"""SampleTimeStepsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import bisect
from typing import List, Tuple

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import values_with_nans
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid import UGrid
from xms.tool_core import ALLOW_ONLY_TRANSIENT, IoDirection, Tool

# 4. Local modules
from xms.tool.utilities.dataset_tool import get_min_max
import xms.tool.utilities.time_units_converter as tc


class SampleTimeStepsTool(Tool):
    """Tool to create a dataset with sampled time steps from another dataset.

    The output dataset holds values at evenly spaced times between a selected starting and ending
    time step of the input dataset. Values at times that fall between two input time steps are
    linearly interpolated (see ``get_bounding_times``).
    """
    # Positions of the arguments in the list returned by initial_arguments().
    ARG_INPUT_DATASET = 0
    ARG_SELECT_STARTING_TIME = 1
    ARG_SELECT_ENDING_TIME = 2
    ARG_TIME_STEP = 3
    ARG_TIME_STEP_UNITS = 4
    ARG_OUTPUT_DATASET = 5

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Sample Time Steps')
        self._dataset_reader = None
        self._time_converter = None  # Converts output time units back to the input dataset time units
        self._current_input_dataset_str = None  # Last input dataset seen by enable_arguments()
        self._num_times_out = 1
        self._valid_dataset_time_units = ['Seconds', 'Minutes', 'Hours', 'Days', 'Years', 'None']

    @staticmethod
    def _time_unit_choices():
        """Get the time step unit choices offered by the tool.

        Returns:
            (list): A new list with the entries of ``tc.TIME_UNITS`` except 'weeks'.
        """
        # Build a copy: removing 'weeks' in place would alter the shared module-level list.
        return [units for units in tc.TIME_UNITS if units != 'weeks']

    def _matching_unit_choice(self, dataset_units):
        """Find the time step unit choice that corresponds to a dataset's time units.

        Args:
            dataset_units (str): The time units of the dataset (may be None).

        Returns:
            (str): The choice that matches ignoring case, the unchanged dataset units if there is no match, or
            ``tc.UNITS_SECONDS`` if the dataset units are None.
        """
        if dataset_units is None:
            return tc.UNITS_SECONDS
        if isinstance(dataset_units, str):
            wanted = dataset_units.lower()
            for choice in self._time_unit_choices():
                if choice.lower() == wanted:
                    return choice
        return dataset_units

    def _reset_sampling_arguments(self, arguments, time_step_units):
        """Clear the sampling interval and time step arguments.

        Args:
            arguments (list): The tool arguments.
            time_step_units (str): The value to assign to the time step units argument.
        """
        arguments[self.ARG_SELECT_STARTING_TIME].value = None
        arguments[self.ARG_SELECT_ENDING_TIME].value = None
        arguments[self.ARG_TIME_STEP].value = 0.0
        arguments[self.ARG_TIME_STEP_UNITS].value = time_step_units

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        desc_ds = 'Input scalar dataset to sample from (transient datasets only)'
        desc_start = 'Select input scalar dataset starting time (beginning of sample interval)'
        desc_end = 'Select input scalar dataset ending time (end of sample interval)'
        time_units = self._time_unit_choices()
        arguments = [
            self.dataset_argument(name='input_dataset', description=desc_ds, filters=ALLOW_ONLY_TRANSIENT),
            self.timestep_argument(name='select_starting_time', description=desc_start),
            self.timestep_argument(name='select_ending_time', description=desc_end),
            self.float_argument(name='time_step', description='Time step for output scalar data', min_value=0.0),
            self.string_argument(name='time_step_units', description='Time step units for output scalar dataset',
                                 choices=time_units, value=tc.UNITS_SECONDS),
            self.dataset_argument(name='output_dataset', description='Name for output scalar dataset',
                                  io_direction=IoDirection.OUTPUT),
        ]
        return arguments

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        dataset_argument = arguments[self.ARG_INPUT_DATASET]
        input_dataset_str = dataset_argument.value
        arguments[self.ARG_SELECT_STARTING_TIME].enable_timestep(dataset_argument)
        arguments[self.ARG_SELECT_ENDING_TIME].enable_timestep(dataset_argument)
        if input_dataset_str is None or input_dataset_str == '':
            # dataset set to be nothing so reset the other arguments
            self._reset_sampling_arguments(arguments, tc.UNITS_SECONDS)
        elif self._current_input_dataset_str is not None and self._current_input_dataset_str != input_dataset_str:
            # dataset has changed - make sure the other arguments are compatible
            input_dataset = self.get_input_dataset(input_dataset_str)
            dataset_units = None if input_dataset is None else input_dataset.time_units
            self._reset_sampling_arguments(arguments, self._matching_unit_choice(dataset_units))
        self._current_input_dataset_str = input_dataset_str

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate input datasets
        self._dataset_reader = self._validate_input_dataset(arguments[self.ARG_INPUT_DATASET], errors)

        start_argument = arguments[self.ARG_SELECT_STARTING_TIME]
        end_argument = arguments[self.ARG_SELECT_ENDING_TIME]
        start_ts = end_ts = 0
        both_times_valid = True
        # int() raises TypeError for None and ValueError for a non-numeric string; both mean "not selected".
        try:
            start_ts = int(start_argument.value)
        except (TypeError, ValueError):
            errors[start_argument.name] = 'Invalid starting time.'
            both_times_valid = False
        try:
            end_ts = int(end_argument.value)
        except (TypeError, ValueError):
            errors[end_argument.name] = 'Invalid ending time.'
            both_times_valid = False
        # Only compare when both values parsed so an "invalid" message is not replaced by a misleading one.
        if both_times_valid and start_ts > end_ts:
            errors[end_argument.name] = 'Ending time must be greater than starting time.'

        # A single sampled time (start == end) does not need a time step.
        time_step = arguments[self.ARG_TIME_STEP].value
        if start_ts != end_ts and (time_step is None or time_step <= 0.0):
            errors[arguments[self.ARG_TIME_STEP].name] = 'Time step must be greater than 0.0.'
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.

        Raises:
            ValueError: If the time step units do not match any valid dataset time units.
        """
        input_dataset_name = arguments[self.ARG_INPUT_DATASET].value
        input_dataset = self.get_input_dataset(input_dataset_name)
        grid = self.get_input_dataset_grid(input_dataset_name)
        start_ts = int(arguments[self.ARG_SELECT_STARTING_TIME].value)
        end_ts = int(arguments[self.ARG_SELECT_ENDING_TIME].value)
        time_step = arguments[self.ARG_TIME_STEP].value
        time_step_units = arguments[self.ARG_TIME_STEP_UNITS].value

        # Map the selected units (compared ignoring case) onto the spelling used for dataset time units.
        out_ds_time_step_units = None
        for tsu in self._valid_dataset_time_units:
            if time_step_units.lower() == tsu.lower():
                out_ds_time_step_units = tsu
                break
        if out_ds_time_step_units is None:
            raise ValueError(f'Unsupported time step units: {time_step_units}')

        output_dataset_name = arguments[self.ARG_OUTPUT_DATASET].value

        self.logger.info(f'Input dataset time units: {input_dataset.time_units}')
        self.logger.info(f'Output dataset time units: {out_ds_time_step_units}')

        convert_time_in_to_out = tc.TimeUnitsConverter(input_dataset.time_units, time_step_units)
        self._time_converter = tc.TimeUnitsConverter(time_step_units, input_dataset.time_units)

        # The starting/ending argument values are used as 1-based indices into the dataset times.
        times = input_dataset.times[:]
        inds_start = times[start_ts - 1]
        self.logger.info(f'Input dataset starting time: {inds_start} {input_dataset.time_units}')
        outds_start = convert_time_in_to_out.convert_value(inds_start)
        self.logger.info(f'Output dataset starting time: {outds_start} {out_ds_time_step_units}')

        inds_end = times[end_ts - 1]
        self.logger.info(f'Input dataset ending time: {inds_end} {input_dataset.time_units}')
        outds_end = convert_time_in_to_out.convert_value(inds_end)
        self.logger.info(f'Output dataset ending time: {outds_end} {out_ds_time_step_units}')

        self.logger.info(f'Output dataset time step: {time_step} {out_ds_time_step_units}')
        if time_step > 0.0:
            # NOTE(review): float floor division can drop the final step when the interval is an inexact
            # multiple of the time step (e.g. 0.3 // 0.1 == 2.0) - confirm whether a tolerance is wanted.
            self._num_times_out = int((outds_end - outds_start) // time_step) + 1
        else:
            # Always reset so a count from a previous run of this tool instance is never reused.
            self._num_times_out = 1
        self.logger.info(f'Number of output time steps: {self._num_times_out}')

        output_dataset = self.get_output_dataset_writer(name=output_dataset_name, geom_uuid=grid.uuid)
        output_dataset.time_units = out_ds_time_step_units

        self._sample_time_steps(grid.ugrid, input_dataset, output_dataset, outds_start, time_step)

    def _sample_time_steps(self, u_grid: UGrid, input_dataset: DatasetReader, output_dataset: DatasetWriter,
                           starting_time: float, time_step: float) -> None:
        """Sample time steps from a dataset.

        Writes ``self._num_times_out`` time steps to the output dataset and registers it as a tool output.
        ``self._time_converter`` must have been set (see ``run``) before calling this method.

        Args:
            u_grid (UGrid): The UGrid.
            input_dataset (DatasetReader): The dataset to sample time steps from.
            output_dataset (DatasetWriter): The output dataset.
            starting_time (float): The starting time offset (in output dataset time units).
            time_step (float): The time step increment (in output dataset time units).
        """
        null_result: float = -999.0

        # initialize the dataset builder
        use_null_value: bool = input_dataset.null_value is not None
        if use_null_value:
            null_result = input_dataset.null_value
            output_dataset.null_value = null_result

        if input_dataset.ref_time is not None:
            output_dataset.ref_time = input_dataset.ref_time

        output_dataset.location = input_dataset.location

        input_values = input_dataset.values
        input_activity = input_dataset.activity
        out_ds_mins = []
        out_ds_maxs = []

        times = input_dataset.times[:]
        for new_time_index in range(self._num_times_out):
            self.logger.info(f'Processing time step {new_time_index + 1} of {self._num_times_out}...')

            time_offset = starting_time + new_time_index * time_step
            self.logger.info(f'Output dataset time step: {time_offset} {output_dataset.time_units}')
            inds_time_offset = self._time_converter.convert_value(time_offset)
            self.logger.info(f'Input dataset time step: {inds_time_offset} {input_dataset.time_units}')
            bounding_times = get_bounding_times(times, inds_time_offset)
            self.logger.info(f'Input dataset time indices/factors: {bounding_times}')

            # Accumulate the weighted sum of the one or two bounding input time steps.
            values_out = None
            activity_out = None
            for time_index, weight in bounding_times:
                values = input_values[time_index]
                activity = None if input_activity is None else input_activity[time_index]
                values = values_with_nans(u_grid, values, activity, input_dataset.null_value)
                if values_out is None:
                    values_out = values * weight
                    activity_out = activity
                else:
                    values_out += values * weight
                    if activity_out is not None:
                        # The product combines the activity arrays of both bounding time steps.
                        activity_out = activity * activity_out

            minimum, maximum = get_min_max(values_out)
            out_ds_mins.append(minimum)
            out_ds_maxs.append(maximum)
            # output the data
            # NOTE(review): NaNs are only replaced when the input has activity; presumably datasets with a null
            # value but no activity may also need this - confirm against values_with_nans().
            if input_activity is not None:
                values_out[np.isnan(values_out)] = null_result
            output_dataset.append_timestep(time_offset, values_out, activity_out)

        # finish up the building
        output_dataset.timestep_mins = out_ds_mins
        output_dataset.timestep_maxs = out_ds_maxs
        output_dataset.appending_finished()
        self.set_output_dataset(output_dataset)


def get_bounding_times(times: np.ndarray, time: float) -> List[Tuple[int, float]]:
    """Find the time step(s) bracketing a time and their linear interpolation weights.

    A single (index, 1.0) entry is returned when there is only one time step, when the time lies at or
    outside either end of the times, or when it coincides with a time step. Otherwise the two neighboring
    time steps are returned with weights that sum to 1.0.

    Args:
        times (np.array): Array of time offsets (searched with bisection, so expected in ascending order).
        time (float): The value to find the bounding time steps of.

    Returns:
        (List[Tuple[int, float]]): A list of tuples with time index and interpolation weight.

    Raises:
        ValueError: If times is empty.
    """
    count = len(times)
    if count == 0:
        raise ValueError('Dataset must have at least one timestep.')
    if count == 1:
        return [(0, 1.0)]

    upper = bisect.bisect_left(times, time)
    last = count - 1
    if upper > last:
        # Past the final time step: clamp to it.
        return [(last, 1.0)]
    if upper == 0 or times[upper] == time:
        # At or before the first time step, or exactly on a time step.
        return [(upper, 1.0)]

    # Strictly between two time steps: weight each by its closeness to the requested time.
    lower = upper - 1
    time_before = times[lower]
    time_after = times[upper]
    lower_weight = (time_after - time) / (time_after - time_before)
    return [(lower, lower_weight), (upper, 1.0 - lower_weight)]


# def main():
#     """Main function, for testing."""
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#
#     qapp = ensure_qapplication_exists()
#     tool = SampleTimeStepsTool()
#     arguments = tool.initial_arguments()
#     tool_dialog = ToolDialog(None, arguments, 'Sample Time Steps', tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#
#
# if __name__ == "__main__":
#     main()
