"""Merges polygons based on a classification."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import time

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.tool.algorithms.coverage.polygon_coverage_builder import PolygonCoverageBuilder
from xms.tool.utilities.coverage_conversion import poly_area_x2


class GridCellToPolygonCoverageBuilder:
    """Given a grid and integer cell dataset, computes polygonal boundaries of areas with the same dataset value.

    - output is a pandas GeoDataFrame coverage
    - create_polygons_and_build_coverage: Builds the Coverage.
    - find_polygons: Just finds the polygons

    Glossary:
        I've tried to be consistent with the following terms to help avoid
        confusion:

        poly = A poly is a list of ugrid point indexes with the last point
        equal to the first point. Poly points can be in any order, clockwise
        or counterclockwise: [9, 3, 1, 6, 9]

        polygon = A polygon is a list of poly, outer first, then inner holes
        if any: [[3, 6, 8, 2, 3], [6, 7, 8, 6]]

        multipoly = list of polygon, each one associated with the same dataset value.
    """
    def __init__(self, co_grid, dataset_values, wkt, coverage_name, null_value=None, logger=None):
        """Initializes the class.

        Args:
            co_grid (xms.constraint.ugrid2d.UGrid2d): The grid.
            dataset_values (list of int): Cell dataset of integers.
            wkt: The map projection's WKT.
            coverage_name (str): Name to be given to the new coverage.
            null_value (Optional[object]): A dataset value to exclude from polygon building. Any grid cells with this
                assigned value will not be merged into the polygons of the output coverage.
            logger (Optional[Logger]): The logger to use. If not provided will be the xms.tool logger.
        """
        self._ugrid = co_grid.ugrid
        self._dataset_values = dataset_values
        self._wkt = wkt
        self._coverage_name = coverage_name
        self._null_value = null_value
        self._logger = logger if logger is not None else logging.getLogger('xms.tool')

        # Stuff for extracting polygons (list of point ids)
        self._unvisited_cells = set()  # Set of unvisited cells (cell indexes)
        self._ugrid_pt_locs = self._ugrid.locations  # xyz locations of grid points

        self.log_interval = 2.0  # Log messages every this many seconds
        self._last_time = time.time()

    def create_polygons_and_build_coverage(self, hash_arcs=True):
        """Creates a Coverage with Polygons made from areas where cells have the same dataset value.

        Cells whose value equals the null value (if one was given) are dropped before the coverage is built.

        Args:
            hash_arcs (bool): If True, will hash arcs to eliminate duplicates.

        Returns:
            GeoDataFrame: See description
        """
        self._logger.info('Creating polygons from grid cell assignments...')
        dataset_multipolys = self.find_polygons()
        if self._null_value is not None:
            dataset_multipolys.pop(self._null_value, None)
        return self._build_coverage(dataset_multipolys, hash_arcs)

    def find_polygons(self):
        """Finds boundaries of areas where cells have the same dataset value.

        Each region of edge-adjacent cells sharing a value is flood filled; the edges on the region's
        boundary (grid border or neighbor with a different value) are collected and walked into polys.

        Returns:
            dataset_multipolys: Dict of dataset value -> multipoly. A multipoly
            is a list of polygon. A polygon is a list of poly, outer first,
            then inner holes if any. A poly is a list of ugrid point indexes
            with the last point equal to the first point. Poly points can be
            in any order, clockwise or counterclockwise.
        """
        self._logger.info('Finding contiguous cells with same category assignments...')
        self._unvisited_cells = set(range(self._ugrid.cell_count))
        dataset_multipolys = {}  # Dict of dataset value -> multipolygons

        # Visit all cells in the grid
        while self._unvisited_cells:
            # Take an arbitrary unvisited cell as the seed of a new region, then put it back so the
            # flood fill below processes it like any other unvisited cell.
            start_cell = self._unvisited_cells.pop()
            value = self._dataset_values[start_cell]
            self._unvisited_cells.add(start_cell)

            # Aggregate all adjacent cells with same dataset value and add boundary edges to pt_map
            pt_map = {}  # Dict of point index -> set of edge adjacent point indices
            stack = [start_cell]  # Stack of cell indexes used to recurse on adjacent cells
            iteration = 0
            while stack:
                cell = stack.pop()
                if cell not in self._unvisited_cells:
                    continue  # A cell can be pushed more than once before it is visited
                self._add_needed_adjacent_edges(cell, value, stack, pt_map)
                self._mark_cell_as_visited(cell)

                if iteration % 100 == 0:  # Only check the logging timer every 100 iterations of this loop.
                    self._do_progress()
                iteration += 1

            # Get polygons from edges and save them
            if value not in dataset_multipolys:
                self._logger.info(f'Found cells with category assignment: {value}')
                dataset_multipolys[value] = []
            polygon = self._polygon_from_edges(pt_map)
            dataset_multipolys[value].append(polygon)

        return dataset_multipolys

    def _do_progress(self):
        """Will log progress every self.log_interval seconds."""
        num_cells = self._ugrid.cell_count
        now = time.time()
        if now - self._last_time > self.log_interval:
            self._last_time = now
            self._logger.info(f'Visited {num_cells - len(self._unvisited_cells)} of {num_cells} cells...')

    def _add_needed_adjacent_edges(self, cell, value, stack, pt_map):
        """Adds outer edges to pt_map and edges between adjacent cells with a different dataset value.

        Also may grow the stack if more adjacent cells with the same dataset value are found.

        Args:
            cell (int): Index of a cell.
            value (int): Current dataset value of region we are aggregating.
            stack (list of int): Stack of adjacent cells sharing the same dataset value.
            pt_map ({pt_idx: {adjacent_point_indices}}): Mapping of point indices to set of adjacent point indices
        """
        cell_edges = self._ugrid.get_cell_edges(cell)
        for edge_idx, edge in enumerate(cell_edges):
            adj_cell = self._ugrid.get_cell_2d_edge_adjacent_cell(cell, edge_idx)
            if adj_cell >= 0 and self._dataset_values[adj_cell] == value:
                # Same region: queue the neighbor if it still needs visiting; the shared edge is interior.
                if adj_cell in self._unvisited_cells:
                    stack.append(adj_cell)
            else:
                # No neighbor (adj_cell < 0) or its value differs: this edge is on the region boundary.
                node0 = edge[0]
                node1 = edge[1]
                pt_map.setdefault(node0, set()).add(node1)
                pt_map.setdefault(node1, set()).add(node0)

    def _mark_cell_as_visited(self, cell):
        """Marks the cell as visited by removing it from the set of unvisited cells (no-op if already visited).

        Args:
            cell (int): Index of a cell.
        """
        self._unvisited_cells.discard(cell)

    def _polygon_from_edges(self, pt_map):
        """Returns the polygon (outer and inners) defined by the pt_map.

        Args:
            pt_map (Dict of point index -> set of edge adjacent point indices): Consumed (emptied) by this call.

        Returns:
            polygon: list of poly with the largest-area poly first.
        """
        polygon = self._polygon_from_point_map(pt_map)
        self._sort_outer_poly_first(polygon)
        return polygon

    def _sort_outer_poly_first(self, polygon):
        """Finds the outer poly and moves it to the front of polygon, in place.

        The outer poly is taken to be the one with the greatest absolute area. On ties the earliest
        poly wins. Polygons with fewer than two polys are left untouched.

        Args:
            polygon: list of poly. Modified in place; nothing is returned.
        """
        if len(polygon) < 2:
            return

        max_area = -1.0
        max_poly = 0
        for i, poly in enumerate(polygon):
            point_locations = [self._ugrid_pt_locs[point] for point in poly]
            area = abs(poly_area_x2(point_locations))  # Twice the area; sign depends on orientation
            if area > max_area:
                max_area = area
                max_poly = i

        if max_poly != 0:
            polygon[max_poly], polygon[0] = polygon[0], polygon[max_poly]

    def _polygon_from_point_map(self, pt_map):
        """Returns the closed polys (outer and inners) traced out of the edges in pt_map.

        Edges are consumed from pt_map as they are walked, so pt_map is empty on return. When a walk
        revisits a point already on the current poly (a pinch point where boundaries touch), the loop
        back to that point is split off as its own poly.

        Args:
            pt_map (Dict of point index -> set of edge adjacent point indices):

        Returns:
            polygon: list of poly, in discovery order (not sorted; see _sort_outer_poly_first).
        """
        polygon = []
        while pt_map:
            list_index = {}  # Index of each point currently in poly, so we know if we've made a loop
            poly = []  # List of point indexes forming a loop

            # Start poly with arbitrary first point in map
            next_point = next(iter(pt_map))
            poly.append(next_point)
            list_index[next_point] = len(poly) - 1

            done = False
            while not done:
                # Get next point
                last_point = poly[-1]
                next_point = self._next_point(last_point, pt_map)

                # Consume the edge in both directions, dropping points with no edges left
                pt_map[last_point].remove(next_point)
                if not pt_map[last_point]:
                    del pt_map[last_point]
                pt_map[next_point].remove(last_point)
                if not pt_map[next_point]:
                    del pt_map[next_point]

                # Stop if we're back to the beginning
                if next_point == poly[0]:
                    done = True
                elif next_point in list_index:
                    # We've looped back on ourselves. Cut off loop to form a poly
                    loop_start = list_index[next_point]
                    new_poly = poly[loop_start:]
                    new_poly.append(next_point)  # Repeat first point as the last point
                    polygon.append(new_poly)
                    # The cut points are no longer in poly; forget their indexes so a later revisit of
                    # one of them (possible at a pinch point) is not mistaken for another loop.
                    for stripped_point in poly[loop_start:]:
                        del list_index[stripped_point]
                    del poly[loop_start:]  # strip the loop from the current poly

                poly.append(next_point)
                list_index[next_point] = len(poly) - 1

            polygon.append(poly)
        return polygon

    @staticmethod
    def _next_point(last_point, pt_map):
        """Returns the next point we will visit starting from the last point using the pt_map.

        There may be more than one next point to choose from. We just pick the first one in the
        set's iteration order that is not last_point itself.

        Args:
            last_point (int): Point index.
            pt_map (Dict of point index -> set of edge adjacent point indices):

        Returns:
            next_point (int): Point index, or -1 if no other point is adjacent.
        """
        return next((point for point in pt_map[last_point] if point != last_point), -1)

    def _build_coverage(self, dataset_multipolys, hash_arcs):
        """Creates a GeoDataFrame Coverage with the Polygons and Arcs.

        Args:
            dataset_multipolys: Dict of dataset value -> multipoly. A multipoly
             is a list of polygon and a polygon is a list of poly (outer, inners),
             and a poly is a list of ugrid point indexes with the last point
             equal to the first point.
            hash_arcs (bool): If True, will hash arcs to eliminate duplicates.

        Returns:
            GeoDataFrame Coverage.

        """
        self._logger.info('Merging contiguous areas with same category assignment into coverage polygons...')
        builder = PolygonCoverageBuilder(self._ugrid_pt_locs, self._wkt, self._coverage_name, self._logger)
        return builder.build_coverage(dataset_multipolys, hash_arcs)
