"""MergeDatasetsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.datasets.merge_datasets import merge_datasets
from xms.tool.utilities.time_units_converter import TimeUnitsConverter


class MergeDatasetsTool(Tool):
    """Tool to merge two transient datasets."""
    # Positions of each argument in the list built by initial_arguments().
    ARG_INPUT_DATASET_1 = 0
    ARG_INPUT_DATASET_2 = 1
    ARG_OUTPUT_DATASET = 2

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Merge Datasets')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments: two input datasets followed by the output dataset.
        """
        arguments = [
            self.dataset_argument(name='dataset_1', description='Dataset one'),
            self.dataset_argument(name='dataset_2', description='Dataset two'),
            self.dataset_argument(name='output_dataset', description='Output dataset', value="new dataset",
                                  io_direction=IoDirection.OUTPUT),
        ]
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Checks that the two input datasets are compatible (same geometry, number of components, reference time and
        number of values), that both have defined time units and at least one time step, and that their time ranges
        do not overlap.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name. All messages for one argument are
            joined into a single space-separated string.
        """
        errors = {}

        # Validate input datasets. A None result means the dataset could not be used; in that case return whatever
        # the helper recorded in ``errors`` (presumably a message for the offending argument).
        dataset_1 = self._validate_input_dataset(arguments[self.ARG_INPUT_DATASET_1], errors)
        dataset_2 = self._validate_input_dataset(arguments[self.ARG_INPUT_DATASET_2], errors)
        if dataset_1 is None or dataset_2 is None:
            return errors

        name_1 = arguments[self.ARG_INPUT_DATASET_1].name
        name_2 = arguments[self.ARG_INPUT_DATASET_2].name

        errors_1 = []
        errors_2 = []

        # Compatibility problems involve both datasets, so each one is reported on both arguments.
        shared_checks = (
            (dataset_1.geom_uuid != dataset_2.geom_uuid, 'The datasets must be from the same UGrid.'),
            (dataset_1.num_components != dataset_2.num_components,
             'Input datasets must have the same number of components.'),
            (dataset_1.ref_time != dataset_2.ref_time, 'Input datasets must have matching reference times.'),
            (dataset_1.num_values != dataset_2.num_values, 'Input datasets must have the same number of values.'),
        )
        for failed, message in shared_checks:
            if failed:
                errors_1.append(message)
                errors_2.append(message)

        # Time units must be defined. If they are not, the dataset is most likely steady state and cannot be merged
        # along the time axis.
        units_1 = dataset_1.time_units
        units_2 = dataset_2.time_units
        if units_1 == 'None':
            errors_1.append('Input datasets must have defined time units.')
        if units_2 == 'None':
            errors_2.append('Input datasets must have defined time units.')

        # Guard against datasets without time steps; indexing the first/last time below would otherwise raise.
        # len() is used rather than truthiness because the times may be array-like.
        times_1 = dataset_1.times
        times_2 = dataset_2.times
        has_times_1 = len(times_1) > 0
        has_times_2 = len(times_2) > 0
        if not has_times_1:
            errors_1.append('Input datasets must have at least one time step.')
        if not has_times_2:
            errors_2.append('Input datasets must have at least one time step.')

        if has_times_1 and has_times_2:
            first_1 = times_1[0]
            last_1 = times_1[-1]
            first_2 = times_2[0]
            last_2 = times_2[-1]
            # Convert only when both datasets have defined, differing units. A dataset with undefined units has
            # already been reported above, so its times are compared as-is.
            if units_1 != units_2 and units_1 != 'None' and units_2 != 'None':
                converter = TimeUnitsConverter(from_units=units_2, to_units=units_1)
                first_2 = converter.convert_value(first_2)
                last_2 = converter.convert_value(last_2)
            # The time ranges are disjoint only if one dataset ends strictly before the other begins.
            dataset_1_first = last_1 < first_2
            dataset_2_first = last_2 < first_1
            if not dataset_1_first and not dataset_2_first:
                errors_1.append('Time steps of the datasets must not overlap.')
                errors_2.append('Time steps of the datasets must not overlap.')

        if errors_1:
            errors[name_1] = ' '.join(errors_1)
        if errors_2:
            errors[name_2] = ' '.join(errors_2)

        return errors

    def run(self, arguments):
        """Override to run the tool.

        Fetches both input datasets and their grids, merges them with ``merge_datasets`` and registers the result
        as the tool's output dataset.

        Args:
            arguments (list): The tool arguments.
        """
        dataset_1 = self.get_input_dataset(arguments[self.ARG_INPUT_DATASET_1].value)
        dataset_2 = self.get_input_dataset(arguments[self.ARG_INPUT_DATASET_2].value)
        grid_1 = self.get_input_dataset_grid(arguments[self.ARG_INPUT_DATASET_1].value)
        grid_2 = self.get_input_dataset_grid(arguments[self.ARG_INPUT_DATASET_2].value)
        output_dataset_name = arguments[self.ARG_OUTPUT_DATASET].value

        # Run merge_datasets algorithm
        output_dataset = merge_datasets(output_dataset_name, dataset_1, dataset_2, grid_1, grid_2, self.logger)

        self.set_output_dataset(output_dataset)


# def main():
#     """Main function, for testing."""
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#
#     qapp = ensure_qapplication_exists()
#     tool = MergeDatasetsTool()
#     arguments = tool.initial_arguments()
#     tool_dialog = ToolDialog(None, arguments, tool.name, tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#
#
# if __name__ == "__main__":
#     main()
