"""PointDatasetFromCellDatasetTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.tool_core import ALLOW_ONLY_CELL_MAPPED, Tool

# 4. Local modules
from xms.tool.utilities import xms_utils

# Positions of the tool arguments in the list returned by initial_arguments():
# the cell-mapped input dataset first, then the name for the new point dataset.
ARG_INPUT_DSET, ARG_OUTPUT_NAME = range(2)


class PointDatasetFromCellDatasetTool(Tool):
    """Tool to convert a cell based dataset to a point based dataset.

    Each output point value is the average of the values of the active cells adjacent to
    that point (component-wise for multi-component datasets).
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Point Dataset from Cell Dataset')
        self._dataset_reader = None  # reader for the input cell dataset, set in validate_arguments()
        self._dataset_writer = None  # writer for the output point dataset, set in _setup_dataset_writer()
        self._ugrid = None  # grid of the input dataset, set in validate_arguments(); its .ugrid is used
        self._activity_calculator = None  # calculator handed to the output writer in _setup_dataset_writer()
        self._args = None  # tool arguments captured by run()

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, ordered so that ARG_INPUT_DSET and
            ARG_OUTPUT_NAME index into it.
        """
        arguments = [
            self.dataset_argument(name='input_dataset', description='Input dataset',
                                  filters=ALLOW_ONLY_CELL_MAPPED),
            self.string_argument(name='output_dataset_name', description='dataset name', value=''),
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        The input dataset's grid is only loaded once the input has passed validation, so no
        grid lookup is attempted for a dataset that was already rejected.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        input_arg = arguments[ARG_INPUT_DSET]

        # Validate input datasets
        self._dataset_reader = self._validate_input_dataset(input_arg, errors)
        # Guard against an unset selection so the prefix check below cannot fail on None.
        ug_txt = input_arg.text_value or ''
        if xms_utils.tool_is_running_from_xms(self):  # TODO we would like a better way to do this kind of check
            if not ug_txt.startswith('UGrid Data'):
                errors[input_arg.name] = 'Input dataset must be associated with a UGrid'

        if errors:
            return errors

        self._ugrid = self.get_input_dataset_grid(ug_txt)
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        self._args = arguments
        self._setup_dataset_writer()
        self._convert_dataset()
        # Send the dataset back to XMS
        self.set_output_dataset(self._dataset_writer)

    def _convert_dataset(self):
        """Convert the cell data to point data, one time step at a time.

        For every time step, each point value is the mean of the values of its adjacent cells,
        skipping cells whose activity is 0. Points with no contributing cell are left at 0 (or a
        zero vector). The cell activity array is passed to the writer as-is; the writer's
        activity calculator (set in _setup_dataset_writer) presumably maps it to points.
        """
        ug = self._ugrid.ugrid
        # Adjacency does not change between time steps, so look it up once.
        adj_cells = [ug.get_point_adjacent_cells(i) for i in range(ug.point_count)]

        reader = self._dataset_reader
        nv = reader.null_value
        ncomp = reader.num_components
        times = reader.times
        time_count = len(times)
        for tsidx in range(time_count):
            self.logger.info(f'Processing time step {tsidx + 1} of {time_count}...')
            data, activity = reader.timestep_with_activity(tsidx, nan_null_values=False)
            if activity is None and nv is not None:  # convert null_value to activity array
                activity = self._activity_from_null_values(data, nv, ncomp)

            pt_data = self._average_adjacent_cells(data, activity, adj_cells, ncomp)
            self._dataset_writer.append_timestep(times[tsidx], pt_data, activity)
        self._dataset_writer.appending_finished()

    @staticmethod
    def _activity_from_null_values(data, null_value, num_components):
        """Build a cell activity list (1 = active, 0 = inactive) from a dataset's null value.

        Args:
            data: Cell values for one time step; for multi-component datasets each item is a
                sequence whose first component decides activity.
            null_value: The dataset's null value.
            num_components (int): Number of components per value.

        Returns:
            (list): One int per cell, 0 where the (first component of the) value equals null_value.
        """
        if num_components == 1:
            return [1 if val != null_value else 0 for val in data]
        return [1 if val[0] != null_value else 0 for val in data]

    @staticmethod
    def _average_adjacent_cells(data, activity, adj_cells, num_components):
        """Average active adjacent cell values onto each point.

        Args:
            data: Cell values for one time step, indexable by cell index.
            activity: Per-cell activity (0 = inactive) or None when every cell counts.
            adj_cells (list): For each point index, the indices of its adjacent cells.
            num_components (int): Number of components per value.

        Returns:
            (list): One value per point (a list of num_components values for vector data).
            Points with no contributing cell keep 0 / a zero vector.
        """
        if num_components == 1:
            pt_data = [0] * len(adj_cells)
        else:
            # Build an independent list per point; ``[[0] * n] * m`` would make every point
            # share (alias) one inner list object.
            pt_data = [[0] * num_components for _ in adj_cells]

        for pt_idx, pt_adj_cells in enumerate(adj_cells):
            if activity is None:
                cell_values = [data[c] for c in pt_adj_cells]
            else:
                cell_values = [data[c] for c in pt_adj_cells if activity[c] != 0]

            if not cell_values:
                continue
            count = len(cell_values)
            if num_components == 1:
                pt_data[pt_idx] = sum(cell_values) / count
            else:
                pt_data[pt_idx] = [sum(component) / count for component in zip(*cell_values)]
        return pt_data

    def _setup_dataset_writer(self):
        """Create the output point dataset writer and attach a cell-to-point activity calculator."""
        dsr = self._dataset_reader
        # Request use_activity_as_null whenever the input has an activity array or a null value.
        use_activity_as_null = dsr.activity is not None or dsr.null_value is not None
        self._dataset_writer = self.get_output_dataset_writer(
            name=self._args[ARG_OUTPUT_NAME].text_value,
            dset_uuid=str(uuid.uuid4()),
            geom_uuid=dsr.geom_uuid,
            num_components=dsr.num_components,
            ref_time=dsr.ref_time,
            time_units=dsr.time_units,
            use_activity_as_null=use_activity_as_null,
            location='points',
        )
        # Keep a reference to the one calculator actually used by the writer.
        self._activity_calculator = CellToPointActivityCalculator(self._ugrid.ugrid)
        self._dataset_writer.activity_calculator = self._activity_calculator
