"""CellDatasetFromPointDataset class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import active_cells_from_points, CellToPointActivityCalculator
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, Tool

# 4. Local modules
from xms.tool.utilities import xms_utils

# Positions of the tool arguments in the list built by initial_arguments().
ARG_INPUT_DSET, ARG_OUTPUT_NAME = 0, 1


class CellDatasetFromPointDatasetTool(Tool):
    """Tool to convert a point based dataset to a cell based dataset.

    For every time step, each active cell gets the average of the input values at the cell's points.
    Cells that are inactive (or have no points) are written with a value of 0, or a list of zeros
    for multi-component datasets.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Cell Dataset from Point Dataset')
        self._dataset_reader = None  # Reader for the input dataset; set in validate_arguments()
        self._dataset_writer = None  # Writer for the output dataset; set in _setup_dataset_writer()
        self._ugrid = None  # Grid associated with the input dataset; set in validate_arguments()
        self._args = None  # Tool arguments; set in run()

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.dataset_argument(name='input_dataset', description='Input dataset',
                                  filters=ALLOW_ONLY_POINT_MAPPED),
            self.string_argument(name='output_dataset_name', description='Dataset name', value=''),
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also caches the input dataset reader and its grid for use by run().

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}
        input_arg = arguments[ARG_INPUT_DSET]

        # Validate input datasets
        self._dataset_reader = self._validate_input_dataset(input_arg, errors)
        ug_txt = input_arg.text_value
        self._ugrid = self.get_input_dataset_grid(ug_txt)
        if xms_utils.tool_is_running_from_xms(self):  # TODO we would like a better way to do this kind of check
            # Inside XMS, datasets that belong to a UGrid have argument text starting with 'UGrid Data'.
            if not ug_txt.startswith('UGrid Data'):
                errors[input_arg.name] = 'Input dataset must be associated with a UGrid'

        return errors

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        self._args = arguments
        self._setup_dataset_writer()
        self._convert_dataset()
        # Send the dataset back to XMS
        self.set_output_dataset(self._dataset_writer)

    def _convert_dataset(self):
        """Convert the point data to cell data.

        Each time step of the input dataset is averaged onto the grid cells and appended to the
        dataset writer, together with the cell activity for that time step (if any).
        """
        ug = self._ugrid.ugrid
        reader = self._dataset_reader

        # If the activity array length does not match the data length, the activity is not point
        # based (point data, cell activity); install a cell-to-point activity calculator on the reader.
        first_values = reader.values[0]
        first_activity = reader.activity[0] if reader.activity is not None else None
        if first_activity is not None and len(first_values) != len(first_activity):
            reader.activity_calculator = CellToPointActivityCalculator(ug)

        # The point indices of each cell do not change between time steps; look them up once.
        all_cell_points = [ug.get_cell_points(cell_idx) for cell_idx in range(ug.cell_count)]
        nv = reader.null_value
        ncomp = reader.num_components
        times = reader.times
        time_count = len(times)
        for tsidx in range(time_count):
            self.logger.info(f'Processing time step {tsidx + 1} of {time_count}...')
            data, activity = reader.timestep_with_activity(tsidx, nan_null_values=False)
            cell_activity = self._get_cell_activity(ug, data, activity, nv, ncomp)
            cell_data = self._average_points_to_cells(data, all_cell_points, cell_activity, ncomp)
            self._dataset_writer.append_timestep(times[tsidx], cell_data, cell_activity)
        self._dataset_writer.appending_finished()

    @staticmethod
    def _get_cell_activity(ug, data, activity, nv, ncomp):
        """Determine the cell activity for one time step.

        Args:
            ug: The grid the dataset is mapped to (the ``ugrid`` of the input dataset's grid).
            data: The point values for the time step.
            activity: The activity returned by the reader for the time step, or None.
            nv: The dataset null value, or None if the dataset has no null value.
            ncomp (int): Number of components per value.

        Returns:
            The cell activity for the time step, or None if there is no activity information.
        """
        pt_activity = None
        if nv is not None:  # Dataset has a null value: points holding the null value are inactive
            if ncomp == 1:
                pt_activity = [1 if val != nv else 0 for val in data]
            else:  # Only the first component is compared against the null value
                pt_activity = [1 if val[0] != nv else 0 for val in data]
        elif activity is not None and len(activity) == len(data):  # Point data, point activity
            pt_activity = activity
        elif activity is not None:  # Point data, cell activity - activity calculator should take care of it
            return activity
        if pt_activity is not None:
            return active_cells_from_points(ug, pt_activity)
        return None

    @staticmethod
    def _average_points_to_cells(data, all_cell_points, cell_activity, ncomp):
        """Average point values onto cells.

        Args:
            data: The point values for one time step (scalars, or sequences of ``ncomp`` values).
            all_cell_points (list): The point indices of each cell, indexed by cell index.
            cell_activity: Per-cell activity where 0 marks an inactive cell, or None if no cell
                should be skipped.
            ncomp (int): Number of components per value.

        Returns:
            (list): One value per cell (a list of ``ncomp`` values when ``ncomp`` > 1). Skipped cells
            and cells without points are 0 (or a list of zeros).
        """
        cell_count = len(all_cell_points)
        if ncomp == 1:
            cell_data = [0] * cell_count
        else:
            # One distinct list per cell; `[[0] * ncomp] * cell_count` would make every cell alias
            # the same list object.
            cell_data = [[0] * ncomp for _ in range(cell_count)]

        # Calculate the cell data
        for cell_idx, cell_points in enumerate(all_cell_points):
            if cell_activity is not None and cell_activity[cell_idx] == 0:
                continue
            cell_values = [data[p] for p in cell_points]
            if len(cell_values) > 0:
                if ncomp == 1:
                    cell_data[cell_idx] = sum(cell_values) / len(cell_values)
                else:
                    # Average each component independently.
                    cell_data[cell_idx] = [sum(x) / len(cell_values) for x in zip(*cell_values)]
        return cell_data

    def _setup_dataset_writer(self):
        """Set up dataset writer for tool.

        The output is a cell based dataset that copies geometry, component count, reference time
        and time units from the input dataset.
        """
        # Create a place for the output dataset file
        loc = 'cells'
        dataset_name = self._args[ARG_OUTPUT_NAME].text_value

        dsr = self._dataset_reader
        # Write activity for the output if the input has either activity or a null value.
        use_activity_as_null = dsr.activity is not None or dsr.null_value is not None
        self._dataset_writer = self.get_output_dataset_writer(
            name=dataset_name,
            dset_uuid=str(uuid.uuid4()),
            geom_uuid=dsr.geom_uuid,
            num_components=dsr.num_components,
            ref_time=dsr.ref_time,
            time_units=dsr.time_units,
            use_activity_as_null=use_activity_as_null,
            location=loc,
        )
        self.logger.info(f'geom_uuid: {dsr.geom_uuid}')
        self.logger.info(f'num_components: {dsr.num_components}')
        self.logger.info(f'h5: {dsr.h5_filename}')
