"""Stores the data for the simulation model control."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import importlib.metadata
import os

# 2. Third party modules
from adhparam.model_control import ModelControl as ParamModelControl
from adhparam.time_series import TimeSeries
import h5py
import numpy
import pandas as pd
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
from xms.adh.data import param_h5_io
from xms.adh.data.json_export import filtered_info_attrs, include_dataframe_if_not_empty, include_if_xarray_not_empty, \
    make_json_serializable, parameterized_to_dict
from xms.adh.data.version import needs_update


class ModelControl(XarrayBase):
    """A class that handles saving and loading simulation model control data."""
    def __init__(self, main_file):
        """Initializes the data class.

        When ``main_file`` already exists, the model control parameters, domain UUID, time series, hot starts,
        and vessel UUIDs are read from it; otherwise the empty defaults created here are kept.

        Args:
            main_file: The main file associated with this component.
        """
        super().__init__(main_file)
        self.main_file = main_file
        attrs = self.info.attrs
        # Add any missing default attributes without overwriting values already stored in the file.
        for key, value in self.default_data().items():
            if key not in attrs:
                attrs[key] = value
        attrs['FILE_TYPE'] = 'ADH_SIM_DATA'
        self.param_control = ParamModelControl()
        self.series_ids = xr.Dataset()
        self.hot_starts = self._make_hot_starts_frame([], [], [])
        self.time_series = {}
        self.domain_uuid = ''
        self.vessel_uuids = xr.Dataset()
        if os.path.exists(main_file):
            param_h5_io.read_from_h5_file(main_file, self.param_control, 'ADH_SIM_DATA')
            uuid_info = xr.load_dataset(self.main_file, group='uuids')
            self.domain_uuid = uuid_info.attrs['domain_uuid']
            self.read_time_series_from_h5(main_file)
            self.read_hot_starts_from_h5(main_file)
            self.read_vessel_uuids_from_h5(main_file)

    @staticmethod
    def _make_hot_starts_frame(names, uuids, time_steps) -> pd.DataFrame:
        """Build a hot start table with the standard column names and dtypes.

        Using one constructor keeps the schema identical whether the table is empty, set by the caller, or
        converted from legacy attributes.

        Args:
            names: The hot start names.
            uuids: The dataset UUID for each hot start.
            time_steps: The time step index for each hot start.

        Returns:
            A DataFrame with 'name', 'uuid', and 'time_step_index' columns.
        """
        return pd.DataFrame(
            {
                'name': pd.Series(names, dtype='str'),
                'uuid': pd.Series(uuids, dtype='str'),
                'time_step_index': pd.Series(time_steps, dtype='int'),
            }
        )

    def get_hot_starts(self) -> dict[str, tuple[str, int]]:
        """Get the hot starts.

        If a name appears more than once, the last row wins.

        Returns:
            A dictionary of hot start names each containing a tuple of UUID and time step.
        """
        table = self.hot_starts
        return {
            name: (dataset_uuid, int(time_step))
            for name, dataset_uuid, time_step in zip(table['name'], table['uuid'], table['time_step_index'])
        }

    def set_hot_starts(self, hot_starts: dict[str, tuple[str, int]]) -> None:
        """
        Replace all hot starts with the given mapping.

        Args:
            hot_starts: A dictionary of hot start names each containing a tuple of UUID and time step.
        """
        names = []
        uuids = []
        time_steps = []
        for name, items in hot_starts.items():
            names.append(name)
            uuids.append(items[0])
            time_steps.append(int(items[1]))
        self.hot_starts = self._make_hot_starts_frame(names, uuids, time_steps)

    @staticmethod
    def default_data():
        """Gets the default data for this class.

        Returns:
            A dictionary of default values that will go into the info dataset attrs.
        """
        version = importlib.metadata.version('xmsadh')
        return {
            'FILE_TYPE': 'ADH_SIM_DATA',
            'VERSION': version,
            'os_time_series': 0,
        }

    def update_constants_to_feet(self):
        """Updates the model constants to their values for feet (English) units."""
        model_constants = self.param_control.model_constants
        model_constants.mannings_unit_constant = 1.486
        model_constants.gravity = 32.17
        model_constants.density = 1.940
        model_constants.kinematic_viscosity = 1.08e-5

    def read_time_series_from_h5(self, filename):
        """Reads the time series from the file.

        Entries in the stored id list that are null, NaN, zero, or negative are skipped.

        Args:
            filename (str): The file to read time series data from.
        """
        self.series_ids = xr.load_dataset(filename, group='/time_series/ids')
        for t_id in self.series_ids['ids']:
            # The truthiness check runs first so a null placeholder never reaches numpy.isnan.
            if not (t_id and t_id > 0 and not numpy.isnan(t_id)):
                continue
            int_id = int(t_id)
            if int_id not in self.time_series:
                self.time_series[int_id] = TimeSeries()
            param_h5_io.read_params_recursive(
                filename, group_name=f'/time_series/{int_id}/', param_class=self.time_series[int_id]
            )

    def read_hot_starts_from_h5(self, filename):
        """Reads the hot starts from the file.

        Files older than version 1.4.0.dev3 store hot starts as info attributes; those are converted instead.

        Args:
            filename (str): The file to read hot starts data from.
        """
        if needs_update(self.info.attrs['VERSION'], '1.4.0.dev3'):
            self.read_hot_starts_before_1_4_0()
            return
        self.hot_starts = xr.load_dataset(filename, group='/hot_starts').to_dataframe()

    def read_hot_starts_before_1_4_0(self):
        """Convert hot starts stored as info attributes (before version 1.4.0) into the hot start table.

        The legacy 'IOH_DATASET', 'USE_IOV_DATASET', and 'IOV_DATASET' attributes are removed from the info
        attrs. A missing attribute is treated as unset rather than raising KeyError.
        """
        attrs = self.info.attrs
        # pop() both reads and removes the legacy values, so they are not written back on the next commit.
        ioh_uuid = attrs.pop('IOH_DATASET', '')
        use_iov = attrs.pop('USE_IOV_DATASET', '')
        iov_uuid = attrs.pop('IOV_DATASET', '')
        names = []
        uuids = []
        if ioh_uuid:
            names.append('ioh')
            uuids.append(ioh_uuid)
        if use_iov:
            names.append('iov')
            uuids.append(iov_uuid)
        # Legacy hot starts had no time step selection, so every entry uses index 0.
        self.hot_starts = self._make_hot_starts_frame(names, uuids, [0] * len(names))

    def read_vessel_uuids_from_h5(self, filename):
        """Reads the vessel uuids from the file.

        Files older than version 1.5.5.dev2 have no vessel UUID group, so an empty dataset is used for them.

        Args:
            filename (str): The file to read vessel uuid data from.
        """
        if needs_update(self.info.attrs['VERSION'], '1.5.5.dev2'):
            self.vessel_uuids = xr.Dataset()
        else:
            self.vessel_uuids = xr.load_dataset(filename, group='/vessel_uuids')

    def commit(self):
        """Stores simulation data in the main file.

        Stamps the current package version, writes the base data and model control parameters, then rewrites
        the domain UUID, time series, hot start, and vessel UUID groups.
        """
        self.info.attrs['VERSION'] = importlib.metadata.version('xmsadh')
        super().commit()
        param_h5_io.write_params_recursive(self.main_file, '/', self.param_control)
        uuid_info = xr.Dataset()
        uuid_info.attrs['domain_uuid'] = self.domain_uuid
        uuid_info.to_netcdf(self.main_file, group='uuids', mode='a')

        # With no series, a single null entry is written (presumably so 'ids' is never empty);
        # read_time_series_from_h5 skips it.
        id_list = [[series_id] for series_id in self.time_series] or [[None]]
        self.series_ids = pd.DataFrame(id_list, columns=['ids']).to_xarray()
        # Remove any previous ids group so it is replaced, not appended to. Only the expected "not present"
        # case is tolerated; other h5py errors now surface instead of being silently swallowed.
        with h5py.File(self.main_file, 'a') as f:
            if '/time_series/ids' in f:
                del f['/time_series/ids']
        self.series_ids.to_netcdf(self.main_file, group='/time_series/ids', mode='a')
        for t_id, series in self.time_series.items():
            if t_id and series:
                param_h5_io.write_params_recursive(
                    self.main_file, group_name=f'/time_series/{t_id}/', param_class=series
                )
        self._drop_h5_groups(['/hot_starts'])
        self.hot_starts.to_xarray().to_netcdf(self.main_file, group='/hot_starts', mode='a')
        self._drop_h5_groups(['/vessel_uuids'])
        self.vessel_uuids.to_netcdf(self.main_file, group='/vessel_uuids', mode='a')

    def as_dict(self) -> dict:
        """
        Converts the object's attributes and related data into a JSON serializable dictionary format.

        Hot start UUIDs are omitted. The vessel entry reports only a count: the number of data variables in the
        vessel UUID dataset.

        Returns:
            dict: A dictionary containing the serialized and filtered attributes of
            the object.

        Raises:
            TypeError: If processing the attributes results in an invalid data type.
        """
        model_control = parameterized_to_dict(self.param_control)
        info_attrs = filtered_info_attrs(self.info.attrs)

        output = {
            "info_attrs": info_attrs,
            "series_ids": include_if_xarray_not_empty(self.series_ids),
            "hot_starts": include_dataframe_if_not_empty(self.hot_starts.drop(columns=['uuid'], errors='ignore')),
            "time_series": self.time_series,
            "model_control": model_control,
            "vessel_count": len(self.vessel_uuids),
        }

        return make_json_serializable(output)
