"""DirectionalRoughnessTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
import multiprocessing
import os

# 2. Third party modules
import numpy as np
import pandas

# 3. Aquaveo modules
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.utilities import gdal_utils as gu
from xms.grid.geometry.geometry import point_in_polygon_2d
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.utilities.landuse_codes import get_directional_roughness_values, get_directional_roughness_values_ccap

# Positions of the tool arguments, in the order they are built by DirectionalRoughnessTool.initial_arguments.
(ARG_INPUT_LANDUSE, ARG_INPUT_METHOD, ARG_INPUT_TOTAL_DISTANCE, ARG_INPUT_WEIGHTED_DISTANCE, ARG_INPUT_GRID,
 ARG_INPUT_LANDUSE_TYPE, ARG_INPUT_CSV, ARG_INPUT_DEFAULT, ARG_OUTPUT_DATASET) = range(9)
OUTPUT_TREE_PATH = 'surface_directional_effective_roughness_length'
# Directions are evaluated every 30 degrees (angle_idx * 30), giving 12 per full circle.
NUMBER_OF_ANGLES = 12
# Approximate meters per degree; the meter distance inputs are divided by this for geographic rasters.
DEGREES_TO_METERS = 1.0e5


def _calculate_linear_for_angle_worker(args):
    """
    Compute the linear-method roughness of every grid point for a single wind direction.

    Designed to run inside a ``multiprocessing`` worker, so all inputs arrive packed in one tuple. For each point a
    straight line ``total_distance`` long is cast in the direction, the raster cells it crosses are sampled, and
    their roughness values are averaged with distance-based weights.

    Args:
        args (tuple): ``(angle_idx, point_locations, total_distance, weighted_distance, code_vals,
                      roughness_vals, default_roughness, raster_params)``. The direction is ``angle_idx * 30``
                      degrees in the oceanographic convention.

    Returns:
        np.ndarray: A 1D array with one roughness value per point; points whose weights sum to zero keep
        ``default_roughness``.
    """
    (angle_idx, locations, reach, falloff,
     codes, roughness_table, fallback, params) = args

    # Offset from a point to the far end of its sampling line.
    radians_per_degree = math.pi / 180.0
    cartesian_deg = convert_from_oceanographic_to_cartesian(angle_idx * 30)
    offset_x = reach * math.cos(cartesian_deg * radians_per_degree)
    offset_y = reach * math.sin(cartesian_deg * radians_per_degree)

    # z of this placeholder marks cells that are off the raster or hold an out-of-range value.
    off_raster = (-9999, -9999, 0)
    row = np.full(len(locations), fallback)

    for idx, node in enumerate(locations):
        target = [node[0] + offset_x, node[1] + offset_y, 0.0]
        samples = _get_path_worker(params, [node, target], off_raster)

        total_weight = 0.0
        weighted_total = 0.0
        for sample in samples:
            if sample[2] == off_raster[2]:
                value = 0.0
            else:
                # Nearest known landuse code to the (rounded) cell value.
                nearest = (np.abs(codes - int(sample[2] + 0.5))).argmin()
                value = roughness_table[nearest]
            weight = compute_weight_from_distance(node, sample, falloff)
            total_weight += weight
            weighted_total += weight * value

        if total_weight != 0.0:
            row[idx] = float(weighted_total / total_weight)

    return row


def _get_path_worker(raster_params, search_path, null_pt):
    """
    A worker function that gets a path through the raster. It should be called from other worker processes.

    Arguments:
        raster_params (dict): A dictionary containing all the necessary raster attributes.
        search_path (list of tuples): The path of points to search in the raster.
        null_pt (point): A null point for nodata values.
    """
    # unpack raster parameters from the dictionary
    origin = raster_params['origin']
    spacing = raster_params['spacing']
    xres = raster_params['xres']
    yres = raster_params['yres']
    min_pt = raster_params['min_pt']
    max_pt = raster_params['max_pt']
    raster_values = raster_params['raster_values']

    # convert world coordinates to raster cell coordinates
    search_path_xy = np.array(search_path, dtype=np.float64)[:, :2]
    raster_coords = np.empty_like(search_path_xy)
    raster_coords[:, 0] = (search_path_xy[:, 0] - origin[0]) / spacing[0]
    raster_coords[:, 1] = (origin[1] - search_path_xy[:, 1]) / spacing[1]
    pi = (xres, yres)

    # iterate over line segments to generate all raster cells along the path
    all_cells = []
    for ii in range(len(raster_coords) - 1):
        p0, p1 = raster_coords[ii], raster_coords[ii + 1]
        num_steps = int(np.ceil(np.max(np.abs(p1 - p0)))) + 1

        x_points, y_points = np.linspace(p0[0], p1[0], num_steps), np.linspace(p0[1], p1[1], num_steps)
        all_cells.append(np.floor(np.vstack([x_points, y_points]).T).astype(int))

    # add the last point of the path to make it a (1, 2) array
    all_cells.append(np.floor(raster_coords[-1]).astype(int).reshape(1, 2))
    rcells = np.vstack(all_cells)

    # filter out duplicate consecutive cells
    unique_mask = np.concatenate(([True], np.any(rcells[1:] != rcells[:-1], axis=1)))
    unique_cells = rcells[unique_mask]

    # get Z values for all unique cells
    valid_bounds_mask = ((unique_cells[:, 0] >= 0) & (unique_cells[:, 0] < pi[0]) & (unique_cells[:, 1] >= 0)
                         & (unique_cells[:, 1] < pi[1]))
    rpath = np.full((len(unique_cells), 3), null_pt, dtype=np.float64)

    # convert cell indices back to world coordinates
    rpath[:, 0] = unique_cells[:, 0] * spacing[0] + origin[0]
    rpath[:, 1] = origin[1] - unique_cells[:, 1] * spacing[1]

    # find valid cells (within bounds and valid Z)
    valid_cells = unique_cells[valid_bounds_mask]
    if valid_cells.size > 0:
        z_values = raster_values[valid_cells[:, 1], valid_cells[:, 0]]
        valid_z_mask = (z_values >= min_pt[2]) & (z_values <= max_pt[2])
        # create a mask to identify valid points
        full_validity_mask = np.copy(valid_bounds_mask)
        full_validity_mask[valid_bounds_mask] = valid_z_mask
        # overwrite z value for valid points
        rpath[full_validity_mask, 2] = z_values[valid_z_mask]

    return rpath.tolist()


def convert_from_oceanographic_to_cartesian(angle):
    """Map an oceanographic angle onto the cartesian angle convention.

    Computes ``90 - angle``, which turns a compass-style angle (clockwise from north) into a math-style angle
    (counter-clockwise from the +x axis). The result is not wrapped into ``[0, 360)``.

    Arguments:
         angle (float):  Angle in degrees, oceanographic convention.

    Returns:
        float:  The same direction in degrees, cartesian convention.
    """
    cartesian = 90.0 - angle
    return cartesian


def compute_weight_from_distance(location, cur_point, weighted_distance):
    """Computes a Gaussian weight from the planar distance between two points.

    The weight is ``exp(-d**2 / (2 * weighted_distance**2))``, where ``d`` is the (x, y) distance between
    ``location`` and ``cur_point``. It is 1.0 at the target and decays with distance. Any z components are ignored.

    Arguments:
        location (tuple):  The target location.
        cur_point (Sequence):  The current location.
        weighted_distance (float):  The weighted distance factor (Gaussian width).

    Returns:
        float:  The computed weight.

    Raises:
        ZeroDivisionError: If ``weighted_distance`` is 0.
    """
    dx = location[0] - cur_point[0]
    dy = location[1] - cur_point[1]
    # Only the squared distance is needed. Taking a square root and squaring it again is wasted work in this hot
    # path, which runs once per sampled cell, per direction, per grid point.
    return math.exp(-(dx * dx + dy * dy) / (2.0 * weighted_distance * weighted_distance))


def compute_double_weighted_average(weight, val, area):
    """Average values using the product of a weight and an area as each entry's weight.

    Arguments:
        weight (list of float):  Per-entry weights.
        val (list of float):  Values to average.
        area (list of float):  Per-entry areas.

    Returns:
        float:  ``sum(w * v * a) / sum(w * a)``, or -1 when the three lists differ in length.
    """
    sizes = (len(weight), len(val), len(area))
    if sizes[1] != sizes[0] or sizes[2] != sizes[0]:
        return -1

    numerator = 0.0
    denominator = 0.0
    for w, v, a in zip(weight, val, area):
        numerator += w * v * a
        denominator += w * a
    # NOTE: raises ZeroDivisionError when every weight * area is zero (including empty input).
    return numerator / denominator


def _get_key_row(item):
    """Functor to return the second item in an indexable object.

    Args:
        item (Sequence): The object to index

    Returns:
        int: See description
    """
    return item[1]


def _get_key_column(item):
    """Functor to return the first item in an indexable object.

    Args:
        item (Sequence): The object to index

    Returns:
        int: See description
    """
    return item[0]


class DirectionalRoughnessTool(Tool):
    """Tool to convert an NLCD raster to directional roughness dataset."""

    def __init__(self):
        """Create the tool with every cached grid, raster, and result attribute in its empty state."""
        super().__init__(name='Directional Roughness')
        # Target grid (set by validate_arguments / run).
        self._cogrid = None
        self._grid = None
        self._grid_uuid = None
        self._builders = []
        # Landuse raster and its geometry (filled in by initialize_raster).
        self._raster = None
        self._raster_values = None
        self._raster_null_value = -999999.0
        self._x_origin = self._y_origin = 0.0
        self._pixel_width = self._pixel_height = 0
        self._xres = self._yres = 0
        self._min_pt = self._max_pt = None
        self._origin = self._spacing = None
        self._pixel_area = 0.0
        # Lookup table, default, and the computed (direction x point) result array.
        self._mapping_table = {}
        self._default_value = 0.025
        self._dataset_values = None

    def initial_arguments(self):
        """Build the tool's argument list.

        Must override. The list order must match the module-level ``ARG_*`` indices.

        Returns:
            (list): A list of the initial tool arguments.
        """
        landuse_raster = self.raster_argument(name='landuse_raster', description='Input landuse raster')
        method = self.string_argument(name='method', description='Method', choices=['Linear', 'Sector'],
                                      value='Linear')
        total_distance = self.float_argument(name='total_distance', description='Total distance (m)',
                                             value=10000.0, min_value=0.0)
        weighted_distance = self.float_argument(name='weighted_distance', description='Weighted distance (m)',
                                                value=3000.0, min_value=0.0)
        grid = self.grid_argument(name='grid', description='Target grid')
        landuse_type = self.string_argument(name='landuse_type', description='Landuse raster type',
                                            choices=['NLCD', 'C-CAP', 'Other'], value='NLCD')
        mapping_csv = self.file_argument(name='mapping_csv',
                                         description='Landuse to directional roughness mapping table',
                                         optional=True)
        default_value = self.float_argument(name='default_value', description='Default wind reduction value',
                                            value=0.001, min_value=0.0)
        roughness_dataset = self.dataset_argument(name='roughness_dataset',
                                                  description='Output wind reduction dataset', value='z0Land',
                                                  io_direction=IoDirection.OUTPUT)

        arguments = [landuse_raster, method, total_distance, weighted_distance, grid, landuse_type, mapping_csv,
                     default_value, roughness_dataset]
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also caches the opened landuse raster and target grid on the tool for later use by ``run``.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        # raster must be an index raster
        self._raster = self.get_input_raster(arguments[ARG_INPUT_LANDUSE].text_value)
        if not ru.is_index_raster(self._raster):
            errors[arguments[ARG_INPUT_LANDUSE].name] = 'Landuse raster must be an index raster.'

        # Validate input data
        self._cogrid = self.get_input_grid(arguments[ARG_INPUT_GRID].text_value)
        if not self._cogrid:
            errors[arguments[ARG_INPUT_GRID].name] = 'Could not open target grid.'

        # The weighted distance is the width of the Gaussian distance weighting and appears squared in a
        # denominator. The argument accepts 0 (min_value=0.0), which would fail with a division by zero partway
        # through the run, so reject it here.
        weighted_arg = arguments[ARG_INPUT_WEIGHTED_DISTANCE]
        try:
            weighted_distance = float(weighted_arg.text_value)
        except (TypeError, ValueError):
            weighted_distance = 0.0
        if not weighted_distance > 0.0:  # also rejects NaN
            errors[weighted_arg.name] = 'Weighted distance must be greater than 0.'

        land_use_type = arguments[ARG_INPUT_LANDUSE_TYPE].text_value or ''
        csv_file = arguments[ARG_INPUT_CSV].text_value or ''

        # Custom mapping table CSV
        if csv_file:
            if not os.path.exists(csv_file):
                errors[arguments[ARG_INPUT_CSV].name] = 'Could not open Landuse to Roughness mapping table.'
        # Must specify a mapping table if using custom land use type
        elif land_use_type not in ('NLCD', 'C-CAP'):
            errors[arguments[ARG_INPUT_CSV].name] = \
                'Must select a Roughness mapping table CSV file when land use type is "Other".'
        return errors

    def enable_arguments(self, arguments):
        """Show the mapping-table CSV argument only when the landuse type is 'Other'.

        Args:
            arguments(list): The tool arguments.
        """
        landuse_type = arguments[ARG_INPUT_LANDUSE_TYPE].text_value
        arguments[ARG_INPUT_CSV].hide = landuse_type != 'Other'

    def run(self, arguments):
        """Compute the directional roughness dataset and write it out.

        Relies on ``validate_arguments`` having already opened the target grid and the landuse raster.

        Args:
            arguments (list): The tool arguments.
        """
        cogrid = self._cogrid
        self._grid_uuid = cogrid.uuid
        self._grid = cogrid.ugrid

        codes, _, roughness = self.get_roughness_info(arguments)
        codes = np.array(codes)
        default_roughness, total_distance, weighted_distance = self.get_input_options(arguments)

        self.initialize_raster(arguments)

        # Iterating a plain Python list point-by-point is much faster than iterating a numpy array.
        locations = self._grid.locations.tolist()
        locations = gu.transform_points_from_wkt(locations, self.default_wkt, self._raster.wkt)

        calculators = {'Linear': self.calculate_linear, 'Sector': self.calculate_sector}
        calculate = calculators.get(arguments[ARG_INPUT_METHOD].text_value)
        if calculate is not None:
            calculate(locations, total_distance, weighted_distance, codes, roughness, default_roughness)

        # Write out the dataset
        self._setup_output_dataset_builder(arguments)
        self._add_output_datasets()

    def get_roughness_info(self, arguments):
        """Return the landuse-to-roughness table, built in or read from the user's CSV file.

        Arguments:
            arguments (list): The tool arguments.

        Returns:
            tuple of lists:  The code values, descriptions, and roughness values for each landuse type.
        """
        landuse_type = arguments[ARG_INPUT_LANDUSE_TYPE].text_value
        if landuse_type == 'NLCD':
            return get_directional_roughness_values()
        if landuse_type == 'C-CAP':
            return get_directional_roughness_values_ccap()

        # Custom table: the first row is a header and the first column (the codes) becomes the index; the
        # remaining columns are renamed to description and roughness.
        table = pandas.read_csv(arguments[ARG_INPUT_CSV].text_value, index_col=0, header=0)
        table.columns = ['Description', 'Roughness']
        codes = table.index.tolist()
        descriptions = table['Description'].tolist()
        roughness = table['Roughness'].tolist()
        return codes, descriptions, roughness

    def get_input_options(self, arguments):
        """Read the numeric tool options, converting distances to degrees for geographic rasters.

        Arguments:
            arguments (list): The tool arguments.

        Returns:
            tuple of float:  The default roughness, total distance, and weighted distance input values
        """
        default_roughness = float(arguments[ARG_INPUT_DEFAULT].text_value)
        total_distance = float(arguments[ARG_INPUT_TOTAL_DISTANCE].text_value)
        weighted_distance = float(arguments[ARG_INPUT_WEIGHTED_DISTANCE].text_value)

        wkt = self._raster.wkt
        if gu.valid_wkt(wkt) and gu.wkt_to_sr(wkt).IsGeographic():
            # The distances are entered in meters; express them in (approximate) degrees instead.
            total_distance /= DEGREES_TO_METERS
            weighted_distance /= DEGREES_TO_METERS
        return default_roughness, total_distance, weighted_distance

    def calculate_linear(self, point_locations, total_distance, weighted_distance, code_vals,
                         roughness_vals, default_roughness):
        """Calculates roughness for the linear method using a pool of multiprocessing workers.

        Each of the ``NUMBER_OF_ANGLES`` directions is computed by ``_calculate_linear_for_angle_worker`` in a
        ``multiprocessing.Pool``. The per-direction rows are stacked into ``self._dataset_values`` with shape
        (NUMBER_OF_ANGLES, number of points).

        Arguments:
            point_locations (list of tuples):  The grid locations used for the calculation.
            total_distance (float):  The length of the sampling line cast from each point.
            weighted_distance (float):  The distance-weighting factor.
            code_vals (np.array):  The landuse code values found for the landuse data.
            roughness_vals (list of float):  The landuse roughness values corresponding to the landuse types.
            default_roughness (float):  The value kept for points whose weights sum to zero.
        """
        self.logger.info("Setting up multiprocessing pool...")

        # create a dictionary of raster parameters to pass to the workers
        raster_params = {
            'origin': self._origin,
            'spacing': self._spacing,
            'xres': self._xres,
            'yres': self._yres,
            'min_pt': self._min_pt,
            'max_pt': self._max_pt,
            'raster_values': self._raster_values,
            'default_value': self._default_value
        }

        # one task per direction
        tasks = [(angle_idx, point_locations, total_distance, weighted_distance, code_vals, roughness_vals,
                  default_roughness, raster_params) for angle_idx in range(NUMBER_OF_ANGLES)]

        # Use all CPU cores except one (left for the OS and SMS). os.cpu_count() may return None, and starting more
        # workers than there are directions only adds process start-up cost.
        num_processes = max(1, min(NUMBER_OF_ANGLES, (os.cpu_count() or 1) - 1))
        self.logger.info(f"Starting calculations for {NUMBER_OF_ANGLES} angles.")
        try:
            with multiprocessing.Pool(processes=num_processes) as pool:
                # distribute 'tasks' list to the worker function and collect results
                results = pool.map(_calculate_linear_for_angle_worker, tasks)

            # one 1D array per direction, stacked into the final (NUMBER_OF_ANGLES, num_points) array
            self._dataset_values = np.array(results)
            self.logger.info("All angles processed successfully.")

        except Exception as e:  # pragma no cover
            self.fail(f"A multiprocessing error occurred: {e}")

    def calculate_sector(self, point_locations, total_distance, weighted_distance, code_vals,
                         roughness_vals, default_roughness):
        """Calculates roughness for the sector method and stores it in ``self._dataset_values``.

        For each point and each of the ``NUMBER_OF_ANGLES`` directions, a triangle is formed from the point and two
        vertices ``total_distance`` away at +/-15 degrees of the direction. The roughness of the raster cells
        inside the triangle is averaged, weighted both by cell area and by a Gaussian of the distance from the
        point. The part of the triangle not covered by raster cells contributes ``default_roughness`` at its
        centroid.

        Arguments:
            point_locations (list of tuples):  The grid locations used for the calculation.
            total_distance (float):  The distance from a point to the two far vertices of its triangles.
            weighted_distance (float):  The distance-weighting factor.
            code_vals (np.array):  The landuse code values found for the landuse data.
            roughness_vals (list of float):  The landuse roughness values corresponding to the landuse types.
            default_roughness (float):  The initial dataset value and the roughness of uncovered triangle area.
        """
        pi180 = math.pi / 180.0
        cartesian_angles = [convert_from_oceanographic_to_cartesian(i * 30) for i in range(NUMBER_OF_ANGLES)]
        # Offsets from a point to its triangle's far vertices, 15 degrees below/above each direction.
        delta_pt_below = [[math.cos((angle - 15) * pi180) * total_distance,
                           math.sin((angle - 15) * pi180) * total_distance] for angle in cartesian_angles]
        delta_pt_above = [[math.cos((angle + 15) * pi180) * total_distance,
                           math.sin((angle + 15) * pi180) * total_distance] for angle in cartesian_angles]

        num_points = len(point_locations)
        self._dataset_values = np.full((NUMBER_OF_ANGLES, num_points), default_roughness)
        for pt_idx, location in enumerate(point_locations):
            self._log_progress(pt_idx, num_points)
            for angle_idx in range(NUMBER_OF_ANGLES):
                above_dx, above_dy = delta_pt_above[angle_idx]
                below_dx, below_dy = delta_pt_below[angle_idx]
                pt_above = [location[0] + above_dx, location[1] + above_dy, 0.0]
                pt_below = [location[0] + below_dx, location[1] + below_dy, 0.0]
                triangle_points = [location, pt_below, pt_above]
                triangle_centroid = ((location[0] + pt_below[0] + pt_above[0]) / 3,
                                     (location[1] + pt_below[1] + pt_above[1]) / 3)

                # Exact triangle area: half the magnitude of the cross product of the two edges leaving the point.
                # The triangle's height is total_distance * cos(15 deg), not total_distance, so
                # "0.5 * base * total_distance" overestimates the area.
                total_area = 0.5 * abs(below_dx * above_dy - above_dx * below_dy)

                raster_points = self.get_points_in_polygon(triangle_points)
                area = [self._pixel_area] * len(raster_points)
                unused_area = max(0.0, total_area - len(raster_points) * self._pixel_area)  # never negative
                if unused_area > 0.0:
                    centroid_unused = self._find_centroid_for_unused_area(area, raster_points, total_area,
                                                                          triangle_centroid)
                else:
                    # The cells cover the whole triangle. The uncovered part has zero area and so no influence on
                    # the average; its centroid is undefined (computing it would divide by zero).
                    centroid_unused = triangle_centroid

                raster_values = []
                weights = []
                for raster_point in raster_points:
                    # Round the cell value to the nearest integer code, as the linear method does.
                    code = int(raster_point[2] + 0.5)
                    idx = np.abs(code_vals - code).argmin()
                    raster_values.append(roughness_vals[idx])
                    weights.append(compute_weight_from_distance(location, raster_point, weighted_distance))

                # The uncovered part of the triangle contributes the default roughness at its centroid.
                area.append(unused_area)
                raster_values.append(default_roughness)
                weights.append(compute_weight_from_distance(location, centroid_unused, weighted_distance))
                self._dataset_values[angle_idx][pt_idx] = compute_double_weighted_average(weights, raster_values,
                                                                                          area)

    def _find_centroid_for_unused_area(self, areas, locs, total_area, xy_centroid_of_total):
        """Finds the centroid of the part of a region that the given cells do not cover.

        The area-weighted centroid of the covered cells is removed from the centroid of the whole region, which
        leaves the centroid of the remainder.

        Arguments:
            areas (list of float):  areas of each covered cell.
            locs (list of tuples):  (x, y, ...) locations of the covered cells, parallel to ``areas``.
            total_area (float):  the area of the whole region.
            xy_centroid_of_total (tuple of float):  centroid of the total area of interest.

        Returns:
            tuple of floats:  the (x, y) centroid of the uncovered area. When the cells cover the whole region (no
            positive uncovered area) that centroid is undefined, and ``xy_centroid_of_total`` is returned instead.
        """
        # Accumulate area-weighted moments over every cell. This also counts a lone cell; a pairwise merge that
        # starts from the second cell never runs when there is only one.
        used_area = 0.0
        moment_x = 0.0
        moment_y = 0.0
        for area, loc in zip(areas, locs):
            used_area += area
            moment_x += area * loc[0]
            moment_y += area * loc[1]

        unused_area = total_area - used_area
        if unused_area <= 0.0:
            return xy_centroid_of_total[0], xy_centroid_of_total[1]

        if used_area:
            used_centroid = (moment_x / used_area, moment_y / used_area)
        else:
            used_centroid = (0.0, 0.0)  # no covered cells: contributes nothing below
        return self._find_xy_centroid_for_unused_area(used_area, unused_area, used_centroid, xy_centroid_of_total)

    def _find_xy_centroid_for_total(self, area_1, area_2, centroid_pt_1, centroid_pt_2):
        """Return the centroid of two regions combined: the area-weighted average of their centroids.

        Arguments:
            area_1 (float):  Area of the first region.
            area_2 (float):  Area of the second region.
            centroid_pt_1 (Sequence):  (x, y, ...) centroid of the first region.
            centroid_pt_2 (Sequence):  (x, y, ...) centroid of the second region.

        Returns:
            tuple of floats:  The (x, y) centroid of the combined area.
        """
        combined_area = area_1 + area_2
        x = (area_1 * centroid_pt_1[0] + (area_2 * centroid_pt_2[0])) / combined_area
        y = (area_1 * centroid_pt_1[1] + (area_2 * centroid_pt_2[1])) / combined_area
        return x, y

    def _find_xy_centroid_for_unused_area(self, area_1, unused_area, centroid_1, centroid_total):
        """Back out the centroid of the remainder of a region.

        Given the centroid of the whole region and the centroid of a part of it with area ``area_1``, solve for
        the centroid of the remaining ``unused_area``. Divides by ``unused_area``, so it must be non-zero.

        Arguments:
            area_1 (float):  Area of the known part.
            unused_area (float):  Area of the remainder.
            centroid_1 (Sequence):  (x, y, ...) centroid of the known part.
            centroid_total (Sequence):  (x, y, ...) centroid of the whole region.

        Returns:
            tuple of floats:  The (x, y) centroid of the unused area.
        """
        whole_area = area_1 + unused_area
        x = (centroid_total[0] * whole_area - (centroid_1[0] * area_1)) / unused_area
        y = (centroid_total[1] * whole_area - (centroid_1[1] * area_1)) / unused_area
        return x, y

    def initialize_raster(self, arguments):
        """Get the raster, its size, geo transform, value range, and cell values.

        Args:
            arguments (list): The tool arguments.
        """
        self.logger.info('Retrieving input raster...')
        self._raster = self.get_input_raster(arguments[ARG_INPUT_LANDUSE].text_value)
        self._x_origin = self._raster.xorigin
        self._y_origin = self._raster.yorigin
        self._pixel_width = self._raster.pixel_width
        self._pixel_height = self._raster.pixel_height
        self._xres, self._yres = self._raster.resolution
        self._raster_null_value = self._raster.nodata_value
        self._pixel_area = abs(self._pixel_width * self._pixel_height)

        min_pt, max_pt = self._raster.get_raster_bounds()
        # Float arrays, so the band statistics stored in the z slot below are never truncated.
        self._min_pt = np.array(min_pt, dtype=np.float64)
        self._max_pt = np.array(max_pt, dtype=np.float64)
        # The band min/max become the range of accepted cell values in _get_path_worker, so request exact
        # statistics (approx_ok=False, force=True). Approximate statistics may come from overviews or a sample of
        # blocks and can miss rarely occurring landuse codes, which would then be treated as having no value.
        band = self._raster.gdal_raster.GetRasterBand(1)
        band_stats = band.GetStatistics(False, True)
        self._min_pt[2] = band_stats[0]
        self._max_pt[2] = band_stats[1]

        # Cell size in world units: the raster extent divided by the number of cells along each axis.
        cell_width = (max_pt[0] - min_pt[0]) / self._xres
        cell_height = (max_pt[1] - min_pt[1]) / self._yres

        # Cell indices are measured from the (min x, max y) corner.
        self._origin = (self._min_pt[0], self._max_pt[1])
        self._spacing = (cell_width, cell_height)
        self._raster_values = self._raster.get_raster_values()

    def _log_progress(self, point_idx, num_points):
        """Occasionally log which point is being processed (roughly 100 messages per run).

        The first point is always reported. After that, a message is written every ``num_points // 100`` points,
        but only when there are more than 100 points.

        Args:
            point_idx (int): 0-based index of the point about to be processed
            num_points (int): Total number of points being processed
        """
        if point_idx != 0:
            if num_points <= 100:
                return
            every = max(1, num_points // 100)
            if (point_idx + 1) % every != 0:
                return
        self.logger.info(f'Processing point {point_idx + 1} of {num_points}...')

    def _get_path(self, search_path, null_pt):
        """Sample the raster cells crossed by a path, using this tool's cached raster state.

        Packs the cached raster attributes into the dictionary expected by ``_get_path_worker``.

        Arguments:
            search_path (list of tuples): Ordered world-coordinate vertices of the path.
            null_pt (point): Template row used for cells without a valid raster value.

        Returns:
            list of list of float: ``[x, y, z]`` per raster cell along the path, as from ``_get_path_worker``.
        """
        params = dict(
            origin=self._origin,
            spacing=self._spacing,
            xres=self._xres,
            yres=self._yres,
            min_pt=self._min_pt,
            max_pt=self._max_pt,
            raster_values=self._raster_values,
            default_value=self._default_value,
        )
        return _get_path_worker(params, search_path, null_pt)

    def get_points_in_path_indices(self, poly_indices, poly_pts):
        """Gets raster points inside the boundary formed by given raster poly indices.

        Walks the sorted boundary cells one raster row at a time. Within a row, consecutive boundary cells delimit
        intervals, and an ``adding`` flag toggles whether the cells of each wide interval are added (a scan-line
        style fill). Cells are appended through ``_add_points_by_index``.

        Arguments:
            poly_indices (list of tuples):  The raster (column, row) indices of the boundary cells, sorted by row and
                then by column within each row.
            poly_pts (list of tuples):  The points defining the polygon area (passed to the inside/outside test).

        Returns:
            list of tuples:  The world locations found inside the polygon, with values from the raster.
        """
        # Get image values
        points = []
        # (number of columns, number of rows) of the raster
        pi = self._xres, self._yres

        index = 0
        end_index = len(poly_indices)
        # Outer loop: one iteration per raster row present in poly_indices.
        while index != end_index:
            curr_j = poly_indices[index][1]
            # Row-major linear raster index of column 0 in this row.
            start_j_index = curr_j * pi[0]
            adding = True
            interval_index_begin = index
            interval_index_end = interval_index_begin + 1

            # if we only have one value in a row, add the index and move to the next row by incrementing begin and end
            if interval_index_end != end_index and poly_indices[interval_index_end][1] != curr_j:
                self._add_points_by_index(start_j_index + poly_indices[interval_index_begin][0], self._spacing[0],
                                          self._spacing[1], self._min_pt, self._max_pt, self._xres, self._yres, points)
                interval_index_begin += 1
                interval_index_end += 1

            start_i = -1
            end_i = -1
            # Loop over the intervals in a single row
            while interval_index_end != end_index and poly_indices[interval_index_end][1] == curr_j:
                start_i = poly_indices[interval_index_begin][0]
                end_i = poly_indices[interval_index_end][0]

                # if the beginning i(x) index and ending are the same, and there is another interval we need to know
                # if the cells in the interval need to be added
                if start_i == end_i:
                    # The same column appears twice in this row; currently a no-op (the handling below is disabled).
                    pass
                    # test_index = start_j_index + end_i + 1
                    # # see if we have more segments on the line
                    # if test_index < end_index and poly_indices[test_index][1] == curr_j:
                    #     # see if the next interval should be adding points or skipped
                    #     adding = test_index_in_out(test_index, self._spacing[0], self._spacing[1], self._min_pt,
                    #                                self._max_pt, self._xres, self._yres, poly_pts)
                    #
                    #     # add endpoint only if we are going to start a non-adding section otherwise it will get added
                    #     # with the next segment or as the lastI/J when starting the new row.
                    #     if not adding:
                    #         last_index = start_j_index + end_i
                    #         self._add_points_by_index(last_index, self._spacing[0], self._spacing[1], self._min_pt,
                    #                                   self._max_pt, self._xres, self._yres, points)
                elif start_i + 1 == end_i:
                    # if our interval is two adjacent pixels
                    # Add starting point, go to next interval w/ non-adjacent pixels or
                    # the row, need to check pixel after adjacent cells
                    # Add segment start
                    start_index = start_j_index + start_i
                    self._add_points_by_index(start_index, self._spacing[0], self._spacing[1], self._min_pt,
                                              self._max_pt, self._xres, self._yres, points)

                    # test one after the end of the segment
                    # NOTE(review): test_index is a linear raster index, yet it is compared with end_index (the number
                    # of boundary cells) and used to index poly_indices below; confirm this mix of index spaces is
                    # intended.
                    test_index = start_j_index + end_i + 1
                    # only test if the next one isn't just one past (continuing the line)
                    # see if we have more segments on the line
                    if test_index < end_index and poly_indices[test_index][1] == curr_j and \
                            poly_indices[test_index][0] != end_i + 1:
                        # is_raster_index_in_poly is not defined in the visible code; presumably it reports whether
                        # the cell at test_index lies inside poly_pts -- TODO confirm.
                        adding = is_raster_index_in_poly(test_index, self._spacing[0], self._spacing[1], self._min_pt,
                                                         self._max_pt, self._xres, self._yres, poly_pts)

                        # add endpoint only if we are going to start a non-adding section otherwise it will get added
                        # with the next segment or as the lastI/J when starting the new row.
                        if not adding:
                            last_index = start_j_index + end_i
                            self._add_points_by_index(last_index, self._spacing[0], self._spacing[1], self._min_pt,
                                                      self._max_pt, self._xres, self._yres, points)
                else:
                    if adding:
                        # add values from beginning to ending, clamped to the raster's columns
                        start_i = max(start_i, 0)
                        end_i = min(end_i, pi[0])

                        for i in range(start_i, end_i):
                            the_index = start_j_index + i
                            self._add_points_by_index(the_index, self._spacing[0], self._spacing[1], self._min_pt,
                                                      self._max_pt, self._xres, self._yres, points)
                    # Each wide interval alternates between being filled and being skipped.
                    adding = not adding

                index = index + 1
                interval_index_begin = index
                interval_index_end = interval_index_begin + 1

            # Add the last end value
            if 0 < end_i < pi[0]:
                # add the previous end segment value
                last_index = start_j_index + end_i
                self._add_points_by_index(last_index, self._spacing[0], self._spacing[1], self._min_pt, self._max_pt,
                                          self._xres, self._yres, points)

            # Just finished a j line, go to the next
            index = index + 1

        return points

    def get_points_in_polygon(self, polygon_points):
        """Gets raster points inside the boundary formed by the polygon passed in.

        The polygon is closed internally by repeating its first vertex at the end. The caller's list is not modified.

        Arguments:
            polygon_points (list of tuples):  The points defining the polygon area.

        Returns:
            list:  The world locations found inside the polygon, with values from the raster. Empty when the polygon
            has no points or its boundary does not touch the raster.
        """
        if len(polygon_points) == 0:
            return []

        # Copy before closing the ring so repeated calls with the same list do not keep growing it.
        poly_pts = list(polygon_points)
        poly_pts.append(poly_pts[0])

        poly_indices = self._get_path_2d_indices(poly_pts)
        if len(poly_indices) == 0:
            return []

        # Sort the indices so we are increasing from bottom to top and left to right within the rows.
        # Python's sort is stable, so the row sort keeps the column order established by the first sort.
        temp_sort = sorted(poly_indices, key=_get_key_column)
        poly_indices = sorted(temp_sort, key=_get_key_row)

        return self.get_points_in_path_indices(poly_indices, poly_pts)

    def _add_points_by_index(self, index, xsize, ysize, min_pt, max_pt, num_cols, num_rows, points):
        """Adds the world location of a raster cell, with its raster value, to the list.

        The flat cell index is decomposed row-major (``index = row * num_cols + col``). Cells outside the raster are
        ignored.

        Arguments:
             index (int):  Flat (row-major) raster cell index.
             xsize (float):  Size in the x direction.
             ysize (float):  Size in the y direction.
             min_pt (tuple):  Min of the raster.
             max_pt (tuple):  Max of the raster.
             num_cols (int):  Number of columns on raster.
             num_rows (int):  Number of rows on raster.
             points (list):  Points found so far. The cell's ``[x, y, value]`` location is appended if it lies in
                 the raster.
        """
        # divmod floors quotient and remainder consistently. A negative index therefore yields a negative row, which
        # find_raster_location rejects, instead of wrapping onto a valid cell in row 0. It also avoids float
        # rounding for very large indices.
        cur_row, cur_col = divmod(index, num_cols)
        found, point = find_raster_location(xsize, ysize, min_pt, max_pt, cur_row, cur_col, num_cols, num_rows)
        if found:
            point[2] = get_raster_value(self._raster_values, cur_col, cur_row, self._default_value)
            points.append(point)

    def _get_path_2d_indices(self, search_path):
        """Gets the column/row indices of the raster cells crossed by a polyline path.

        Each vertex is converted to fractional cell coordinates relative to ``self._origin``. The x coordinate grows
        to the right in units of ``self._spacing[0]``. The y coordinate is measured downward from the origin in units
        of ``self._spacing[1]``.

        Each segment is then walked one whole unit at a time along its dominant axis, recording every visited cell
        that lies inside the raster. The start cell of each segment is also recorded when it is inside the raster.

        Arguments:
            search_path (list of tuples):  The path to perform the search on, as world (x, y, ...) vertices.

        Returns:
            list:  ``(column, row)`` index pairs of the cells on the path, in traversal order. Segment start cells
            are stored as lists and traversed cells as tuples. Duplicates are not removed.
        """
        # Terrain following algorithm
        # For each search segment:
        # Collect elevation points along that segment
        # points whose x or y distance to the elevation is less than half a cell
        # Fractional (column, row) coordinates of the first vertex; the third slot is unused padding.
        p1 = [
            (search_path[0][0] - self._origin[0]) / self._spacing[0],
            (self._origin[1] - search_path[0][1]) / self._spacing[1], 0.0
        ]
        # Raster extent as (number of columns, number of rows); _xres is passed as num_cols elsewhere in this class.
        pi = self._xres, self._yres

        # for all search segments
        indices_2d = []
        for ii in range(1, len(search_path)):
            # The previous end vertex becomes this segment's start (copied, since p1 is updated in place below).
            p0 = [coord for coord in p1]
            p1[0] = (search_path[ii][0] - self._origin[0]) / self._spacing[0]
            p1[1] = (self._origin[1] - search_path[ii][1]) / self._spacing[1]

            # add first point if in raster
            # NOTE(review): int() truncates toward zero, so coordinates in (-1, 0) map to index 0 and pass the bounds
            # check below. Confirm that including cells just past the left/top edge is intended.
            pt1 = [int(p0[0]), int(p0[1])]
            #  point within raster
            if not (pt1[0] < 0 or pt1[0] >= pi[0] or pt1[1] < 0 or pt1[1] >= pi[1]):
                indices_2d.append(pt1)

            # if segment to one side of raster then next segment
            # (both endpoints beyond the same edge means the segment cannot cross the raster)
            if p0[0] < 0 and p1[0] < 0 or p0[1] < 0 and p1[1] < 0 or p0[0] > float(pi[0]) and p1[0] > float(pi[0]) or \
                    p0[1] > float(pi[1]) and p1[1] > float(pi[1]):
                continue

            # prepare to traverse segment in unit x direction with a delta for y
            cdiff = [p1[0] - p0[0], p1[1] - p0[1], 0.0]
            # Walk along the axis with the larger extent so each step advances exactly one cell on that axis. When y
            # dominates, the coordinates are swapped so index 0 is always the major axis below.
            ynotx = cdiff[0] ** 2 < cdiff[1] ** 2
            if ynotx:
                # Swap
                cdiff[0], cdiff[1] = cdiff[1], cdiff[0]
                p0[0], p0[1] = p0[1], p0[0]
                p1[0], p1[1] = p1[1], p1[0]
            # Step vector scaled so that |delta[0]| == 1.
            # NOTE(review): a zero-length segment divides by zero here (NaN delta, numpy RuntimeWarning). The while
            # loop below is then skipped because comparisons with NaN are False.
            delta = tuple(np.divide(cdiff, -cdiff[0] if cdiff[0] < 0 else cdiff[0]))

            # move the initial point to the next integer coordinate
            # the adjustment is forward or behind depending the sign of delta
            # (back to int(p0[0]) when moving in the negative direction, forward to int(p0[0] + 1) otherwise)
            pp = list(np.add(p0, np.multiply(delta, p0[0] - int(p0[0]) if delta[0] < 0 else int(p0[0] + 1) - p0[0])))

            # traverse all integer locations in x
            # (on the major axis, until the segment end p1 is passed)
            while pp[0] > p1[0] if delta[0] < 0 else pp[0] <= p1[0]:
                # Get unswapped point in double for saving and in integer for address
                ps = [coord for coord in pp]
                if ynotx:
                    # Swap
                    ps[0], ps[1] = ps[1], ps[0]
                pt2 = int(ps[0]), int(ps[1])
                # point within raster else next point
                if pt2[0] < 0 or pt2[0] >= pi[0] or pt2[1] < 0 or pt2[1] >= pi[1]:
                    pp = list(np.add(pp, delta))
                    continue

                indices_2d.append(pt2)
                pp = list(np.add(pp, delta))

            if ynotx:
                # Swap
                # Restore p1 to (column, row) order for the next segment. p0 is rebuilt from p1 there, so it does not
                # need to be unswapped.
                p1[0], p1[1] = p1[1], p1[0]

        return indices_2d

    def _setup_output_dataset_builder(self, arguments):
        """Set up dataset builders for selected tool outputs.

        One writer is created per direction (``NUMBER_OF_ANGLES`` of them) and appended to ``self._builders`` in
        increasing angle order. Each dataset is named ``<OUTPUT_TREE_PATH>/<output name>_<angle>``, where the angle
        is ``index * 30`` degrees formatted as three zero-padded digits (000, 030, ...).

        Args:
            arguments (list): The tool arguments.
        """
        # The output name does not change per direction; read it once.
        base_name = arguments[ARG_OUTPUT_DATASET].text_value
        # Create a place for each output dataset file. Use the shared constant so the builder count always matches
        # the per-angle reordering done in _add_output_datasets.
        for angle_idx in range(NUMBER_OF_ANGLES):
            dataset_name = f'{OUTPUT_TREE_PATH}/{base_name}_{angle_idx * 30:03d}'
            writer = self.get_output_dataset_writer(
                name=dataset_name,
                geom_uuid=self._grid_uuid,
                null_value=self._raster_null_value,
            )
            self._builders.append(writer)

    def _add_output_datasets(self):
        """Write every directional roughness dataset to XMDF and hand it back to XMS.

        Rows of ``self._dataset_values`` are first permuted so that builder ``k`` receives the row computed for
        angle ``(270 - 30 * k) mod 360``.
        """
        self.logger.info('Adding output datasets...')
        # Reorder the array so it is in the "TO" direction in Cartesian space.
        # Every (270 - 30k) mod 360 is a multiple of 30, so the floor division below is exact.
        row_order = [((270 - k * 30) % 360) // 30 for k in range(NUMBER_OF_ANGLES)]
        reordered_values = self._dataset_values[row_order]
        for k, builder in enumerate(self._builders):
            self.logger.info('Writing output directional roughness dataset to XMDF file...')
            builder.write_xmdf_dataset([0.0], [reordered_values[k]])
            # Send the dataset back to XMS in a folder under the target geometry
            self.set_output_dataset(builder)


def find_raster_location(xsize, ysize, min_pt, max_pt, cur_row, cur_col, num_cols, num_rows):
    """Compute the world XY location of a raster pixel from its row and column.

    Columns advance in +x from ``min_pt[0]``; rows advance in -y from ``max_pt[1]``.

    Arguments:
        xsize (float):  The x step size/pixel size in the x direction
        ysize (float):  The y step size/pixel size in the y direction
        min_pt (tuple):  The image bounding box minimum
        max_pt (tuple):  The image bounding box maximum
        cur_row (int):  The row being found
        cur_col (int):  The column being found
        num_cols (int):  The number of columns in the raster
        num_rows (int):  The number of rows in the raster

    Returns:
        tuple (bool, list):  ``(True, [x, y, 0])`` when the pixel lies within the raster, otherwise
        ``(False, [0, 0, 0])``. The list is freshly created so callers may fill in index 2.
    """
    out_of_range = any((
        cur_row < 0,
        cur_row > num_rows - 1,
        cur_col < 0,
        cur_col > num_cols - 1,
    ))
    if out_of_range:
        return False, [0, 0, 0]
    x = min_pt[0] + (xsize * cur_col)
    y = max_pt[1] - (ysize * cur_row)
    return True, [x, y, 0]


def get_raster_value(raster_values: np.ndarray, col, row, default_value):
    """Look up one pixel of a raster, falling back to a default outside its bounds.

    Args:
        raster_values (np.ndarray): Raster values, indexed by row first and column second.
        col (int): Column of the pixel value.
        row (int): Row of the pixel value.
        default_value (float): Value returned when the pixel is outside the raster.

    Returns:
        The raster value at ``(row, col)``, or ``default_value`` if either index is out of bounds.
    """
    # Reject negatives up front: numpy would otherwise silently wrap them to the far edge.
    if row < 0 or col < 0:
        return default_value
    dims = raster_values.shape
    if row < dims[0] and col < dims[1]:
        return raster_values[row, col]
    return default_value


def is_raster_index_in_poly(index, xsize, ysize, min_pt, max_pt, num_cols, num_rows, points):
    """Tests if the raster cell index is inside the polygon passed in.

    The flat index is decomposed row-major (``index = row * num_cols + col``). The cell's world location is then
    tested against the polygon. Cells outside the raster are never reported as inside.

    Arguments:
         index (int):  Flat (row-major) raster cell index.
         xsize (float):  Size in the x direction.
         ysize (float):  Size in the y direction.
         min_pt (tuple):  Min of the raster.
         max_pt (tuple):  Max of the raster.
         num_cols (int):  Number of columns on raster.
         num_rows (int):  Number of rows on raster.
         points (list of tuples):  The polygon vertices to test the cell location against.

    Returns:
        bool:  True if ``point_in_polygon_2d`` returns anything other than -1 for the cell location. False otherwise,
        including when the cell is outside the raster.
    """
    # divmod floors quotient and remainder consistently. A negative index therefore yields a negative (rejected) row
    # instead of wrapping onto a valid cell in row 0. It also avoids float rounding for very large indices.
    cur_row, cur_col = divmod(index, num_cols)
    found, point = find_raster_location(xsize, ysize, min_pt, max_pt, cur_row, cur_col, num_cols, num_rows)
    if not found:
        return False
    # NOTE(review): -1 is presumably point_in_polygon_2d's "outside" code, so on-boundary points count as inside.
    return point_in_polygon_2d(polygon=points, point=point) != -1
