"""BcCoverageData class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from pathlib import Path

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.gmi.data.coverage_data import CoverageData

# 4. Local modules

# Constants
GSSHA_ARC_LINKS = 'gssha_arc_links'  # Name of the xarray dataset that stores arc link numbers

# Type aliases
ArcLinks = dict[int, int]  # Maps arc.id -> link number


class BcCoverageData(CoverageData):
    """GSSHA BC coverage data."""

    GSSHA_DATA_VERSION = 'GSSHA_DATA_VERSION'  # Key in self.info.attrs under which the data version is stored
    LATEST_VERSION = 1  # The latest data version; written to self.info.attrs on construction

    def __init__(self, data_file: str | Path):
        """Initialize the data class.

        Args:
            data_file: The netcdf file (with path) associated with this instance data. Probably the owning
                       component's main file.
        """
        super().__init__(data_file)
        # Migrate older data first; once migrated the data is current, so stamp it with the latest version.
        self._migrate()
        self.info.attrs[self.GSSHA_DATA_VERSION] = self.LATEST_VERSION
        # NOTE(review): self.commit() is deliberately not called here - confirm whether the version stamp
        # should be persisted immediately or only on the next explicit commit.

    def _migrate(self) -> None:
        """Migrate data written by older versions of the file to the latest version.

        Currently a no-op; LATEST_VERSION is 1 and there are no older versions to convert.
        """

    def get_arc_links(self) -> ArcLinks:
        """Returns the mapping of arc.id -> link number.

        Returns:
            Dict of arc.id -> link number. Keys and values are built-in Python ints.
        """
        arc_links_dataset = self._get_dataset(GSSHA_ARC_LINKS)
        # `.values` yields numpy arrays whose elements are numpy scalars (e.g. np.int64). `.tolist()` converts
        # them to built-in ints so the result matches the ArcLinks type and serializes cleanly (e.g. json).
        arc_ids = arc_links_dataset['arc_id'].values.tolist()
        link_numbers = arc_links_dataset['link_number'].values.tolist()
        return dict(zip(arc_ids, link_numbers))

    def set_arc_links(self, arc_links: ArcLinks) -> None:
        """Sets the mapping of arc.id -> link number.

        Replaces the stored arc links dataset with one built from `arc_links`.

        Args:
            arc_links: Dict of arc.id -> link number
        """
        self._set_dataset(make_links_dataset(arc_links), GSSHA_ARC_LINKS)

    @property
    def _dataset_names(self) -> set[str]:
        """The names of datasets used by this data class.

        Extends the base class names with the arc links dataset.

        Derived classes can override it to add/remove names. If they add names, they should also override
        `self._create_dataset()`.
        """
        return super()._dataset_names | {GSSHA_ARC_LINKS}

    def _create_dataset(self, name: str) -> xr.Dataset:
        """Create an empty dataset, given its name.

        Derived classes should override this to handle any names they add to `self._dataset_names`. Any names they
        don't need to specifically handle should be handled by `return super()._create_dataset(name)`.

        Args:
            name: The name of the dataset to create.

        Returns:
            A new dataset with the appropriate structure for this name.
        """
        if name == GSSHA_ARC_LINKS:
            # An empty links dataset: no arcs, but the correct variable/coordinate structure.
            return make_links_dataset()

        return super()._create_dataset(name)


def make_links_dataset(arc_links: 'ArcLinks | None' = None) -> xr.Dataset:
    """Build an xarray dataset holding arc link numbers.

    The dataset has one data variable, ``link_number``, laid out along an ``arc_id`` coordinate.

    Args:
        arc_links: Dict of arc id -> link number. ``None`` or an empty dict produces an empty dataset.

    Returns:
        The newly built dataset.
    """
    links = arc_links if arc_links else {}
    pairs = list(links.items())
    # The arc id coordinate is the lookup key; link numbers follow it in the same order.
    id_column = np.array([arc_id for arc_id, _ in pairs], dtype=int)
    number_column = np.array([number for _, number in pairs], dtype=int)
    return xr.Dataset(
        data_vars={'link_number': ('arc_id', number_column)},
        coords={'arc_id': id_column},
    )
