"""BlendRasterToEdgesTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.utilities import GdalRunner
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.rasters.merge_elevation_rasters_tool import merge_rasters_with_blend

ARG_PRIMARY_RASTER = 0
ARG_SECONDARY_RASTER = 1
ARG_BLEND_WIDTH = 2
ARG_OUTPUT_RASTER = 3


class BlendRasterToEdgesTool(Tool):
    """Tool to blend the secondary raster into the edges of the primary raster.

    The secondary raster is warped onto the primary raster's extent, grid size and projection. When the primary
    raster defines a no-data value, the warped secondary raster is masked to the primary raster's footprint. The
    result is then blended into the primary raster over ``blend_width`` via ``merge_rasters_with_blend``. Finally it
    is reprojected to the display projection.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Blend Raster to Edges')
        # NOTE(review): not read in this class; presumably used by merge_rasters_with_blend, which receives this
        # tool instance — confirm before removing.
        self._file_count = 0

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, in the order given by the ARG_* index constants.
        """
        arguments = [
            self.raster_argument(name='primary_raster', description='Primary raster'),
            self.raster_argument(name='secondary_raster', description='Secondary raster'),
            self.float_argument(name='blend_width', description='Blend width along edge', min_value=0.0),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name (empty when all are valid).
        """
        errors = {}
        # Validate the raster files
        multiple_band_error = 'This tool should only be used on rasters with a single dataset and is not designed to ' \
                              'be used with RGB images or rasters with more than one band (dataset).'
        # NOTE(review): the check rejects only rasters with MORE than two bands even though the message says
        # "more than one band"; presumably this tolerates a data band plus a mask/alpha band — confirm.
        primary_raster = self.get_input_raster(arguments[ARG_PRIMARY_RASTER].value)
        if primary_raster.gdal_raster.RasterCount > 2:
            errors[arguments[ARG_PRIMARY_RASTER].name] = multiple_band_error
        secondary_raster = self.get_input_raster(arguments[ARG_SECONDARY_RASTER].value)
        if secondary_raster.gdal_raster.RasterCount > 2:
            errors[arguments[ARG_SECONDARY_RASTER].name] = multiple_band_error
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Warps the secondary raster onto the primary raster's grid and masks it to the primary raster's no-data
        footprint. It then blends it into the primary raster and reprojects the result to the display projection.

        Both temporary folders (this tool's folder holding the projection file and the GdalRunner's folder) are
        removed when the method finishes, including when a step fails partway through.

        Args:
            arguments (list): The tool arguments.
        """
        primary_raster = self.get_input_raster(arguments[ARG_PRIMARY_RASTER].value)
        primary_raster_file = self.get_input_raster_file(arguments[ARG_PRIMARY_RASTER].value)
        secondary_raster = self.get_input_raster(arguments[ARG_SECONDARY_RASTER].value)
        secondary_raster_file = self.get_input_raster_file(arguments[ARG_SECONDARY_RASTER].value)
        blend_width = arguments[ARG_BLEND_WIDTH].value

        # Make a temporary directory for the projection file; it is removed in the finally block below.
        self.temp_file_path = filesystem.temp_filename()
        os.mkdir(self.temp_file_path)

        gdal_runner = None
        try:
            wkt_file = self._write_projection_file(primary_raster.wkt)

            gdal_runner = GdalRunner()

            # Warp the secondary raster onto the primary raster's extent (-te) and grid size (-ts), in the primary
            # raster's projection when it has a valid one.
            # NOTE(review): gdalwarp's -ts expects a size in pixels (width height); this assumes `resolution`
            # returns column/row counts — confirm.
            bounds_min, bounds_max = primary_raster.get_raster_bounds()
            xmin, ymin = bounds_min[0], bounds_min[1]
            xmax, ymax = bounds_max[0], bounds_max[1]
            xsize, ysize = primary_raster.resolution
            args = ['-te', f'{xmin}', f'{ymin}', f'{xmax}', f'{ymax}', '-ts', f'{xsize}', f'{ysize}', '-r', 'bilinear']
            if wkt_file is not None:
                args.extend(['-t_srs', f'{wkt_file}'])
            args.extend(['-of', 'GTiff'])
            clipper_file_no_data = gdal_runner.run_wrapper('gdalwarp', secondary_raster_file, 'clipper_no_data.tif',
                                                           options=args)

            # Use the primary raster to put matching no-data cells in the secondary raster. gdal_calc writes
            # --NoDataValue wherever an input pixel is no-data, so cells outside the primary footprint become no-data.
            primary_no_data_value = primary_raster.nodata_value
            secondary_no_data_value = secondary_raster.nodata_value
            if primary_no_data_value is not None:
                if secondary_no_data_value is None:
                    secondary_no_data_value = primary_no_data_value
                args = ['-A', primary_raster_file, '-B', clipper_file_no_data, '--outfile=$OUT_FILE$',
                        f'--calc=B*(A!={primary_no_data_value})', f'--NoDataValue={secondary_no_data_value}']
                blender_file = gdal_runner.run('gdal_calc.py', 'blender.tif', args)
            else:
                blender_file = clipper_file_no_data

            # Blend
            blended_file = merge_rasters_with_blend(self, gdal_runner, primary_raster_file, blend_width,
                                                    blender_file)

            # Reproject raster to the display projection
            output_raster = ru.reproject_raster(blended_file,
                                                self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value),
                                                self.default_wkt, self.vertical_datum, self.vertical_units)
            self.set_output_raster_file(output_raster, arguments[ARG_OUTPUT_RASTER].text_value)
        finally:
            # Remove the folders containing any generated files. The previous version removed only the GdalRunner
            # folder, and only on success, so the projection-file folder was always leaked.
            self._remove_folder(getattr(gdal_runner, 'temp_file_path', None))
            self._remove_folder(self.temp_file_path)

    def _write_projection_file(self, wkt):
        """Write the projection WKT to ``output.prj`` in this tool's temporary folder.

        Args:
            wkt (str): Well-known text of the primary raster's projection.

        Returns:
            (str | None): Path to the written file, or None when ``gu.valid_wkt`` rejects ``wkt``.
        """
        if not gu.valid_wkt(wkt):
            return None
        wkt_file = os.path.join(self.temp_file_path, 'output.prj')
        with open(wkt_file, 'wt') as file:
            file.write(wkt)
        return wkt_file

    @staticmethod
    def _remove_folder(folder):
        """Best-effort removal of a temporary folder; empty/None paths are ignored.

        Errors are ignored deliberately. Failing to delete scratch files (for example because a file handle is still
        open) should not fail, or mask the original error of, a tool run.

        Args:
            folder (str | None): Folder to remove.
        """
        if folder:
            shutil.rmtree(folder, ignore_errors=True)


# def main():
#     """Main function, for testing."""
#     pass
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#     from xms.tool.utilities.file_utils import get_test_files_path
#
#     qapp = ensure_qapplication_exists()
#     tool = BlendRasterToEdgesTool()
#     tool.set_gui_data_folder(get_test_files_path())
#     arguments = tool.initial_arguments()
#     tool_dialog = ToolDialog(None, arguments, tool.name, tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#     qapp = None
#
#
# if __name__ == "__main__":
#     main()
