"""Testing utility functions."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import itertools
from pathlib import Path

# 2. Third party modules
from geopandas import GeoDataFrame
import numpy as np
from pyproj.enums import WktVersion
from shapely.coords import CoordinateSequence

# 3. Aquaveo modules
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.vectors import VectorOutput
from xms.tool_core.coverage_builder import CoverageBuilder
from xms.tool_core.coverage_writer import CoverageWriter

# 4. Local modules


def convert_lines_to_coverage(lines: list, name: str, from_wkt: str, to_wkt: str = None) -> GeoDataFrame:
    """Converts the lines to a GeoDataFrame.

    When a conversion between ``from_wkt`` and ``to_wkt`` is requested, the points are transformed horizontally and
    their Z values are scaled by the vertical multiplier between the two projections. The coverage is created in
    ``to_wkt`` when it is given, otherwise in ``from_wkt``. Lines with fewer than two points are skipped.

    Args:
        lines (list): A list of polylines with each point containing XY or XYZ coordinates as a tuple. XY points are
            given a Z of 0.0.
        name (str): The name of the new coverage.
        from_wkt (str): The WKT specifying the projection of the lines passed to this function.
        to_wkt (str): The WKT specifying the display projection.

    Returns:
        (GeoDataFrame): A GeoDataFrame containing the given lines.
    """
    coord_trans, vm = _get_coord_trans_and_vm(from_wkt, to_wkt)
    crs = to_wkt if to_wkt else from_wkt
    coverage_builder = CoverageBuilder(cov_wkt=crs, cov_name=name)
    for line in lines:
        if coord_trans is not None:
            line = gu.transform_points(line, coord_trans)
        # Points may be XY or XYZ; a missing Z defaults to 0.0 before the vertical multiplier is applied.
        line_pts = [(p[0], p[1], (p[2] if len(p) > 2 else 0.0) * vm) for p in line]
        # An arc needs at least two points.
        if len(line_pts) > 1:
            coverage_builder.add_arc(line_pts)
    return coverage_builder.build_coverage()


def convert_points_to_coverage(points: list, name: str, from_wkt: str, to_wkt: str = None) -> GeoDataFrame:
    """Converts the points to a GeoDataFrame.

    When a conversion between ``from_wkt`` and ``to_wkt`` is requested, each point is transformed horizontally and its
    Z value is scaled by the vertical multiplier between the two projections. The coverage is created in ``to_wkt``
    when it is given, otherwise in ``from_wkt``.

    Args:
        points (list): A list of points with each point containing XY or XYZ coordinates as a tuple. XY points are
            given a Z of 0.0.
        name (str): The name of the new coverage.
        from_wkt (str): The WKT specifying the projection of the points passed to this function.
        to_wkt (str): The WKT specifying the display projection.

    Returns:
        (GeoDataFrame): A GeoDataFrame containing the given points.
    """
    coord_trans, vm = _get_coord_trans_and_vm(from_wkt, to_wkt)
    crs = to_wkt if to_wkt else from_wkt
    coverage_builder = CoverageBuilder(cov_wkt=crs, cov_name=name)
    for point in points:
        if coord_trans is not None:  # pragma: no cover - reenable with test_watershed_from_raster_tool.py
            point = gu.transform_points([point], coord_trans)[0]
        # Points may be XY or XYZ; a missing Z defaults to 0.0 before the vertical multiplier is applied.
        z = point[2] if len(point) > 2 else 0.0
        coverage_builder.add_point((point[0], point[1], z * vm))
    return coverage_builder.build_coverage()


def convert_polygons_to_coverage(polygons: list, name: str, from_wkt: str, to_wkt: str = None) -> GeoDataFrame:
    """Converts the polygons to a GeoDataFrame.

    Right now, this function only converts the outside polygons to a GeoDataFrame.  If we ever need to handle polygon
    holes, we'll need to write code to convert the points representing these holes (the rings following the first ring
    in the lists of points for each polygon) to interior arcs for each polygon.

    The input polygons are not modified. When a conversion between ``from_wkt`` and ``to_wkt`` is requested, the outer
    ring is transformed horizontally and its Z values are scaled by the vertical multiplier between the projections.
    The coverage is created in ``to_wkt`` when it is given, otherwise in ``from_wkt``.

    Args:
        polygons (list): A list of polygons, each a list of rings whose first ring is the outer boundary. Each point
            contains XY or XYZ coordinates as a tuple; XY points are given a Z of 0.0.
        name (str): The name of the new coverage.
        from_wkt (str): The WKT specifying the projection of the polygons passed to this function.
        to_wkt (str): The WKT specifying the display projection.

    Returns:
        (GeoDataFrame): A GeoDataFrame containing the given polygons.
    """
    # Convert the polygons from the projection specified by from_wkt to the projection specified by to_wkt
    coord_trans, vm = _get_coord_trans_and_vm(from_wkt, to_wkt)
    crs = to_wkt if to_wkt else from_wkt
    coverage_builder = CoverageBuilder(cov_wkt=crs, cov_name=name)
    for polygon in polygons:
        # Only the outer ring (the first ring) is used.
        outer_ring = polygon[0]
        if coord_trans is not None:  # pragma: no cover - reenable with test_watershed_from_raster_tool.py
            # Keep the transformed ring in a local instead of assigning it back into polygon[0], so the caller's data
            # is left untouched and repeated calls do not transform it twice.
            outer_ring = gu.transform_points(outer_ring, coord_trans)
        # Points may be XY or XYZ; a missing Z defaults to 0.0 before the vertical multiplier is applied.
        polygon_pts = [(p[0], p[1], (p[2] if len(p) > 2 else 0.0) * vm) for p in outer_ring]
        coverage_builder.add_polygon([polygon_pts])
    return coverage_builder.build_coverage()


def shapely_polygon_list_to_polygons(polygon_list):
    """Converts a shapely polygon list to a list of polygons.

    Only each polygon's exterior ring is used. Every coordinate becomes an ``(x, y, z)`` tuple, with a Z of 0.0 for
    coordinates that have only X and Y.

    Args:
        polygon_list (list): A shapely polygon list.

    Returns:
        (list): One entry per polygon, each a single-ring list ``[exterior_points]`` of XYZ tuples.
    """
    def _to_xyz(coord):
        # Pad 2D coordinates with a zero elevation.
        if len(coord) < 3:
            return (coord[0], coord[1], 0.0)
        return (coord[0], coord[1], coord[2])

    return [[[_to_xyz(coord) for coord in shape.exterior.coords]] for shape in polygon_list]


def points_to_shapefile(points, filename, to_wkt, from_wkt='', strip_vertical=False):
    """Converts coverage points to a shapefile.

    Only the X and Y of each point geometry are written.

    Args:
        points (GeoDataFrame): The points from a coverage.
        filename (str): The shapefile filename.
        to_wkt (str): The WKT of the projection to convert to.
        from_wkt (str): The WKT of the projection to convert from.
        strip_vertical (bool): Whether to strip the vertical projection.
    """
    vo = _initialize_file(filename, to_wkt, from_wkt, strip_vertical)
    for point in points.itertuples():
        vo.write_point([point.geometry.x, point.geometry.y])
    # Set vo to None to flush the file buffer
    vo = None


def arcs_to_shapefile(arcs, filename, to_wkt, from_wkt='', strip_vertical=False):
    """Converts coverage arcs to a shapefile.

    Each arc is written as a list of ``[x, y, z]`` points; coordinates without a Z are written with a Z of 0.0.

    Args:
        arcs (GeoDataFrame): The arcs from a coverage.
        filename (str): The shapefile filename.
        to_wkt (str): The WKT of the projection to convert to.
        from_wkt (str): The WKT of the projection to convert from.
        strip_vertical (bool): Whether to strip the vertical projection.
    """
    vo = _initialize_file(filename, to_wkt, from_wkt, strip_vertical)
    for arc in arcs.itertuples():
        # Line geometries may be 2D; pad them so write_arc always receives XYZ points.
        pts = [[pt[0], pt[1], pt[2] if len(pt) > 2 else 0.0] for pt in arc.geometry.coords]
        vo.write_arc(pts)
    # Set vo to None to flush the file buffer
    vo = None


def polygons_to_shapefile(polygons, filename, to_wkt, from_wkt='', strip_vertical=False):
    """Converts coverage polygons to a shapefile.

    Args:
        polygons (tuple): A ``(hole_data, poly_data)`` tuple, in that order, as returned by
            ``get_polygons_from_coverage``. ``poly_data`` is a list of dicts with a ``'poly_pts'`` outer ring, and
            ``hole_data`` holds, for each polygon, a list of dicts with a ``'poly_pts'`` interior ring.
        filename (str): The shapefile filename.
        to_wkt (str): The WKT of the projection to convert to.
        from_wkt (str): The WKT of the projection to convert from.
        strip_vertical (bool): Whether to strip the vertical projection
    """
    vo = _initialize_file(filename, to_wkt, from_wkt, strip_vertical)
    # Add coverage polygons as new features in the layer; holes and polygons are parallel lists.
    for holes_data, poly_data in zip(polygons[0], polygons[1]):
        # Interior rings (may be empty)
        holes = [hole['poly_pts'] for hole in holes_data]
        vo.write_polygon(poly_data['poly_pts'], holes)
    # Set vo to None to flush the file buffer
    vo = None


def get_arcs_from_coverage(arcs_coverage: GeoDataFrame, to_wkt=None):
    """Gets the arcs from the given coverage.

    Only rows whose ``geometry_types`` value is ``'Arc'`` are used. The coverage's CRS (as WKT1_GDAL) is the source
    projection. When a conversion to ``to_wkt`` applies, Z values are scaled first and then the points are
    transformed.

    Args:
        arcs_coverage (GeoDataFrame): A coverage (GeoDataFrame) containing arcs, or None.
        to_wkt (str): An optional WKT string used to convert the points.

    Return:
        (list): One dict per arc with keys ``'id'`` (the row id), ``'arc_pts'`` (the arc's points) and ``'cov_geom'``
        (the coverage itself). Empty when ``arcs_coverage`` is None.
    """
    have_coverage = arcs_coverage is not None
    source_wkt = ''
    if have_coverage and arcs_coverage.crs is not None:
        source_wkt = arcs_coverage.crs.to_wkt(version=WktVersion.WKT1_GDAL)
    coord_trans, vm = _get_coord_trans_and_vm(source_wkt, to_wkt)
    if not have_coverage:
        return []

    arc_rows = arcs_coverage[arcs_coverage['geometry_types'] == 'Arc']
    result = []
    for row in arc_rows.itertuples():
        # Scale Z first, then apply the horizontal transformation (if any).
        scaled_pts = [[x, y, z * vm] for x, y, z in list(row.geometry.coords)]
        arc_pts = gu.transform_points(scaled_pts, coord_trans) if coord_trans is not None else scaled_pts
        result.append({'id': row.id, 'arc_pts': arc_pts, 'cov_geom': arcs_coverage})
    return result


def get_arcs_from_list(arcs_list):  # pragma no cover (runs in Linux tests)
    """Gets the arcs from the given list of arcs.

    Args:
        arcs_list (list): A list of arcs (each a list of points), or None.

    Return:
        (list): One dict per arc with keys ``'id'`` (1-based position in ``arcs_list``), ``'arc_pts'`` (the arc's
        points, not copied) and ``'cov_geom'`` (always None). Empty when ``arcs_list`` is None.
    """
    if arcs_list is None:
        return []
    return [{'id': arc_id, 'arc_pts': pts, 'cov_geom': None} for arc_id, pts in enumerate(arcs_list, start=1)]


def get_polygons_from_coverage(polygons_coverage: GeoDataFrame):
    """Gets the polygons from the given coverage.

    Only rows whose ``geometry_types`` value is ``'Polygon'`` are used. In every ring each coordinate is converted to a
    list and consecutive duplicate vertices are collapsed.

    Args:
        polygons_coverage (GeoDataFrame): A coverage containing polygons, or None.

    Return:
        (tuple): ``(hole_data, poly_data)``. ``poly_data`` has one dict per polygon with keys ``'id'``, ``'poly_pts'``
        (outer ring points) and ``'cov_geom'`` (the coverage). ``hole_data`` is parallel to ``poly_data``: for each
        polygon, a list of dicts of the same form, one per interior ring. Both are empty when the coverage is None.
    """
    hole_data = []
    poly_data = []
    if polygons_coverage is None:
        return hole_data, poly_data

    def _ring_pts(coords):
        # Convert each coordinate to a list and drop consecutive repeats.
        return [pt for pt, _ in itertools.groupby(list(c) for c in coords)]

    polygon_rows = polygons_coverage[polygons_coverage['geometry_types'] == 'Polygon']
    for row in polygon_rows.itertuples():
        geom = row.geometry
        poly_data.append({'id': row.id, 'poly_pts': _ring_pts(geom.exterior.coords), 'cov_geom': polygons_coverage})
        hole_data.append([{'id': row.id, 'poly_pts': _ring_pts(ring.coords), 'cov_geom': polygons_coverage}
                          for ring in geom.interiors])
    return hole_data, poly_data


def parallel_polygon_perimeters(cov: GeoDataFrame):
    """Returns the dict of location lists for each poly in the coverage.  For use in run_parallel_points_in_polygon.

    For each polygon the list holds the outer perimeter first, followed by one entry per hole.

    Args:
        cov (GeoDataFrame): An activity coverage.

    Return:
        dict: A dictionary of the location lists for each polygon in the coverage with the poly ID as the key. Each
        location list is a numpy array of the ring's (x, y) pairs.
    """
    holes_data, polys_data = get_polygons_from_coverage(cov)
    perimeters = {}
    for outer, inner_rings in zip(polys_data, holes_data):
        rings = [outer['poly_pts'], *(ring['poly_pts'] for ring in inner_rings)]
        # Keep only X and Y for each vertex.
        perimeters[outer['id']] = [np.array([(pt[0], pt[1]) for pt in ring]) for ring in rings]
    return perimeters


def get_polygons_from_list(polygons_list):  # pragma no cover (runs in Linux tests)
    """Gets polygons from the given list of polygons.

    Each polygon is a list of rings: the first is the outer boundary and any others are holes. Consecutive duplicate
    points are collapsed in every ring.

    Args:
        polygons_list (list): A list of polygons, or None.

    Return:
        (tuple): ``(hole_data, poly_data)``. ``poly_data`` has one dict per polygon with keys ``'id'`` (1-based
        position), ``'poly_pts'`` and ``'cov_geom'`` (None). ``hole_data`` holds, for each polygon, a list of dicts of
        the same form, one per hole. Both are empty when ``polygons_list`` is None.
    """
    hole_data = []
    poly_data = []
    if polygons_list is None:
        return hole_data, poly_data

    def _collapse(ring):
        # Drop consecutive repeated points.
        return [pt for pt, _ in itertools.groupby(ring)]

    for poly_id, rings in enumerate(polygons_list, start=1):
        poly_data.append({'id': poly_id, 'poly_pts': _collapse(rings[0]), 'cov_geom': None})
        hole_data.append([{'id': poly_id, 'poly_pts': _collapse(ring), 'cov_geom': None} for ring in rings[1:]])
    return hole_data, poly_data


def make_poly_clockwise(poly):
    """Makes the polygon clockwise by reversing it in place when its signed area is positive.

    Last point expected to be repeat of first point. Polygons with zero area are left unchanged.

    Args:
        poly (list[tuple]): List of xyz points, e.g. [(0,0,0), (1,0,0), (1,1,0), (0,0,0)]. Modified in place.
    """
    is_counter_clockwise = poly_area_x2(poly) > 0.0
    if is_counter_clockwise:
        poly.reverse()


def make_poly_counter_clockwise(poly):
    """Makes the polygon counterclockwise by reversing it in place when its signed area is negative.

    Last point expected to be repeat of first point. Polygons with zero area are left unchanged.

    Args:
        poly (list[tuple]): List of xyz points, e.g. [(1,2,3), (4,5,6), (1,2,3)]. Modified in place.
    """
    # poly_area_x2 is negative for clockwise polygons.
    if poly_area_x2(poly) < 0.0:
        poly.reverse()


def poly_area_x2(poly):
    """Computes 2 times the area of the poly (2 times because its faster).

    Last point expected to be repeat of first point.
    Positive area if counterclockwise polygon, otherwise area is negative.

    This doesn't include a translation to the origin to improve accuracy as
    we don't need it to be that accurate.

    Args:
        poly: A poly (list of point indexes).

    Returns:
        The area.
    """
    if len(poly) < 3:
        return 0.0

    area = 0.0
    for i in range(len(poly) - 1):
        area += (poly[i][0] * poly[i + 1][1])
        area -= (poly[i][1] * poly[i + 1][0])
    return area


def get_poly_points(coords: CoordinateSequence):
    """Returns the coordinates as a list of point tuples with consecutive duplicate points removed.

    For a closed ring (as shapely ring coordinates are), the last point is a repeat of the first point.

    Args:
        coords (CoordinateSequence): The coordinate sequence.

    Returns:
        (list of tuple of float): The de-duplicated points in their original order.
    """
    return [pt for pt, _group in itertools.groupby(map(tuple, coords))]


def get_polygon_point_lists(df_polygon):
    """Returns a list of lists of points defining a polygon with possible inner holes.

    The first list is the exterior ring, reordered to be clockwise; each following list is an interior ring (hole),
    reordered to be counterclockwise. Last points are repeats of the first points.

    Args:
        df_polygon: the polygon row from the dataframe.

    Returns:
        (list of lists of tuple): The lists of points defining the polygon.
    """
    geometry = df_polygon.geometry
    outer = get_poly_points(geometry.exterior.coords)
    make_poly_clockwise(outer)
    rings = [outer]
    for interior in geometry.interiors:
        inner = get_poly_points(interior.coords)
        make_poly_counter_clockwise(inner)
        rings.append(inner)
    return rings


def export_coverage_to_ascii(filename: str | Path, coverage: GeoDataFrame):
    """Writes the coverage to an ASCII file to help in debugging.

    Args:
        filename: Filepath of the ASCII file to write.
        coverage: The coverage to export.
    """
    CoverageWriter(filename).export_to_ascii(coverage)


def export_map_file(file_path: str | Path, coverages: list[GeoDataFrame], write_uuid: bool = False) -> None:
    """Writes the given coverages to a single .map file.

    Args:
        file_path: Path of the .map file to create.
        coverages: The coverages to write.
        write_uuid: Whether each coverage's uuid is written too.
    """
    CoverageWriter(file_path).export_to_map(coverages, write_uuid)


def export_coverage_to_map(filename: str | Path, coverage: GeoDataFrame, write_uuid: bool = False):
    """Writes a single coverage to an xms .map file to help in debugging.

    Args:
        filename: Filepath of the map file to write.
        coverage: The coverage to export.
        write_uuid: Whether the coverage uuid is written too.
    """
    single = [coverage]
    export_map_file(filename, single, write_uuid)


def _initialize_file(filename, to_wkt, from_wkt, strip_vertical):
    """Creates a VectorOutput and initializes it to write ``filename``.

    When ``strip_vertical`` is True, the vertical part is stripped from each WKT that ``gu.valid_wkt`` accepts, and
    an invalid ``to_wkt`` is replaced by the (possibly stripped) ``from_wkt``. When it is False, both WKTs are passed
    through unchanged.

    Args:
        filename (str): The output filename.
        to_wkt (str): The WKT of the projection to convert to.
        from_wkt (str): The WKT of the projection to convert from.
        strip_vertical (bool): Whether to strip the vertical projection.

    Returns:
        (VectorOutput): The initialized writer.
    """
    writer = VectorOutput()
    if strip_vertical:
        if gu.valid_wkt(from_wkt):
            from_wkt = gu.strip_vertical(from_wkt)
        # NOTE(review): the fallback to from_wkt only happens on this branch; confirm that is intended.
        to_wkt = gu.strip_vertical(to_wkt) if gu.valid_wkt(to_wkt) else from_wkt
    writer.initialize_file(filename, to_wkt, from_wkt=from_wkt)
    return writer


def _get_coord_trans_and_vm(from_wkt, to_wkt):
    # Convert the lines from the projection specified by from_wkt to the projection specified by to_wkt
    coord_trans = None
    vm = 1.0
    if from_wkt and to_wkt is not None:
        from_wkt_no_vert = gu.strip_vertical(from_wkt)
        to_wkt_no_vert = gu.strip_vertical(to_wkt)
        coord_trans = gu.get_coordinate_transformation(from_wkt_no_vert, to_wkt_no_vert)
        vm = gu.get_vertical_multiplier(gu.get_vert_unit_from_wkt(from_wkt), gu.get_vert_unit_from_wkt(to_wkt))
    return coord_trans, vm
