"""EditRasterElevationsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.rasters import RasterInput
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.utilities import gdal_wrappers as gw
from xms.gdal.vectors import VectorOutput
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.utilities.coverage_conversion import get_arcs_from_coverage

# Positions of the tool arguments in the list built by initial_arguments():
# input raster, input coverage, output raster (in that order).
ARG_INPUT_RASTER, ARG_INPUT_COVERAGE, ARG_OUTPUT_RASTER = range(3)


class EditRasterElevationsTool(Tool):
    """Tool to edit a raster's elevation values via a coverage, writing out a new raster.

    The Z values of the coverage's points and arcs are burned into band 1 of a copy of the input raster.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Edit Elevations')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, ordered to match the ARG_* index constants.
        """
        arguments = [
            self.raster_argument(name='input_raster', description='Input raster'),
            self.coverage_argument(name='input_coverage', description='Coverage with elevation values'),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        input_raster = self.get_input_raster(arguments[ARG_INPUT_RASTER].value)
        input_coverage = self.get_input_coverage(arguments[ARG_INPUT_COVERAGE].value)
        out_path = self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value)

        # Burn the elevations of the coverage points and arcs into a copy of the input raster
        if self._edit_raster(input_raster, input_coverage, out_path):
            # NOTE(review): get_output_raster is called a second time to get the reprojection target; confirm
            # whether it returns the same path as out_path above.
            out_path = ru.reproject_raster(out_path, self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value),
                                           self.default_wkt, self.vertical_datum, self.vertical_units)
            self.set_output_raster_file(out_path, arguments[ARG_OUTPUT_RASTER].text_value)

    def _edit_raster(self, input_raster, coverage, out_path):
        """Edits the raster based on the points and arcs of the coverage.

        The input raster is copied to out_path. Then the Z values of the coverage's points and arcs are
        rasterized into band 1 of the copy.

        Args:
            input_raster (xms.gdal.RasterInput): The input raster.
            coverage (obj): The input coverage.
            out_path (str): The filename of the raster to create.

        Returns:
            bool: True once the output raster has been written. Errors propagate as exceptions.
        """
        # Grab some of the information needed from the coverage and raster
        points = coverage[coverage['geometry_types'] == 'Point']
        arcs = get_arcs_from_coverage(coverage)
        layer_wkt, coverage_wkt = self._get_layer_wkts(input_raster)

        feature_point_file = '/vsimem/points.shp'
        feature_arc_file = '/vsimem/arcs.shp'
        created_files = []  # Only the in-memory files that were actually created get deleted
        point_vo = arc_vo = ri = None
        try:
            # Make an in memory vector layer with the points from the coverage in order to rasterize it.
            # Coordinates are written in the coverage (from_wkt) projection and the layer uses the raster's.
            point_vo = VectorOutput()
            point_vo.initialize_file(feature_point_file, layer_wkt, from_wkt=coverage_wkt)
            created_files.append(feature_point_file)
            for point in points.itertuples():
                point_vo.write_point([point.geometry.x, point.geometry.y, point.geometry.z])

            # Make an in memory vector layer with arcs from the coverage in order to rasterize it
            arc_vo = VectorOutput()
            arc_vo.initialize_file(feature_arc_file, layer_wkt, from_wkt=coverage_wkt)
            created_files.append(feature_arc_file)
            for arc in arcs:
                arc_vo.write_arc(arc['arc_pts'])

            # Create output raster, rasterize the vector layers, using the Z values from the point and arc layers
            ru.copy_raster_from_raster_input(input_raster, out_path)
            ri = RasterInput(out_path, True)
            gw.gdal_rasterize_layer(point_vo.ogr_layer, ri.gdal_raster, [1], [0], ["BURN_VALUE_FROM=Z"])
            gw.gdal_rasterize_layer(arc_vo.ogr_layer, ri.gdal_raster, [1], [0], ["BURN_VALUE_FROM=Z"])
        finally:
            # Drop the dataset references before deleting the in-memory files. This also runs when an error
            # occurred, so stale fixed-name /vsimem files are not left behind for the next run.
            ri = None
            point_vo = None
            arc_vo = None
            for filename in created_files:
                gu.delete_vector_file(filename)
        return True

    def _get_layer_wkts(self, input_raster):
        """Determine the projections used to build the in-memory vector layers.

        Args:
            input_raster (xms.gdal.RasterInput): The input raster.

        Returns:
            tuple(str, str): The WKT of the vector layers (the raster's projection) and the WKT the coverage
            coordinates are expressed in (the display projection, possibly promoted to 3D).
        """
        raster_wkt = input_raster.wkt
        display_wkt = self.default_wkt if self.default_wkt else ''
        raster_sr = gu.wkt_to_sr(raster_wkt) if gu.valid_wkt(raster_wkt) else None
        display_sr = gu.wkt_to_sr(display_wkt) if gu.valid_wkt(display_wkt) else None

        if raster_sr is None or display_sr is None:
            # Without both projections there is nothing to reproject between. Use the same WKT for both, so the
            # coverage coordinates are taken as-is in the raster's coordinate system instead of dereferencing None.
            known_sr = raster_sr if raster_sr is not None else display_sr
            # NOTE(review): when neither projection is valid this passes an empty WKT to
            # VectorOutput.initialize_file; confirm it accepts that.
            wkt = known_sr.ExportToWkt() if known_sr is not None else ''
            return wkt, wkt

        if raster_sr.IsCompound() and not display_sr.IsCompound():
            # May demote raster_sr to 2D in place, so export it only after this call
            display_sr = self._handle_vertical_projection(display_sr, raster_sr)
        return raster_sr.ExportToWkt(), display_sr.ExportToWkt()

    def _handle_vertical_projection(self, display_sr, raster_sr):
        """Promote display_sr to a 3D projection based on the display projection settings.

        Side effect: if no vertical projection could be added to display_sr, raster_sr is demoted to 2D in place.

        Args:
            display_sr (osr.SpatialReference): The display spatial reference.
            raster_sr (osr.SpatialReference): The raster spatial reference.

        Returns:
            osr.SpatialReference: The display spatial reference
        """
        # Promote display_sr to a 3D projection based on the display projection settings
        datum_v = self.vertical_datum
        units_v = self.vertical_units
        add_vertical, display_sr = gu.add_vertical_projection(display_sr, datum_v, units_v)
        if not add_vertical:
            # We need to demote the raster's projection from a 3D to a 2D projection in this case since a 3D
            # projection will incorrectly attempt to convert the Z coordinates of the map data.  See bug 13655.
            raster_sr.DemoteTo2D()
        return display_sr
