"""CoverageMapper class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import copy
import logging
from typing import Type

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import TreeNode
from xms.coverage.grid import ugrid_mapper
from xms.coverage.grid.ugrid_mapper import UgridMapper
from xms.data_objects.parameters import Coverage, FilterLocation
from xms.gmi.data.generic_model import GenericModel
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.hgs.components.bc_coverage_component import BcCoverageComponent
from xms.hgs.components.hgs_coverage_component_base import HgsCoverageComponentBase
from xms.hgs.components.hydrograph_coverage_component import HydrographCoverageComponent
from xms.hgs.components.obs_coverage_component import ObsCoverageComponent
from xms.hgs.data import bc_generic_model, hydrograph_generic_model, obs_generic_model
from xms.hgs.file_io import file_io_util
from xms.hgs.mapping.map_att import MapAtt

# Type aliases
# The lists hold MapAtt instances (the objects returned by a component's make_map_att()), not MapAtt classes.
MapAttDict = dict[str, list[MapAtt]]  # component class name -> list of MapAtt


def _init_generic_model(coverage_type: str) -> GenericModel:
    """Creates the generic model that matches a coverage type.

    Args:
        coverage_type: The coverage type ('Boundary Conditions', 'Observations' or 'Hydrographs').

    Returns:
        (GenericModel): The newly created generic model.

    Raises:
        ValueError: If the coverage type is not one of the recognized types.
    """
    # Pick the module that knows how to build the model, then build it in one place.
    if coverage_type == 'Boundary Conditions':
        model_module = bc_generic_model
    elif coverage_type == 'Observations':
        model_module = obs_generic_model
    elif coverage_type == 'Hydrographs':
        model_module = hydrograph_generic_model
    else:
        raise ValueError(f'Unrecognized coverage type: "{coverage_type}".')
    return model_module.create()


class CoverageMapper:
    """Maps the coverages attached to a simulation onto a UGrid."""
    def __init__(self, query: 'Query | None' = None, sim_node: 'TreeNode | None' = None, co_grid_3d=None) -> None:
        """Initializes the class.

        Args:
            query: Object for communicating with GMS (only None when testing).
            sim_node: Simulation tree node (only None when testing).
            co_grid_3d: The 3D grid (only None when testing).
        """
        self._query = query
        self._sim_node = sim_node
        self._co_grid_3d = co_grid_3d

        self._ugrid_mapper: UgridMapper | None = None  # Created in map() once we know there is something to map
        self._map_atts: MapAttDict = {}  # component class name -> MapAtt items created so far
        self._log = logging.getLogger('xms.hgs')

    def map(self) -> MapAttDict:
        """Maps coverages to a UGrid.

        For every coverage pointer found under the simulation node, the points, arcs and polygons of the coverage
        are turned into MapAtt items which are intersected with the grid.

        Returns:
            (dict[str, list[MapAtt]]): The intersected MapAtt items, keyed by component class name. An empty dict
            if no coverage pointers are found.

        Raises:
            RuntimeError: If the component of a coverage cannot be obtained.
            ValueError: If a coverage has an unrecognized coverage type.
        """
        coverage_ptr_nodes = file_io_util.get_coverage_pointers(self._sim_node, '')
        if not coverage_ptr_nodes:
            return {}

        if self._co_grid_3d is not None:  # Can be None when testing
            # Add point_sheets attribute to co_grid
            self._co_grid_3d.point_sheets = ugrid_mapper.compute_point_sheets(self._co_grid_3d)
            self._ugrid_mapper = UgridMapper(self._co_grid_3d, cell_materials=None)

        for coverage_ptr_node in coverage_ptr_nodes:
            # Get coverage and component
            coverage = self._query.item_with_uuid(coverage_ptr_node.uuid)
            component = self._get_component(coverage_ptr_node)
            if not component:
                raise RuntimeError(
                    f'Could not get component for "{coverage.name}".'
                )  # pragma no cover - can't happen

            self._log.info(f'Mapping "{coverage.name}" coverage...')
            self._query.load_component_ids(component, points=True, arcs=True, polygons=True)

            # One generic model per coverage, shared by all three feature types
            generic_model = _init_generic_model(coverage_ptr_node.coverage_type)
            for feature_type in (TargetType.point, TargetType.arc, TargetType.polygon):
                self._intersect_features(coverage, component, generic_model, feature_type)

        return self._map_atts

    def _get_component(self, coverage_ptr_node: TreeNode) -> 'HgsCoverageComponentBase | None':
        """Returns the coverage component given the tree node.

        Args:
            coverage_ptr_node (TreeNode): Tree node pointing to coverage.

        Returns:
            (HgsCoverageComponentBase | None): The component for the coverage. The caller treats a falsy result as
            "component could not be obtained".

        Raises:
            ValueError: If the coverage type of the node is not recognized.
        """
        coverage_class, class_name = self._component_from_coverage_type(coverage_ptr_node.coverage_type)
        return file_io_util.get_coverage_component(coverage_ptr_node, coverage_class, class_name, self._query)

    @staticmethod
    def _component_from_coverage_type(coverage_type: str) -> tuple[Type, str]:
        """Returns the coverage component class and class name given the coverage type.

        Args:
            coverage_type: The coverage type ('Boundary Conditions', 'Observations' or 'Hydrographs').

        Returns:
            (tuple[Type, str]): The component class and its class name.

        Raises:
            ValueError: If the coverage type is not recognized.
        """
        if coverage_type == 'Boundary Conditions':
            return BcCoverageComponent, 'BcCoverageComponent'
        elif coverage_type == 'Observations':
            return ObsCoverageComponent, 'ObsCoverageComponent'
        elif coverage_type == 'Hydrographs':
            return HydrographCoverageComponent, 'HydrographCoverageComponent'
        else:
            raise ValueError(f'Unrecognized coverage type: "{coverage_type}".')

    @staticmethod
    def _features_from_type(coverage: Coverage, feature_type: TargetType):
        """Returns the features of the coverage that are of the given feature type.

        Args:
            coverage: The coverage.
            feature_type: Type of feature object (point, arc or polygon).

        Returns:
            The coverage points selected with FilterLocation.PT_LOC_DISJOINT, the coverage arcs, or the coverage
            polygons, depending on feature_type.

        Raises:
            ValueError: If feature_type is not point, arc or polygon.
        """
        if feature_type == TargetType.point:
            features = coverage.get_points(FilterLocation.PT_LOC_DISJOINT)
        elif feature_type == TargetType.arc:
            features = coverage.arcs
        elif feature_type == TargetType.polygon:
            features = coverage.polygons
        else:
            raise ValueError(f'Unsupported feature type: "{str(feature_type)}".')
        return features

    def _intersect_features(
        self, coverage: Coverage, component: HgsCoverageComponentBase, generic_model: GenericModel,
        feature_type: TargetType
    ) -> None:
        """Intersects all the features of type.

        Creates a MapAtt for every feature that has component data, intersects it with the grid and stores it in
        self._map_atts under the component's class name.

        Args:
            coverage: The coverage.
            component: The coverage component.
            generic_model: The generic model.
            feature_type: Type of feature object.
        """
        features = self._features_from_type(coverage, feature_type)
        gmi_section = generic_model.section_from_target_type(feature_type)  # point_parameters etc

        for feature in features:
            # Features without a valid component id have nothing assigned to them
            comp_id = component.get_comp_id(feature_type, feature.id)
            if comp_id is None or comp_id < 0:
                continue

            # Get the data for this feature
            att_type, values = component.data.feature_type_values(feature_type, comp_id)
            if not values:
                continue  # pragma no cover - should never happen (? I think) and can't test

            # Create MapAtt items, intersect them, and store them in a dict by component class name
            gmi_section.restore_values(values)
            if not gmi_section.has_group(att_type):
                continue  # pragma no cover - should never happen and can't test
            # The same section is restored again for the next feature, so each MapAtt gets its own copy of the group
            # (presumably the group would otherwise reflect the values of the last feature restored).
            group = copy.deepcopy(gmi_section.group(att_type))

            map_att = component.make_map_att(feature, att_type, group)
            map_att.intersect(self._ugrid_mapper)
            self._map_atts.setdefault(component.class_name, []).append(map_att)
