"""FeaturesFromRasterTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os.path

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.gdal.rasters import RasterInput
from xms.gdal.utilities import gdal_utils as gu
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.utilities.coverage_conversion import convert_lines_to_coverage
from xms.tool.utilities.file_utils import get_raster_filename
from xms.tool.whitebox import WhiteboxToolRunner
import xms.tool.whitebox.whitebox_tool_runner as wbr

# Positions of the tool's arguments in the list returned by FeaturesFromRasterTool.initial_arguments().
(
    ARG_INPUT_RASTER,
    ARG_FEATURE_TYPE,
    ARG_THRESHOLD_AREA,
    ARG_PREPROCESSING_ENGINE,
    ARG_OUTPUT_COVERAGE,
) = range(5)


class StreamTracker:
    """Follows a single stream segment while it is traced cell by cell across a raster.

    Attributes:
        cur_pos (tuple): The (row, col) of the most recently added cell.
        list_segment (list): Every (row, col) added so far, in the order it was added.
    """

    def __init__(self, pos):
        """Starts a new segment at the given cell.

        Args:
            pos (tuple): The (row, col) location on the raster where the segment begins.
        """
        self.list_segment = [pos]
        self.cur_pos = pos

    def add_point_and_set_cur_pos(self, pos):
        """Extends the segment by one cell and makes that cell the current position.

        Args:
            pos (tuple): The (row, col) location on the raster to append.
        """
        self.cur_pos = pos
        self.list_segment.append(pos)


class FeaturesFromRasterTool(Tool):
    """Tool that traces stream (or ridge) polylines from an elevation raster into a coverage.

    Whitebox tools compute flow direction (pointer) and flow accumulation rasters.  Starting from the unvisited cell
    with the greatest accumulation, cells whose upstream area meets the threshold are traced upstream to build
    polylines.  Ridges are found the same way after negating the elevations.
    """
    FEATURE_TYPE_STREAM = 'Stream'
    FEATURE_TYPE_RIDGE = 'Ridge'

    WHITEBOX_RHO8 = 'Whitebox rho8'
    WHITEBOX_FULL_WORKFLOW = 'Whitebox full workflow'

    # Neighbors examined when looking for upstream cells, as (row offset, col offset, Whitebox flow pointer value the
    # neighbor must have for it to drain into the center cell).  Whitebox flow pointers:
    #   64 128 1
    #   32  0  2
    #   16  8  4
    # The order matters: the first matching neighbor continues the current segment and any others start branches.
    _UPSTREAM_NEIGHBORS = (
        (1, -1, 1),     # lower left cell draining toward the upper right
        (1, 0, 128),    # lower cell draining up
        (1, 1, 64),     # lower right cell draining toward the upper left
        (0, -1, 2),     # left cell draining right
        (0, 1, 32),     # right cell draining left
        (-1, -1, 4),    # upper left cell draining toward the lower right
        (-1, 0, 8),     # upper cell draining down
        (-1, 1, 16),    # upper right cell draining toward the lower left
    )

    def __init__(self):
        """Initializes the class."""
        super().__init__('Features from Raster')
        self._raster = None  # The input elevation raster
        self.flowacc = None  # Flow accumulation values (cell counts) read from the Whitebox output
        self.flowdir = None  # Whitebox flow pointer values
        self.elevations = None  # Input raster values, used as the Z of the output points
        self.visited = None  # Boolean mask of cells already traced
        self.num_rows = 0
        self.num_cols = 0
        self.terminus_cells = []  # (row, col) of the downstream-most cell of each traced branch

    def initial_arguments(self):
        """Get initial arguments for tool.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.raster_argument(name='input_raster', description='Input raster'),
            self.string_argument(name='feature_type', description='Feature type', value=self.FEATURE_TYPE_STREAM,
                                 choices=[self.FEATURE_TYPE_STREAM, self.FEATURE_TYPE_RIDGE]),
            self.float_argument(name='threshold_area', description='Threshold area', value=100000.0),
            self.string_argument(name='preprocessing_engine', description='Pre-processing engine',
                                 value=self.WHITEBOX_RHO8, choices=[self.WHITEBOX_RHO8, self.WHITEBOX_FULL_WORKFLOW]),
            self.coverage_argument(name='output_coverage', description='Output coverage',
                                   io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def run(self, arguments):
        """Runs the tool.

        Computes the flow rasters, traces every stream whose upstream area meets the threshold and writes the
        resulting polylines to the output coverage.  Nothing is written when no cell meets the threshold.

        Args:
            arguments (list): A list of the tool's arguments.
        """
        input_filename = arguments[ARG_INPUT_RASTER].value
        self._raster = self.get_input_raster(input_filename)
        raster_filename = self.get_input_raster_file(input_filename)
        raster_base = os.path.splitext(os.path.basename(raster_filename))[0]
        flowdir_filename, flowaccum_filename = self._compute_flow_rasters(arguments, input_filename, raster_base)
        self._read_raster_flow_attributes(flowdir_filename, flowaccum_filename)
        ok, branches = self.find_streams(arguments[ARG_THRESHOLD_AREA].value)
        if not ok:
            return
        # Traced segments run downstream -> upstream; reverse them so each polyline starts at its upstream end.
        streams = [stream[::-1] for branch in branches for stream in branch]
        display_wkt = None
        if self.default_wkt is not None:
            display_wkt = gu.add_vertical_to_wkt(self.default_wkt, self.vertical_datum, self.vertical_units)
        new_cov = convert_lines_to_coverage(streams, arguments[ARG_OUTPUT_COVERAGE].value, self._raster.wkt,
                                            display_wkt)
        self.set_output_coverage(new_cov, arguments[ARG_OUTPUT_COVERAGE])

    def _compute_flow_rasters(self, arguments, input_filename, raster_base):
        """Runs the Whitebox pre-processing tools and returns the names of the flow rasters they produce.

        Args:
            arguments (list): A list of the tool's arguments.
            input_filename (str): The input raster argument value.
            raster_base (str): Base name (no directory or extension) used to name the Whitebox output rasters.

        Returns:
            (tuple): The flow direction (pointer) raster name and the flow accumulation raster name.
        """
        wbt_runner = WhiteboxToolRunner(self)
        if arguments[ARG_FEATURE_TYPE].value == self.FEATURE_TYPE_RIDGE:
            # Ridges are traced as the streams of the negated surface.
            inverse_filename = f'{raster_base}_inverse'
            arg_values = {
                'input_file': input_filename,
                'output_file': inverse_filename
            }
            wbr.run_wbt_tool(wbt_runner, 'Negate', arg_values, False)
            input_filename = inverse_filename

        full_workflow = arguments[ARG_PREPROCESSING_ENGINE].value == self.WHITEBOX_FULL_WORKFLOW
        extension = '_full' if full_workflow else '_rho8'
        demfill_filename = f'{raster_base}_breached{extension}'
        flowdir_filename = f'{raster_base}_flowdir{extension}'
        flowaccum_filename = f'{raster_base}_flowaccum{extension}'

        if full_workflow:
            arg_values = {
                'input_dem_file': input_filename,
                'output_dem_file': demfill_filename,
                'output_flow_pointer_file': flowdir_filename,
                'output_flow_accumulation_file': flowaccum_filename,
                'output_type': 'Cells'
            }
            wbr.run_wbt_tool(wbt_runner, 'FlowAccumulationFullWorkflow', arg_values, False)
        else:
            # Rho8 method: breach depressions, compute Rho8 pointers, then accumulate flow from those pointers.
            maximum_search_distance_cells = 1
            arg_values = {
                'input_dem_file': input_filename,
                'output_file': demfill_filename,
                'maximum_search_distance_cells': maximum_search_distance_cells,
                'minimize_breach_distances': True,
                'fill_unbreached_depressions': True
            }
            wbr.run_wbt_tool(wbt_runner, 'BreachDepressionsLeastCost', arg_values, False)
            arg_values = {
                'input_dem_file': demfill_filename,
                'output_file': flowdir_filename
            }
            wbr.run_wbt_tool(wbt_runner, 'Rho8Pointer', arg_values, False)
            arg_values = {
                'input_dem_or_rho8_pointer_file': flowdir_filename,
                'output_raster_file': flowaccum_filename,
                'output_type': 'cells',
                'is_the_input_raster_a_rho8_flow_pointer': True
            }
            wbr.run_wbt_tool(wbt_runner, 'Rho8FlowAccumulation', arg_values, False)
        return flowdir_filename, flowaccum_filename

    def _read_raster_flow_attributes(self, flowdir_filename, flowaccum_filename):
        """Reads the flow direction/accumulation rasters and resets the per-run tracing state.

        Args:
            flowdir_filename (str): Name of the Whitebox flow pointer raster.
            flowaccum_filename (str): Name of the Whitebox flow accumulation raster.
        """
        self.num_rows = self._raster.resolution[1]
        self.num_cols = self._raster.resolution[0]
        flowdir = RasterInput(get_raster_filename(flowdir_filename))
        self.flowdir = flowdir.get_raster_values()
        flowacc = RasterInput(get_raster_filename(flowaccum_filename))
        self.flowacc = flowacc.get_raster_values()
        self.visited = np.full(np.shape(self.flowacc), False)
        self.elevations = self._raster.get_raster_values()
        # Start each run with a clean list so repeated runs on the same instance don't accumulate outlets.
        self.terminus_cells = []

    def _cell_area(self):
        """Returns the (positive) area of one raster cell, in squared XY units of the input raster."""
        return abs(self._raster.pixel_width * self._raster.pixel_height)

    def _find_ij_for_max_accum(self, accum_threshold):
        """Finds the unvisited row, col location with the max accumulation whose upstream area meets the threshold.

        Args:
            accum_threshold (float): The upstream area threshold, in squared XY units of the raster.

        Returns:
            (tuple): The row, col of the max accumulation and whether a location was found (-1, -1, False if not).
        """
        # Visited cells get -inf so they can never be selected, whatever the threshold's sign.
        active_accum = np.where(self.visited, -np.inf, self.flowacc)
        flat_index = active_accum.argmax()
        max_accum = active_accum.flat[flat_index]
        # A non-finite maximum means every cell has been visited (or the data holds NaN).  Stopping here keeps a
        # non-positive threshold from returning an already traced cell forever.
        if np.isfinite(max_accum) and max_accum * self._cell_area() >= accum_threshold:
            row, col = np.unravel_index(flat_index, active_accum.shape)
            return row, col, True
        return -1, -1, False

    def find_streams(self, accum_threshold):
        """Finds all the streams on the raster whose upstream area meets the given threshold.

        Args:
            accum_threshold (float): The upstream area threshold, in squared XY units of the raster.

        Returns:
            (tuple): Whether any stream was found, and a list with one entry per stream branch.  Each entry is a list
            of segments, and each segment is a list of (x, y, z) points.
        """
        streams = []
        i, j, found = self._find_ij_for_max_accum(accum_threshold)
        while found:
            self.terminus_cells.append((i, j))
            streams.append(self._get_upstream_stream_segments(i, j, accum_threshold))
            i, j, found = self._find_ij_for_max_accum(accum_threshold)
        return bool(streams), streams

    def _get_upstream_stream_segments(self, i, j, accum_threshold):
        """Traces the stream network upstream of the given cell and returns its segments as polylines.

        Upstream cells are followed while their upstream area is at least ``accum_threshold``.  A segment ends where
        the stream dead-ends or branches.  At a branch, each upstream cell starts a new segment that begins at the
        branching cell.

        Args:
            i (int): The row index of the cell to start from (the downstream end of the branch).
            j (int): The column index of the cell to start from.
            accum_threshold (float): The minimum upstream area for a cell to be part of the stream, in squared XY
                units of the raster (see ``_get_upstream_area``).

        Returns:
            (list): A list of segments.  Each segment is a list of (x, y, z) points, ordered from the downstream end
            of the segment to its upstream end.
        """
        segments = []
        trackers = [StreamTracker((i, j))]
        while True:
            list_in = [tracker.cur_pos for tracker in trackers]
            found, list_up = self._get_upstream_cells(list_in, accum_threshold)
            if not found:
                break
            # Every upstream cell after the first starts a new branch tracker, appended after the existing ones.
            for cur_in, ups in zip(list_in, list_up):
                for cur_up in ups[1:]:
                    trackers.append(StreamTracker(cur_in))
                    trackers[-1].add_point_and_set_cur_pos(cur_up)
            # Walk the pre-existing trackers in step with list_in: finish dead-end segments, restart branching
            # segments at the branch cell, and extend each survivor with its first upstream cell.
            stream_index = 0
            for cur_in, ups in zip(list_in, list_up):
                if not ups:
                    # Dead end: the segment is complete and its tracker is dropped.
                    segments.append(trackers[stream_index].list_segment)
                    del trackers[stream_index]
                    continue
                if len(ups) > 1:
                    # Branch: close the current segment and start a new one at the branching cell.
                    segments.append(trackers[stream_index].list_segment)
                    trackers[stream_index] = StreamTracker(cur_in)
                trackers[stream_index].add_point_and_set_cur_pos(ups[0])
                stream_index += 1
        # Every tracker still open ended at a dead end on the last pass.
        segments.extend(tracker.list_segment for tracker in trackers)
        return [self._add_segment_line(segment) for segment in segments]

    def _add_segment_line(self, segment):
        """Returns the XYZ points of a line segment given its row, col locations.

        Points are placed at cell centers, and Z is taken from the input raster values.

        Args:
            segment (list): The row, col locations of a line segment.

        Returns:
            (list): A list of XYZ points defining the stream segment polyline.
        """
        width = self._raster.pixel_width
        height = self._raster.pixel_height
        # Cell centers sit half a pixel in from the raster origin.
        x0 = self._raster.xorigin + 0.5 * width
        y0 = self._raster.yorigin + 0.5 * height
        return [(x0 + j0 * width, y0 + i0 * height, self.elevations[i0][j0]) for i0, j0 in segment]

    def _get_upstream_cells(self, list_in, accum_threshold):
        """Gets all the unvisited upstream cells whose upstream area meets the threshold.

        A neighbor is upstream of an input cell when its flow pointer drains into that cell.  The returned list has
        one entry per cell in ``list_in``, holding 0 or more upstream cells in ``_UPSTREAM_NEIGHBORS`` order.  Each
        input cell is marked as visited.

        Args:
            list_in (list): The row, col locations of the input cells whose upstream cells are wanted.
            accum_threshold (float): The upstream area threshold, in squared XY units of the elevation raster.

        Returns:
            (tuple): Whether any upstream cell was found, and a 2D list of row, col locations giving the upstream
            cells for each cell in the input list.
        """
        found = False
        list_up = []
        for cur_i, cur_j in list_in:
            # Traced cells are never selected again as outlets (see _find_ij_for_max_accum).
            self.visited[cur_i][cur_j] = True
            upstream = []
            for di, dj, pointer in self._UPSTREAM_NEIGHBORS:
                i0 = cur_i + di
                j0 = cur_j + dj
                if not (0 <= i0 < self.num_rows and 0 <= j0 < self.num_cols):
                    continue
                # An already traced cell is never followed again.  With single-pointer flow this cannot skip a
                # legitimate cell, and it keeps a cyclic (malformed) pointer raster from looping forever.
                if self.visited[i0][j0]:
                    continue
                if self._get_upstream_area(i0, j0) >= accum_threshold and self.flowdir[i0][j0] == pointer:
                    upstream.append((i0, j0))
            list_up.append(upstream)
            if upstream:
                found = True
        return found, list_up

    def _get_upstream_area(self, i0, j0):
        """Returns the upstream area of the raster cell at row i0 and column j0.

        The area is the cell's flow accumulation times the cell area.  This is the same measure used when selecting
        stream outlets.  A negative accumulation value, such as a negative no-data value if the raster uses one,
        yields a negative area and so never meets a positive threshold.
        """
        return self.flowacc[i0][j0] * self._cell_area()
