"""Class to read a shapefile and convert it to a data_objects coverage."""
# 1. Standard python modules
import logging
import os
import uuid

# 2. Third party modules
from osgeo import osr
import osgeo.gdal as gdal
import shapefile
from shapely.geometry import shape

# 3. Aquaveo modules
from xms.data_objects.parameters import Arc, Coverage, Point, Polygon
from xms.gdal.utilities import gdal_utils as gu

# 4. Local modules

# Number of polygons processed between progress messages in ShapefileConverter.convert_polygons.
log_frequency = 1_000


class ShapefileConverter:
    """Class to read a shapefile and convert it to a data_objects coverage."""

    def __init__(self, filename, project_wkt, extents=None):
        """Constructor.

        Args:
            filename (:obj:`str`): Path to the polygon shapefile
            project_wkt (:obj:`str`): WKT containing a projection to project to
            extents (:obj:`list`): List of 2D points containing a bounding box for clipping polygons,
                as [(min_x, min_y), (max_x, max_y)]
        """
        try:
            from shapely import speedups
        except ImportError:  # The speedups module may be unavailable in newer shapely releases.
            speedups = None
        if speedups is not None:
            speedups.disable()  # Wish I knew why we have to do this.
        self._logger = logging.getLogger('xms.srhw')
        self._filename = filename
        gdal.SetConfigOption('OGR_CT_FORCE_TRADITIONAL_GIS_ORDER', 'YES')
        # The shapefile's projection lives in the sibling .prj file.
        shapefile_wkt = gu.read_projection_file(os.path.splitext(self._filename)[0] + '.prj')
        # Only the horizontal parts of both projections are used for the transformation.
        project_sr = osr.SpatialReference()
        project_sr.ImportFromWkt(gu.strip_vertical(project_wkt))
        shapefile_sr = osr.SpatialReference()
        shapefile_sr.ImportFromWkt(gu.strip_vertical(shapefile_wkt))
        self._coord_trans = osr.CreateCoordinateTransformation(shapefile_sr, project_sr)
        self._extents = extents
        self._next_point_id = 1
        self._next_arc_id = 1
        self._next_poly_id = 1
        self._pt_hash = {}  # {(x, y): Point}
        self._arc_hash = {}  # {tuple of sorted point ids: Arc}
        # This is the mapping of XMS map module feature id to the attribute (BC or material) id in the file.
        self.feature_id_to_att_id = {}  # {feature_id: bc_id/mat_id}

    def _find_field_name(self, shp_f):
        """Find the name of the "texture" field (matched case-insensitively) in a shapefile.

        Args:
            shp_f (:obj:`shapefile.Reader`): The reader as opened by pyshp

        Returns:
            (:obj:`str`): The field name exactly as spelled in the shapefile

        Raises:
            ValueError: If the shapefile has no "texture" field
        """
        for field in shp_f.fields:
            if field[0].lower() == "texture":
                return field[0]
        raise ValueError(f'Could not find "texture" field name in shapefile: {self._filename}')

    def _create_do_coverage(self):
        """Get an empty data_objects Coverage.

        Returns:
            (:obj:`data_objects.parameters.Coverage`): See description
        """
        # Using the file's basename as the coverage name.
        cov_name = f'{os.path.splitext(os.path.basename(self._filename))[0]}'
        cov = Coverage(name=cov_name, uuid=str(uuid.uuid4()))
        cov.complete()
        return cov

    def _get_hashed_point(self, x, y):
        """Get an existing data_objects point at a location (creates the point if it doesn't exist).

        Args:
            x (:obj:`float`): The x-coordinate of the point
            y (:obj:`float`): The y-coordinate of the point

        Returns:
            (:obj:`data_objects.parameters.Point`): The existing or newly created point associated with the passed
            in location
        """
        # Key by the coordinates themselves rather than their hash so that two distinct locations
        # whose hashes collide are never merged into one point.
        key = (x, y)
        do_point = self._pt_hash.get(key)
        if do_point is None:
            do_point = Point(x=x, y=y, feature_id=self._next_point_id)
            self._next_point_id += 1
            self._pt_hash[key] = do_point
        return do_point

    def _get_hashed_arc(self, coords):
        """Create a data_objects arc from a sequence of point coordinates, or retrieve a previously created one.

        Args:
            coords (:obj:`list`): The x,y coords of the arc points

        Returns:
            (:obj:`data_objects.parameters.Arc`): The existing or newly created arc associated with the passed
            in locations
        """
        points = [self._get_hashed_point(x, y) for x, y in coords]
        # Key by the sorted point ids (ignores direction and starting vertex). The tuple itself is the
        # key, not its hash, so colliding hashes cannot return the wrong arc.
        key = tuple(sorted(point.id for point in points))
        do_arc = self._arc_hash.get(key)
        if do_arc is None:
            vertices = points[1:-1] if len(points) > 2 else None
            do_arc = Arc(start_node=points[0], end_node=points[-1], vertices=vertices, feature_id=self._next_arc_id)
            self._next_arc_id += 1
            self._arc_hash[key] = do_arc
        return do_arc

    def _convert_polygon_part(self, part, att_id):
        """Convert one single-part shapely polygon to a data_objects Polygon.

        Args:
            part (:obj:`shapely.geometry.Polygon`): The polygon part, in the shapefile's coordinate system
            att_id: The attribute value read from the shapefile record for this part

        Returns:
            (:obj:`Polygon`): The new polygon, or None if the part lies outside the clipping extents
        """
        exterior_points = self._convert_points(list(part.exterior.coords))
        if not self._in_extents(exterior_points):
            return None
        exterior = self._get_hashed_arc(exterior_points)
        holes = [[self._get_hashed_arc(self._convert_points(list(interior.coords)))] for interior in part.interiors]
        do_polygon = Polygon(feature_id=self._next_poly_id)
        self.feature_id_to_att_id[self._next_poly_id] = att_id
        self._next_poly_id += 1
        do_polygon.set_arcs([exterior])
        if holes:
            do_polygon.set_interior_arcs(holes)
        return do_polygon

    def convert_polygons(self):
        """Convert a polygon shapefile to a data_object Coverage.

        Records with null or empty geometry are skipped. Each part of a multipart record becomes its
        own polygon, and every part maps to the record's attribute id in `feature_id_to_att_id`.

        Returns:
             (:obj:`Coverage`): The data_objects Coverage geometry, or None if the conversion failed
             (the error is logged)
        """
        try:
            do_polygons = []
            with shapefile.Reader(self._filename) as f:
                num_polys = len(f)
                field_name = self._find_field_name(f)
                self._logger.info(f'Converting {num_polys} polygons to features...')
                for i, polygon in enumerate(f):
                    if (i + 1) % log_frequency == 0:
                        self._logger.info(f'Processing polygon {i + 1} of {num_polys}...')
                    shp = polygon.shape
                    if shp.shapeType == shapefile.NULL or not shp.points:
                        continue  # No geometry to convert.
                    geometry = shape(shp)
                    if geometry.is_empty:
                        continue
                    # pyshp reports multi-exterior records as MultiPolygons, which have no single exterior.
                    parts = list(geometry.geoms) if geometry.geom_type == 'MultiPolygon' else [geometry]
                    att_id = polygon.record[field_name]
                    for part in parts:
                        do_polygon = self._convert_polygon_part(part, att_id)
                        if do_polygon is not None:
                            do_polygons.append(do_polygon)
            do_cov = self._create_do_coverage()
            do_cov.polygons = do_polygons
            return do_cov
        except ValueError as ve:
            self._logger.error(f'{str(ve)}')
        except Exception:
            # Log with the traceback so unexpected failures can be diagnosed.
            self._logger.exception(f'Error converting the shapefile {self._filename} to polygons.')
        return None

    def _convert_points(self, coords):
        """Project coordinates into the project coordinate system and reduce them to 2D.

        Args:
            coords (:obj:`list`): Sequence of (x, y) or (x, y, z[, m]) coordinates

        Returns:
            (:obj:`list`): List of (x, y) tuples
        """
        coords = list(coords)
        if not coords:
            return []
        if self._coord_trans is not None:
            coords = self._coord_trans.TransformPoints(coords)
        # Drop any Z/M values so callers can always unpack (x, y).
        return [(pt[0], pt[1]) for pt in coords]

    def _in_extents(self, exterior_points):
        """Check whether a polygon's exterior bounding box overlaps the clipping extents.

        A bounding-box overlap test is used rather than "some vertex is inside the box", so polygons
        that enclose or straddle the extents without a vertex inside them are kept.

        Args:
            exterior_points (:obj:`list`): The (x, y) points of the polygon's exterior ring

        Returns:
            (:obj:`bool`): True if there are no extents, or the bounding boxes overlap
        """
        if not self._extents:
            return True
        if not exterior_points:
            return False
        min_x, min_y = self._extents[0][0], self._extents[0][1]
        max_x, max_y = self._extents[1][0], self._extents[1][1]
        xs = [point[0] for point in exterior_points]
        ys = [point[1] for point in exterior_points]
        return min(xs) <= max_x and max(xs) >= min_x and min(ys) <= max_y and max(ys) >= min_y
