"""MapAtt class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import itertools
from typing import Type

# 2. Third party modules

# 3. Aquaveo modules
from xms.coverage.grid.ugrid_mapper import UgridMapper
from xms.data_objects.parameters import Arc, Point, Polygon
from xms.gmi.data.generic_model import Group
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.hgs.data import bc_generic_model
from xms.hgs.data.domains import Domains
from xms.hgs.file_io.section import SectionBase


def _feature_type(feature) -> str:
    """Maps a feature object to the lowercase name of its type.

    Args:
        feature: A Point, Arc, or Polygon.

    Returns:
        (str): 'point', 'arc', or 'polygon'.

    Raises:
        ValueError: If feature is not a Point, Arc, or Polygon.
    """
    # Checked in order, exactly like an isinstance if/elif chain.
    known_types = (
        (Point, 'point'),
        (Arc, 'arc'),
        (Polygon, 'polygon'),
    )
    for feature_class, type_name in known_types:
        if isinstance(feature, feature_class):
            return type_name
    raise ValueError(f'Unexpected value "{feature}".')


class MapAtt:
    """A coverage attribute item (i.e. BC, observation, hydrograph) with attributes and intersected grid components.

    Subclasses are expected to override ``get_grid_component_type`` (which raises NotImplementedError here) and
    ``write`` (which is a no-op here).
    """
    def __init__(self, feature, att_type: str, group: Group) -> None:
        """Initializes the class.

        Args:
            feature: A point, arc, or polygon.
            att_type (str): Attribute type (e.g. 'Flux nodal')
            group (Group): The BC as a generic model group.
        """
        self.att_type: str = att_type
        self.group = group

        self.feature = feature
        # Intersected grid components, filled in by intersect(). Keys that may be set:
        # 'points', 'weights', 't_values', 'faces', 'face nodes'. Value types vary by key.
        self.intersection = {}
        # Vertical range used when intersecting; set by _find_vertical_range() (either the sheet pair
        # or the elevation pair is used, the other stays None).
        self._max_sheet = None
        self._min_sheet = None
        self._top_elev = None
        self._bottom_elev = None

    def value(self, parameter_name: str, default=None):
        """Helper function that returns the parameter value or the default if the parameter doesn't exist.

        Args:
            parameter_name (str): Name of the parameter.
            default: The default value returned if the parameter doesn't exist.

        Returns:
            The parameter value or the default.
        """
        try:
            return self.group.parameter(parameter_name).value
        except KeyError:
            return default

    def set_value(self, parameter_name: str, value):
        """Helper function that sets the parameter value, but only locally - the coverage isn't affected.

        Missing parameters are silently ignored (best-effort).

        Args:
            parameter_name (str): Name of the parameter.
            value: The value.
        """
        try:
            self.group.parameter(parameter_name).value = value
        except KeyError:
            pass  # Deliberate: not every group defines every parameter.

    def intersect(self, ugrid_mapper: UgridMapper) -> None:
        """Intersects the feature object with the grid and stores the results in self.intersection.

        Args:
            ugrid_mapper (UgridMapper): The UgridMapper for the 3D grid.

        Raises:
            RuntimeError: If the feature type / grid component type combination is unsupported.
        """
        grid_component_type = self.get_grid_component_type()
        self._find_vertical_range(ugrid_mapper)
        if isinstance(self.feature, Point):
            # NOTE(review): the component type is not validated for point features; a 'face' type here would
            # fail with KeyError in _create_choose_faces_by_nodes_list because no 'faces' entry is stored.
            self._points_at_point(self.feature, ugrid_mapper)
        elif isinstance(self.feature, Arc):
            if grid_component_type in {'node', 'face', 'segment'}:
                self._points_on_arc(self.feature, ugrid_mapper)
                if grid_component_type == 'face':
                    self._faces_on_arc(self.feature, ugrid_mapper)
            else:
                self._raise_unsupported_combination_error(grid_component_type)  # pragma no cover - too hard to test
        else:  # isinstance(feature, Polygon):
            if grid_component_type == 'node':
                self._points_in_polygon(self.feature, ugrid_mapper)
            elif grid_component_type == 'face':
                self._top_faces_in_polygon(self.feature, ugrid_mapper)
            else:
                self._raise_unsupported_combination_error(grid_component_type)  # pragma no cover - too hard to test

        # Create the extra 'face nodes' list from the stored faces
        if grid_component_type == 'face':
            self._create_choose_faces_by_nodes_list(ugrid_mapper._ugrid)

    def _raise_unsupported_combination_error(self, grid_component_type):  # pragma no cover - too hard to test
        """Raise a RuntimeError naming the unsupported feature / grid component combination."""
        feature_type = _feature_type(self.feature)
        raise RuntimeError(f'Unsupported combination "{feature_type}" and "{grid_component_type}".')

    def write(self, section: Type[SectionBase], *args, **kwargs):
        """Writes this MapAtt object to grok. No-op in the base class.

        Args:
            section (Type[SectionBase]): The section we are writing to.
            *args: Other arguments.
            **kwargs: Arbitrary keyword arguments.
        """
        pass  # pragma no cover - can't test this

    def _find_vertical_range(self, ugrid_mapper: UgridMapper):
        """Computes and stores the max and min point sheets and the top and bottom elevations.

        If the top sheet is used, both sheets are set to the highest point sheet. Otherwise the 'range_opts'
        parameter selects between the sheet pair and the elevation pair. In every case, two of the four
        values stay None.

        Args:
            ugrid_mapper (UgridMapper): The UGrid 3d mapper.
        """
        self._max_sheet, self._min_sheet, self._top_elev, self._bottom_elev = None, None, None, None
        if self._use_top_sheet():
            self._min_sheet = self._max_sheet = max(ugrid_mapper.co_grid.point_sheets)
        else:
            range_option = self.value('range_opts')
            if range_option == bc_generic_model.max_min_sheets:
                self._max_sheet = int(self.value('max_sheet'))
                self._min_sheet = int(self.value('min_sheet'))
            else:
                self._top_elev = float(self.value('top_elev'))
                self._bottom_elev = float(self.value('bottom_elev'))

    def _use_top_sheet(self) -> bool:
        """Returns true if this should be mapped to the 2D grid (i.e. the domain is missing or not Domains.PM).

        Returns:
            (bool): See description.
        """
        domain = self.value('domain')
        return domain is None or domain != Domains.PM

    def get_grid_component_type(self) -> str:
        """Returns the type of grid component ('node', 'face', 'segment') from the att type.

        Override this.

        Returns:
              (str): See description.
        """
        raise NotImplementedError('Implement')  # pragma no cover - shouldn't ever get here

    def _points_at_point(self, point: Point, ugrid_mapper: UgridMapper) -> None:
        """Stores the UGrid point indices and weights under the feature point.

        Sets self.intersection['points'] and self.intersection['weights'].

        Args:
            point (Point): The feature point.
            ugrid_mapper (UgridMapper): The UGrid 3d mapper.
        """
        point_idxs, weights = ugrid_mapper.get_ugrid_points_at_point(
            point, z_min=self._bottom_elev, z_max=self._top_elev, sheet_min=self._min_sheet, sheet_max=self._max_sheet
        )
        self.intersection['points'] = point_idxs
        self.intersection['weights'] = weights

    def _points_on_arc(self, arc: Arc, ugrid_mapper: UgridMapper) -> None:
        """Stores the UGrid point indices and t values under the feature arc.

        Sets self.intersection['points'] and self.intersection['t_values'].

        Args:
            arc (Arc): The feature arc.
            ugrid_mapper (UgridMapper): The UGrid 3d mapper.
        """
        exterior = self.value('map_to_boundary')
        point_idxs, t_values = ugrid_mapper.get_ugrid_points_on_arc(
            arc,
            exterior=exterior,
            z_min=self._bottom_elev,
            z_max=self._top_elev,
            sheet_min=self._min_sheet,
            sheet_max=self._max_sheet
        )
        self.intersection['points'] = point_idxs
        self.intersection['t_values'] = t_values

    def _points_in_polygon(self, polygon: Polygon, ugrid_mapper: UgridMapper) -> None:
        """Stores the UGrid point indices strictly inside the feature polygon.

        Sets self.intersection['points'].

        Args:
            polygon (Polygon): The feature polygon.
            ugrid_mapper (UgridMapper): The UGrid 3d mapper.
        """
        point_idxs = ugrid_mapper.get_ugrid_points_in_polygon(
            polygon,
            strictly_in=True,
            z_min=self._bottom_elev,
            z_max=self._top_elev,
            sheet_min=self._min_sheet,
            sheet_max=self._max_sheet
        )
        self.intersection['points'] = point_idxs

    def _faces_on_arc(self, arc: Arc, ugrid_mapper: UgridMapper) -> None:
        """Stores the vertical faces intersected by the arc in self.intersection['faces'].

        Args:
            arc (Arc): The feature arc.
            ugrid_mapper (UgridMapper): The UGrid 3d mapper.
        """
        exterior = self.value('map_to_boundary')
        if exterior:
            sorted_faces = ugrid_mapper.get_ugrid_side_faces_on_arc(
                arc,
                z_min=self._bottom_elev,
                z_max=self._top_elev,
                sheet_min=self._min_sheet,
                sheet_max=self._max_sheet
            )
            self.intersection['faces'] = sorted_faces
        else:
            layer_face_lists = ugrid_mapper.get_vertical_faces_intersected_by_arc(
                arc,
                z_min=self._bottom_elev,
                z_max=self._top_elev,
                sheet_min=self._min_sheet,
                sheet_max=self._max_sheet
            )

            # Flatten the list of lists. Materialize it: a bare chain iterator can only be consumed once,
            # and intersect() iterates the faces (to build 'face nodes') before anyone else reads them.
            self.intersection['faces'] = list(itertools.chain.from_iterable(layer_face_lists))

    def _top_faces_in_polygon(self, polygon: Polygon, ugrid_mapper: UgridMapper) -> None:
        """Stores the top faces in the polygon in self.intersection['faces'].

        Unlike the other intersection helpers, no vertical range is passed to the mapper here.

        Args:
            polygon (Polygon): The feature polygon.
            ugrid_mapper (UgridMapper): The UGrid 3d mapper.
        """
        self.intersection['faces'] = ugrid_mapper.get_ugrid_top_faces_in_polygon(polygon)

    def _create_choose_faces_by_nodes_list(self, ugrid3d: UGrid) -> None:
        """Stores, in self.intersection['face nodes'], a list for the 'Choose faces by nodes list' command.

        n1(i), n2(i), n3(i), n4(i)...end Node numbers for each face to be chosen.

        Each entry holds the face's 1-based node numbers. Faces with fewer than 4 nodes get a single trailing 0.
        Faces with no points are skipped.

        Args:
            ugrid3d (UGrid): The 3D UGrid the faces belong to.
        """
        faces_by_nodes_list = []
        for cell_idx, face_idx in self.intersection['faces']:
            face_points = list(ugrid3d.get_cell_3d_face_points(cell_idx, face_idx))  # tuple to list
            if face_points:
                face_points = [p + 1 for p in face_points]  # Make it 1-based here so when we append 0 it stays 0
                if len(face_points) < 4:
                    face_points.append(0)
                faces_by_nodes_list.append(face_points)
        self.intersection['face nodes'] = faces_by_nodes_list