"""MergeDatasets Algorithm."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from logging import Logger
from typing import Optional, Tuple

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import Grid
from xms.constraint.ugrid_activity import active_cells_from_points, active_points_from_cells, values_with_nans
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.tool.utilities.dataset_tool import get_min_max
from xms.tool.utilities.time_units_converter import TimeUnitsConverter


def merge_datasets(output_dataset_name: str, dataset_1: DatasetReader, dataset_2: DatasetReader, grid_1: Grid,
                   grid_2: Grid, logger: Logger) -> DatasetWriter:
    """Merges two transient datasets into a single output dataset.

    The dataset whose first time value is smaller is written first, followed by the other one. The output
    dataset takes its geometry UUID from ``grid_1`` and its number of components, reference time, and time
    units from ``dataset_1`` (as passed in, regardless of which dataset ends up being written first).

    Args:
        output_dataset_name: The name for the output dataset
        dataset_1: Dataset reader for the first input dataset
        dataset_2: Dataset reader for the second input dataset
        grid_1: Grid for the first input dataset
        grid_2: Grid for the second input dataset
        logger: Logger for user output

    Returns:
        The output dataset writer, with all time steps appended and appending finished.
    """
    ugrid_1 = grid_1.ugrid
    ugrid_2 = grid_2.ugrid

    activity_type, null_value = get_output_activity_type(dataset_1, ugrid_1, dataset_2, ugrid_2)

    # Setup the output DatasetWriter (metadata comes from the first dataset/grid as passed in)
    output_dataset = DatasetWriter()
    output_dataset.name = output_dataset_name
    output_dataset.geom_uuid = grid_1.uuid
    output_dataset.num_components = dataset_1.num_components
    output_dataset.ref_time = dataset_1.ref_time
    output_dataset.time_units = dataset_1.time_units
    output_dataset.null_value = null_value

    # Per-time-step minimums and maximums are accumulated while the time steps are written
    output_dataset.timestep_mins = []
    output_dataset.timestep_maxs = []

    # Write the dataset that starts earlier first.
    # NOTE(review): the first time values are compared as raw numbers; if the two datasets use different
    # time units or reference times this may not be the chronological order - confirm against callers.
    if dataset_1.times[0] > dataset_2.times[0]:
        dataset_1, dataset_2 = dataset_2, dataset_1
        # Keep each dataset paired with its own UGrid after the swap
        ugrid_1, ugrid_2 = ugrid_2, ugrid_1
    logger.info('Processing the first dataset...')
    _write_time_steps(dataset_1, ugrid_1, activity_type, output_dataset, logger)
    logger.info('Processing the second dataset...')
    _write_time_steps(dataset_2, ugrid_2, activity_type, output_dataset, logger)

    output_dataset.appending_finished()

    return output_dataset


def _write_time_steps(input_dataset: DatasetReader, ugrid: UGrid, activity_type: str,
                      output_dataset: DatasetWriter, logger: Logger):
    """Write time steps to output dataset copied from another dataset.

    For every time step of the input dataset the time is converted to the output dataset's time units (when
    they differ), the time step minimum and maximum are appended to ``output_dataset.timestep_mins`` and
    ``output_dataset.timestep_maxs``, and the values and activity are appended to the output dataset using
    the requested activity type.

    Args:
        input_dataset (DatasetReader): The dataset to copy from.
        ugrid (UGrid): The dataset's UGrid.
        activity_type (str): The type of output activity.
        output_dataset (DatasetWriter): The dataset to write to.
        logger (Logger): Logger for user output
    """
    input_values = input_dataset.values
    input_activity = input_dataset.activity
    num_times = input_dataset.num_times

    # The time units and the output null value do not change between time steps, so set them up once.
    # The converter is only built when there is at least one time step to convert.
    converter = None
    if num_times > 0 and input_dataset.time_units != output_dataset.time_units:
        converter = TimeUnitsConverter(from_units=input_dataset.time_units, to_units=output_dataset.time_units)
    # Fall back to -999.0 when the output dataset has no null value
    null_value = output_dataset.null_value if output_dataset.null_value is not None else -999.0

    for timestep_idx in range(num_times):
        logger.info(f'Processing time step {timestep_idx + 1} of {num_times}...')

        activity = None if input_activity is None else input_activity[timestep_idx]
        time = input_dataset.times[timestep_idx]
        if converter is not None:
            time = converter.convert_value(time)

        # get values with the input activity applied (presumably inactive/null entries become NaN)
        values_out = values_with_nans(ugrid, input_values[timestep_idx], activity, input_dataset.null_value)

        minimum, maximum = get_min_max(values_out)
        output_dataset.timestep_mins.append(minimum)
        output_dataset.timestep_maxs.append(maximum)

        # get values with output activity applied
        values_out, activity_out = get_output_dataset_values(values_out, activity, ugrid, activity_type,
                                                             null_value)
        output_dataset.append_timestep(time, values_out, activity_out)


def get_output_activity_type(dataset_1, ugrid_1, dataset_2, ugrid_2):
    """Choose how activity is represented in the output when two datasets are combined.

    The candidates are tried in order of preference, checking the first dataset before the second for each:
    a cell-sized activity array, any other activity array (treated as point activity), a null value, and
    finally no activity at all.

    Args:
        dataset_1 (DatasetReader): The first dataset.
        ugrid_1 (UGrid): The first UGrid.
        dataset_2 (DatasetReader): The second dataset.
        ugrid_2 (UGrid): The second UGrid.

    Returns:
        (str, Optional[float]): Activity type and null value for value activity.
    """
    candidates = ((dataset_1, ugrid_1), (dataset_2, ugrid_2))

    # Highest preference: an activity array with one entry per cell
    for dataset, ugrid in candidates:
        activity = dataset.activity
        if activity is not None and len(activity[0]) == ugrid.cell_count:
            return 'cell activity', None

    # Next: any remaining activity array is handled as point activity
    if any(dataset.activity is not None for dataset, _ in candidates):
        return 'point activity', None

    # Next: activity expressed through a null value; the first dataset's null value wins
    for dataset, _ in candidates:
        if dataset.null_value is not None:
            return 'null value activity', dataset.null_value

    return 'no activity', None


def get_output_dataset_values(values: np.ndarray,
                              activity: Optional[np.ndarray],
                              grid: UGrid,
                              activity_type: str,
                              null_value: Optional[float]) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Get dataset values and activity with given activity type.

    NaN values are replaced with ``null_value`` in the returned values. For the 'cell activity' and
    'point activity' types an activity array is also returned: the given activity is used when there is
    one (converted between cells and points if its length matches the other location), otherwise the
    activity is derived from the invalid (NaN/inf) entries of ``values``. For any other activity type no
    activity array is returned.

    Args:
        values (np.ndarray): The dataset values.
        activity (Optional[np.ndarray]): The dataset activity.
        grid (UGrid): The UGrid.
        activity_type (str): The output activity type.
        null_value (Optional[float]): The output null value. If None, NaN values are left unchanged.

    Returns:
        (Tuple[np.ndarray, Optional[np.ndarray]]): Tuple of values and activity.

    Raises:
        ValueError: If the length of the activity (or of the values, when there is no activity) matches
            neither the number of cells nor the number of points of the grid.
    """
    if null_value is None:
        # Nothing to substitute; copy so the result never aliases the input (np.where also returns a new array)
        values_out = np.array(values)
    else:
        values_out = np.where(np.isnan(values), null_value, values)

    if activity_type not in ('cell activity', 'point activity'):
        return values_out, None

    # "native" is the location the output activity is defined on, "other" is the one needing conversion.
    # The native count is tested first so a grid with equal cell and point counts needs no conversion.
    want_cells = activity_type == 'cell activity'
    if want_cells:
        native_count, other_count = grid.cell_count, grid.point_count
    else:
        native_count, other_count = grid.point_count, grid.cell_count

    if activity is not None:
        source_activity = activity
        size_error = 'Number of values in activity array should match the number of cells or points.'
    else:
        # use activity from nans
        source_activity = _activity_from_nans(values)
        size_error = 'Number of dataset values should match the number of cells or points.'

    if len(source_activity) == native_count:
        return values_out, source_activity
    if len(source_activity) == other_count:
        if want_cells:
            return values_out, active_cells_from_points(grid, source_activity)
        return values_out, active_points_from_cells(grid, source_activity)
    raise ValueError(size_error)


def _activity_from_nans(values: np.ndarray) -> np.ndarray:
    """Build an integer activity array (1 = valid, 0 = NaN or inf) from dataset values.

    Args:
        values (np.ndarray): The dataset values.

    Returns:
        (np.ndarray): Integer array with the same shape as ``values``.
    """
    # getmaskarray always yields a full boolean array, even when nothing is masked
    # NOTE(review): for multi-component values the result keeps the component axis - confirm whether the
    # dataset writer expects a single flag per cell/point in that case.
    invalid = np.ma.getmaskarray(np.ma.masked_invalid(values))
    return np.array(np.logical_not(invalid), dtype=int)
