"""SelUGridFromCoverage class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from typing import List

# 2. Third party modules
import numpy as np
from shapely import LineString, Polygon

# 3. Aquaveo modules
from xms.gdal.utilities import gdal_utils as gu
from xms.tool_core import Argument, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.geometry.geometry import run_parallel_points_in_polygon
from xms.tool.utilities.coverage_conversion import get_arcs_from_coverage, get_polygons_from_coverage

# Positions of each argument in the list built by DatasetFromCoverageTool.initial_arguments().
(
    ARG_INPUT_GRID,
    ARG_INPUT_COVERAGE,
    ARG_INPUT_GEOM,
    ARG_INPUT_OFFSET,
    ARG_OUTPUT_DATASET,
) = range(5)

# Dataset type codes; not referenced by DatasetFromCoverageTool itself
# (presumably kept for other users of this module — confirm before removing).
DATASET_TYPE_UNDEFINED, DATASET_TYPE_CELLS, DATASET_TYPE_POINTS = range(-1, 2)


class DatasetFromCoverageTool(Tool):
    """Tool that writes a grid dataset flagging points covered by a coverage's polygons or arcs.

    The output dataset has one value per grid point: 1.0 when the point lies inside one of the
    coverage polygons (and not inside one of its holes), or inside the region within the offset
    distance of one of the coverage arcs; 0.0 otherwise.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Dataset from Coverage')
        self._args = []
        self._sel_cov = None
        self._orig_grid = None  # grid returned by get_input_grid(); set in run()
        self._ugrid = None
        self._dataset_name = ''

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        choices = ['Polygons', 'Arcs']
        arguments = [
            self.grid_argument(name='input_grid', description='Grid', io_direction=IoDirection.INPUT),
            self.coverage_argument(name='input_coverage', description='Selection coverage'),
            self.string_argument(name='geometry', description='Geometry', value='Polygons', choices=choices),
            self.float_argument(name='offset', description='Offset (ft or m)', value=5.0, hide=True, optional=True),
            self.dataset_argument(name='dataset_name', description='Dataset name',
                                  value="Dataset", io_direction=IoDirection.OUTPUT),
        ]
        self.enable_arguments(arguments)
        return arguments

    def enable_arguments(self, arguments: List[Argument]):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        arguments[ARG_INPUT_OFFSET].show = arguments[ARG_INPUT_GEOM].value == 'Arcs'

    def _polygon_rings(self):
        """Collect the boundary rings of every polygon in the selection coverage.

        Returns:
            (list): One entry per polygon. Each entry is a list of (n, 2) arrays: the outer ring
            first, followed by the ring of each hole.
        """
        holes_data, polys_data = get_polygons_from_coverage(self._sel_cov)
        all_rings = []
        for poly, holes in zip(polys_data, holes_data):
            locs_lists = [poly['poly_pts']] + [cur_hole['poly_pts'] for cur_hole in holes]
            all_rings.append([np.asarray([(loc[0], loc[1]) for loc in locs]) for locs in locs_lists])
        return all_rings

    def _arc_rings(self, offset):
        """Buffer every arc in the selection coverage and collect the rings of the buffered shapes.

        Args:
            offset (float): Buffer distance, in the coordinate units of the arc points.

        Returns:
            (list): Same layout as _polygon_rings(): one entry per buffered polygon, the outer ring
            first (counter-clockwise), followed by any interior rings as holes.
        """
        all_rings = []
        for arc in get_arcs_from_coverage(self._sel_cov):
            # point list along the arc
            locations = [(x, y) for (x, y, _) in arc['arc_pts']]

            # Let shapely build polygons with end caps; cap_style 1 = round (2 = flat, 3 = square).
            shape = LineString(locations).buffer(offset, cap_style=1)

            if shape.is_empty:  # zero-length arc - hard to test
                continue  # pragma no cover

            # Repair degenerate output (invalid or nearly self-touching rings).
            if (not shape.is_valid) or (shape.minimum_clearance < 1e-9):
                shape = shape.buffer(0)

            # A buffer is normally a single Polygon, but the repair above can yield a MultiPolygon.
            parts = list(shape.geoms) if shape.geom_type == 'MultiPolygon' else [shape]
            for part in parts:
                if part.is_empty:
                    continue  # pragma no cover
                ring = np.asarray(part.exterior.coords)
                # Ensure counter-clockwise orientation
                if not part.exterior.is_ccw:
                    ring = ring[::-1]
                # Interior rings appear when an arc loops back on itself; points inside them are
                # farther than `offset` from the arc, so they are subtracted as holes.
                holes = [np.asarray(interior.coords) for interior in part.interiors]
                all_rings.append([ring] + holes)
        return all_rings

    @staticmethod
    def _flag_points_in_rings(grid_locs, all_rings, sel_values):
        """Set sel_values to 1.0 for grid points inside any polygon (outer ring minus holes).

        Args:
            grid_locs (np.ndarray): (n, 2) array of grid point x, y coordinates.
            all_rings (list): Polygons as returned by _polygon_rings() / _arc_rings().
            sel_values (np.ndarray): Length-n float array, updated in place.
        """
        for rings in all_rings:
            # Grow the outer ring slightly so points lying on its boundary are counted as inside.
            test_ring = np.asarray(Polygon(rings[0]).buffer(1e-9).exterior.coords)
            # dtype=bool keeps the truthiness of whatever values the helper returns per point.
            inside = np.asarray(list(run_parallel_points_in_polygon(grid_locs, test_ring)), dtype=bool)

            # subtract any holes
            for hole in rings[1:]:
                in_hole = np.asarray(list(run_parallel_points_in_polygon(grid_locs, hole)), dtype=bool)
                inside &= ~in_hole

            # flag locked nodes
            sel_values[inside] = 1.0

    def _create_dataset(self, use_arcs, offset):
        """Determine which grid nodes lie within an offset of polygons or arcs and write the dataset.

        Writes a single time step (time 0.0) on the input grid: 1.0 for flagged points, 0.0 otherwise.

        Args:
            use_arcs (bool): True to buffer the coverage arcs by `offset`; False to use the coverage polygons.
            offset (float): Arc buffer distance in grid coordinate units. Ignored for polygons.
        """
        # Get grid coordinates (only x, y are used); reshape keeps a 2D array even for an empty grid.
        grid_locs = np.asarray([(x, y) for x, y, *_ in self._ugrid.locations], dtype=float).reshape(-1, 2)
        sel_values = np.zeros(len(grid_locs), dtype=float)

        all_rings = self._arc_rings(offset) if use_arcs else self._polygon_rings()
        self._flag_points_in_rings(grid_locs, all_rings, sel_values)

        # Write the dataset
        writer = self.get_output_dataset_writer(name=self._dataset_name, geom_uuid=self._orig_grid.uuid)
        writer.append_timestep(0.0, sel_values)
        writer.appending_finished()
        self.set_output_dataset(writer)

    def _offset_factor(self):
        """Return the multiplier that converts the offset argument into projection units.

        Returns:
            (float): 1/111000 (rough meters-to-degrees) when the default projection is geographic
            with degree units; 1.0 otherwise.
        """
        if not gu.valid_wkt(self.default_wkt):
            return 1.0
        # look in GdalUtility.cpp - LaUnit imUnitFromProjectionWKT(const std::string wkt)
        sr = gu.wkt_to_sr(self.default_wkt)
        if sr.IsGeographic() and sr.GetAngularUnitsName().upper() in ['DEGREE', 'DS']:
            # rough conversion from meters to degrees
            return 1.0 / 111000.0
        return 1.0

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        self._args = arguments
        if self._args[ARG_INPUT_COVERAGE].text_value:
            self._sel_cov = self.get_input_coverage(self._args[ARG_INPUT_COVERAGE].value)

        self._orig_grid = self.get_input_grid(self._args[ARG_INPUT_GRID].text_value)
        self._ugrid = self._orig_grid.ugrid
        self._dataset_name = arguments[ARG_OUTPUT_DATASET].value

        use_arcs = arguments[ARG_INPUT_GEOM].value == 'Arcs'
        offset = 0.0
        if use_arcs:
            # The offset argument is optional and hidden unless 'Arcs' is chosen, so it is only
            # read (and scaled to projection units) when buffering arcs.
            offset = arguments[ARG_INPUT_OFFSET].value * self._offset_factor()

        self._create_dataset(use_arcs, offset)
