"""Helper class for hot start data."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from collections import defaultdict
import logging
import re
from typing import TYPE_CHECKING

# 2. Third party modules
from adhparam.hot_start_data_set import HotStartDataSet
from adhparam.model_control import ModelControl as ParamModelControl
import pandas

# 3. Aquaveo modules
from xms.datasets.dataset_reader import DatasetReader
from xms.tool_core.table_definition import ChoicesColumnType, StringColumnType, TableDefinition

# 4. Local modules

if TYPE_CHECKING:  # noqa: I300
    from xms.adh.data.xms_query_data import XmsQueryData  # noqa: I300

# Column headers of the hot start table built by HotStartHelper; also used as DataFrame column keys.
DATASET_HEADER = "Dataset"
# Column holding each row's list of selectable dataset paths.
DATASET_CHOICES = "Dataset Choices"
VARIABLE_HEADER = "Variable Name"
TIMESTEP_HEADER = "Timestep"
# Column holding each row's list of selectable time step strings.
TIMESTEP_CHOICES = "Timestep Choices"
# Placeholder dataset value meaning "no hot start assigned for this row".
NOT_ENABLED = "-- Not enabled --"


class HotStartHelper:
    """Helper for displaying and editing hot start data in a table.

    Builds the table definition and the pandas DataFrame shown to the user, keeps each row's time step column in
    sync with its selected dataset, and converts an edited table back into a dictionary of
    ``{variable_name: (dataset_uuid, time_step_index)}``.
    """

    def __init__(self, xms_data: 'XmsQueryData', param_data: ParamModelControl):
        """Construct HotStartHelper.

        Args:
            xms_data: XmsData object for communicating with XMS.
            param_data: The param ModelControl object.
        """
        self._xms_data = xms_data
        self._param_data = param_data
        # Snapshot of the last table passed to update_time_steps(); used to detect changed dataset selections.
        self.previous_hot_starts = None
        # Two-way lookup between Project Explorer paths and dataset UUIDs, filled by _build_dataset_list().
        self._dataset_uuid_from_path = {}
        self._dataset_path_from_uuid = {}
        self._scalar_choices = self._build_dataset_list(xms_data.scalar_datasets)
        self._vector_choices = self._build_dataset_list(xms_data.vector_datasets)

    def create_hot_start_table_definition(self) -> TableDefinition:
        """Creates the hot start table definition.

        The table has a fixed row count equal to the number of potential hot starts.

        Returns:
            The hot start table definition.
        """
        columns = [
            # choices=1 / choices=4 match the positions of the DATASET_CHOICES / TIMESTEP_CHOICES columns in this
            # list; presumably each names the column that holds that cell's per-row choices.
            StringColumnType(header=DATASET_HEADER, tool_tip="Dataset selection", default=NOT_ENABLED, choices=1),
            ChoicesColumnType(header=DATASET_CHOICES, default=[[""]]),
            StringColumnType(header=VARIABLE_HEADER, tool_tip="Hot start name", default=VARIABLE_HEADER, enabled=False),
            StringColumnType(header=TIMESTEP_HEADER, tool_tip="Time step selection", choices=4, default=""),
            ChoicesColumnType(header=TIMESTEP_CHOICES, default=[[""]])
        ]
        hot_start_row_count = len(self._potential_hot_starts())
        return TableDefinition(columns, fixed_row_count=hot_start_row_count)

    def build_data_frame(self, hot_starts: dict[str, tuple[str, int]]) -> pandas.DataFrame:
        """
        Builds a pandas DataFrame with one row per potential hot start.

        A row is shown as enabled only when its stored dataset UUID still maps to a known dataset of the right
        type (vector for ``iov``, scalar otherwise) that has at least one time step. Otherwise the row is shown as
        ``NOT_ENABLED`` with no time step choices.

        Args:
            hot_starts: A dictionary containing hot starts for each item.
                        The keys are names, and the values are tuples of the form (uuid, time_step_idx).

        Returns:
            The built DataFrame.
        """
        datasets = []
        dataset_choices = []
        variables = []
        time_steps = []
        time_step_choices = []
        for variable_name, ui_name in self._potential_hot_starts():
            datasets_for_type = self._vector_choices if variable_name == 'iov' else self._scalar_choices
            dataset_path = NOT_ENABLED
            time_step = ''
            dataset_time_steps = []

            uuid, time_step_idx = hot_starts.get(variable_name, (None, None))
            # The stored UUID may refer to a dataset that has since been deleted, so look it up leniently.
            stored_path = self._dataset_path_from_uuid.get(uuid) if uuid is not None else None
            if stored_path is not None and stored_path in datasets_for_type:
                stored_time_steps = self._get_timesteps(stored_path)
                if stored_time_steps:
                    dataset_path = stored_path
                    dataset_time_steps = stored_time_steps
                    # Fall back to the first time step when the stored index is missing or out of range.
                    in_range = time_step_idx is not None and 0 <= time_step_idx < len(stored_time_steps)
                    time_step = stored_time_steps[time_step_idx] if in_range else stored_time_steps[0]

            datasets.append(dataset_path)
            dataset_choices.append(datasets_for_type)
            variables.append(ui_name)
            time_steps.append(time_step)
            time_step_choices.append(dataset_time_steps)
        data_frame = pandas.DataFrame(
            {
                DATASET_HEADER: pandas.Series(datasets, dtype='str'),
                DATASET_CHOICES: pandas.Series(dataset_choices, dtype='object'),
                VARIABLE_HEADER: pandas.Series(variables, dtype='str'),
                TIMESTEP_HEADER: pandas.Series(time_steps, dtype='str'),
                TIMESTEP_CHOICES: pandas.Series(time_step_choices, dtype='object')
            }
        )
        return data_frame

    def update_time_steps(self, new_hot_starts):
        """
        Updates the time step columns of rows whose dataset selection changed since the previous call.

        The first call only records a snapshot; later calls compare against it.

        Args:
            new_hot_starts: changed hot starts. Modified in place.
        """
        if self.previous_hot_starts is not None:
            for index, row in new_hot_starts.iterrows():
                if row[DATASET_HEADER] != self.previous_hot_starts.at[index, DATASET_HEADER]:
                    self._update_row_timestep(index, row[DATASET_HEADER], new_hot_starts)
        self.previous_hot_starts = new_hot_starts.copy()

    def get_hot_starts(self, data_frame) -> dict[str, tuple[str, int]]:
        """
        Get hot start data from a data frame.

        Args:
            data_frame: The pandas DataFrame containing information about the hot starts.

        Returns: A dictionary containing the hot starts information.
            The keys are the variable names, and the values are tuples of (dataset_uuid, time_step_index).
            Rows whose dataset is ``NOT_ENABLED`` are omitted.
        """
        hot_starts = {}
        variable_names = self._get_variable_names()
        for _index, row in data_frame.iterrows():
            dataset = row[DATASET_HEADER]
            if dataset == NOT_ENABLED:
                continue
            variable_name = variable_names[row[VARIABLE_HEADER]]
            # The stored index is the selected time step's position within this row's choices; list() accepts any
            # sequence type the table may hand back.
            time_step = list(row[TIMESTEP_CHOICES]).index(row[TIMESTEP_HEADER])
            hot_starts[variable_name] = (self._dataset_uuid_from_path[dataset], time_step)
        return hot_starts

    def _potential_hot_starts(self) -> list[tuple[str, str]]:
        """Get a list of potential hot starts.

        Depth and velocity are always present. Sediment, transport, and bed layer entries are added only when the
        corresponding data is present on ``adh_data``.

        Returns: A list of tuples (variable name, UI name).
        """
        hot_starts = [
            ('ioh', 'Initial depth (ioh)'), ('iov', 'Initial velocity (iov)')
        ]
        sediment_constituents_io = self._xms_data.adh_data.sediment_constituents_io
        if sediment_constituents_io is not None:
            hot_starts.append(('ibd', 'Sediment bed displacement (ibd)'))
            param_control = sediment_constituents_io.param_control
            # Each clay/sand constituent contributes three rows (components 1-3); clay rows come before sand rows.
            for kind, label, constituents in (('clay', 'Clay', param_control.clay), ('sand', 'Sand', param_control.sand)):
                for number, constituent_id in enumerate(constituents['ID'], start=1):
                    for component in (1, 2, 3):
                        hot_starts.append((
                            f'icon-{kind}-{number}-{component}-{constituent_id}',
                            f'{label} {number} (icon) {component}',
                        ))
        transport_io = self._xms_data.adh_data.transport_constituents_io
        if transport_io is not None:
            if transport_io.param_control.salinity:
                hot_starts.append(('icon-salinity', 'Salinity (icon)'))
            if transport_io.param_control.temperature:
                hot_starts.append(('icon-temperature', 'Temperature (icon)'))
            if transport_io.param_control.vorticity:
                hot_starts.append(('icon-vorticity', 'Vorticity (icon)'))
            for number in range(1, transport_io.param_control.general_constituents.shape[0] + 1):
                hot_starts.append((f'icon-concentration-{number}', f'Concentration {number} (icon)'))
        sediment_materials_io = self._xms_data.adh_data.sediment_materials_io
        if sediment_materials_io is not None:
            # NOTE(review): the layer count is taken from materials[0] only; presumably every material has the
            # same number of bed layers - confirm.
            for number in range(1, sediment_materials_io.materials[0].bed_layers.shape[0] + 1):
                hot_starts.append((f'iblt {number}', f'Bed layer {number} thickness (iblt)'))
        return hot_starts

    def _get_variable_names(self) -> dict[str, str]:
        """Retrieves a dictionary of UI names to variable names.

        Returns: A dictionary mapping hot start UI names to variable names.
        """
        return {ui_name: variable_name for variable_name, ui_name in self._potential_hot_starts()}

    def _update_row_timestep(self, index, dataset_path, new_data_frame):
        """
        Refreshes the time step cell and time step choices for one row after its dataset changed.

        Cells are written with ``.at`` so the update lands in ``new_data_frame`` itself. Chained
        ``df[col][idx] = value`` assignment does not modify the frame under pandas Copy-on-Write.

        Args:
            index: The row index (label).
            dataset_path: The dataset path
            new_data_frame: The data frame. Modified in place.
        """
        if new_data_frame.at[index, DATASET_HEADER] == NOT_ENABLED:
            new_data_frame.at[index, TIMESTEP_HEADER] = ""
            new_data_frame.at[index, TIMESTEP_CHOICES] = []
            return
        time_steps = self._get_timesteps(dataset_path)
        # Default to the first time step; a dataset without time steps gets an empty selection instead of crashing.
        new_data_frame.at[index, TIMESTEP_HEADER] = time_steps[0] if time_steps else ""
        new_data_frame.at[index, TIMESTEP_CHOICES] = time_steps

    def _get_timesteps(self, dataset_path: str) -> list[str]:
        """Retrieve a dataset's time steps as display strings.

        Args:
            dataset_path(str): The dataset path.

        Returns:
            One string per time step, formatted ``"<1-based index>: <time>"``.
        """
        dataset_uuid = self._dataset_uuid_from_path[dataset_path]
        dataset_reader = self._xms_data.get_dataset_from_uuid(dataset_uuid)
        times = dataset_reader.times[:]
        return [f'{index + 1}: {time}' for index, time in enumerate(times)]

    def _get_dataset(self, dataset_path) -> DatasetReader:
        """Get a dataset by its Project Explorer path.

        Args:
            dataset_path(str): The dataset path.

        Returns:
            The dataset reader, or None if the path is unknown.
        """
        dset_uuid = self._dataset_uuid_from_path.get(dataset_path)
        if dset_uuid is None:
            return None
        return self._xms_data.get_dataset_from_uuid(dset_uuid)

    def _build_dataset_list(self, datasets) -> list[str]:
        """
        Build a list of datasets to choose from and register their path/UUID mappings.

        Args:
            datasets (list): A list of tree items representing datasets.

        Returns:
            A list of choices generated from the datasets, including 'NOT_ENABLED' as the first choice.
        """
        choices = [NOT_ENABLED]
        for tree_item in datasets:
            path = _get_tree_path(tree_item)
            uuid = tree_item.uuid
            choices.append(path)
            self._dataset_uuid_from_path[path] = uuid
            self._dataset_path_from_uuid[uuid] = path
        return choices


def _get_tree_path(tree_item):
    """Return the Project Explorer path of a tree node.

    Follows the ``parent`` chain from *tree_item* up to the root. It then joins the node names top-down with
    ``'/'`` and omits the root (project) node's name.

    Args:
        tree_item (xms.guipy.tree.tree_node.TreeNode): Tree node of the dataset.

    Returns:
        (str): Project explorer path to the dataset.
    """
    names_leaf_first = [tree_item.name]
    ancestor = tree_item.parent
    while ancestor is not None:
        names_leaf_first.append(ancestor.name)
        ancestor = ancestor.parent
    # Reverse to root-first order, then skip the leading project node.
    names_root_first = names_leaf_first[::-1]
    return '/'.join(names_root_first[1:])


def merge_hot_start_datasets(hot_start_list: list[HotStartDataSet], logger: logging.Logger) -> list[HotStartDataSet]:
    """
    Merges sand and clay hot start datasets into single 3-component datasets.

    The function buckets datasets named ``icon-<sand|clay>-<group>-<component>-<constituent_id>`` by
    ``(kind, group, constituent_id)``. When components 1, 2 and 3 are all present, the first column of each
    is concatenated side by side (columns named 0, 1, 2) into one dataset named ``icon-<constituent_id>``.
    Incomplete groups are skipped with a warning that names the missing components. All other datasets pass
    through unchanged and appear before the merged ones.

    Args:
        hot_start_list: The list of HotStartDataSet objects.
        logger: Logger for warnings/errors.

    Returns:
        List of HotStartDataSet objects with sand/clay merged.
    """
    merged_hot_starts = []
    groups: dict[tuple[str, int, int], dict[int, HotStartDataSet]] = defaultdict(dict)

    # Match names like icon-sand-1-2-3 (kind, group, component, constituent_id)
    pattern = re.compile(r"icon-(sand|clay)-(\d+)-(\d+)-(\d+)")
    required_components = (1, 2, 3)

    # Bucket datasets
    for ht in hot_start_list:
        m = pattern.fullmatch(ht.name)
        if m:
            kind = m.group(1)  # "sand" or "clay"
            group = int(m.group(2))  # group number
            comp = int(m.group(3))  # 1, 2, or 3
            constituent_id = int(m.group(4))  # the constituent id
            groups[(kind, group, constituent_id)][comp] = ht
        else:
            merged_hot_starts.append(ht)

    # Merge groups
    for (kind, group, constituent_id), comp_map in groups.items():
        missing = [c for c in required_components if c not in comp_map]
        if missing:
            logger.warning(
                f"Skipping {kind} group {group} (constituent {constituent_id}): "
                f"missing component dataset(s) {', '.join(str(c) for c in missing)}."
            )
            continue
        group_list = [comp_map[c] for c in required_components]
        merged_values = pandas.concat([g.values.iloc[:, 0].rename(i) for i, g in enumerate(group_list)], axis=1)
        merged = HotStartDataSet(
            name=f"icon-{constituent_id}",
            values=merged_values,
            number_of_cells=group_list[0].number_of_cells
        )
        merged_hot_starts.append(merged)

    return merged_hot_starts
