"""CoverageReader class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import uuid

# 2. Third party modules
from geopandas import GeoDataFrame
import h5py
from shapely import LineString, Point, Polygon

# 3. Aquaveo modules

# 4. Local modules

# Dataset names read from a coverage's 'Polygons' group. Each row pairs a polygon id with one of
# its arcs, that arc's direction and an 'Inner Poly' value (0 = outer ring, non-zero = a hole).
POLY_ID = 'Id'
POLY_ARC_ID = 'Arc Id'
POLY_ARC_DIRECTION = 'ArcDirection'
POLY_INNER_POLY = 'Inner Poly'

# Dataset names read from a coverage's 'Arcs' group: arc ids and the ids of their end nodes.
ARC_ID = 'Id'
ARC_START_NODE = 'Node1 Id'
ARC_END_NODE = 'Node2 Id'

# Dataset names read from both the 'Points' and the 'Nodes' groups.
POINT_ID = 'Id'
POINT_X = 'X Locations'
POINT_Y = 'Y Locations'
POINT_Z = 'Z Locations'

# Dataset names read from a coverage's 'Vertices' group: the owning arc id and the location of
# each interior arc vertex.
VERTEX_ARC_ID = 'Arc Id'
VERTEX_X = 'X Locations'
VERTEX_Y = 'Y Locations'
VERTEX_Z = 'Z Locations'


def read_coverage_group_name(h5_file):
    """Reads and returns the coverage name from the h5 file.

    Only the first coverage group under ``Map Data`` is considered.

    Args:
        h5_file: h5 file object (anything supporting ``in`` and ``[]`` lookups by group name,
            such as an open ``h5py.File``).

    Returns:
        (str): The name of the first child of the ``Map Data`` group, or an empty string when
        the file has no ``Map Data`` group or that group is empty.
    """
    if 'Map Data' not in h5_file:
        # Callers treat '' as "no coverage in this file"; don't raise KeyError for that case.
        return ''
    # First child in the group's iteration order, or '' when the group has no children.
    return next(iter(h5_file['Map Data']), '')


def _read_attributes(h5_group) -> dict:
    """Read every dataset directly inside an h5py group into a Python dictionary.

    Despite the name, this reads the group's child *datasets* (not HDF5 attributes); child
    groups are skipped.

    Args:
        h5_group: An h5py group.

    Returns:
        (dict): Maps each dataset name to a list of its values along the first axis, with byte
        strings decoded as UTF-8. A scalar dataset becomes a one-element list and a dataset
        with a null dataspace becomes an empty list.
    """
    atts = {}
    for name, item in h5_group.items():
        if not isinstance(item, h5py.Dataset):
            continue
        if item.shape is None:
            # Null dataspace: the dataset holds no data at all.
            atts[name] = []
            continue
        # Read the whole dataset in one call; iterating an h5py Dataset directly issues a
        # separate HDF5 read for every element.
        values = item[()]
        if item.shape == ():
            # Scalar datasets cannot be iterated; treat them as a single value.
            values = [values]
        atts[name] = [value.decode('utf-8') if isinstance(value, bytes) else value for value in values]
    return atts


def _coverage_attributes_from_h5(coverage_group):
    atts = dict()
    if 'Attributes' in coverage_group:
        attributes_group = coverage_group['Attributes']
        atts = _read_attributes(attributes_group)
    if 'CoverageType' in coverage_group.attrs:
        atts['CoverageType'] = coverage_group.attrs['CoverageType'][0].decode('utf-8')
    return atts


class CoverageReader:
    """Class to read a coverage from an H5 file as a GeoDataFrame."""

    def __init__(self, file_path: str, default_wkt: str = '', atts=None):
        """Construct a CoverageReader object.

        Args:
            file_path (str): The filename to read
            default_wkt (str): The default WKT for the coverage's projection
            atts (dict): The dict specifying the attributes for the coverage
        """
        if atts is None:
            atts = {}
        self._file_path = file_path
        self._default_wkt = default_wkt
        self._cov_atts = atts  # stored as given; not used by the methods of this class
        self._reset()

    def _reset(self):
        """Clear all per-read state so that read() can be called more than once.

        The per-row lists are parallel: index i of _geometry_types, _start_node, _end_node,
        _polygon_arc_* and _interior_arc_* describes row i of the GeoDataFrame built by read().
        Rows are ordered polygons, arcs, points, nodes, vertices.
        """
        self._polygon_list = []
        self._poly_ids = []
        self._geometry_types = []
        self._polygon_arc_ids = []
        self._polygon_arc_directions = []
        self._interior_arc_ids = []
        self._interior_arc_directions = []
        self._start_node = []
        self._end_node = []
        self._arc_list = []
        self._arc_ids = []
        self._point_list = []  # points, then nodes, then vertices
        self._point_ids = []  # parallel to _point_list; vertices use -1
        self._vertex_arc_ids = []  # owning arc id of each vertex, in vertex order

    def read_uuid(self) -> str | None:
        """Reads the coverage UUID only.

        Returns:
            (str): The coverage UUID, a freshly generated UUID when the coverage has no 'GUID'
            dataset, or None when the file contains no coverage.
        """
        cov_uuid = str(uuid.uuid4())
        with h5py.File(self._file_path, 'r') as file:
            cov_group_name = read_coverage_group_name(file)
            if cov_group_name == '':
                return None
            coverage_group = file['Map Data/' + cov_group_name]
            if 'GUID' in coverage_group:
                cov_uuid = coverage_group['GUID'][0].decode('utf-8')
        return cov_uuid

    def read_attributes(self) -> dict:
        """Reads the coverage attributes only.

        Returns:
            (dict): The attributes as a Python dictionary (empty when the file has no coverage).
        """
        atts = dict()
        with h5py.File(self._file_path, 'r') as file:
            cov_group_name = read_coverage_group_name(file)
            if cov_group_name == '':
                return atts
            coverage_group = file['Map Data/' + cov_group_name]
            atts = _coverage_attributes_from_h5(coverage_group)
        return atts

    def read(self) -> GeoDataFrame | None:
        """Converts a coverage to a Pandas GeoDataFrame.

        Each polygon, arc, point, node and arc vertex becomes one row (in that order). The frame's
        ``attrs`` hold the coverage name, uuid, attributes and source filename.

        Returns:
            (GeoDataFrame): The coverage, or None when the file contains no coverage.
        """
        # Start from a clean slate so repeated calls don't append duplicate rows.
        self._reset()
        with h5py.File(self._file_path, 'r') as file:
            cov_group_name = read_coverage_group_name(file)
            if cov_group_name == '':
                return None
            coverage_group = file['Map Data/' + cov_group_name]
            # Order matters: rows must be polygons, arcs, points, nodes, vertices because the
            # geometry assignment below relies on that layout.
            readers = (
                ('Polygons', self._read_polygons),
                ('Arcs', self._read_arcs),
                ('Points', self._read_points),
                ('Nodes', self._read_nodes),
                ('Vertices', self._read_vertices),
            )
            for group_name, reader in readers:
                if group_name in coverage_group:
                    reader(coverage_group[group_name])
            # Arcs first: polygon rings are chained from the arc geometry.
            self._assign_arc_geometry()
            self._assign_polygon_geometry()
            json_strings = [''] * len(self._polygon_arc_ids)
            crs = ''
            if 'Coordinates' in coverage_group:
                coordinate_group = coverage_group['Coordinates']
                if 'WKT' in coordinate_group.attrs:
                    crs = coordinate_group.attrs['WKT'][0].decode('utf-8')
            if not crs:
                crs = self._default_wkt
            cov_uuid = str(uuid.uuid4())
            if 'GUID' in coverage_group:
                cov_uuid = coverage_group['GUID'][0].decode('utf-8')
            cov_name = cov_group_name
            if 'Name' in coverage_group:
                cov_name = coverage_group['Name'][0].decode('utf-8')
            gdf = GeoDataFrame({'id': self._poly_ids + self._arc_ids + self._point_ids,
                                'geometry_types': self._geometry_types,
                                'geometry': self._polygon_list + self._arc_list + self._point_list,
                                'polygon_arc_ids': self._polygon_arc_ids,
                                'polygon_arc_directions': self._polygon_arc_directions,
                                'interior_arc_ids': self._interior_arc_ids,
                                'interior_arc_directions': self._interior_arc_directions,
                                'start_node': self._start_node, 'end_node': self._end_node,
                                'attributes': json_strings},
                               crs=crs or None)  # no WKT at all -> no CRS
            gdf.attrs['name'] = cov_name
            gdf.attrs['uuid'] = cov_uuid
            gdf.attrs['attributes'] = _coverage_attributes_from_h5(coverage_group)
            gdf.attrs['filename'] = self._file_path
            return gdf

    @staticmethod
    def _read_values(group, name) -> list:
        """Return every value of dataset ``name`` in ``group`` as a list, or [] when it is absent."""
        if name not in group:
            return []
        # One bulk read; iterating an h5py Dataset issues one HDF5 read per element.
        return list(group[name][()])

    @staticmethod
    def _read_column(group, name, count) -> list:
        """Return the first ``count`` values of dataset ``name``; 0.0 for each when it is absent.

        Raises:
            IndexError: If the dataset holds fewer than ``count`` values.
        """
        if name not in group:
            return [0.0] * count
        values = list(group[name][:count])
        if len(values) < count:
            raise IndexError(f"Dataset '{name}' has {len(values)} values but {count} are required")
        return values

    def _append_empty_topology(self):
        """Append the start/end node and arc-list columns for a row that is not a polygon or arc."""
        self._start_node.append(-1)
        self._end_node.append(-1)
        self._polygon_arc_ids.append([])
        self._polygon_arc_directions.append([])
        self._interior_arc_ids.append([])
        self._interior_arc_directions.append([])

    def _read_polygons(self, polygons_group):
        """Append one row per polygon id, recording its outer-ring arcs and hole arcs.

        Each dataset row pairs a polygon id with one arc, that arc's direction and an 'Inner Poly'
        value; rows whose polygon id is 0 are ignored. Geometry is built later by
        _assign_polygon_geometry().
        """
        poly_ids = self._read_values(polygons_group, POLY_ID)
        if not poly_ids:
            return
        poly_dict = {x: {} for x in poly_ids if x != 0}
        for identifier in (POLY_ARC_ID, POLY_ARC_DIRECTION, POLY_INNER_POLY):
            _add_to_dict(poly_dict, poly_ids, polygons_group, identifier, False)
        for poly_id, poly in poly_dict.items():
            arc_ids = poly.get(POLY_ARC_ID, [])
            directions = poly.get(POLY_ARC_DIRECTION)
            inners = poly.get(POLY_INNER_POLY)
            poly_arc_ids = []
            poly_arc_directions = []
            hole_arc_ids = []
            hole_arc_directions = []
            cur_inner = 0
            for index, arc_id in enumerate(arc_ids):
                # Arcs have direction "1" when not reversed; assume that if no directions are stored.
                arc_dir = directions[index] if directions is not None else 1
                # Without an 'Inner Poly' dataset every arc belongs to the outer ring.
                inner = inners[index] if inners is not None else 0
                if inner == 0:
                    poly_arc_ids.append(arc_id)
                    poly_arc_directions.append(arc_dir)
                elif inner != cur_inner:
                    # A change in the non-zero 'Inner Poly' value starts a new hole.
                    hole_arc_ids.append([arc_id])
                    hole_arc_directions.append([arc_dir])
                    cur_inner = inner
                else:
                    hole_arc_ids[-1].append(arc_id)
                    hole_arc_directions[-1].append(arc_dir)
            self._geometry_types.append('Polygon')
            self._poly_ids.append(poly_id)
            self._polygon_arc_ids.append(poly_arc_ids)
            self._polygon_arc_directions.append(poly_arc_directions)
            self._interior_arc_ids.append(hole_arc_ids)
            self._interior_arc_directions.append(hole_arc_directions)
            self._polygon_list.append(Polygon())
            self._start_node.append(-1)
            self._end_node.append(-1)

    def _read_arcs(self, arcs_group):
        """Append one row per arc id; geometry is built later by _assign_arc_geometry()."""
        self._arc_ids = self._read_values(arcs_group, ARC_ID)
        count = len(self._arc_ids)
        # Pad/trim node ids so every arc row has exactly one start and one end entry
        # (-1 = unknown); otherwise the parallel columns drift out of alignment.
        start_nodes = self._read_values(arcs_group, ARC_START_NODE)
        end_nodes = self._read_values(arcs_group, ARC_END_NODE)
        self._start_node.extend((start_nodes + [-1] * count)[:count])
        self._end_node.extend((end_nodes + [-1] * count)[:count])
        self._geometry_types.extend(['Arc'] * count)
        self._arc_list.extend([LineString()] * count)
        for column in (self._polygon_arc_ids, self._polygon_arc_directions,
                       self._interior_arc_ids, self._interior_arc_directions):
            column.extend([] for _ in range(count))

    def _append_point_rows(self, group, ids, geometry_type, x_name, y_name, z_name):
        """Append one Point row per id, taking coordinates from the group's x/y/z datasets."""
        count = len(ids)
        xs = self._read_column(group, x_name, count)
        ys = self._read_column(group, y_name, count)
        zs = self._read_column(group, z_name, count)
        for point_id, x, y, z in zip(ids, xs, ys, zs):
            self._point_ids.append(point_id)
            self._point_list.append(Point(x, y, z))
            self._geometry_types.append(geometry_type)
            self._append_empty_topology()

    def _read_points(self, points_group):
        """Append one 'Point' row per id in the Points group."""
        ids = self._read_values(points_group, POINT_ID)
        self._append_point_rows(points_group, ids, 'Point', POINT_X, POINT_Y, POINT_Z)

    def _read_nodes(self, nodes_group):
        """Append one 'Node' row per id in the Nodes group."""
        ids = self._read_values(nodes_group, POINT_ID)
        self._append_point_rows(nodes_group, ids, 'Node', POINT_X, POINT_Y, POINT_Z)

    def _read_vertices(self, vertices_group):
        """Append one 'Vertex' row per entry of the Vertices group's arc-id dataset."""
        self._vertex_arc_ids = self._read_values(vertices_group, VERTEX_ARC_ID)
        # Vertices have no ids of their own; -1 keeps _point_ids parallel to _point_list.
        ids = [-1] * len(self._vertex_arc_ids)
        self._append_point_rows(vertices_group, ids, 'Vertex', VERTEX_X, VERTEX_Y, VERTEX_Z)

    def _point_rows(self):
        """Return (geometry_type, id, Point) for every point/node/vertex row, in row order."""
        first_point_row = len(self._geometry_types) - len(self._point_list)
        return zip(self._geometry_types[first_point_row:], self._point_ids, self._point_list)

    @staticmethod
    def _ring_coords(arc_lines, arc_ids, arc_dirs) -> list:
        """Chain arcs into one list of (x, y, z) ring coordinates.

        Arcs with a falsy direction are reversed; arc ids missing from ``arc_lines`` are skipped.
        Consecutive arcs share a node, so the first coordinate of every arc after the first
        contributing one is dropped.
        """
        coords = []
        for arc_id, arc_dir in zip(arc_ids, arc_dirs):
            line = arc_lines.get(arc_id)
            if line is None:
                continue
            # Arcs have direction "1" if not reversed.
            if not arc_dir:
                line = line.reverse()
            arc_coords = [(pt[0], pt[1], pt[2]) for pt in line.coords]
            # Key on "ring already started" rather than "first arc": a missing or empty first
            # arc must not cause the next arc's leading node to be dropped.
            coords.extend(arc_coords[1:] if coords else arc_coords)
        return coords

    def _assign_polygon_geometry(self):
        """Assigns polygon geometry based on polygon information."""
        arc_lines = {}
        for arc_id, line in zip(self._arc_ids, self._arc_list):
            arc_lines.setdefault(arc_id, line)  # first arc with a given id wins
        num_polys = len(self._poly_ids)
        rows = zip(self._polygon_arc_ids[:num_polys], self._polygon_arc_directions[:num_polys],
                   self._interior_arc_ids[:num_polys], self._interior_arc_directions[:num_polys])
        for poly_index, (outer_ids, outer_dirs, holes_ids, holes_dirs) in enumerate(rows):
            shell = self._ring_coords(arc_lines, outer_ids, outer_dirs)
            holes = []
            for hole_ids, hole_dirs in zip(holes_ids, holes_dirs):
                ring = self._ring_coords(arc_lines, hole_ids, hole_dirs)
                if ring:  # a hole whose arcs were all missing contributes nothing
                    holes.append(ring)
            self._polygon_list[poly_index] = Polygon(shell, holes)

    def _assign_arc_geometry(self):
        """Assigns arc geometry: start node, the arc's vertices in file order, then end node.

        Arcs whose start or end node is not found among the Node rows keep an empty LineString.
        """
        node_locs = {}
        vertex_locs = {}
        vertex_arc_ids = iter(self._vertex_arc_ids)
        for geom_type, point_id, point in self._point_rows():
            if geom_type == 'Node':
                # Only nodes can be arc end points; first node with a given id wins.
                node_locs.setdefault(point_id, point)
            elif geom_type == 'Vertex':
                vertex_locs.setdefault(next(vertex_arc_ids), []).append(point)
        num_polys = len(self._poly_ids)
        for arc_index, arc_id in enumerate(self._arc_ids):
            row = num_polys + arc_index
            start_pt = node_locs.get(self._start_node[row])
            end_pt = node_locs.get(self._end_node[row])
            if start_pt is not None and end_pt is not None:
                self._arc_list[arc_index] = LineString([start_pt, *vertex_locs.get(arc_id, []), end_pt])

    def _get_node_loc_from_id(self, node_id: int) -> Point | None:
        """Return the location of the first node with ``node_id``, or None if there is none."""
        for geom_type, point_id, point in self._point_rows():
            if geom_type == 'Node' and point_id == node_id:
                return point
        return None

    def _arc_points_from_id(self, arc_id) -> LineString | None:
        """Return the LineString of the first arc with ``arc_id``, or None if there is none."""
        for cur_arc_id, line in zip(self._arc_ids, self._arc_list):
            if cur_arc_id == arc_id:
                return line
        return None


def _add_to_dict(feature_dict: dict, feature_ids: list, feature_group: h5py.Group, identifier: str,
                 use_id_0: bool = True):
    if identifier in feature_group:
        for index, value in enumerate(feature_group[identifier]):
            cur_id = feature_ids[index]
            if use_id_0 or cur_id != 0:
                if identifier in feature_dict[cur_id]:
                    feature_dict[cur_id][identifier].append(value)
                else:
                    feature_dict[cur_id][identifier] = [value]
