"""CoverageWriter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"


# 1. Standard Python modules
from pathlib import Path
import uuid

# 2. Third party modules
from geopandas import GeoDataFrame
import h5py
import numpy as np
from pyproj.enums import WktVersion

# 3. Aquaveo modules

# 4. Local modules
from .coverage_reader import (ARC_END_NODE, ARC_ID, ARC_START_NODE, POINT_ID, POINT_X, POINT_Y, POINT_Z,
                              POLY_ARC_DIRECTION, POLY_ARC_ID, POLY_ID, POLY_INNER_POLY, VERTEX_ARC_ID, VERTEX_X,
                              VERTEX_Y, VERTEX_Z)


def _write_polygons(polygons: GeoDataFrame, group: h5py.Group):
    """Write the polygon-to-arc connectivity table into an HDF5 group.

    One row is produced for every (polygon, arc) pair. Rows that belong to a polygon's outer boundary get an inner
    polygon index of 0; rows that belong to its holes get 1, 2, ... in the order the holes are listed.

    Args:
        polygons (GeoDataFrame): Rows providing ``id``, ``polygon_arc_ids``, ``polygon_arc_directions``,
            ``interior_arc_ids`` and ``interior_arc_directions``.
        group (h5py.Group): The group that receives the four int32 datasets.
    """
    rows = []  # (polygon id, arc id, arc direction, inner polygon index)
    for polygon in polygons.itertuples():
        outer_loop = zip(polygon.polygon_arc_ids, polygon.polygon_arc_directions)
        rows.extend((polygon.id, arc, direction, 0) for arc, direction in outer_loop)
        holes = zip(polygon.interior_arc_ids, polygon.interior_arc_directions)
        for hole_index, (hole_arcs, hole_directions) in enumerate(holes, start=1):
            hole_loop = zip(hole_arcs, hole_directions)
            rows.extend((polygon.id, arc, direction, hole_index) for arc, direction in hole_loop)

    # Transpose the row tuples into one sequence per dataset; keep four empty columns when there are no rows.
    columns = list(zip(*rows)) if rows else [(), (), (), ()]
    dataset_names = (POLY_ID, POLY_ARC_ID, POLY_ARC_DIRECTION, POLY_INNER_POLY)
    for dataset_name, values in zip(dataset_names, columns):
        group.create_dataset(dataset_name, data=np.array(values, dtype=np.int32))


def _write_arcs(arcs: GeoDataFrame, group: h5py.Group):
    """Write the arc IDs and their start/end node IDs into an HDF5 group.

    Args:
        arcs (GeoDataFrame): Rows providing ``id``, ``start_node`` and ``end_node``.
        group (h5py.Group): The group that receives the three int32 datasets.
    """
    records = [(row.id, row.start_node, row.end_node) for row in arcs.itertuples()]
    for column, dataset_name in enumerate((ARC_ID, ARC_START_NODE, ARC_END_NODE)):
        values = [record[column] for record in records]
        group.create_dataset(dataset_name, data=np.array(values, dtype=np.int32))


def _write_points(points: GeoDataFrame, group: h5py.Group):
    """Write point IDs and their x, y, z locations into an HDF5 group.

    Args:
        points (GeoDataFrame): Rows providing ``id`` and a ``geometry`` exposing ``x``, ``y`` and ``z``.
        group (h5py.Group): The group that receives the ID dataset (int32) and the three coordinate datasets.
    """
    records = [(row.id, row.geometry.x, row.geometry.y, row.geometry.z) for row in points.itertuples()]
    point_ids = [record[0] for record in records]
    group.create_dataset(POINT_ID, data=np.array(point_ids, dtype=np.int32))
    # Coordinates are stored with NumPy's inferred dtype, exactly as collected from the geometries.
    for column, dataset_name in enumerate((POINT_X, POINT_Y, POINT_Z), start=1):
        group.create_dataset(dataset_name, data=np.array([record[column] for record in records]))


def _write_vertices(arcs: GeoDataFrame, group: h5py.Group):
    """Write the interior vertices of every arc into an HDF5 group.

    The first and last coordinate of each arc are skipped; only the coordinates between them are written. No dataset
    is created when none of the arcs has an interior vertex.

    Args:
        arcs (GeoDataFrame): Rows providing ``id`` and a ``geometry`` exposing ``coords``.
        group (h5py.Group): The group that receives the owning-arc ID dataset and the three coordinate datasets.
    """
    records = [
        (row.id, vertex[0], vertex[1], vertex[2])
        for row in arcs.itertuples()
        for vertex in list(row.geometry.coords)[1:-1]
    ]
    if not records:
        return
    owner_ids = [record[0] for record in records]
    group.create_dataset(VERTEX_ARC_ID, data=np.array(owner_ids, dtype=np.int32))
    for column, dataset_name in enumerate((VERTEX_X, VERTEX_Y, VERTEX_Z), start=1):
        group.create_dataset(dataset_name, data=np.array([record[column] for record in records]))


def _convert_to_np_ascii(text: str) -> np.ndarray:
    """Converts a text string to a np.array to write to an HDF5 file.

    Args:
        text (str): The text to convert.

    Returns:
        np.ndarray: A one-element, fixed-width byte-string array whose item size is ``len(text) + 1``.

    Raises:
        UnicodeEncodeError: If text contains non-ASCII characters (NumPy encodes ``str`` into the ``S`` dtype as
            ASCII).
    """
    # One extra byte is reserved beyond the text itself — presumably room for a NUL terminator; TODO confirm.
    return np.array([text], dtype=f'S{len(text) + 1}')


class CoverageWriter:
    """Class to write a coverage to an H5 file, a debugging ASCII file or a .map file."""

    def __init__(self, file_path: str, default_wkt: str = ''):
        """Construct a CoverageWriter object.

        Args:
            file_path (str): The filename to write
            default_wkt (str): The coverage projection's WKT information, used when the coverage has no CRS
        """
        self._file_path = file_path
        self._default_wkt = default_wkt

    def write(self, coverage: GeoDataFrame):
        """Writes a Pandas GeoDataFrame coverage to an HDF5 file.

        Args:
            coverage (GeoDataFrame): The coverage.
        """
        self._create_output_directory()
        # (group name, 'geometry_types' value to select, writer for the selected rows)
        sections = (
            ('Polygons', 'Polygon', _write_polygons),
            ('Arcs', 'Arc', _write_arcs),
            ('Points', 'Point', _write_points),
            ('Nodes', 'Node', _write_points),
            ('Vertices', 'Arc', _write_vertices),
        )
        with h5py.File(self._file_path, 'w') as file:
            file.create_dataset('File Type', data=_convert_to_np_ascii('Xmdf'))
            file.create_dataset('File Version', data=np.array([99.99], dtype=np.float32))
            map_group = self._create_group(file, 'Map Data')
            map_group.create_dataset('Number', data=np.array([1], dtype=np.int32))
            coverage_group = self._create_group(map_group, 'Coverage1')
            crs = self._get_wkt(coverage)
            if crs:
                coordinate_group = self._create_group(coverage_group, 'Coordinates', 'Coordinates')
                coordinate_group.attrs['Version'] = np.array([2], dtype=np.int32)
                coordinate_group.attrs['WKT'] = _convert_to_np_ascii(crs)
            # A fresh UUID is only generated when the coverage does not carry one.
            cov_uuid = coverage.attrs['uuid'] if 'uuid' in coverage.attrs else str(uuid.uuid4())
            coverage_group.create_dataset('GUID', data=_convert_to_np_ascii(cov_uuid))
            cov_name = coverage.attrs['name'] if 'name' in coverage.attrs else 'default coverage'
            coverage_group.create_dataset('Name', data=_convert_to_np_ascii(cov_name))
            coverage_group.create_dataset('ReadAs', data=_convert_to_np_ascii('Coverage'))
            # This code does not currently support Arc Groups, so the group is created but left empty
            self._create_group(coverage_group, 'Arc Groups')
            for group_name, geometry_type, writer in sections:
                group = self._create_group(coverage_group, group_name)
                features = self._features_of_type(coverage, geometry_type)
                # Every group is always created; datasets are only written when there is something to write.
                if len(features):
                    writer(features, group)

    def _create_output_directory(self):
        """Create the parent directory of the output file (and any missing ancestors) if it does not exist."""
        Path(self._file_path).parent.mkdir(parents=True, exist_ok=True)

    def _get_wkt(self, coverage: GeoDataFrame) -> str:
        """Return the WKT to store for the coverage.

        Args:
            coverage (GeoDataFrame): The coverage.

        Returns:
            str: The coverage CRS as WKT1 (GDAL flavor), or the default WKT given to the constructor when the
                coverage has no CRS or its WKT is empty.
        """
        crs = ''
        if coverage.crs is not None:
            crs = coverage.crs.to_wkt(WktVersion.WKT1_GDAL)
        return crs if crs else self._default_wkt

    @staticmethod
    def _create_group(parent: h5py.Group, name: str, group_type: str = 'Generic') -> h5py.Group:
        """Create a child group and tag it with a 'Grouptype' attribute.

        Args:
            parent (h5py.Group): The file or group that receives the new group.
            name (str): The name of the new group.
            group_type (str): The value of the 'Grouptype' attribute.

        Returns:
            h5py.Group: The new group.
        """
        group = parent.create_group(name)
        group.attrs['Grouptype'] = _convert_to_np_ascii(group_type)
        return group

    @staticmethod
    def _features_of_type(coverage: GeoDataFrame, geometry_type: str) -> GeoDataFrame:
        """Return the rows of the coverage whose 'geometry_types' value equals geometry_type.

        Args:
            coverage (GeoDataFrame): The coverage.
            geometry_type (str): The 'geometry_types' value to select.

        Returns:
            GeoDataFrame: The selected rows.
        """
        return coverage[coverage['geometry_types'] == geometry_type]

    @staticmethod
    def _write_ascii_locations(file, locations: GeoDataFrame):
        """Write a tab-separated ID/X/Y/Z table, header included, for points or nodes.

        Args:
            file: The open text file to write to.
            locations (GeoDataFrame): The point or node rows.
        """
        file.write('ID\tX\tY\tZ\n')
        for pt in locations.itertuples():
            file.write(f'{pt.id}\t{pt.geometry.x}\t{pt.geometry.y}\t{pt.geometry.z}\n')

    @staticmethod
    def _write_map_locations(file, keyword: str, locations: GeoDataFrame):
        """Write one .map record per point or node.

        Args:
            file: The open text file to write to.
            keyword (str): The record keyword that opens each record ('POINT' or 'NODE').
            locations (GeoDataFrame): The point or node rows.
        """
        for pt in locations.itertuples():
            file.write(f'{keyword}\n')
            file.write(f'XY {pt.geometry.x} {pt.geometry.y} {pt.geometry.z}\n')
            file.write(f'ID {pt.id}\n')
            file.write('END\n')

    def export_to_ascii(self, coverage: GeoDataFrame):
        """Exports the coverage to filename as ASCII to help in debugging.

        Args:
            coverage: The coverage.
        """
        self._create_output_directory()
        with open(self._file_path, 'w') as file:
            if 'name' in coverage.attrs:
                file.write(f'Name: {coverage.attrs["name"]}\n')

            # Points
            disjoint_pts = self._features_of_type(coverage, 'Point')
            file.write('Disjoint Points\n')
            self._write_ascii_locations(file, disjoint_pts)

            # Nodes
            nodes = self._features_of_type(coverage, 'Node')
            file.write('\nNodes\n')
            self._write_ascii_locations(file, nodes)

            # Arcs
            arcs = self._features_of_type(coverage, 'Arc')
            file.write('\nArcs\n')
            for arc in arcs.itertuples():
                file.write(f'Arc ID: {arc.id}\t')
                file.write(f'Node1 ID: {arc.start_node}\t')
                file.write(f'Node2 ID: {arc.end_node}\n')
                file.write('Vertices\n')
                file.write('X\tY\tZ\n')
                # Skip the first and last coordinates; only the ones in between are listed as vertices.
                for pt in arc.geometry.coords[1:-1]:
                    file.write(f'{pt[0]}\t{pt[1]}\t{pt[2]}\n')

            # Polygons
            file.write('\nPolygons\n')
            polygons = self._features_of_type(coverage, 'Polygon')
            for poly in polygons.itertuples():
                file.write(f'Polygon ID: {poly.id}\n')
                file.write(f'Outer arcs: {" ".join([str(arc_id) for arc_id in poly.polygon_arc_ids])}\n')
                for i, inner_poly in enumerate(poly.interior_arc_ids):
                    file.write(f'Inner poly {i}: {" ".join([str(arc_id) for arc_id in inner_poly])}\n')

    def export_to_map(self, coverages: list[GeoDataFrame], write_uuid: bool = False):
        """Exports a .map file given the list of coverages.

        Args:
            coverages: List of coverages.
            write_uuid: If True, the coverage uuid is written.
        """
        # Match write() and export_to_ascii(): make sure the destination folder exists before opening the file.
        self._create_output_directory()
        with open(self._file_path, 'w') as file:
            file.write('MAP\n')
            file.write('VERSION 5.0\n')
            for coverage in coverages:
                file.write('BEGCOV\n')
                if 'name' in coverage.attrs:
                    file.write(f'COVNAME "{coverage.attrs["name"]}"\n')
                if write_uuid and 'uuid' in coverage.attrs:
                    file.write(f'COVGUID "{coverage.attrs["uuid"]}"\n')

                # Points
                self._write_map_locations(file, 'POINT', self._features_of_type(coverage, 'Point'))

                # Nodes
                self._write_map_locations(file, 'NODE', self._features_of_type(coverage, 'Node'))

                # Arcs
                arcs = self._features_of_type(coverage, 'Arc')
                for arc in arcs.itertuples():
                    file.write('ARC\n')
                    # NOTE(review): arc IDs are written as absolute values here, while the polygon arc lists below
                    # are written as stored - confirm this is intended.
                    file.write(f'ID {abs(arc.id)}\n')
                    file.write(f'NODES {arc.start_node} {arc.end_node}\n')
                    # Skip the first and last coordinates; only the ones in between are written as vertices.
                    verts = arc.geometry.coords[1:-1]
                    file.write(f'ARCVERTICES {len(verts)}\n')
                    for v in verts:
                        file.write(f'{v[0]} {v[1]} {v[2]}\n')
                    file.write('END\n')

                # Polygons
                polys = self._features_of_type(coverage, 'Polygon')
                for poly in polys.itertuples():
                    file.write('POLYGON\n')
                    file.write(f'ID {poly.id}\n')
                    out_arcs = poly.polygon_arc_ids
                    file.write(f'ARCS {len(out_arcs)}\n')
                    for arc_id in out_arcs:
                        file.write(f'{arc_id}\n')
                    for inner_poly in poly.interior_arc_ids:
                        file.write(f'HARCS {len(inner_poly)}\n')
                        for arc_id in inner_poly:
                            file.write(f'{arc_id}\n')
                    file.write('END\n')
                file.write('ENDCOV\n')
