"""Data class for a mapped BC component."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from array import array
from pathlib import Path
from typing import Optional, Sequence

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules

# 4. Local modules
from xms.schism.data.base_data import BaseData

# Names of the datasets MappedBcData reads and writes through _get_dataset/_set_dataset.
# Per open arc (keyed by component ID): where its nodes start in the open node dataset, how many there are, and its
# value string.
_OPEN_ARCS_DS = 'open_arcs'
# Node IDs of every open arc, stored back to back in the order the arcs were committed.
_OPEN_NODES_DS = 'open_nodes'
# Per closed arc (keyed by component ID): where its nodes start in the closed node dataset, how many there are, and
# its land boundary flag.
_CLOSED_ARCS_DS = 'closed_arcs'
# Node IDs of every closed arc, stored back to back in the order the arcs were committed.
_CLOSED_NODES_DS = 'closed_nodes'


def make_open_arc_dataset(
    comp_id: Optional[int] = None,
    start: Optional[int] = None,
    length: Optional[int] = None,
    value: Optional[str] = None
) -> xr.Dataset:
    """
    Make a dataset for an open arc.

    Either pass every argument, which yields a dataset describing exactly one arc, or pass none of them, which yields
    an empty dataset with the same variables (as used when the open arc dataset is first created).

    Args:
        comp_id: The arc's component ID.
        start: The index in the node dataset of the arc's first node.
        length: How many nodes comprise the arc.
        value: A string that can be restored into the arc parameters of a generic model to determine which
            properties are assigned to the arc.

    Returns:
        The dataset, indexed by `comp_id`, with `start_index`, `count`, and `value` variables.

    Raises:
        ValueError: If only some of the arguments are provided.
    """
    provided = [arg is not None for arg in (comp_id, start, length, value)]
    # A partial set of arguments would give the variables different lengths along 'comp_id'; report that clearly
    # instead of letting xarray fail with a dimension-size conflict.
    if any(provided) and not all(provided):
        raise ValueError('Either all or none of comp_id, start, length, and value must be provided.')

    if all(provided):
        comp_ids, starts, lengths, values = [comp_id], [start], [length], [value]
    else:
        comp_ids, starts, lengths, values = [], [], [], []

    coords = {'comp_id': np.array(comp_ids, dtype=np.int32)}
    arc_data = {
        'start_index': ('comp_id', np.array(starts, dtype=np.int32)),
        'count': ('comp_id', np.array(lengths, dtype=np.int32)),
        # Object dtype keeps arbitrary-length Python strings intact.
        'value': ('comp_id', np.array(values, dtype=object)),
    }
    return xr.Dataset(data_vars=arc_data, coords=coords)


def make_closed_arc_dataset(
    comp_id: Optional[int] = None,
    start: Optional[int] = None,
    length: Optional[int] = None,
    flag: Optional[int] = None
) -> xr.Dataset:
    """
    Make a dataset for a closed arc.

    Either pass every argument, which yields a dataset describing exactly one arc, or pass none of them, which yields
    an empty dataset with the same variables (as used when the closed arc dataset is first created).

    Args:
        comp_id: The arc's component ID.
        start: The index in the node dataset of the arc's first node.
        length: How many nodes comprise the arc.
        flag: The arc's land boundary type flag. The flag appears in the hgrid.gr3 file, as the second part of the
            lines that tell you how many nodes are in a particular closed boundary.

    Returns:
        The dataset, indexed by `comp_id`, with `start_index`, `count`, and `flag` variables.

    Raises:
        ValueError: If only some of the arguments are provided.
    """
    provided = [arg is not None for arg in (comp_id, start, length, flag)]
    # A partial set of arguments would give the variables different lengths along 'comp_id'; report that clearly
    # instead of letting xarray fail with a dimension-size conflict.
    if any(provided) and not all(provided):
        raise ValueError('Either all or none of comp_id, start, length, and flag must be provided.')

    if all(provided):
        comp_ids, starts, lengths, flags = [comp_id], [start], [length], [flag]
    else:
        comp_ids, starts, lengths, flags = [], [], [], []

    coords = {'comp_id': np.array(comp_ids, dtype=np.int32)}
    arc_data = {
        'start_index': ('comp_id', np.array(starts, dtype=np.int32)),
        'count': ('comp_id', np.array(lengths, dtype=np.int32)),
        'flag': ('comp_id', np.array(flags, dtype=np.int32)),
    }
    return xr.Dataset(data_vars=arc_data, coords=coords)


def make_node_dataset(nodes: Optional[Sequence[int]] = None) -> xr.Dataset:
    """
    Build a node dataset from a sequence of node IDs.

    Args:
        nodes: The node IDs to store. When omitted (or None) the resulting dataset is empty.

    Returns:
        A dataset with a single int32 `id` variable holding the node IDs in the order given.
    """
    if nodes is None:
        nodes = []
    node_ids = np.array(nodes, dtype=np.int32)
    return xr.Dataset(data_vars={'id': node_ids})


class MappedBcData(BaseData):
    """
    Manages a mapped BC coverage's data.

    Open and closed arcs are each stored as a pair of datasets: a node dataset holding the nodes of every arc of that
    kind back to back, and an arc dataset recording, per component ID, where the arc's nodes start in the node dataset
    and how many there are. Arcs added through `add_open_arc` and `add_closed_arc` are buffered in memory and are only
    merged into those datasets (and so only reported by `open_arcs`, `closed_arcs`, and `open_nodes`) once `commit`
    runs.
    """
    def __init__(self, main_file: Optional[str | Path] = None):
        """
        Initialize the object.

        Args:
            main_file: Path to the component's main file.
        """
        # The superclass calls self.commit(), which we override to require these, so initialize them first.
        self._reset_pending()
        self.coverage_uuid = ''  # SCHISM doesn't use this, but we need it to keep GMI's CoverageComponent happy.

        super().__init__(main_file)

        self.info.attrs['SCHISM_FILE_TYPE'] = 'MAPPED_BC'

        self.commit()

    @property
    def domain_hash(self) -> str:
        """The hash of the domain the mapped coverage was made for."""
        return self.info.attrs.get('domain_hash', '')

    @domain_hash.setter
    def domain_hash(self, value: str):
        """The hash of the domain the mapped coverage was made for."""
        self.info.attrs['domain_hash'] = value

    def add_open_arc(self, nodes: Sequence[int], values: str) -> int:
        """
        Add an open arc.

        The arc is buffered in memory until the next call to `commit`.

        Args:
            nodes: Indices of the nodes comprising the arc.
            values: A string that can be restored into the arc parameters of a generic model to determine which
                properties are assigned to the arc.

        Returns:
            The component ID assigned to the newly added arc.
        """
        # The start index is relative to the pending buffer; commit() shifts it past the nodes already stored.
        start = len(self._open_nodes)
        length = len(nodes)
        component_id = self._get_component_id()
        self._open_nodes.extend(nodes)
        arc = make_open_arc_dataset(component_id, start, length, values)
        self._open_arcs.append(arc)
        return component_id

    @property
    def open_arcs(self) -> Sequence[tuple[Sequence[int], str]]:
        """
        Open arc data.

        The data is a list of arcs, in the order they are stored. Each arc is a tuple of (locations, values). The
        locations are the arc's node indices in the geometry (as a NumPy array), and the value is a string that can be
        restored into the arc parameters of a generic model to determine which properties are assigned to the arc.
        Only committed arcs are included.

        Returns:
            Open arc data.
        """
        return self._read_arcs(_OPEN_ARCS_DS, _OPEN_NODES_DS, 'value')

    def add_closed_arc(self, nodes: Sequence[int], flag: int) -> int:
        """
        Add a closed arc.

        The arc is buffered in memory until the next call to `commit`.

        Args:
            nodes: Indices of the nodes comprising the arc.
            flag: What type of land boundary it is. The flag appears in the hgrid.gr3 file, as the second part of the
                lines that tell you how many nodes are in a particular closed boundary.

        Returns:
            The component ID assigned to the newly added arc.
        """
        # The start index is relative to the pending buffer; commit() shifts it past the nodes already stored.
        start = len(self._closed_nodes)
        length = len(nodes)
        component_id = self._get_component_id()
        self._closed_nodes.extend(nodes)
        arc = make_closed_arc_dataset(component_id, start, length, flag)
        self._closed_arcs.append(arc)
        return component_id

    @property
    def closed_arcs(self) -> Sequence[tuple[Sequence[int], int]]:
        """
        Closed arc data.

        The data is a list of arcs, in the order they are stored. Each arc is a tuple of (locations, flag). The
        locations are the arc's node indices in the geometry (as a NumPy array), and the flag is what type of land
        boundary it is. The flag appears in the hgrid.gr3 file, as the second part of the lines that tell you how many
        nodes are in a particular closed boundary. Only committed arcs are included.

        Returns:
            Closed arc data.
        """
        arcs = self._read_arcs(_CLOSED_ARCS_DS, _CLOSED_NODES_DS, 'flag')
        return [(nodes, int(flag)) for nodes, flag in arcs]

    @property
    def open_nodes(self) -> Sequence[int]:
        """The IDs of all the committed open nodes."""
        return self._get_dataset(_OPEN_NODES_DS)['id'].values

    @property
    def _dataset_names(self) -> set[str]:
        """The names of datasets used by this data class."""
        return super()._dataset_names | {_OPEN_NODES_DS, _CLOSED_NODES_DS, _OPEN_ARCS_DS, _CLOSED_ARCS_DS}

    @property
    def _main_file_name(self) -> str:
        """What to name the component's main file."""
        return 'schism_mapped_bc.nc'

    def _create_dataset(self, name: str):
        """Create an empty dataset with the layout expected for `name`."""
        if name in (_OPEN_NODES_DS, _CLOSED_NODES_DS):
            return make_node_dataset()
        if name == _OPEN_ARCS_DS:
            return make_open_arc_dataset()
        if name == _CLOSED_ARCS_DS:
            return make_closed_arc_dataset()
        return super()._create_dataset(name)

    def commit(self):
        """Save current in-memory component data to main file."""
        self._merge_pending(_OPEN_NODES_DS, _OPEN_ARCS_DS, self._open_nodes, self._open_arcs)
        self._merge_pending(_CLOSED_NODES_DS, _CLOSED_ARCS_DS, self._closed_nodes, self._closed_arcs)
        # Everything pending now lives in the datasets; forget it so a later commit doesn't append it a second time.
        self._reset_pending()

        super().commit()

    def _reset_pending(self):
        """Start empty buffers for arcs (and their nodes) that have been added but not yet committed."""
        self._open_nodes = array('l')
        self._open_arcs: list[xr.Dataset] = []
        self._closed_nodes = array('l')
        self._closed_arcs: list[xr.Dataset] = []

    def _merge_pending(
        self, nodes_name: str, arcs_name: str, pending_nodes: Sequence[int], pending_arcs: list[xr.Dataset]
    ):
        """Append buffered arcs, and the nodes they reference, to the named node and arc datasets."""
        # Gate on arcs rather than nodes so an arc with no nodes is still recorded.
        if not pending_arcs:
            return

        existing_nodes = self._get_dataset(nodes_name)
        # Pending start indices count from the beginning of the pending buffer, but the pending nodes are appended
        # after the nodes already stored, so shift the indices by that many.
        offset = existing_nodes.sizes.get('id', 0)
        combined_nodes = xr.concat([existing_nodes, make_node_dataset(pending_nodes)], dim='id')
        self._set_dataset(combined_nodes, nodes_name)

        new_arcs = xr.concat(pending_arcs, dim='comp_id')
        new_arcs = new_arcs.assign(start_index=(new_arcs['start_index'] + offset).astype(np.int32))
        existing_arcs = self._get_dataset(arcs_name)
        combined_arcs = xr.concat([existing_arcs, new_arcs], dim='comp_id')
        self._set_dataset(combined_arcs, arcs_name)

    def _read_arcs(self, arcs_name: str, nodes_name: str, field: str) -> list[tuple[np.ndarray, object]]:
        """Pair each stored arc's node IDs with the raw value of one of its per-arc variables, in dataset order."""
        arc_ds = self._get_dataset(arcs_name)
        node_ids = self._get_dataset(nodes_name)['id'].values

        comp_ids = arc_ds['comp_id'].values
        if np.unique(comp_ids).size != comp_ids.size:
            raise AssertionError('Reused component ID in file')  # pragma: nocover

        arcs = []
        rows = zip(arc_ds['start_index'].values, arc_ds['count'].values, arc_ds[field].values)
        for start, count, field_value in rows:
            start = int(start)
            count = int(count)
            # Copy so callers can't mutate the stored node data through the returned array.
            arcs.append((node_ids[start:start + count].copy(), field_value))
        return arcs
