"""Dataset Calculator Algorithm."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import numexpr as ne
import numpy as np


# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import active_points_from_cells, CellToPointActivityCalculator
from xms.datasets.dataset_reader import DatasetReader

# 4. Local modules


def get_activity_type(ugrid, activity_array):
    """Classify an activity array by comparing its length against the grid's counts.

    The length of the last axis of the activity array is compared with the
    grid's cell count first and then with its point count.

    Args:
        ugrid(UGrid): UGrid origin of dataset(s); its ``cell_count`` and ``point_count`` are read.
        activity_array(list): activity array, or list of activity arrays (one per timestep), from dataset.

    Returns:
        activity_type(str): "cells" when the length matches the cell count, "points" when it matches the
            point count, "invalid" when it matches neither, or None when no activity array was given.
    """
    if activity_array is None:
        return None

    # The last axis holds the per-location flags for both 1-D and stacked (per-timestep) input.
    location_count = np.array(activity_array).shape[-1]

    if location_count == ugrid.cell_count:
        return "cells"
    if location_count == ugrid.point_count:
        return "points"
    return "invalid"


class CalculationResult:
    """Plain container describing the outcome of a calculation.

    Attributes:
        result_type: Either "error" or "success".
        value_count: Either "single" or "multiple".
        error_info: Used for any extra information to display to the user.
    """

    def __init__(self):
        """Create a result with every field unset (None)."""
        self.result_type = self.value_count = self.error_info = None


class DatasetCalculator:
    """Class for calculating expressions with datasets."""

    def __init__(self, datasets, variables, timesteps, ugrid, expression, output_dataset, logger):
        """
        Initializer dataset calculator.

        Args:
            datasets: list of provided dataset reader objects
            variables: list of provided variable names
            timesteps: list of chosen timestep indices, or "All" to signify all timesteps are chosen
            ugrid: UGrid
            expression: provided mathematical expression to be computed.
            output_dataset: name of dataset file to output results to.
            logger: tool logger for displaying errors/information.
        """
        self._datasets = datasets
        self._variables = variables
        self._timesteps = timesteps
        self._ugrid = ugrid
        self._expression = expression
        self._logger = logger
        self._output_dataset = output_dataset

    def _evaluate_expression(self, expression, variables_dict):
        """Evaluate expressions.

        Args:
            expression(str): the provided mathematical expression.
            variables_dict(dict): dictionary with provided names to use
                in expression for each dataset and the datasets' values

        Returns:
            result(np.ndarray): array of resulting value arrays from executing the expression on the datasets' values.
        """
        try:
            result = ne.evaluate(expression, local_dict=variables_dict)
            if result.size == 1:
                return result.item()

            return result
        except Exception as e:
            self._logger.error(str(e))
            return f"Error: {str(e)}"

    def _build_activity_array(self, values, null_value):
        """Forms an activity array from a dataset's values and provided null value.

        Args:
            values(np.ndarray): array of value arrays from dataset.
            null_value(float): value to consider inactive in activity if found at location(s) in the values array.

        Returns:
            activity(np.ndarray): array of activity arrays where if a value is equal to the null_value,
                activity is 0(inactive), otherwise 1(active).
        """
        activity = np.array(np.where(values == null_value, 0, 1))
        return activity

    def _get_single_activity(self, dataset: DatasetReader, timestep_index: int | None = None):
        """Gets activity from a single dataset.

        Args:
            dataset: dataset reader object
            timestep_index: index of chosen timestep, or if 'all time steps' is selected, None.

        Returns:
            activity(np.ndarray): array of activity array(s) to use when evaluating the final resulting
                dataset's activity
        """
        activity = None

        if dataset.activity is None:
            if dataset.null_value is None:
                return None
            else:
                values = np.array(dataset.values)
                activity = self._build_activity_array(values, dataset.null_value)
            if timestep_index is not None:
                activity = np.array([activity[0]])

        elif timestep_index is None:
            activity = np.array(dataset.activity)

        else:
            activity = np.array([dataset.activity[timestep_index]])

        return activity

    def _get_overall_activity(self, datasets, timestep_indices):
        """Gets the overall activity for the resulting dataset using the activity array(s) of the input dataset(s).

        Args:
            datasets: list of dataset objects.
            timestep_indices: list of indices of chosen timesteps for dataset timestep arrays.

        Returns:
            activity(np.ndarray): resulting activity array evaluated from all chosen datasets
        """
        activity = None
        activity_list = []

        # Retrieve a list of the formed activity arrays for each dataset
        for index in range(0, len(datasets)):
            single_activity = self._get_single_activity(datasets[index], timestep_indices[index])
            if single_activity is not None:
                activity_list.append(single_activity)

        maximum_timesteps = 1
        for index in range(len(datasets)):
            # Grab the first dataset with all timesteps selected
            if timestep_indices[index] is None:
                maximum_timesteps = len(datasets[index].times)
                break

        if len(activity_list) > 0:
            # Grab the highest amount of timesteps that exists for any of the user's chosen datasets
            if None in timestep_indices:
                for index in range(len(activity_list)):
                    if len(activity_list[index]) != maximum_timesteps:
                        activity_list[index] = np.broadcast_to(activity_list[index],
                                                               (maximum_timesteps, activity_list[index].shape[1]))

            if len(activity_list) == 1:
                return activity_list[0]

        else:
            return None

        # If there's more than one type of activity, one or more must be cell
        # activity arrays that need to be converted to point activity arrays
        if any((get_activity_type(self._ugrid, activity) == "points") for activity in activity_list):
            for index in range(len(activity_list)):
                if get_activity_type(self._ugrid, activity_list[index]) == "cells":
                    activity_list[index] = np.array([active_points_from_cells(self._ugrid, activity_sublist)
                                                     for activity_sublist in activity_list[index]])

        activity_list = [np.array(activity) for activity in activity_list]

        # Form the final resulting activity array
        current_activity = activity_list[0]
        for index in range(1, len(activity_list)):
            current_activity = np.where((current_activity == 0) | (activity_list[index] == 0), 0, 1)

        activity = current_activity

        return activity

    def _build_output_dataset(self, activity, results, times):
        """
        Build resulting dataset with activity, results of calculation, and any timesteps provided (if any).

        Args:
            activity: list of resulting activity arrays
            results: list of value arrays resulting from calculation
            times: list of timesteps
        """
        # Make sure aspects like mins and maxs are calculated with final activity.
        if get_activity_type(self._ugrid, activity) == "cells" and self._output_dataset.location == "points":
            self._output_dataset.activity_calculator = CellToPointActivityCalculator(ugrid=self._ugrid)

        for index in range(len(times)):
            time = times[index]
            if activity is None:
                self._output_dataset.append_timestep(time, results[index])
            else:
                self._output_dataset.append_timestep(time, results[index], activity[index])
        self._output_dataset.appending_finished()

    def calculate(self):
        """Runs calculation."""
        calculation_result = CalculationResult()
        timestep_indices = []
        values_list = []
        times = None
        array_shape = None

        # Get timestep indices and values for each dataset
        for index in range(len(self._timesteps)):
            dataset = self._datasets[index]
            if self._timesteps[index] == "All":
                if times is None:
                    times = np.array(dataset.times)

                timestep_indices.append(None)
                values_list.append(np.array(dataset.values))
                if array_shape is None:
                    array_shape = dataset.values.shape
            else:
                timestep_indices.append(self._timesteps[index])
                values_list.append(np.array(dataset.values[self._timesteps[index]]))

        if times is None:
            times = [0.0]

        if array_shape is None:
            if self._output_dataset.location == "cells":
                array_shape = (1, self._ugrid.cell_count)
            elif self._output_dataset.location == "points":
                array_shape = (1, self._ugrid.point_count)

        activity = self._get_overall_activity(self._datasets, timestep_indices)

        # Make sure all value lists have the same size.
        for index in range(len(values_list)):
            if values_list[index].shape[0] != array_shape[0]:
                values_list[index] = np.broadcast_to(values_list[index], array_shape)

        variables_dict = {}
        for index in range(len(self._datasets)):
            variables_dict[self._variables[index]] = values_list[index]

        results = self._evaluate_expression(self._expression, variables_dict)

        # Check if the results are a single value, meaning a constant expression was used
        if not isinstance(results, np.ndarray):
            calculation_result.value_count = "single"
            # Check for an invalid value
            if np.isnan(results) or np.isinf(results):
                calculation_result.result_type = "error"
                return calculation_result

            # Fill an array with the correct number of values, all values being the
            # same result of the constant expression
            results = np.full(array_shape, results, dtype=float)
        else:
            calculation_result.value_count = "multiple"

        # Check for any invalid values in results
        if np.any(np.isnan(results) | np.isinf(results)):
            # Create a mask that selects only invalid values in results
            value_mask = np.isnan(results) | np.isinf(results)
            if activity is not None:
                # If the activity array size doesn't match the values array's, meaning point values with cell activity
                if results[0].shape != activity[0].shape:
                    # Convert the cell activity to point activity
                    activity = np.array([active_points_from_cells(self._ugrid, activity_sublist)
                                         for activity_sublist in activity])
                # Create a mask that selects only active points or cells
                activity_mask = (activity == 1)
                # Create a mask that selects only points or cells that have invalid values and are also active
                errors_mask = value_mask & activity_mask
            else:
                # If there is no activity, any invalid values will result in errors.
                errors_mask = value_mask

            # If there are any values in the error mask to be displayed to the user
            if np.any(errors_mask):
                calculation_result.result_type = "error"
                # Grab the timestep and value indices to use in the error message
                timestep_indices, value_indices = np.where(errors_mask)

                # Dictionary used to store each timestep and a list of it's corresponding point or cell id(s)
                timestep_dict = {}
                for timestep_index, value_index in zip(timestep_indices, value_indices):
                    timestep = times[timestep_index]
                    if timestep in timestep_dict:
                        timestep_dict[timestep].append(value_index + 1)
                    else:
                        timestep_dict[timestep] = [value_index + 1]

                calculation_result.error_info = timestep_dict
                return calculation_result

            else:
                if activity is not None:
                    # Create a mask that selects only points or cells that are inactive
                    inactive_mask = (activity == 0)
                    # Create a mask that selects points or cells that have invalid values and are inactive
                    replace_mask = value_mask & inactive_mask
                    # Replace all values in the replace mask with the value '0.0'
                    results[replace_mask] = 0.0

        self._build_output_dataset(activity, results, times)
        calculation_result.result_type = "success"
        return calculation_result
