"""SmoothRasterTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.rasters import RasterOutput
from xms.tool_core import IoDirection, Tool

# 4. Local modules

# Positions of the tool arguments, in the order SmoothRasterTool.initial_arguments() builds them.
(
    ARG_INPUT_RASTER,
    ARG_FILTER_OPTION,
    ARG_NUMBER_ITERATIONS,
    ARG_MAXIMUM_CHANGE,
    ARG_FILTER_RATIO,
    ARG_MAXIMUM_CHANGE_OPTION,
    ARG_OUTPUT_RASTER,
) = range(7)


class SmoothRasterTool(Tool):
    """Tool to smooth a raster, writing out a new raster.

    Every cell that is not NoData is replaced by a weighted average of the cells in a 3 x 3 or 5 x 5
    window centered on it. The center cell is weighted by the filter ratio. The remaining
    ``1 - filter ratio`` is shared among the other window cells in proportion to the inverse of their
    squared distance from the center, computed from the raster's pixel width (columns) and pixel
    height (rows). NoData cells and cells outside the raster are dropped from the window, and the
    remaining weights are renormalized. Each iteration reads only the values produced by the previous
    iteration. The change to a cell is limited either relative to the original raster or relative to
    the previous iteration.
    """
    MAXIMUM_CHANGE_FROM_ORIGINAL = 'Maximum change defined from original raster'
    MAXIMUM_CHANGE_PER_ITERATION = 'Maximum change defined for each iteration'

    FILTER_SIZE_3 = '3 x 3'
    FILTER_SIZE_5 = '5 x 5'

    PROGRESS_OUTPUT_FREQUENCY = 50

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Smooth Raster')
        self._raster_data = None
        self._input_raster = None
        self._no_data = 0.0
        self._filter_option = self.FILTER_SIZE_3
        self._filter_size = 3
        self._padding_width = 1
        self._number_iterations = 1
        self._maximum_change = 0.5
        self._filter_ratio = 0.25
        self._maximum_change_option = self.MAXIMUM_CHANGE_FROM_ORIGINAL
        self._out_path = ''
        self._base_smoothing_filter = None
        self._smoothed_data = None
        self._padded_array = None
        # Boolean mask aligned with self._padded_array; True for NoData cells and for the padding.
        self._padded_no_data = None
        self._height = 0

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.raster_argument(name='input_raster', description='Input raster'),
            self.string_argument(name='filter_option', description='Filter size', value=self.FILTER_SIZE_3,
                                 choices=[self.FILTER_SIZE_3, self.FILTER_SIZE_5]),
            self.integer_argument(name='number_iterations', description='Number of iterations', min_value=1, value=1),
            self.float_argument(name='maximum_change', description='Maximum elevation change', min_value=0.0,
                                value=0.5),
            self.float_argument(name='filter_ratio', description='Filter ratio (0.0-1.0)', min_value=0.0, max_value=1.0,
                                value=0.25),
            self.string_argument(name='maximum_change_option',
                                 description='Define maximum change from original raster or per iteration',
                                 value=self.MAXIMUM_CHANGE_FROM_ORIGINAL,
                                 choices=[self.MAXIMUM_CHANGE_FROM_ORIGINAL, self.MAXIMUM_CHANGE_PER_ITERATION]),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Only single-band rasters are accepted, because only one band is smoothed and written out.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}
        # Validate the raster files
        multiple_band_error = 'This tool should only be used on rasters with a single dataset and is not designed to ' \
                              'be used with RGB images or rasters with more than one band (dataset).'
        input_raster = self.get_input_raster(arguments[ARG_INPUT_RASTER].value)
        # Reject anything with more than one band, as the error message states (a 2-band raster included).
        if input_raster.gdal_raster.RasterCount > 1:
            errors[arguments[ARG_INPUT_RASTER].name] = multiple_band_error
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        self._input_raster = self.get_input_raster(arguments[ARG_INPUT_RASTER].value)
        self._filter_option = arguments[ARG_FILTER_OPTION].value
        self._number_iterations = arguments[ARG_NUMBER_ITERATIONS].value
        self._maximum_change = arguments[ARG_MAXIMUM_CHANGE].value
        self._filter_ratio = arguments[ARG_FILTER_RATIO].value
        self._maximum_change_option = arguments[ARG_MAXIMUM_CHANGE_OPTION].value
        self._out_path = self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value)
        self._create_smoothed_raster()
        self._out_path = ru.reproject_raster(self._out_path, self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value),
                                             self.default_wkt, self.vertical_datum, self.vertical_units)
        self.set_output_raster_file(self._out_path, arguments[ARG_OUTPUT_RASTER].text_value)

    def _log_progress(self, row):
        """Log an info message periodically.

        A message is logged for rows 1, 1 + N, 1 + 2N, ..., where N is PROGRESS_OUTPUT_FREQUENCY. Each
        message names the block of rows about to be processed, e.g. "rows 1-50" then "rows 51-100".

        Args:
            row (int): The one-based row currently being processed
        """
        # Log frequently enough to keep things interesting, but not so much that the dialog lags.
        if (row - 1) % self.PROGRESS_OUTPUT_FREQUENCY != 0:
            return
        end_row = min(self._height, row + self.PROGRESS_OUTPUT_FREQUENCY - 1)
        self.logger.info(
            f'Processing points in rows {row}-{end_row} of {self._height}...'
        )

    def _is_no_data(self, values):
        """Return which entries of ``values`` equal the NoData value.

        A NaN NoData value never compares equal to anything, so NaN entries are matched with np.isnan.

        Args:
            values: A number or numpy array.

        Returns:
            A boolean (or boolean array shaped like ``values``).
        """
        if math.isnan(self._no_data):
            return np.isnan(values)
        return values == self._no_data

    def _create_smoothed_raster(self):
        """Smooth the input raster and write the result to ``self._out_path``."""
        self._filter_size = 3 if self._filter_option == self.FILTER_SIZE_3 else 5
        self._raster_data = self._input_raster.get_raster_values()
        self._no_data = self._input_raster.nodata_value
        if self._no_data is None:
            self._no_data = -999999.0
        self._height = self._raster_data.shape[0]
        pad = self._filter_size // 2
        self._padding_width = pad
        # Find the cells to smooth once, comparing in the raster's own dtype. That way smoothed values can
        # never later be mistaken for NoData, and float32 NoData values compare correctly.
        valid = ~self._is_no_data(self._raster_data)
        # Pad by half the filter width so the window never leaves the array. The padding counts as NoData.
        self._padded_no_data = np.pad(~valid, pad, 'constant', constant_values=True)
        # Convert to float before padding so a NaN NoData value can also pad integer rasters.
        self._padded_array = np.pad(np.asarray(self._raster_data, dtype=float), pad, 'constant',
                                    constant_values=self._no_data)
        self._init_smoothing()
        for iteration in range(self._number_iterations):
            self.logger.info(f'Processing iteration {iteration + 1} of {self._number_iterations}...')
            for y in range(self._height):
                self._log_progress(y + 1)
                for x in np.flatnonzero(valid[y]):
                    smooth_filter = self._get_smoothing_filter(y, x)
                    self._filter_cell(y, x, smooth_filter)
            # Snapshot the finished iteration. Without the copy, the next iteration would read from the very
            # array it writes to, mixing updated and not-yet-updated neighbors (unlike the first iteration).
            self._smoothed_data = self._padded_array[pad:-pad, pad:-pad].copy()
        raster_output = RasterOutput(xorigin=self._input_raster.xorigin, yorigin=self._input_raster.yorigin,
                                     width=self._input_raster.resolution[0], height=self._input_raster.resolution[1],
                                     pixel_width=self._input_raster.pixel_width,
                                     pixel_height=self._input_raster.pixel_height,
                                     nodata_value=self._no_data, wkt=self._input_raster.wkt)
        raster_output.write_raster(self._out_path, self._smoothed_data)

    def _init_smoothing(self):
        """Build the base smoothing filter and the buffer the first iteration reads from.

        The center weight is the filter ratio. Every other cell in the window gets a weight proportional to
        ``1 / d**2``, where ``d`` is its distance from the center. Row offsets are scaled by the pixel
        height and column offsets by the pixel width. These neighbor weights are scaled to sum to
        ``1 - filter ratio``. Requires ``self._filter_size`` to be set.
        """
        half = self._filter_size // 2
        offsets = np.arange(-half, half + 1, dtype=float)
        dy, dx = np.meshgrid(offsets * self._input_raster.pixel_height,
                             offsets * self._input_raster.pixel_width, indexing='ij')
        distance_sq = dx * dx + dy * dy
        distance_sq[half, half] = 1.0  # placeholder to avoid 1/0; the center weight is set below
        weights = 1.0 / distance_sq
        weights[half, half] = 0.0
        weights *= (1.0 - self._filter_ratio) / weights.sum()
        weights[half, half] = self._filter_ratio
        self._base_smoothing_filter = weights
        self._smoothed_data = np.array(self._raster_data)

    def _get_smoothing_filter(self, y, x):
        """Return the filter for one cell, with NoData and out-of-raster window cells removed.

        If any window cell is excluded, the remaining weights are renormalized to sum to 1. If nothing
        remains (filter ratio 0 and no valid neighbors), the filter keeps the cell's own value.

        Args:
            y (int): Zero-based row of the cell in the unpadded raster.
            x (int): Zero-based column of the cell in the unpadded raster.

        Returns:
            (np.ndarray): A filter_size x filter_size array of weights.
        """
        size = self._filter_size
        # In padded coordinates, the window centered on (y, x) starts at (y, x).
        excluded = self._padded_no_data[y:y + size, x:x + size]
        smooth_filter = np.array(self._base_smoothing_filter)
        if excluded.any():
            smooth_filter[excluded] = 0.0
            filter_sum = smooth_filter.sum()
            if filter_sum > 0.0:
                smooth_filter /= filter_sum
            else:
                smooth_filter[self._padding_width, self._padding_width] = 1.0
        return smooth_filter

    def _filter_cell(self, y, x, smooth_filter):
        """Apply ``smooth_filter`` to one cell, clamp the change, and store the result in the padded array.

        Neighbor values are read from ``self._smoothed_data`` (the previous iteration's result). The new
        value may differ by at most ``self._maximum_change`` from either the original value or the
        previous iteration's value, depending on the maximum change option.

        Args:
            y (int): Zero-based row of the cell in the unpadded raster.
            x (int): Zero-based column of the cell in the unpadded raster.
            smooth_filter (np.ndarray): Weights from ``_get_smoothing_filter``.
        """
        pad = self._padding_width
        # Only positive weights are used. Excluded and out-of-raster cells have weight 0, so every index here
        # falls inside the raster.
        rows, cols = np.nonzero(smooth_filter > 0.0)
        neighbor_values = self._smoothed_data[y - pad + rows, x - pad + cols]
        new_z = float(np.dot(neighbor_values, smooth_filter[rows, cols]))
        if self._maximum_change_option == self.MAXIMUM_CHANGE_FROM_ORIGINAL:
            reference_z = self._raster_data[y][x]
        else:
            reference_z = self._smoothed_data[y][x]
        if new_z - reference_z > self._maximum_change:
            new_z = reference_z + self._maximum_change
        elif reference_z - new_z > self._maximum_change:
            new_z = reference_z - self._maximum_change
        self._padded_array[y + pad][x + pad] = new_z
