"""SrhCoverageData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import orjson
import pandas as pd
import pkg_resources

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
from xms.srh.data.par import par_util


class SrhCoverageData(XarrayBase):
    """Base class for SRH coverage data whose records are JSON-serialized param instances.

    The main dataset (stored in the file group named ``data_name_str``) has an ``id`` column
    holding the component id and a ``json`` column holding the serialized param instance for
    that id. Derived classes can load extra data by overriding ``_on_load_all``.
    """
    def __init__(self, filename, file_type_str, data_name_str):
        """Constructor.

        Args:
            filename (:obj:`str`): file name
            file_type_str (:obj:`str`): file type identifier, written to the ``FILE_TYPE`` attribute
            data_name_str (:obj:`str`): name of the file group that holds this class's dataset

        """
        super().__init__(filename)
        # Always mirror the attrs onto the instance. Previously these attributes were only
        # assigned when the value already existed in the file, so for a new file
        # self.cov_uuid / self.display_uuid were never set.
        self.cov_uuid = self.info.attrs.setdefault('cov_uuid', '')  # gets set later
        self.display_uuid = self.info.attrs.setdefault('display_uuid', '')
        self.info.attrs['FILE_TYPE'] = file_type_str
        self.data_name_str = data_name_str
        self._data = None
        self.load_all()

    def load_all(self):
        """Loads all datasets from the file into memory, then closes the file."""
        # Touching the properties forces them to be read while the file is open.
        _ = self.info
        _ = self.data
        self._on_load_all()
        self.close()

    def _on_load_all(self):
        """Hook for derived classes to load additional data; the base implementation does nothing."""
        pass

    @property
    def data(self):
        """Get the dataset, loading it on first access.

        If the group named ``data_name_str`` cannot be read (``get_dataset`` returns None),
        a default dataset with a single ``id`` 0 / empty ``json`` row is used instead.

        Returns:
            (:obj:`xarray.Dataset`): The dataset

        """
        if self._data is None:
            self._data = self.get_dataset(self.data_name_str, False)
            if self._data is None:
                self._data = self._default_param_dataset()
        return self._data

    @data.setter
    def data(self, dset):
        """Setter for the _data dataset; a value of None is ignored."""
        # Compare against None explicitly: the truthiness of an xarray.Dataset depends on
        # whether it has data variables, and bool() on a pandas object raises.
        if dset is not None:
            self._data = dset

    def _default_param_dataset(self):
        """Creates a default dataset for holding instances of param data classes.

        Returns:
            (:obj:`xarray.Dataset`): Dataset with one row: ``id`` 0 and an empty ``json`` string

        """
        default_data = {'id': [0], 'json': ''}
        return pd.DataFrame(default_data).to_xarray()

    def _data_record_from_id(self, data_id):
        """Gets the record(s) with a given id. If the id is not in the data then returns None.

        Args:
            data_id (:obj:`int`): component id

        Returns:
            (:obj:`dict`): Mapping of column name -> {row position: value} for every matching row
            (positions renumbered from 0), or None when no row matches.
        """
        df = self.data.to_dataframe()
        record = df.loc[df['id'] == data_id]
        if len(record):
            return record.reset_index(drop=True).to_dict()
        return None

    def param_from_id(self, data_id, param_cls):
        """Builds a param instance from the record with the given id.

        Args:
            data_id (:obj:`int`): component id
            param_cls (:obj:`param.Parameterized`): instance whose class is used to create the
                result; the passed instance itself is not modified

        Returns:
            A new instance of ``param_cls``'s class. It is filled from the stored JSON of the first
            record with ``data_id`` when that record has non-empty JSON text; otherwise it is left
            default-constructed. Never returns None.
        """
        record = self._data_record_from_id(data_id)
        param = param_cls.__class__()
        if record is None:
            return param
        json_txt = record['json'][0]  # first matching row
        # Only decode real, non-empty text. A missing value (e.g. NaN) is truthy and
        # would otherwise fail on .encode().
        if isinstance(json_txt, str) and json_txt:
            par_dict = orjson.loads(json_txt.encode())
            par_util.orjson_dict_to_param_cls(par_dict, param, None)
        return param

    def append_param_data_with_id(self, param_data, data_id):
        """Appends a record for an id holding the JSON-serialized param data. If the id is < 1 then do nothing.

        A new row is always added; an existing record with the same id is not replaced.

        Args:
            param_data (:obj:`param.Parameterized`): instance of param class
            data_id (:obj:`int`): component id

        """
        if data_id < 1:
            return
        df = self.data.to_dataframe()
        # Use the next unused row label from the index. It was previously derived from
        # max(id) + 1, which mixes component ids with row labels. When ids are appended out of
        # order, that label can match an existing row and silently overwrite it. It also
        # failed on an empty dataset.
        row = df.index.max() + 1 if len(df.index) else 0
        pdict = par_util.param_cls_to_orjson_dict(param_data, None, False)
        json_txt = orjson.dumps(pdict).decode()
        df.loc[row, 'id'] = data_id
        df.loc[row, 'json'] = json_txt
        self._data = df.to_xarray()

    def commit(self):
        """Save in memory datasets to the NetCDF file.

        Records the installed ``xmssrh`` version in the ``VERSION`` attribute, then replaces the
        ``data_name_str`` group in the file with the in-memory dataset (if one is loaded).
        """
        self.info.attrs['VERSION'] = pkg_resources.get_distribution('xmssrh').version
        super().commit()

        # Drop the old group first so the append-mode write below does not collide with it.
        self._drop_h5_groups([self.data_name_str])
        if self._data is not None:
            self._data.to_netcdf(self._filename, group=self.data_name_str, mode='a')

    def close(self):
        """Closes the H5 file and does not write any data that is in memory."""
        super().close()
        if self._data is not None:
            self._data.close()
