"""TimeDerivativeTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import values_with_nans
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid import UGrid
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.utilities.dataset_tool import get_min_max
import xms.tool.utilities.time_units_converter as tc


class TimeDerivativeTool(Tool):
    """Tool to compute the change, or time derivative, between consecutive time steps of a scalar dataset."""
    ARG_INPUT_DATASET = 0
    ARG_OPTION = 1
    ARG_TIME_UNITS = 2
    ARG_OUTPUT_DATASET = 3

    def __init__(self, **kwargs):
        """Initializes the class."""
        super().__init__(name='Time Derivative', **kwargs)
        # Reader for the input dataset, set by validate_arguments().
        self._dataset_reader = None

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.dataset_argument(name='input_dataset', description='Input scalar dataset'),
            self.string_argument(name='option', description='Calculation option',
                                 choices=['Change', 'Derivative'], value='Change'),
            self.string_argument(name='derivative_units', description='Derivative time units',
                                 choices=tc.TIME_UNITS, value=tc.UNITS_SECONDS),
            self.dataset_argument(name='output_dataset', description='Output dataset', value="new dataset",
                                  io_direction=IoDirection.OUTPUT),
        ]
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate input datasets
        self._dataset_reader = self._validate_input_dataset(arguments[self.ARG_INPUT_DATASET], errors)
        # NOTE(review): assumes the helper yields None (and records its own error) when the input is
        # invalid; without this guard the num_times check below would raise AttributeError.
        if self._dataset_reader is None:
            return errors
        if self._dataset_reader.num_times < 2:
            name = arguments[self.ARG_INPUT_DATASET].name
            errors[name] = 'The input dataset must have at least 2 time steps.'
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        input_dataset = self.get_input_dataset(arguments[self.ARG_INPUT_DATASET].value)
        grid = self.get_input_dataset_grid(arguments[self.ARG_INPUT_DATASET].value)
        calculate_derivative = arguments[self.ARG_OPTION].value == 'Derivative'
        time_units = arguments[self.ARG_TIME_UNITS].value
        output_dataset_name = arguments[self.ARG_OUTPUT_DATASET].value

        output_dataset = self.get_output_dataset_writer(
            name=output_dataset_name,
            geom_uuid=grid.uuid,
        )

        self._build_time_derivative(grid.ugrid, input_dataset, output_dataset, calculate_derivative, time_units)

    def _build_time_derivative(self, u_grid: UGrid, input_dataset: DatasetReader, output_dataset: DatasetWriter,
                               calculate_derivative: bool, time_units: str) -> None:
        """Build the time derivative for a dataset.

        For each pair of consecutive input time steps (t1, t2), one output time step is appended at the
        midpoint time (t1 + t2) / 2. It holds either the change (v2 - v1) or, when calculate_derivative
        is True, (v2 - v1) divided by the time span converted to time_units. The output therefore has
        one fewer time step than the input.

        Args:
            u_grid (UGrid): The UGrid.
            input_dataset (DatasetReader): The dataset to calculate the time derivative of.
            output_dataset (DatasetWriter): The output dataset.
            calculate_derivative (bool): Should the derivative be calculated (or a difference)?
            time_units (str): The time units for the derivative.
        """
        null_result: float = -999.0

        # Carry the input's null value, reference time and location over to the output.
        use_null_value: bool = input_dataset.null_value is not None
        if use_null_value:
            null_result = input_dataset.null_value
            output_dataset.null_value = null_result

        if input_dataset.ref_time is not None:
            output_dataset.ref_time = input_dataset.ref_time

        output_dataset.location = input_dataset.location

        num_times: int = input_dataset.num_times
        times = input_dataset.times  # read once instead of several times per iteration
        input_values = input_dataset.values
        input_activity = input_dataset.activity

        # values_with_nans() is given both the activity and the null value, so its results may carry
        # NaN markers whenever either is present. Those NaNs must be turned back into the null value
        # before writing, and the per-timestep min/max must be computed from the NaN-marked data.
        may_contain_nans = input_activity is not None or use_null_value
        ts_mins = []
        ts_maxs = []

        # Time-span converter, only needed for the derivative option.
        converter = None
        if calculate_derivative:
            converter = tc.TimeUnitsConverter()
            converter.from_units = input_dataset.time_units.title()
            converter.to_units = time_units

        self.logger.info(f'Processing time step 1 of {num_times}...')
        activity_t1 = None if input_activity is None else input_activity[0]
        values_t1 = values_with_nans(u_grid, input_values[0], activity_t1, input_dataset.null_value)

        # Difference each time step against the previous one.
        for ts_idx in range(1, num_times):
            self.logger.info(f'Processing time step {ts_idx + 1} of {num_times}...')

            activity_t2 = None if input_activity is None else input_activity[ts_idx]
            values_t2 = values_with_nans(u_grid, input_values[ts_idx], activity_t2, input_dataset.null_value)

            factor = 1.0
            if converter is not None:
                # Scale the change by the elapsed time expressed in the requested units.
                time_span = times[ts_idx] - times[ts_idx - 1]
                factor = 1.0 / converter.convert_value(time_span)

            values_out = (values_t2 - values_t1) * factor

            # The result is placed midway between the two source time steps.
            time_offset = (times[ts_idx - 1] + times[ts_idx]) / 2.0

            activity_out = None
            if input_activity is not None:
                # inactive if either time step is inactive
                activity_out = activity_t1 * activity_t2
            if may_contain_nans:
                ts_min, ts_max = get_min_max(values_out)
                ts_mins.append(ts_min)
                ts_maxs.append(ts_max)
                values_out[np.isnan(values_out)] = null_result
            output_dataset.append_timestep(time_offset, values_out, activity_out)

            # Roll the current step into the "previous" slot for the next iteration.
            values_t1 = values_t2
            activity_t1 = activity_t2

        # finish up the building
        if may_contain_nans:
            output_dataset.timestep_mins = ts_mins
            output_dataset.timestep_maxs = ts_maxs
        output_dataset.appending_finished()
        self.set_output_dataset(output_dataset)

# def main():
#     """Main function, for testing."""
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#
#     qapp = ensure_qapplication_exists()
#     tool = TimeDerivativeTool()
#     arguments = tool.initial_arguments()
#     tool_dialog = ToolDialog(None, arguments, 'Time Derivative', tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#
#
# if __name__ == "__main__":
#     main()
