"""Creates a coverage with polygons."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from itertools import count, pairwise
from typing import Optional, Sequence, TypeAlias
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.data_objects.parameters import Arc as DoArc, Coverage, Point as DoPoint, Polygon as DoPolygon, Projection

# 4. Local modules

# Public type aliases
# Loosely-typed shapes accepted by the public API; any sequence type may be passed.
Point: TypeAlias = Sequence[float]  # Exactly three components: (x, y, z).
Arc: TypeAlias = Sequence[Point]  # At least two points; first is the start node, last is the end node.
Ring: TypeAlias = Sequence[Point]  # At least three points around a boundary; repeating the first point is optional.
Polygon: TypeAlias = Sequence[Ring]  # Outer boundary ring first, followed by any hole rings.

# Internally used type aliases
# Normalized, hashable tuple forms of the public aliases; used as dictionary keys to de-duplicate features.
IPoint: TypeAlias = tuple[float, float, float]
IArc: TypeAlias = tuple[IPoint, ...]
IRing: TypeAlias = tuple[IPoint, ...]
IPolygon: TypeAlias = tuple[IRing, ...]


def build(
    disjoint_points: Sequence[Point] | None = None,
    polylines: Sequence[Arc] | None = None,
    name: str | None = None,
    uuid: str | None = None,
    projection: Projection | None = None,
    polygons: Sequence[Polygon] | None = None,
) -> Coverage:
    """Builds and returns a Coverage from xyz locations.

    Args:
        disjoint_points: xyz locations which will become data_objects.Point disjoint points.
        polylines: Lines that will become arcs, defined by xyz locations. Each line's first point becomes the arc's
            start node and its last point becomes the end node.
        name: Coverage name.
        uuid: Coverage uuid (if not provided, a random one will be assigned).
        projection: Coverage projection.
        polygons: Polygons to add. Each polygon is a sequence of rings (outer boundary first, then any holes), and
            each ring is a sequence of xyz locations. See `CoverageBuilder.add_polygon`.

    Returns:
        The completed coverage.
    """
    # The `uuid` parameter shadows the `uuid` module inside this function; the name is kept for API compatibility.
    builder = CoverageBuilder(name, uuid, projection)

    # Compare against None rather than relying on truthiness, so array-like inputs are accepted.
    for point in () if disjoint_points is None else disjoint_points:
        builder.add_point(point)

    for line in () if polylines is None else polylines:
        builder.add_arc(line)

    for polygon in () if polygons is None else polygons:
        builder.add_polygon(polygon)

    return builder.build()


class CoverageBuilder:
    """Builds a coverage from xyz locations.

    Features are keyed by their coordinates, so adding the same point, arc, or polygon more than once returns the
    feature ID that was assigned the first time.
    """
    def __init__(
        self, name: Optional[str] = None, coverage_uuid: Optional[str] = None, projection: Optional[Projection] = None
    ):
        """Initialize an empty builder.

        Args:
            name: Coverage name.
            coverage_uuid: Coverage uuid (if not provided, a random one will be assigned).
            projection: Coverage projection.
        """
        self._points: dict[IPoint, DoPoint] = {}  # Disjoint points, keyed by location.
        self._nodes: dict[IPoint, DoPoint] = {}  # Arc end nodes, keyed by location.
        self._point_id = count(start=1)  # Points and nodes share the same ID namespace in SMS.
        self._arcs: dict[IArc, DoArc] = {}
        self._arc_id = count(start=1)
        self._polygons: dict[IPolygon, DoPolygon] = {}
        self._polygon_id = count(start=1)
        # Arcs used by some polygon's outer ring; build() leaves these out of the coverage's arc list.
        self._outer_ring_arcs: set[IArc] = set()

        self._name = name
        self._uuid = coverage_uuid or str(uuid.uuid4())
        self._projection = projection

    def add_point(self, point: Point) -> int:
        """
        Add a point to the coverage.

        Points added by this method are disjoint (not attached to any other features).

        Args:
            point: A sequence of three floats, representing the x, y, and z coordinates of the point.

        Returns:
            The feature ID assigned to the point.

        Raises:
            AssertionError: If the point does not have exactly three components.
        """
        _validate_point(point)
        # Normalize to a float tuple (as add_polygon does) so it can be used as a dictionary key.
        tuple_point = (float(point[0]), float(point[1]), float(point[2]))
        return self._get_point(tuple_point).id

    def _get_point(self, point: IPoint) -> DoPoint:
        """Create or retrieve the data_objects disjoint point at the provided location."""
        if point not in self._points:
            x, y, z = point
            self._points[point] = DoPoint(x, y, z, feature_id=next(self._point_id))
        return self._points[point]

    def add_arc(self, arc: Arc) -> int:
        """
        Add an arc to the coverage.

        Args:
            arc: The arc to add. Should be a sequence of points. Each point should be a sequence of three floats,
                representing the x, y, and z coordinates of the point. The arc will be oriented such that the first
                point in the input is the arc's start node and the last point is the arc's end node.

        Returns:
            The feature ID assigned to the arc.

        Raises:
            AssertionError: If the arc has fewer than two points or any point does not have three components.
        """
        _validate_arc(arc)
        # Normalize to float tuples (as add_polygon does) so the arc can be used as a dictionary key.
        tuple_arc = tuple((float(x), float(y), float(z)) for (x, y, z) in arc)
        return self._get_arc(tuple_arc, False).id

    def _get_arc(self, arc: IArc, is_outer_ring: bool) -> DoArc:
        """Create or retrieve the data_objects arc for the provided arc.

        Args:
            arc: The arc's points. The first and last become its start and end nodes; the rest become vertices.
            is_outer_ring: Whether the arc is part of a polygon's outer ring. Such arcs are left out of the
                coverage's arc list by build().

        Returns:
            The data_objects arc.
        """
        if arc not in self._arcs:
            start = self._get_node(arc[0])
            end = self._get_node(arc[-1])
            # Interior vertices are created without a feature ID; only the end nodes get one.
            vertices = [DoPoint(x, y, z) for (x, y, z) in arc[1:-1]]
            feature_id = next(self._arc_id)
            self._arcs[arc] = DoArc(start_node=start, end_node=end, vertices=vertices, feature_id=feature_id)
        if is_outer_ring:
            self._outer_ring_arcs.add(arc)
        return self._arcs[arc]

    def _get_node(self, node: IPoint) -> DoPoint:
        """Create or retrieve the data_objects node (arc end point) at the provided location."""
        if node not in self._nodes:
            x, y, z = node
            self._nodes[node] = DoPoint(x, y, z, feature_id=next(self._point_id))
        return self._nodes[node]

    def add_polygon(self, polygon: Polygon) -> int:
        """
        Add a polygon to the coverage.

        Args:
            polygon: The polygon to add. Should be a sequence of rings, where the first ring is the outer boundary and
                any subsequent rings (if present) are inner holes. Each ring should be a sequence of points, in either
                clockwise or counterclockwise order. Each point should be a sequence of three floats, representing the
                point's x, y, and z coordinates, in that order.

        Returns:
            The feature ID assigned to the polygon.

        Raises:
            AssertionError: If the polygon has no rings, a ring has fewer than three points, or any point does not
                have three components.
        """
        _validate_polygon(polygon)
        tuple_polygon = tuple(tuple((float(x), float(y), float(z)) for (x, y, z) in ring) for ring in polygon)
        return self._get_polygon(tuple_polygon).id

    def _get_polygon(self, polygon: IPolygon) -> DoPolygon:
        """Create or retrieve the data_objects polygon for the provided polygon.

        Every ring segment becomes its own two-point arc. Segments are stored in sorted order, so a segment shared
        by several rings or polygons maps to a single arc.
        """
        # Without this check, re-adding a polygon would consume a new ID and orphan the previously returned one.
        if polygon in self._polygons:
            return self._polygons[polygon]

        do_rings = []
        for ring_index, ring in enumerate(polygon):
            is_outer_ring = ring_index == 0
            segments = list(pairwise(ring))
            if ring[0] != ring[-1]:
                # Open ring: add the closing segment back to the first point.
                segments.append((ring[-1], ring[0]))

            ring_arcs = []
            directions = []
            for start, end in segments:
                # `swapped` is True when the ring traverses the stored arc from its end node to its start node.
                start, end, swapped = _sort(start, end)
                ring_arcs.append(self._get_arc((start, end), is_outer_ring))
                directions.append(swapped)
            do_rings.append((ring_arcs, directions))

        feature_id = next(self._polygon_id)
        do_polygon = DoPolygon(feature_id=feature_id)
        outer_arcs, outer_directions = do_rings[0]
        do_polygon.set_arcs(outer_arcs, outer_directions)
        if len(do_rings) > 1:
            inner_arcs = [ring_arcs for ring_arcs, _directions in do_rings[1:]]
            inner_directions = [directions for _ring_arcs, directions in do_rings[1:]]
            do_polygon.set_interior_arcs(inner_arcs, inner_directions)

        self._polygons[polygon] = do_polygon
        return do_polygon

    def build(self) -> Coverage:
        """Builds the coverage from every feature added so far.

        Returns:
            A completed coverage containing the disjoint points, the arcs not used by any polygon's outer ring, and
            the polygons.
        """
        coverage = Coverage(name=self._name, uuid=self._uuid, projection=self._projection)

        if self._points:
            coverage.set_points(list(self._points.values()))

        # If we add arcs that are in use in a polygon's outer ring, then the coverage will end up with duplicates.
        non_outer_arcs = [do_arc for arc, do_arc in self._arcs.items() if arc not in self._outer_ring_arcs]
        if non_outer_arcs:
            coverage.arcs = non_outer_arcs

        if self._polygons:
            coverage.polygons = list(self._polygons.values())

        coverage.complete()

        return coverage


def _sort(start: IPoint, end: IPoint) -> tuple[IPoint, IPoint, bool]:
    """Put a pair of points into ascending tuple order.

    Args:
        start: First point of the pair.
        end: Second point of the pair.

    Returns:
        A ``(first, second, swapped)`` triple. ``swapped`` is True only when ``end < start``, in which case the two
        points are returned exchanged; otherwise they come back in the order given.
    """
    if not end < start:
        return start, end, False
    return end, start, True


def _validate_point(point: Point):
    """Raise AssertionError unless ``point`` has exactly three (x, y, z) components."""
    component_count = len(point)
    if component_count == 3:
        return
    raise AssertionError('Points must have exactly three components.')


def _validate_arc(arc: Arc):
    """Raise AssertionError unless ``arc`` has at least two points and every point has three components."""
    point_count = len(arc)
    if point_count < 2:
        raise AssertionError('Arcs must have at least two points.')

    for vertex in arc:
        _validate_point(vertex)


def _validate_ring(ring: Ring):
    """Raise AssertionError unless ``ring`` has at least three points and every point has three components."""
    minimum_points = 3
    if len(ring) >= minimum_points:
        _validate_arc(ring)
        return
    raise AssertionError('Polygon rings must have at least three points.')


def _validate_polygon(polygon: Polygon):
    """Raise AssertionError unless ``polygon`` has at least one ring and every ring is itself valid."""
    ring_count = len(polygon)
    if ring_count < 1:
        raise AssertionError('Polygons must have at least one ring.')

    for candidate_ring in polygon:
        _validate_ring(candidate_ring)
