"""RefineUGridByErrorTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import itertools
import os

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import UGridBuilder
from xms.core.filesystem import filesystem
from xms.gdal.rasters import RasterInput, RasterReproject
from xms.gdal.rasters.raster_utils import fix_data_types, make_raster_projections_consistent
from xms.gdal.utilities import gdal_wrappers as gw
from xms.grid.geometry.geometry import on_line_and_between_endpoints_2d
from xms.grid.ugrid import UGrid as XmUGrid
from xms.tool_core import IoDirection, Tool
from xms.tool_core.tool import equivalent_arguments

# 4. Local modules
from xms.tool.algorithms.mesh_2d.mesh_from_ugrid import MeshFromUGrid

# Positions of the tool arguments, in the order they are created by RefineUGridByErrorTool.initial_arguments()
ARG_INPUT_GRID = 0
ARG_OUTPUT_GRID = 1
ARG_ERROR_THRESHOLD = 2
ARG_MAX_ITRS = 3
ARG_RASTER_1 = 4  # first raster argument; every argument from here on is a raster
ARG_RASTER_2 = 5

# Tolerance handed to the collinearity test used when grouping cell points into edges
DEFAULT_TOLERANCE = 0.000001  # XM_ZERO_TOL


class RefineUGridByErrorTool(Tool):
    """Tool to interpolate multiple rasters to a UGrid with priority."""

    def __init__(self, name='Refine UGrid by Error'):
        """Construct the tool and create its (initially empty) bookkeeping containers.

        Args:
            name (str): Display name of the tool.
        """
        super().__init__(name=name)

        # Output flavor and the geometry label used in messages and argument descriptions
        self._geom_txt = 'UGrid'
        self._force_ugrid = True

        # Inputs resolved during validation
        self._file_count = 0
        self._input_cogrid = None
        self._ugrid = None
        self._input_raster_filenames = []

        # Error threshold (overwritten from the tool arguments when the tool runs)
        self._error = 2.0

        # Refinement bookkeeping
        self._cells_over_threshold = set()  # ids of cells whose error exceeds the threshold
        self._split_cellstreams = dict()  # cell id -> cellstream of the cells that replace it
        self._split_into_num_cells = dict()  # cell id -> number of cells it was split into
        self._adjacent_cellstreams = dict()  # cell id -> cellstream of a neighbor that gained a point
        self._split_edges_to_midpoint = dict()  # (pt1, pt2) -> id of the point created at the edge midpoint
        self._cell_refinement_level = dict()  # cell id -> number of times the cell has been refined

        # Used to build up the set of node ids that are outside the raster bounds
        self._out_of_bounds_ids = set()

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='grid', description=f'{self._geom_txt.capitalize()}'),
            self.string_argument(name='refined_ugrid', description=f'Output {self._geom_txt} name', value='refined',
                                 optional=False),
            self.float_argument(name='error_threshold', description='Error threshold', value=5.0, min_value=0.0),
            self.integer_argument(name='max_itrs', description='Maximum iterations', value=10, min_value=1),
            self.raster_argument(name='raster_1', description='Raster 1 (highest priority)'),
            self.raster_argument(name='raster_2', description='Raster 2', optional=True),
        ]
        return arguments

    def input_raster_count(self, arguments):
        """Count the raster argument slots in the argument list.

        Every argument from index ARG_RASTER_1 onward is counted, whether or not it has a value.

        Args:
            arguments(list): The tool arguments.

        Returns:
            (int): The number of input raster arguments.
            Not necessarily set to a value.
        """
        total_arguments = len(arguments)
        return total_arguments - ARG_RASTER_1

    def add_input_raster(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.

        Returns:
            (int): The number of input raster arguments.
            Not necessarily set to a value.
        """
        raster_count = self.input_raster_count(arguments)

        # add new last raster argument
        raster_count += 1
        raster_argument = self.raster_argument(name=f'raster_{raster_count}', description=f'Raster {raster_count}',
                                               optional=True, value=None)
        arguments.append(raster_argument)

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        last_raster = arguments[-1]
        if last_raster.value is not None and last_raster.value != '':
            self.add_input_raster(arguments)

    def validate_arguments(self, arguments):
        """Check the input grid and the input rasters.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        problems = {}

        # The grid comes first; a successful check also caches the grid on the tool
        grid_argument = arguments[ARG_INPUT_GRID]
        self._validate_input_grid(problems, grid_argument)

        # Then the rasters that have actually been given a value
        raster_names = self._get_input_raster_names(arguments)
        self._validate_input_rasters(problems, arguments, raster_names)

        return problems

    def validate_from_history(self, arguments):
        """Check that arguments restored from history are compatible with this tool.

        The leading arguments must be equivalent to the defaults, and every additional argument must be an
        input raster.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (bool): True if no errors, False otherwise.
        """
        defaults = self.initial_arguments()
        required = len(defaults)

        # There must be at least as many arguments as the tool starts with
        if len(arguments) < required:
            return False
        if not equivalent_arguments(arguments[0:required], defaults):
            return False

        # Anything past the defaults has to be an extra input raster
        extras = arguments[required:]
        return not any(
            extra.io_direction != IoDirection.INPUT or extra.type != 'raster'
            for extra in extras
        )

    def _validate_input_grid(self, errors, argument):
        """Validate grid is specified and 2D.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (GridArgument): The grid argument.
        """
        self._input_cogrid = self.get_input_grid(argument.text_value)
        if not self._input_cogrid:
            errors[argument.name] = f'Could not read {self._geom_txt}.'
        elif not self._input_cogrid.check_all_cells_2d():
            errors[argument.name] = f'{self._geom_txt.capitalize()} must be 2D.'
        else:
            self._ugrid = self._input_cogrid.ugrid

    def _validate_input_rasters(self, errors, arguments, rasters):
        """Validate input rasters are specified.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            arguments (list): The tool arguments.
            rasters (list): The input raster names.
        """
        for raster_text in rasters:
            key = None
            for argument in arguments:
                if argument.value == raster_text:
                    key = argument.name
            if key is not None:
                raster_filename = self.get_input_raster_file(raster_text)
                if raster_filename and os.path.isfile(raster_filename):
                    self._input_raster_filenames.append(raster_filename)

    def _get_input_raster_names(self, arguments):
        """Collect the values of the raster arguments that have been filled in.

        Args:
            arguments(list): The tool arguments.

        Returns:
            list: List of the input raster names
        """
        raster_arguments = arguments[ARG_RASTER_1:]
        return [candidate.value for candidate in raster_arguments if candidate.text_value != '']

    def get_cell_edges_for_splitting(self, ug, index, locations):
        """Group the points around a cell into edges, merging runs of collinear points into a single edge.

        The cell's points are taken from its pending entry in self._adjacent_cellstreams if there is one (the
        cell gained a point because a neighbor was split), otherwise from the grid itself.

        Args:
            ug (UGrid): the grid to modify.
            index (int): the cell index
            locations (list): locations of grid points

        Returns:
            (list): list of edges (all collinear points that make up an edge). Each edge is a list of point
            ids that starts and ends at a corner, with any points found to lie on the line in between.
        """
        pts = []
        if index not in self._adjacent_cellstreams:
            pts = ug.get_cell_points(index)
        else:
            # Skip the two leading cellstream entries (cell type and number of points)
            cellstream = self._adjacent_cellstreams[index]
            pts = cellstream[2:]

        num_pts = len(pts)
        edge_list = []

        # find the first point to start with
        # (index1, index2, index3) is a rolling window of three consecutive points around the cell
        index1 = 0
        loc1 = locations[pts[index1]]
        index2 = self.advance_index(index1, num_pts, True)
        loc2 = locations[pts[index2]]
        index3 = self.advance_index(index2, num_pts, True)
        loc3 = locations[pts[index3]]

        # Slide the window forward while the middle point passes the on-line test
        # (presumably: loc2 lies on the segment loc1-loc3), so that the walk starts at a corner
        while on_line_and_between_endpoints_2d(loc1, loc3, loc2, DEFAULT_TOLERANCE):
            loc1 = loc2
            loc2 = loc3
            index2 = index3
            index3 = self.advance_index(index3, num_pts, True)
            loc3 = locations[pts[index3]]

        # Remember the starting point so we know when the walk has gone all the way around
        beginning = first_pt = pts[index2]
        done = False

        while done is False:
            cur_edge = []
            cur_edge.extend([first_pt])

            # advance everything again
            loc1 = loc2
            loc2 = loc3
            index2 = index3
            index3 = self.advance_index(index3, num_pts, True)
            loc3 = locations[pts[index3]]

            # Middle points that pass the on-line test are interior to the current edge
            while on_line_and_between_endpoints_2d(loc1, loc3, loc2, DEFAULT_TOLERANCE):
                cur_edge.extend([pts[index2]])
                loc1 = loc2
                loc2 = loc3
                index2 = index3
                index3 = self.advance_index(index3, num_pts, True)
                loc3 = locations[pts[index3]]

            # Back at the starting point: close the last edge and stop
            if pts[index2] == beginning:
                cur_edge.extend([pts[index2]])
                edge_list.extend([cur_edge])
                done = True
                continue

            # The point that ends this edge is also the first point of the next one
            cur_edge.extend([pts[index2]])
            edge_list.extend([cur_edge])
            first_pt = pts[index2]

        return edge_list

    def edge_exceeds_threshold(self, ug, raster, pt1, pt2):
        """Called for ugrid patch.

        Args:
            ug (UGrid): the grid to modify.
            raster (RasterInput): the raster
            pt1 (int): first pt on edge
            pt2 (int): second pt on edge

        Returns:
            (bool): does the error exceed the threshold
        """
        end_pts = [pt1, pt2]
        end_locs = ug.get_points_locations(end_pts)
        dx = end_locs[1][0] - end_locs[0][0]
        dy = end_locs[1][1] - end_locs[0][1]
        dz = end_locs[1][2] - end_locs[0][2]

        num_check_pts_x = int(abs(dx / raster.pixel_width))
        num_check_pts_y = int(abs(dy / raster.pixel_height))
        num_check_pts = max(num_check_pts_x, num_check_pts_y)

        cur_check_pt = 1
        xy_loc = []

        while cur_check_pt <= num_check_pts:
            xy_loc = [end_locs[0][0] + (dx * (cur_check_pt / (num_check_pts + 1))),
                      end_locs[0][1] + (dy * (cur_check_pt / (num_check_pts + 1)))]
            ugrid_elev = end_locs[0][2] + (dz * (cur_check_pt / (num_check_pts + 1)))

            raster_elev = self.raster_elev_at_loc(raster, xy_loc[0], xy_loc[1], ugrid_elev)

            if abs(ugrid_elev - raster_elev) > self._error:
                return True

            cur_check_pt += 1

        return False

    def calculate_centroid_z(self, ug, cell_id):
        """Called for ugrid patch.

        Args:
            ug (UGrid): the grid to modify.
            cell_id (int): id of cell to check

        Returns:
            (float): centroid z
        """
        pts = ug.get_cell_locations(cell_id)

        tot_z = 0.0
        for pt in pts:
            tot_z += pt[2]

        return tot_z / len(pts)

    def centroid_exceeds_threshold(self, ug, raster, cell_id):
        """Called for ugrid patch.

        Args:
            ug (UGrid): the grid to modify.
            raster (RasterInput): the raster
            cell_id (int): id of cell to check

        Returns:
            (bool): does the error exceed the threshold
        """
        centroid = ug.get_cell_centroid(cell_id)[1]
        ugrid_elev = self.calculate_centroid_z(ug, cell_id)

        raster_elev = self.raster_elev_at_loc(raster, centroid[0], centroid[1], ugrid_elev)

        if abs(ugrid_elev - raster_elev) > self._error:
            return True

        return False

    def do_cell_split(self, ug, locations, id, raster):
        """Split one cell at its edge midpoints and record the replacement cellstream.

        A cell with more than three edges gets a new point at its centroid and one new cell per edge; any other
        cell is split into three corner cells plus a triangle joining the edge midpoints. Nothing is written to
        the grid here: the results are stored in self._split_cellstreams and self._split_into_num_cells, and
        neighbors that share a newly split edge get the new midpoint added to their own cellstream. A neighbor
        whose refinement level would end up more than one below this cell's is split as well (recursively).

        Args:
            ug (UGrid): the grid to modify.
            locations (list): locations of grid points. New points are appended to this list.
            id (int): cell id
            raster (RasterInput): raster for elevations

        """
        new_cellstream = []
        edges_for_split = self.get_cell_edges_for_splitting(ug, id, locations)

        # One midpoint per edge, reusing a midpoint that already exists for the edge
        mid_pts = []
        for edge in edges_for_split:
            pt1 = edge[0]
            pt2 = edge[len(edge) - 1]
            mid_pt_id = self.get_edge_midpoint(ug, pt1, pt2, locations, raster)
            mid_pts.extend([mid_pt_id])

        if len(edges_for_split) > 3:
            # Add a point at the centroid, with its elevation taken from the raster
            # (the average z of the cell's points is used where the raster has no data)
            centroid = list(ug.get_cell_centroid(id)[1])
            centroid_z = self.calculate_centroid_z(ug, id)
            centroid[2] = self.raster_elev_at_loc(raster, centroid[0], centroid[1], centroid_z)
            locations.extend([centroid])
            centroid_id = len(locations) - 1

            # create the new cellstreams - all new cells are quads
            for i, edge in enumerate(edges_for_split):
                # Each new cell runs: first point of this edge -> midpoint of this edge -> centroid ->
                # midpoint of the previous edge (and then back to the first point)
                new_edges = []
                new_edges.extend([edge[0]])

                if len(edge) > 2:
                    # Keep the existing points between the first point and the midpoint
                    mid_pt_idx = edge.index(mid_pts[i])
                    new_edges.extend(edge[1:mid_pt_idx])
                new_edges.extend([mid_pts[i], centroid_id])

                prev_edge_index = self.advance_index(i, len(edges_for_split), False)
                prev_edge = edges_for_split[prev_edge_index]
                prev_mid_pt = mid_pts[prev_edge_index]

                if len(prev_edge) > 2:
                    # Midpoint of the previous edge plus the existing points after it (its last point is
                    # this cell's first point, so it is left off)
                    mid_pt_idx = prev_edge.index(prev_mid_pt)
                    new_edges.extend(prev_edge[mid_pt_idx:-1])
                else:
                    new_edges.extend([prev_mid_pt])

                new_cellstream.extend([7, len(new_edges)])  # polygon, number of sides
                new_cellstream.extend(new_edges)

            num_new_cells = len(edges_for_split)

        else:
            # create the new cellstreams - all new cells are tris
            for i, edge in enumerate(edges_for_split):
                # Each corner cell runs: midpoint of this edge -> midpoint of the previous edge ->
                # first point of this edge
                new_edges = []

                if len(edge) > 2:
                    # Existing points between the first point and the midpoint of this edge
                    mid_pt_idx = edge.index(mid_pts[i])
                    new_edges.extend(edge[1:mid_pt_idx])

                prev_edge_index = self.advance_index(i, len(edges_for_split), False)
                prev_edge = edges_for_split[prev_edge_index]
                new_edges.extend([mid_pts[i], mid_pts[prev_edge_index]])

                if len(prev_edge) > 2:
                    # Existing points after the midpoint of the previous edge (its last point is added below)
                    prev_mid_pt = mid_pts[prev_edge_index]
                    mid_pt_idx = prev_edge.index(prev_mid_pt)
                    new_edges.extend(prev_edge[mid_pt_idx + 1:-1])
                new_edges.extend([edge[0]])

                new_cellstream.extend([7, len(new_edges)])  # polygon, number of sides
                new_cellstream.extend(new_edges)

            # create the triangle in the middle
            new_cellstream.extend([5, 3, mid_pts[0], mid_pts[1], mid_pts[2]])

            num_new_cells = 4

        # update the refinement level
        self._cell_refinement_level[id] = self._cell_refinement_level[id] + 1
        self._split_cellstreams[id] = new_cellstream
        self._split_into_num_cells[id] = num_new_cells

        # update adjacent cells
        for i, edge in enumerate(edges_for_split):
            if len(edge) == 2:
                # this is the only case where a point was added
                adj_ids = self._ugrid.get_edge_adjacent_cells(edge)
                for adj_id in adj_ids:
                    if adj_id == id:
                        continue
                    self.update_cell_adj_to_split(ug, edge[0], mid_pts[i], edge[1], adj_id)

                    # check if it should be split due to refinement levels
                    if self._cell_refinement_level[id] > self._cell_refinement_level[adj_id] + 1:
                        self.do_cell_split(ug, locations, adj_id, raster)

        # The cell has been replaced, so any cellstream queued for it as a neighbor is obsolete
        if id in self._adjacent_cellstreams.keys():
            del self._adjacent_cellstreams[id]

    def split_cells(self, ug, raster):
        """Split the specified cells and create a new ugrid.

        Args:
            ug (UGrid): the grid to modify.
            raster (RasterInput): raster for elevations

        Returns:
            (UGrid): the new grid
            (set): set of cells to check on the next iteration
        """
        if len(self._cells_over_threshold) == 0:
            return ug, []

        check_cells = set()
        locations = list(ug.locations)

        # split cells
        for id in self._cells_over_threshold:
            self.do_cell_split(ug, locations, id, raster)

        # create the new UGrid
        old_cell_count = ug.cell_count
        old_cellstream = np.asarray(ug.cellstream)
        new_cellstream = []
        tmp_cellstream = []

        cur_old_cell_id = 0
        cur_new_cell_id = 0
        stream_idx = 0

        num_to_change = len(self._split_cellstreams) + len(self._adjacent_cellstreams)
        num_changed = 0

        new_refinement_levels = {}
        while cur_old_cell_id < old_cell_count:
            num_cell_pts = old_cellstream[stream_idx + 1]
            next_cell_stream_idx = stream_idx + num_cell_pts + 2
            refine_level = self._cell_refinement_level[cur_old_cell_id]

            # is the cell split?
            if cur_old_cell_id in self._split_cellstreams.keys():
                over_threshold = cur_old_cell_id in self._cells_over_threshold
                tmp_cellstream.extend(self._split_cellstreams[cur_old_cell_id])
                new_cell = 0
                while new_cell < self._split_into_num_cells[cur_old_cell_id]:
                    if over_threshold is True:
                        check_cells.add(cur_new_cell_id)
                    new_refinement_levels[cur_new_cell_id] = refine_level
                    cur_new_cell_id += 1
                    new_cell += 1

                num_changed += 1

            # is it an adjacent cell to be modified?
            elif cur_old_cell_id in self._adjacent_cellstreams.keys():
                tmp_cellstream.extend(self._adjacent_cellstreams[cur_old_cell_id])
                new_refinement_levels[cur_new_cell_id] = refine_level
                cur_new_cell_id += 1
                num_changed += 1

            else:
                # add the cell to the new cellstream
                tmp_cellstream.extend(old_cellstream[stream_idx:next_cell_stream_idx])
                new_refinement_levels[cur_new_cell_id] = refine_level
                cur_new_cell_id += 1

            if num_changed == num_to_change:
                # we are done removing after this
                tmp_cellstream.extend(old_cellstream[next_cell_stream_idx:])

                # copy over the rest of the refinement levels
                cur_old_cell_id += 1
                while cur_old_cell_id < old_cell_count:
                    new_refinement_levels[cur_new_cell_id] = self._cell_refinement_level[cur_old_cell_id]
                    cur_old_cell_id += 1
                    cur_new_cell_id += 1
                new_cellstream.extend(tmp_cellstream)
                break

            cur_old_cell_id += 1
            stream_idx = next_cell_stream_idx

            if (len(tmp_cellstream)) % 5000 == 0:
                new_cellstream.extend(tmp_cellstream)
                tmp_cellstream.clear()

        # create the new ugrid
        ug = XmUGrid(locations, new_cellstream)

        # update the refine levels
        self._cell_refinement_level = new_refinement_levels

        return ug, check_cells

    def update_cell_adj_to_split(self, ug, edge_pt1, new_pt, edge_pt2, cell_id):
        """Split the specified cells and create a new ugrid.

        Args:
            ug (UGrid): the grid to modify.
            edge_pt1 (int): pt1 of split edge
            new_pt (int): new pt
            edge_pt2 (int): pt2 of split edge
            cell_id (int): cell id

        """
        if cell_id in self._split_cellstreams:
            cellstream = self._split_cellstreams[cell_id]
            index = 0
            while index < len(cellstream):
                cell_type_index = index
                index += 1
                num_pts_index = index
                num_points = cellstream[num_pts_index]
                index += 1
                pts = cellstream[index:(index + num_points)]
                for i, pt in enumerate(pts):
                    if pt == edge_pt2:
                        next_i = self.advance_index(i, num_points, True)
                        if pts[next_i] == edge_pt1:
                            cellstream[cell_type_index] = 7
                            cellstream[num_pts_index] = num_points + 1
                            cellstream.insert(index + next_i, new_pt)
                            return

                # move on to the next split cell
                index += num_points
        else:
            new_cellstream = []
            if cell_id in self._adjacent_cellstreams.keys():
                cellstream = self._adjacent_cellstreams[cell_id]
                num_sides = len(cellstream) - 1  # only subtract one instead of two because we are adding a side
                new_cellstream.extend([7, num_sides])  # poly cell type and number of sides
                pts = cellstream[2:]
                index = pts.index(edge_pt2)
                pts.insert(index + 1, new_pt)
                new_cellstream.extend(pts)

            else:
                pts = list(self._ugrid.get_cell_points(cell_id))
                num_sides = len(pts)
                new_cellstream.extend([7, num_sides + 1])  # poly cell type and number of sides
                index = pts.index(edge_pt2)
                pts.insert(index + 1, new_pt)
                new_cellstream.extend(pts)

            self._adjacent_cellstreams[cell_id] = new_cellstream

    def advance_index(self, cur_index, num_items, forward):
        """Called to determine if arguments are valid.

        Args:
            cur_index (int): index to be advanced.
            num_items (int): number of items.
            forward (bool): true = move forward, false = move back

        Returns:
            (int): new index
        """
        new_index = -1
        if forward is True:
            if cur_index == num_items - 1:
                new_index = 0
            else:
                new_index = cur_index + 1
        else:
            if cur_index == 0:
                new_index = num_items - 1
            else:
                new_index = cur_index - 1

        return new_index

    def raster_elev_at_loc(self, raster, x, y, default_z, log_id=None):
        """Find the pixel containing a point coordinate.

        Args:
            raster (RasterInput): the raster
            x (float): X-coordinate of the location
            y (float): Y-coordinate of the location
            default_z (float): Value to use if raster value is nodata
            log_id (int): If provided, will store the id when the location is outside the raster bounds

        Returns:
            (float): The raster elevation at the x,y location. Returns the default value if there is no data.
        """
        elev = raster.get_raster_value_at_loc(x, y)
        if elev == raster.nodata_value:
            if log_id is not None:
                self._out_of_bounds_ids.add(log_id)
            return default_z
        return elev

    def get_edge_midpoint(self, ug, pt1, pt2, locations, raster):
        """Split the specified cells and create a new ugrid.

        Args:
            ug (UGrid): the grid to modify.
            pt1 (int): first point on edge
            pt2 (int): second point on edge
            locations (list): locations of grid points
            raster (RasterInput): raster for the elevation

        Returns:
            (int): point id for the new point
        """
        # has a midpoint already been created?
        if (pt2, pt1) in self._split_edges_to_midpoint.keys():
            return self._split_edges_to_midpoint[(pt2, pt1)]

        end_pts = [pt1, pt2]
        end_locs = ug.get_points_locations(end_pts)
        xy_mid = [((end_locs[0][0] + end_locs[1][0]) / 2.0), ((end_locs[0][1] + end_locs[1][1]) / 2.0),
                  ((end_locs[1][2] + end_locs[1][2]) / 2.0)]
        xy_mid[2] = self.raster_elev_at_loc(raster, xy_mid[0], xy_mid[1], xy_mid[2])
        locations.append(xy_mid)
        mid_pt_id = len(locations) - 1
        self._split_edges_to_midpoint[(pt1, pt2)] = mid_pt_id

        return mid_pt_id

    def run(self, arguments):
        """Set the grid elevations from the rasters, then refine the grid until the error threshold is met.

        The input rasters are merged into one virtual raster, the grid point elevations are taken from it,
        and cells are then split repeatedly: a cell is split when the raster differs from the cell by more
        than the error threshold at the cell centroid or along one of its edges. Refinement stops when there
        are no cells left to check or the maximum number of iterations has been run.

        Args:
            arguments (list): The tool arguments.
        """
        # start with the right raster values
        temp_folder = filesystem.temp_filename()  # Gets deleted with process
        os.mkdir(temp_folder)
        vrt_filename = os.path.join(temp_folder, 'work.vrt')
        self.logger.info('Generating merged virtual raster from inputs...')
        # Raster 1 has the highest priority; the list is reversed so it comes last in the VRT
        # (presumably later sources take precedence where rasters overlap)
        self._input_raster_filenames.reverse()
        self._input_raster_filenames = fix_data_types(self._input_raster_filenames)
        self._input_raster_filenames = make_raster_projections_consistent(self._input_raster_filenames,
                                                                          RasterReproject.GRA_Cubic)
        gw.gdal_build_vrt(self._input_raster_filenames, vrt_filename)
        raster_input = RasterInput(vrt_filename)

        # Interpolate from the VRT raster to the input grid locations
        self.logger.info(f'Interpolating from source raster to target {self._geom_txt}...')
        locations = self._ugrid.locations
        num_points = len(self._ugrid.locations)

        for i, location in enumerate(locations):
            if (i + 1) % 10000 == 0:
                self.logger.info(f'Processing point {i + 1} of {num_points}...')
            # Points where the raster has no data keep their elevation and are reported by 1-based id
            location[2] = self.raster_elev_at_loc(raster_input, location[0], location[1], location[2], log_id=i + 1)
        self._report_out_of_bounds_points()

        self._ugrid.locations = locations

        # now do refining
        self._error = float(arguments[ARG_ERROR_THRESHOLD].text_value)

        # Midpoints cached by an earlier run refer to point ids of a different grid
        self._split_edges_to_midpoint.clear()

        visited = set()
        cur_num_cells = self._ugrid.cell_count
        to_check = set(range(cur_num_cells))
        self._cell_refinement_level = dict.fromkeys(to_check, 0)
        max_iterations = int(arguments[ARG_MAX_ITRS].text_value)
        cur_iteration = 1

        while len(to_check) > 0 and cur_iteration <= max_iterations:
            for cell_id in to_check:
                if cell_id in visited:
                    continue

                if self.centroid_exceeds_threshold(self._ugrid, raster_input, cell_id):
                    self._cells_over_threshold.add(cell_id)
                    visited.add(cell_id)
                    continue

                edges = self._ugrid.get_cell_edges(cell_id)
                done = False
                for edge in edges:
                    if done:
                        break
                    adj_cells = self._ugrid.get_edge_adjacent_cells(edge)
                    for adj_id in adj_cells:
                        # Skip the cell itself, unless it is the only cell on the edge (a boundary edge)
                        if len(adj_cells) != 1 and adj_id == cell_id:
                            continue

                        if adj_id in to_check:
                            # if this cell has already been visited and it's not split, this edge is fine
                            if adj_id in visited and adj_id not in self._cells_over_threshold:
                                continue

                            if self.edge_exceeds_threshold(self._ugrid, raster_input, edge[0], edge[1]):
                                # The edge is shared, so both cells on it get split
                                self._cells_over_threshold.add(cell_id)
                                self._cells_over_threshold.add(adj_id)
                                visited.add(adj_id)
                                done = True

                visited.add(cell_id)

            self._ugrid, to_check = self.split_cells(self._ugrid, raster_input)

            new_num_cells = self._ugrid.cell_count
            cells_created_this_itr = new_num_cells - cur_num_cells
            self.logger.info(f'Iteration {cur_iteration} complete. {cells_created_this_itr} cells created.')

            # reset all the lists for the next iteration
            visited.clear()
            self._cells_over_threshold.clear()
            self._split_into_num_cells.clear()
            self._split_cellstreams.clear()
            self._adjacent_cellstreams.clear()

            # advance for next iteration
            cur_iteration += 1
            cur_num_cells = new_num_cells

        # The loop ends either because nothing is left to check or because the iteration limit was hit.
        # Cells still waiting to be checked therefore mean the limit is what stopped the refinement.
        if len(to_check) > 0:
            self.logger.info('Maximum number of iterations reached.')
        else:
            self.logger.info('Refining completed.')

        if self._force_ugrid is True:
            co_builder = UGridBuilder()
            co_builder.set_is_2d()
            co_builder.set_unconstrained()
            co_builder.set_ugrid(self._ugrid)
            output = co_builder.build_grid()
        else:
            convert = MeshFromUGrid()
            output, _ = convert.convert(source_opt=convert.SOURCE_OPT_POINTS, input_ugrid=self._ugrid,
                                        logger=self.logger, tris_only=False, split_collinear=True)

        self.set_output_grid(output, arguments[ARG_OUTPUT_GRID], None, force_ugrid=self._force_ugrid)

    def _report_out_of_bounds_points(self):
        """Log a message if there are any input locations that are outside the raster bounds."""
        if len(self._out_of_bounds_ids) > 0:
            self.logger.warning(
                'Locations with the following ids were outside of the raster bounds:\n'
                f'{np.array(list(self._out_of_bounds_ids))}'
            )
            self._out_of_bounds_ids = set()
