"""InterpolateToUGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
import os
import time
from timeit import default_timer
from typing import List

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.misc.observer import Observer
from xms.datasets.dataset_writer import DatasetWriter
from xms.interp.interpolate.interp_idw import InterpIdw
from xms.interp.interpolate.interp_linear import InterpLinear
from xms.tool_core import ALLOW_ONLY_SCALARS, Argument, Tool

# 4. Local modules
from xms.tool.algorithms.earcut.earcut import earcut

DEFAULT_EXTRAP_VALUE = -9876543        # default extrapolation value
XM_NODATA = -9999999                   # no-data value (from xmscore); used as the tool's output null value
# Indices into the argument list built by InterpolateToUGridTool.initial_arguments(); they must follow its order.
ARG_INPUT_DATASET = 0                  # input dataset to interpolate from
ARG_INPUT_ALL_TIMESTEPS = 1            # option to interpolate all timesteps
ARG_INPUT_DS_TIMESTEP = 2              # input dataset time step
ARG_INPUT_TARGET_UGRID = 3             # ugrid to interpolate to
ARG_OUTPUT_DATASET_NAME = 4            # dataset name - can be blank
ARG_INPUT_DATASET_LOCATION = 5         # points or cells
ARG_INPUT_INTERP_METHOD = 6            # Interp method - linear, idw, natural neighbor
ARG_INPUT_INTERP_DIMENSION = 7         # 2D or 3D interpolation
ARG_INPUT_TRUNCATE = 8                 # truncate interpolated values
ARG_INPUT_TRUNC_MIN = 9                # truncate range minimum
ARG_INPUT_TRUNC_MAX = 10                # truncate range maximum
ARG_INPUT_EXTRAP = 11                  # extrapolation option - none, constant, idw, existing dataset
ARG_INPUT_EXTRAP_VAL = 12              # constant extrapolation value
ARG_INPUT_EXTRAP_IDW = 13              # idw weight computation option for extrapolation
ARG_INPUT_EXTRAP_IDW_NUM_NEAREST = 14  # number of nearest points for IDW weight
ARG_INPUT_EXTRAP_IDW_QUADRANT = 15     # use points in quadrants
ARG_INPUT_EXTRAP_DATASET = 16          # existing dataset for extrapolation
ARG_INPUT_EXTRAP_DS_TIMESTEP = 17      # time step for extrapolation existing dataset
ARG_INPUT_LINEAR_CLOUGH_TOCHER = 18    # check box to use clough tocher linear interp
ARG_INPUT_IDW_NODAL_FUNC = 19          # idw nodal function option
ARG_INPUT_IDW_NODAL_CLASSIC = 20       # idw classic weight function for the constant nodal function
ARG_INPUT_IDW_NODAL_EXPONENT = 21      # idw weighting exponent
ARG_INPUT_IDW_NOD_COEF = 22            # idw nodal coefficient computation option
ARG_INPUT_IDW_NOD_COEF_NEAR = 23       # idw nodal coeff computation nearest number of points
ARG_INPUT_IDW_NOD_COEF_QUAD = 24       # idw nodal coeff computation nearest in quadrants
ARG_INPUT_IDW_WEIGHT = 25              # idw interp weight computation option
ARG_INPUT_IDW_WEIGHT_NEAR = 26         # idw interp weight nearest number of points
ARG_INPUT_IDW_WEIGHT_QUAD = 27         # idw interp weight nearest in quadrants
ARG_INPUT_NN_NODAL_FUNC = 28           # natural neighbor nodal function
ARG_INPUT_NN_NODAL_COMP = 29           # natural neighbor nodal functions "computed by" option
ARG_INPUT_NN_NODAL_COMP_NEAR = 30      # natural neighbor number nearest for computing nodal function
ARG_INPUT_ANISOTROPY = 31              # toggle to specify anisotropy
ARG_INPUT_ANIS_HORIZ = 32              # horizontal anisotropy factor
ARG_INPUT_ANIS_AZIMUTH = 33            # horizontal anisotropy azimuth
ARG_INPUT_ANIS_VERTICAL = 34           # vertical anisotropy
ARG_INPUT_LOG = 35                     # log interpolation toggle
ARG_INPUT_LOG_MIN = 36                 # value substituted for data <= 0 before the log10 transform
ARG_END = 37                           # number of arguments (one past the last index)


class InterpObserver(Observer):
    """Observer that forwards interpolation progress to a logger."""

    def __init__(self, logger):
        """Remember the logger that receives progress messages.

        Args:
            logger: logger used for the progress messages
        """
        super().__init__()
        self.logger = logger

    def on_progress_status(self, percent_complete):
        """Log the current progress as an integer percentage.

        Args:
            percent_complete (float): fraction complete (multiplied by 100 before logging)
        """
        percent = int(100 * percent_complete)
        self.logger.info(f'Percent complete: {percent}')

    def time_remaining_in_seconds(self, remaining_seconds):
        """Receive the estimated time remaining; the value is ignored.

        Args:
            remaining_seconds: Time remaining in seconds.
        """
        # This method must be defined or you get a stack overflow.

    def time_elapsed_in_seconds(self, elapsed_seconds):
        """Receive the elapsed time; the value is ignored.

        Args:
            elapsed_seconds: Time that has elapsed in seconds.
        """
        # This method must be defined or you get a stack overflow.


class InterpolateToUGridTool(Tool):
    """Tool to interpolate from a ugrid to another ugrid."""

    def __init__(self):
        """Create the tool with empty interpolation state."""
        super().__init__(name='Interpolate to UGrid')
        # State filled in by validate_arguments / run.
        self._args = None
        self._interpolator = None
        self._output_ds = None
        self._data = {'log_interp': False, 'min_log': 1.0e-6}
        # Progress reporting.
        self._observer = InterpObserver(self.logger)
        self._prog_increment = 1
        self._timer = None
        self._linear_prog_time = 1.0
        # Value written to the output for inactive points.
        self._null_value = float(XM_NODATA)

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of this list must match the ARG_* index constants at the top of the module.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            # source dataset / time steps / target
            self.dataset_argument(name='dataset', description='Source dataset'),
            self.bool_argument(name='all_time_steps', description='Interpolate all timesteps', value=False),
            self.timestep_argument(name='time_step', description='Timestep', value=Argument.NONE_SELECTED,
                                   optional=True),
            self.grid_argument(name='target_grid', description='Target grid'),
            self.string_argument(name='dataset_name', description='Target dataset name', optional=True),
            self.string_argument(name='dataset_location', description='Target dataset location', value='Points',
                                 choices=['Points', 'Cells']),
            # interpolation method
            self.string_argument(name='interp_method', description='Interpolation method', value='Linear',
                                 choices=['Linear', 'Inverse distance weighted (IDW)', 'Natural neighbor']),
            self.string_argument(name='interp_dimension', description='Interpolation dimension', value='2D',
                                 choices=['2D', '3D']),
            # truncation
            self.string_argument(name='truncate', description='Truncate interpolated values option',
                                 value='Do not truncate',
                                 choices=['Do not truncate', 'Truncate to min/max of source dataset',
                                          'Truncate to specified min/max']),
            self.float_argument(name='trunc_min', description='Truncate range minimum', value=0.0),
            self.float_argument(name='trunc_max', description='Truncate range maximum', value=0.0),
            # extrapolation
            self.string_argument(name='extrap_option', description='Extrapolation option', value='Constant value',
                                 choices=['No extrapolation', 'Constant value', 'Inverse distance weighted (IDW)',
                                          'Existing dataset']),
            self.float_argument(name='extrap_val', description='Extrapolation constant value', value=0.0),
            self.string_argument(name='extrap_idw_option',
                                 description='Extrapolation IDW interpolation weights computation option',
                                 value='Use nearest points', choices=['Use nearest points', 'Use all points']),
            self.integer_argument(name='extrap_idw_nearest', description='Extrapolation IDW number of nearest points',
                                  min_value=1, value=16),
            self.bool_argument(name='extrap_idw_quadrant',
                               description='Extrapolation IDW use nearest points in each quadrant', value=False),
            self.dataset_argument(name='extrap_dataset', description='Existing dataset', filters=[ALLOW_ONLY_SCALARS],
                                  optional=True),
            self.timestep_argument(name='extrap_ds_time_step', description='Existing dataset timestep',
                                   value=Argument.NONE_SELECTED),
            # linear
            self.bool_argument(name='linear_clough_tocher', description='Clough-Tocher', value=False),
            # inverse distance weighted
            self.string_argument(name='idw_nodal_func', description='IDW nodal function',
                                 value="Constant (Shepard's method)",
                                 choices=["Constant (Shepard's method)", 'Gradient plane', 'Quadratic']),
            self.bool_argument(name='idw_const_classic',
                               description='IDW constant nodal function use classic weight function', value=False),
            self.float_argument(name='idw_const_exp', description='IDW constant nodal function weighting exponent',
                                value=2.0),
            self.string_argument(name='idw_nodal_coeff', description='IDW computation of nodal coefficients option',
                                 value='Use nearest points', choices=['Use nearest points', 'Use all points']),
            self.integer_argument(name='idw_nodal_coeff_nearest',
                                  description='IDW nodal coefficients number of nearest points', min_value=1, value=16),
            self.bool_argument(name='idw_nodal_coeff_quadrant',
                               description='IDW nodal coefficients use nearest points in each quadrant', value=False),
            self.string_argument(name='idw_weights', description='IDW computation of interpolation weights option',
                                 value='Use nearest points', choices=['Use nearest points', 'Use all points']),
            self.integer_argument(name='idw_weights_nearest',
                                  description='IDW interpolation weights number of nearest points', min_value=1,
                                  value=16),
            self.bool_argument(name='idw_weights_quadrant',
                               description='IDW interpolation weights use nearest points in each quadrant',
                               value=False),
            # natural neighbor
            self.string_argument(name='nn_nodal_func', description='Natural neighbor nodal function',
                                 value='Constant', choices=['Constant', 'Gradient plane', 'Quadratic']),
            self.string_argument(name='nn_nodal_comp', description='Natural neighbor nodal function computation option',
                                 value='Use natural neighbors',
                                 choices=['Use natural neighbors', 'Use nearest points', 'Use all points']),
            self.integer_argument(name='nn_nodal_comp_nearest',
                                  description='Natural neighbor nodal function computation number of nearest points',
                                  min_value=1, value=16),
            # anisotropy
            self.bool_argument(name='anisotropy', description='Specify anisotropy', value=False),
            self.float_argument(name='anis_horiz', description='Horizontal anisotropy', value=1.0),
            self.float_argument(name='anis_azimuth', description='Azimuth', value=0.0),
            # NOTE: the argument name keeps its historical spelling; renaming it could break saved tool settings.
            self.float_argument(name='anis_verical', description='Vertical anisotropy (1/z-mag)', value=1.0),
            # log interpolation
            self.bool_argument(name='log_interp', description='Log interpolation', value=False),
            self.float_argument(name='min_log_val',
                                description='Minimum log interpolation scalar: set dataset values <= 0 to this value',
                                value=1.0e-6),
        ]
        self.enable_arguments(arguments)
        return arguments

    def enable_arguments(self, arguments: List[Argument]):
        """Called to show/hide arguments, change argument values and add new arguments.

        Only the arguments relevant to the current selections are shown; the other optional arguments are hidden.

        Args:
            arguments(list): The tool arguments.
        """
        method = arguments[ARG_INPUT_INTERP_METHOD].value
        arguments[ARG_OUTPUT_DATASET_NAME].show = True

        # truncation range is only used when truncating to a user specified min/max
        show = arguments[ARG_INPUT_TRUNCATE].value == 'Truncate to specified min/max'
        arguments[ARG_INPUT_TRUNC_MIN].show = show
        arguments[ARG_INPUT_TRUNC_MAX].show = show

        # extrapolation options apply to linear and natural neighbor only
        extrap_list = [ARG_INPUT_EXTRAP, ARG_INPUT_EXTRAP_VAL, ARG_INPUT_EXTRAP_IDW, ARG_INPUT_EXTRAP_IDW_NUM_NEAREST,
                       ARG_INPUT_EXTRAP_IDW_QUADRANT, ARG_INPUT_EXTRAP_DATASET, ARG_INPUT_EXTRAP_DS_TIMESTEP]
        for item in extrap_list:
            arguments[item].show = False
        if method in ('Linear', 'Natural neighbor'):
            arguments[ARG_INPUT_EXTRAP].show = True
            extrap = arguments[ARG_INPUT_EXTRAP].value
            if extrap == 'Constant value':
                arguments[ARG_INPUT_EXTRAP_VAL].show = True
            elif extrap == 'Inverse distance weighted (IDW)':
                arguments[ARG_INPUT_EXTRAP_IDW].show = True
                if arguments[ARG_INPUT_EXTRAP_IDW].value == 'Use nearest points':
                    arguments[ARG_INPUT_EXTRAP_IDW_NUM_NEAREST].show = True
                    arguments[ARG_INPUT_EXTRAP_IDW_QUADRANT].show = True
            elif extrap == 'Existing dataset':
                arguments[ARG_INPUT_EXTRAP_DATASET].show = True

        # clough-tocher applies to linear only
        arguments[ARG_INPUT_LINEAR_CLOUGH_TOCHER].show = method == 'Linear'

        # idw options (and the 2D/3D choice) apply to idw only
        idw_list = [ARG_INPUT_IDW_NODAL_FUNC, ARG_INPUT_IDW_NODAL_CLASSIC, ARG_INPUT_IDW_NODAL_EXPONENT,
                    ARG_INPUT_IDW_NOD_COEF, ARG_INPUT_IDW_NOD_COEF_NEAR, ARG_INPUT_IDW_NOD_COEF_QUAD,
                    ARG_INPUT_IDW_WEIGHT, ARG_INPUT_IDW_WEIGHT_NEAR, ARG_INPUT_IDW_WEIGHT_QUAD]
        arguments[ARG_INPUT_INTERP_DIMENSION].show = False
        for item in idw_list:
            arguments[item].show = False
        if method == 'Inverse distance weighted (IDW)':
            arguments[ARG_INPUT_INTERP_DIMENSION].show = True
            arguments[ARG_INPUT_IDW_NODAL_FUNC].show = True
            arguments[ARG_INPUT_IDW_NOD_COEF].show = True
            arguments[ARG_INPUT_IDW_WEIGHT].show = True
            if arguments[ARG_INPUT_IDW_NODAL_FUNC].value == "Constant (Shepard's method)":
                arguments[ARG_INPUT_IDW_NODAL_CLASSIC].show = True
                arguments[ARG_INPUT_IDW_NODAL_EXPONENT].show = True
            if arguments[ARG_INPUT_IDW_NOD_COEF].value == 'Use nearest points':
                arguments[ARG_INPUT_IDW_NOD_COEF_NEAR].show = True
                arguments[ARG_INPUT_IDW_NOD_COEF_QUAD].show = True
            if arguments[ARG_INPUT_IDW_WEIGHT].value == 'Use nearest points':
                arguments[ARG_INPUT_IDW_WEIGHT_NEAR].show = True
                arguments[ARG_INPUT_IDW_WEIGHT_QUAD].show = True

        # natural neighbor options
        nn_list = [ARG_INPUT_NN_NODAL_FUNC, ARG_INPUT_NN_NODAL_COMP, ARG_INPUT_NN_NODAL_COMP_NEAR]
        for item in nn_list:
            arguments[item].show = False
        if method == 'Natural neighbor':
            arguments[ARG_INPUT_NN_NODAL_FUNC].show = True
            arguments[ARG_INPUT_NN_NODAL_COMP].show = True
            if arguments[ARG_INPUT_NN_NODAL_COMP].value == 'Use nearest points':
                arguments[ARG_INPUT_NN_NODAL_COMP_NEAR].show = True

        # anisotropy values follow the anisotropy toggle
        anis_list = [ARG_INPUT_ANIS_HORIZ, ARG_INPUT_ANIS_AZIMUTH, ARG_INPUT_ANIS_VERTICAL]
        show = arguments[ARG_INPUT_ANISOTROPY].value
        for item in anis_list:
            arguments[item].show = show

        # log interpolation minimum follows the log toggle
        arguments[ARG_INPUT_LOG_MIN].show = arguments[ARG_INPUT_LOG].value

        # input dataset time picker
        arguments[ARG_INPUT_ALL_TIMESTEPS].show = False
        arguments[ARG_INPUT_DS_TIMESTEP].show = False
        if arguments[ARG_INPUT_DATASET].value:
            dataset = self._data_handler.get_input_dataset(arguments[ARG_INPUT_DATASET].value)
            if dataset is not None and dataset.num_times > 1:  # If transient, show the option to use all timesteps
                arguments[ARG_INPUT_ALL_TIMESTEPS].show = True
            if not arguments[ARG_INPUT_ALL_TIMESTEPS].value:  # Only show timestep picker if not using all timesteps
                arguments[ARG_INPUT_DS_TIMESTEP].enable_timestep(arguments[ARG_INPUT_DATASET])

        # Extrapolation dataset time picker. Only offer it while the extrapolation dataset argument itself is shown,
        # so a previously selected dataset does not leave a stray time picker after switching options.
        arguments[ARG_INPUT_EXTRAP_DS_TIMESTEP].show = False
        extrap_ds_arg = arguments[ARG_INPUT_EXTRAP_DATASET]
        if extrap_ds_arg.show and extrap_ds_arg.value:
            arguments[ARG_INPUT_EXTRAP_DS_TIMESTEP].enable_timestep(extrap_ds_arg)

    def validate_arguments(self, arguments):
        """Check the tool arguments and cache the source and target data needed by run().

        Side effects: stores ``arguments`` in ``self._args`` and the source dataset, source grid and target grid
        in ``self._data``.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Maps argument names to error messages; empty when the arguments are valid.
        """
        self._args = arguments
        errors = {}
        dataset_arg = arguments[ARG_INPUT_DATASET]
        target_arg = arguments[ARG_INPUT_TARGET_UGRID]

        # source dataset and the grid it lives on
        src_ds = self.get_input_dataset(dataset_arg.value)
        self._data['src_ds'] = src_ds
        self._data['src_grid'] = None
        if src_ds is not None:
            self._data['src_grid'] = self.get_input_dataset_grid(dataset_arg.text_value)
        src_grid = self._data['src_grid']
        if not src_grid:
            errors[dataset_arg.name] = 'Unable to read UGrid from dataset.'
        elif (arguments[ARG_INPUT_INTERP_METHOD].value != 'Inverse distance weighted (IDW)'
              and not src_grid.check_all_cells_2d()):
            # only IDW supports source grids with 3D cells
            errors[dataset_arg.name] = 'The selected interpolation option does not support 3D cells.'

        # target grid
        self._data['dst_grid'] = self.get_input_grid(target_arg.value)
        if not self._data['dst_grid']:
            errors[target_arg.name] = 'Unable to read target UGrid.'

        # cell datasets can't be created on a mesh or scatter target
        location_arg = arguments[ARG_INPUT_DATASET_LOCATION]
        if location_arg.text_value == 'Cells' and target_arg.text_value.startswith(('Scatter Data', 'Mesh Data')):
            errors[location_arg.name] = 'Unable to interpolate to cells for Mesh or Scatter.'

        # a transient source needs a selected timestep unless all timesteps are interpolated
        dataset = self._data_handler.get_input_dataset(dataset_arg.value)
        is_transient = dataset is not None and dataset.num_times > 1
        timestep_arg = arguments[ARG_INPUT_DS_TIMESTEP]
        timestep_missing = timestep_arg.value in ['', Argument.NONE_SELECTED]
        use_all = arguments[ARG_INPUT_ALL_TIMESTEPS].value
        if is_transient and timestep_missing and not use_all:
            errors[timestep_arg.name] = 'Must select a timestep or enable the option to interpolate all timesteps.'

        return errors

    def run(self, arguments):
        """Run the tool: interpolate the source dataset to the target grid and publish the result.

        Args:
            arguments (list): The tool arguments (not read directly; the steps use the arguments stored by
                validate_arguments).
        """
        self.logger.info(f'Process ID: {os.getpid()}')
        pipeline = (
            self._get_grid_data,
            self._scale_point_locations,
            self._create_interpolator,
            self._do_interpolation,
            self.copy_output_dataset,
        )
        for step in pipeline:
            step()

    def copy_output_dataset(self):
        """Set the interpolated dataset as the tool output, if one was created."""
        output = self._output_ds
        if not output:
            return
        self.set_output_dataset(output)

    def _do_interpolation(self):
        """Interpolate every selected source time step (and component) and append it to the output dataset."""
        self.logger.info('Begin interpolation to target grid locations.')
        self._create_output_dataset()
        self._get_interpolation_times()
        self._begin_interpolation_set_truncation()

        # for cell activity
        # self._output_ds.timestep_mins = list(self._data['src_ds'].mins)
        # self._output_ds.timestep_maxs = list(self._data['src_ds'].maxs)

        src_ds = self._data['src_ds']
        writer = self._output_ds
        for step_number, src_time_idx in enumerate(self._data['ds_times']):
            self.logger.info(f'Interpolating time index: {step_number}')
            time.sleep(0.1)  # give the GUI a moment to update
            self._data.update(cur_time_idx=step_number, cur_ds_time=src_time_idx)

            self._set_truncation_time_step()
            for component in range(src_ds.num_components):
                self._set_interpolator_scalars(component)
                self._interp_time_step(component)
            writer.append_timestep(time=src_ds.times[src_time_idx], data=self._data['new_scalars'],
                                   activity=self._data['new_activity'])
        writer.appending_finished()

    def _interp_time_step_linear(self, idx, pt):
        """Interpolate to a single target point with the linear / natural neighbor interpolator.

        Progress is bumped once every ``self._prog_increment`` points (capped at 100) and logged at most once per
        ``self._linear_prog_time`` seconds. Points not inside a triangle are recorded in
        ``self._data['extrap_idxs']``.

        Args:
            idx (int): point index that is being interpolated
            pt (iterable): point that is being interpolated to

        Returns:
            (float): the interpolated value
        """
        data = self._data
        reached_increment = (idx + 1) % self._prog_increment == 0
        if reached_increment and data['prog'] < 100:
            data['prog'] += 1
            now = default_timer()
            if now - self._timer > self._linear_prog_time:
                self._timer = now
                self.logger.info(f'Time step interpolation percent complete: {data["prog"]}%')

        x, y, z = pt[0], pt[1], pt[2]
        location = (x, y, z)
        interpolator = self._interpolator
        if interpolator.triangle_containing_point(location) < 0:
            data['extrap_idxs'].append(idx)
        return interpolator.interpolate_to_point(location)

    def _interp_time_step(self, comp_idx):
        """Interpolate one component of the current time step to the target points.

        The result is stored in ``self._data['new_scalars']`` (a flat list for scalar datasets, or one list of
        component values per point for vector datasets) and ``self._data['new_activity']``. Inactive points are
        written as the null value, after which ``new_activity`` is reset to None.

        Args:
            comp_idx (int): index of the component being interpolated (applies to vector datasets)
        """
        is_idw = self._args[ARG_INPUT_INTERP_METHOD].value == 'Inverse distance weighted (IDW)'
        if is_idw:
            new_scalars = self._interpolator.interpolate_to_points(self._data['dst_pts'])
        else:
            self._data['prog'] = 0
            self._data['extrap_idxs'] = []  # filled by _interp_time_step_linear
            pts = self._data['dst_pts']
            self._prog_increment = max(1, int(len(pts) / 100))
            self._timer = default_timer()
            new_scalars = [self._interp_time_step_linear(i, pts[i]) for i in range(len(pts))]
        if self._data['log_interp']:
            # the interpolator worked on log10 values; convert back to the original scale
            new_scalars = [math.pow(10, s) for s in new_scalars]

        new_activity = None
        if not is_idw:
            new_scalars, new_activity = self._extrapolate_outside_points(new_scalars)

        num_comp = self._data['src_ds'].num_components
        if num_comp > 1:
            if comp_idx == 0:
                self._data['new_scalars'] = [[0.0] * num_comp for _ in range(len(new_scalars))]
            for i in range(len(new_scalars)):
                self._data['new_scalars'][i][comp_idx] = new_scalars[i]
        else:
            self._data['new_scalars'] = new_scalars

        # Write inactive points as the null value instead of keeping a separate activity array.
        if self._null_value is not None and new_activity:
            scalar = self._data['new_scalars']
            for i in range(len(scalar)):
                if not new_activity[i]:
                    if num_comp > 1:
                        for j in range(len(scalar[i])):
                            scalar[i][j] = self._null_value
                    else:
                        scalar[i] = self._null_value
            new_activity = None

        # Create Cell Activity
        # ug = self._data['dst_grid'].ugrid
        # cell_act = [1] * ug.cell_count
        # for i in range(len(new_activity)):
        #     if new_activity[i] == 0:
        #         adj_cells = ug.get_point_adjacent_cells(i)
        #         for cidx in adj_cells:
        #             cell_act[cidx] = 0
        # new_activity = cell_act

        self._data['new_activity'] = new_activity
        print_activity = new_activity if new_activity is None or len(new_activity) < 1000 else new_activity[:1000]
        self.logger.info(f'New activity (up to index 1000 printed): {print_activity}')

        # no longer needed because we are using the NULL value. Leave this here in case you want to
        # create point data with point activity
        # if new_activity is not None:
        #     self._output_ds.use_activity_as_null = True

    def _extrapolate_outside_points(self, new_scalars):
        """Apply the selected extrapolation option to the points recorded in ``self._data['extrap_idxs']``.

        Args:
            new_scalars (list): interpolated value for every target point

        Returns:
            (tuple): ``(new_scalars, new_activity)`` where new_activity is a per-point list of 0/1 flags, or None
            when no point needs to be marked inactive.
        """
        new_activity = None
        extrap_pts = self._data['extrap_idxs']
        extrap_ids = [i + 1 for i in extrap_pts]  # 1-based ids for reporting
        extrap_opt = self._args[ARG_INPUT_EXTRAP].value
        if extrap_opt == 'No extrapolation':
            if len(extrap_pts) > 0:
                self.logger.info(f'{len(extrap_pts)} points were outside the input grid.')
                self.logger.info(f'Outside of input grid point ids: {extrap_ids}')
                self.logger.info('These points were assigned an activity value of 0')
                set_pts = set(extrap_pts)
                new_activity = [1 if i not in set_pts else 0 for i in range(len(new_scalars))]
            return new_scalars, new_activity

        tri_act = self._interpolator.triangle_activity
        if len(tri_act) > 0:
            if self._data.get('points_outside_src_grid', None) is None:
                # Activate every triangle so points truly outside the source grid can be told apart from points
                # that only fall in inactive triangles. The result is cached for later time steps.
                # NOTE(review): triangle activity is not restored here; it appears to be re-applied by
                # _set_interpolator_scalars for cell activity - confirm for other cases.
                self._interpolator.triangle_activity = [1] * len(tri_act)
                check = self._interpolator.triangle_containing_point
                pts = self._data['dst_pts']
                pts_outside = [i for i in extrap_pts if check((pts[i][0], pts[i][1], pts[i][2])) == -1]
                self._data['points_outside_src_grid'] = set(pts_outside)
            # points inside the source grid but in inactive triangles become inactive
            inactive_pts = set(extrap_pts) - self._data['points_outside_src_grid']
            new_activity = [1 if i not in inactive_pts else 0 for i in range(len(new_scalars))]
        new_scalars = list(new_scalars)
        self.logger.info(f'{len(extrap_pts)} points were extrapolated.')
        print_ids = extrap_ids if len(extrap_ids) < 1000 else extrap_ids[:1000]
        self.logger.info(f'Extrapolation point ids (up to 1000 printed): {print_ids}')
        if extrap_opt == 'Inverse distance weighted (IDW)':
            self.logger.info('IDW used to extrapolate to points')
            pts = self._data['dst_pts']
            extrapolator = self._data['extrapolator']
            log_interp = self._data['log_interp']
            for idx in extrap_pts:
                p = (pts[idx][0], pts[idx][1], pts[idx][2])
                value = extrapolator.interpolate_to_point(p)
                # With log interpolation the extrapolator is given the same log10 scalars as the interpolator,
                # so its result must be converted back to the original scale as well.
                new_scalars[idx] = math.pow(10, value) if log_interp else value
        elif extrap_opt == 'Existing dataset':
            ds_name = self._args[ARG_INPUT_EXTRAP_DATASET].value
            self.logger.info(f'Existing dataset values assigned to point: {ds_name}')
            ds, ts = self._data['extrap_ds'], self._data['extrap_ds_ts']
            act = None
            if ds.activity is not None:
                act = ds.activity[ts]
                new_activity = [1] * len(new_scalars)
            ds_scalar = ds.values[ts]
            for idx in extrap_pts:
                new_scalars[idx] = ds_scalar[idx]
                if act is not None:
                    new_activity[idx] = act[idx]
        else:
            # constant value extrapolation is handled by the interpolator itself
            val = self._args[ARG_INPUT_EXTRAP_VAL].value
            self.logger.info(f'Constant value applied to extrapolation points: {val}')
        return new_scalars, new_activity

    def _set_interpolator_scalars(self, comp_idx):
        """Set the current time step's scalars and activity on the interpolator (and extrapolator, if any).

        Activity is derived from the source dataset null value when it has one, otherwise from its activity
        array. Activity with one value per source point is applied as point activity; any other length is treated
        as cell activity and mapped onto the interpolator's triangles.

        Args:
            comp_idx (int): index of the component being interpolated (applies to vector datasets)
        """
        src_ds = self._data['src_ds']
        ds_time = self._data['cur_ds_time']
        raw_scalars = src_ds.values[ds_time]
        if src_ds.num_components > 1:
            raw_scalars = [s[comp_idx] for s in raw_scalars]
        scalars = raw_scalars
        if self._data['log_interp']:
            scalars = [self._get_log_val(s) for s in raw_scalars]

        self._interpolator.scalars = scalars
        if 'extrapolator' in self._data:
            self._data['extrapolator'].scalars = scalars

        null_value = src_ds.null_value
        if src_ds.activity is None and null_value is None:
            return
        if null_value is not None:
            # Compare against the untransformed values: after the log10 transform null values no longer equal
            # null_value and would be treated as active.
            activity = [0 if null_value == s else 1 for s in raw_scalars]
        else:
            activity = src_ds.activity[ds_time]
        if len(activity) == len(self._data['src_pts']):
            # point activity
            self._interpolator.point_activity = activity
            if 'extrapolator' in self._data:
                self._data['extrapolator'].point_activity = activity
        else:  # cell activity
            if 'tri_idx_to_cell_idx' in self._data:
                tri_to_cell = self._data['tri_idx_to_cell_idx']
                tri_act = [activity[tri_to_cell[i]] for i in range(len(tri_to_cell))]
            else:
                tri_act = activity
            self._interpolator.triangle_activity = tri_act

    def _set_truncation_time_step(self):
        """Apply the current time step's truncation range when per-time-step ranges were prepared."""
        maxs = self._data['trunc_maxs']
        if not maxs:
            # no per-time-step ranges; any truncation was already set once up front
            return
        step = self._data['cur_time_idx']
        self._interpolator.set_truncation(maximum=maxs[step], minimum=self._data['trunc_mins'][step])

    def _get_interpolation_times(self):
        """Store in ``self._data['ds_times']`` the 0-based source time step indices to interpolate."""
        if not self._args[ARG_INPUT_ALL_TIMESTEPS].value:
            try:
                selected = int(self._args[ARG_INPUT_DS_TIMESTEP].value)
            except (ValueError, TypeError):
                selected = 1  # no usable selection: fall back to the first time step
            self._data['ds_times'] = [selected - 1]  # the argument is 1-based
        else:
            self._data['ds_times'] = list(range(len(self._data['src_ds'].times)))

    def _begin_interpolation_set_truncation(self):
        """Set up truncation of the interpolated values.

        Truncation to a user-specified range, or to the source dataset range of the first selected time step, is
        set on the interpolator once. When truncating to the source dataset range over multiple time steps, the
        per-time-step ranges are also stored in ``self._data['trunc_mins']`` / ``['trunc_maxs']`` for
        ``_set_truncation_time_step``. With log interpolation all ranges are converted to log10 space.
        """
        ds_times = self._data['ds_times']
        trunc_opt = self._args[ARG_INPUT_TRUNCATE].value
        # two independent lists so that updating one can never change the other
        trunc_mins = []
        trunc_maxs = []
        if trunc_opt != 'Do not truncate':
            if trunc_opt == 'Truncate to min/max of source dataset':
                src_ds = self._data['src_ds']
                tmin = src_ds.mins[ds_times[0]]
                tmax = src_ds.maxs[ds_times[0]]
                if len(ds_times) > 1:
                    trunc_maxs = list(src_ds.maxs)
                    trunc_mins = list(src_ds.mins)
            else:
                tmin = self._args[ARG_INPUT_TRUNC_MIN].value
                tmax = self._args[ARG_INPUT_TRUNC_MAX].value

            if self._data['log_interp']:
                tmin = self._get_log_val(tmin)
                tmax = self._get_log_val(tmax)
                trunc_mins = [self._get_log_val(v) for v in trunc_mins]
                trunc_maxs = [self._get_log_val(v) for v in trunc_maxs]

            self._interpolator.set_truncation(maximum=tmax, minimum=tmin)

        self._data['trunc_mins'] = trunc_mins
        self._data['trunc_maxs'] = trunc_maxs

    def _get_log_val(self, val):
        """Get the log value of val and if val is <= 0.0 then set it to the log_min.

        Args:
            val (float): the value

        Returns:
            (float): the log of value or the log of log_min_val
        """
        return math.log10(val) if val > 0.0 else math.log10(self._data['min_log'])

    def _create_output_dataset(self):
        """Create the DatasetWriter that receives the interpolated values.

        The dataset is located on the target grid's cells when the location option is 'Cells', and on its
        points otherwise. When no output name was given, a name is built from three parts: the source dataset
        name, an abbreviation of the interpolation method, and, for IDW and natural neighbor, an abbreviation of
        the nodal function.
        """
        ds_loc = 'cells' if self._args[ARG_INPUT_DATASET_LOCATION].value == 'Cells' else 'points'
        ds_name = self._args[ARG_OUTPUT_DATASET_NAME].value
        if not ds_name:
            method_abbrevs = {'Linear': 'linear', 'Inverse distance weighted (IDW)': 'idw', 'Natural neighbor': 'nn'}
            nodal_abbrevs = {"Constant (Shepard's method)": 'const', 'Gradient plane': 'grad', 'Quadratic': 'quad',
                             'Constant': 'const'}
            method_abbrev = method_abbrevs[self._args[ARG_INPUT_INTERP_METHOD].value]
            suffix = method_abbrev
            if method_abbrev == 'idw':
                suffix += '_' + nodal_abbrevs[self._args[ARG_INPUT_IDW_NODAL_FUNC].value]
            elif method_abbrev == 'nn':
                suffix += '_' + nodal_abbrevs[self._args[ARG_INPUT_NN_NODAL_FUNC].value]
            ds_name = f'{self._data["src_ds"].name}_{suffix}'

        src_ds = self._data['src_ds']
        self._output_ds = DatasetWriter(
            name=ds_name,
            num_components=src_ds.num_components,
            ref_time=src_ds.ref_time,
            time_units=src_ds.time_units,
            null_value=self._null_value,
            location=ds_loc,
            geom_uuid=self._data['dst_grid'].uuid,
        )

    def _create_interpolator(self):
        """Create ``self._interpolator`` from the tool arguments and configure extrapolation.

        The interpolator type depends on the method:
            - 'Inverse distance weighted (IDW)': builds an ``InterpIdw`` over ``self._data['src_pts']``.
            - 'Natural neighbor': builds an ``InterpLinear`` over ``self._data['src_pts']`` and
              ``self._data['src_tris']``, with a nodal function.
            - Anything else (linear): builds an ``InterpLinear`` over the same data, optionally with Clough-Tocher.

        For the two ``InterpLinear`` methods only, the extrapolation option is also applied. It is one of:
            - a constant value (log-converted when log interpolation is on);
            - an IDW extrapolator stored in ``self._data['extrapolator']``;
            - an existing dataset stored in ``self._data['extrap_ds']``, with its 0-based time step in
              ``self._data['extrap_ds_ts']``.

        Also stores the log-interpolation flag and minimum in ``self._data['log_interp']`` and
        ``self._data['min_log']``.
        """
        self.logger.info('Creating interpolator.')
        self._data['log_interp'] = self._args[ARG_INPUT_LOG].value
        self._data['min_log'] = self._args[ARG_INPUT_LOG_MIN].value
        method = self._args[ARG_INPUT_INTERP_METHOD].value
        if method == 'Inverse distance weighted (IDW)':
            # map the UI nodal-function label to the keyword passed to InterpIdw
            n_dict = {"Constant (Shepard's method)": 'constant', 'Gradient plane': 'gradient_plane',
                      'Quadratic': 'quadratic'}
            nodal = n_dict[self._args[ARG_INPUT_IDW_NODAL_FUNC].value]
            # point search options used for the nodal function coefficients
            nearest = self._args[ARG_INPUT_IDW_NOD_COEF_NEAR].value
            quad_search = self._args[ARG_INPUT_IDW_NOD_COEF_QUAD].value
            if self._args[ARG_INPUT_IDW_NOD_COEF].value == 'Use all points':
                # NOTE(review): -1 is presumably InterpIdw's "use every point" value - confirm
                nearest = -1
                quad_search = False
            is_2d = self._args[ARG_INPUT_INTERP_DIMENSION].value == '2D'
            # interpolation weights
            weight_opt = self._args[ARG_INPUT_IDW_WEIGHT].value
            if weight_opt == 'Use nearest points':
                w_nearest = self._args[ARG_INPUT_IDW_WEIGHT_NEAR].value
                w_quad_search = self._args[ARG_INPUT_IDW_WEIGHT_QUAD].value
            else:
                w_nearest = -1
                w_quad_search = False
            self._interpolator = InterpIdw(points=self._data['src_pts'], nodal_function=nodal,
                                           number_nearest_points=nearest, quadrant_oct=quad_search,
                                           is_2d=is_2d, progress=self._observer,
                                           weight_number_nearest_points=w_nearest, weight_quadrant_oct=w_quad_search)
            # the exponent and the classic/modified weight options are only set for the constant nodal function
            if nodal == 'constant':
                self._interpolator.power = self._args[ARG_INPUT_IDW_NODAL_EXPONENT].value
                weight_calc = 'classic' if self._args[ARG_INPUT_IDW_NODAL_CLASSIC].value else 'modified'
                self._interpolator.weight_calculation_method = weight_calc
            self._interpolator.set_observer(self._observer)
        else:
            if method == 'Natural neighbor':
                n_dict = {'Constant': 'constant', 'Gradient plane': 'gradient_plane', 'Quadratic': 'quadratic'}
                nodal = n_dict[self._args[ARG_INPUT_NN_NODAL_FUNC].value]
                nodal_comp = self._args[ARG_INPUT_NN_NODAL_COMP].value
                nearest = -1
                if nodal_comp == 'Use natural neighbors':
                    ps_opt = 'natural_neighbor'
                    # NOTE(review): nearest is set to 1 here; presumably unused with the natural_neighbor
                    # search option - confirm against InterpLinear
                    nearest = 1
                else:
                    ps_opt = 'nearest_points'
                    if nodal_comp == 'Use nearest points':
                        nearest = self._args[ARG_INPUT_NN_NODAL_COMP_NEAR].value
                self._interpolator = InterpLinear(points=self._data['src_pts'], triangles=self._data['src_tris'],
                                                  nodal_function=nodal, point_search_option=ps_opt,
                                                  number_nearest_points=nearest, progress=self._observer)
            else:
                self._interpolator = InterpLinear(points=self._data['src_pts'], triangles=self._data['src_tris'],
                                                  progress=self._observer)
                if self._args[ARG_INPUT_LINEAR_CLOUGH_TOCHER].value:
                    self._interpolator.set_use_clough_tocher(True)

            # DEFAULT_EXTRAP_VALUE is kept for every option except 'Constant value'.
            # NOTE(review): for the IDW / existing-dataset options this presumably acts as a marker for values
            # that are filled in later - confirm with the caller.
            extrap_val = DEFAULT_EXTRAP_VALUE
            if self._args[ARG_INPUT_EXTRAP].value == 'Constant value':
                extrap_val = self._args[ARG_INPUT_EXTRAP_VAL].value
                if self._data['log_interp']:
                    extrap_val = self._get_log_val(extrap_val)
            elif self._args[ARG_INPUT_EXTRAP].value == 'Inverse distance weighted (IDW)':
                self._data['extrapolator'] = InterpIdw(points=self._data['src_pts'], nodal_function='constant')
                nearest = self._args[ARG_INPUT_EXTRAP_IDW_NUM_NEAREST].value
                quad_search = self._args[ARG_INPUT_EXTRAP_IDW_QUADRANT].value
                if self._args[ARG_INPUT_EXTRAP_IDW].value == 'Use all points':
                    nearest = -1
                    quad_search = False
                self._data['extrapolator'].set_search_options(nearest_point=nearest, quadrant_oct_search=quad_search)
            elif self._args[ARG_INPUT_EXTRAP].value == 'Existing dataset':
                self._data['extrap_ds'] = self.get_input_dataset(self._args[ARG_INPUT_EXTRAP_DATASET].value)
                ts = self._args[ARG_INPUT_EXTRAP_DS_TIMESTEP].value
                # the UI time step is 1-based; no selection falls back to the first step
                if ts in ['', '-- None Selected --']:
                    ts = 0
                else:
                    ts = int(ts) - 1
                self._data['extrap_ds_ts'] = ts

            self._interpolator.extrapolation_value = extrap_val

    def _get_grid_data(self):
        """Get the point and triangle information from the source and target grids.

        Populates ``self._data``:
            - 'src_pts': ``[x, y, z]`` source locations. These are grid points for point-located source
              datasets and cell centers otherwise.
            - 'src_tris': flat list of source point indices, three per triangle. It is None for cell-located
              source data.
            - 'cells_to_tris' and 'tri_idx_to_cell_idx': set only when the source grid has non-triangle cells
              that had to be triangulated. 'tri_idx_to_cell_idx' gives, for each triangle, its source cell index.
            - 'dst_pts': ``[x, y, z]`` target locations. These are grid points or cell centers, following the
              dataset location option.
        """
        self.logger.info('Getting source grid point locations.')
        self._data['src_tris'] = None
        ug = self._data['src_grid'].ugrid
        if self._data['src_ds'].location in ['points', '', ('points', 'points')]:
            self._data['src_pts'] = [[p[0], p[1], p[2]] for p in ug.locations]
            self.logger.info('Getting triangles from source grid.')
            if self._data['src_grid'].check_all_cells_are_of_type(ug.cell_type_enum.TRIANGLE):
                # every cell is already a triangle, so its point indices are used directly
                self._data['src_tris'] = [p for i in range(ug.cell_count) for p in ug.get_cell_points(i)]
            else:
                # at least one cell is not a triangle: triangulate cell by cell
                self._data['cells_to_tris'] = {}
                tris, tri_idx_to_cell_idx = self._triangulate_source_cells(ug, self._data['src_pts'])
                self._data['src_tris'] = tris
                self._data['tri_idx_to_cell_idx'] = tri_idx_to_cell_idx
        else:
            self._get_grid_cell_centers('src_grid', 'src_pts')

        self.logger.info('Getting target grid point locations.')
        if self._args[ARG_INPUT_DATASET_LOCATION].value == 'Points':
            self._data['dst_pts'] = [[p[0], p[1], p[2]] for p in self._data['dst_grid'].ugrid.locations]
        else:
            self._get_grid_cell_centers('dst_grid', 'dst_pts')

    def _triangulate_source_cells(self, ug, grid_locs):
        """Split the cells of the source UGrid into triangles, logging progress periodically.

        Triangles are kept as-is and cells with more than three points are triangulated with earcut in the XY
        plane. Cells with fewer than three points cannot form a triangle and are skipped, so the result always
        holds a whole number of point-index triples.

        Args:
            ug (UGrid): the source ugrid
            grid_locs (list[list[float]]): point locations of ``ug``, indexed by point index

        Returns:
            (tuple(list[int], list[int])): the flat list of triangle point indices (three per triangle), and the
            index of the originating cell for each triangle
        """
        tris = []
        tri_idx_to_cell_idx = []
        cell_count = ug.cell_count
        timer = default_timer()
        for cell_idx in range(cell_count):
            pts = ug.get_cell_points(cell_idx)
            now = default_timer()
            if now - timer > self._linear_prog_time:
                timer = now
                self.logger.info(f'{cell_idx} cells processed of {cell_count} total cells.')
            if len(pts) == 3:
                tris.extend(pts)
                tri_idx_to_cell_idx.append(cell_idx)
            elif len(pts) > 3:
                # earcut takes flattened x, y coordinates and returns indices into pts, three per triangle
                locs = [xy for p in pts for xy in grid_locs[p][:2]]
                it = iter(earcut(locs))
                for a, b, c in zip(it, it, it):
                    tris.extend([pts[a], pts[b], pts[c]])
                    tri_idx_to_cell_idx.append(cell_idx)
        return tris, tri_idx_to_cell_idx

    def _get_grid_cell_centers(self, grid_str, pt_str):
        """Fill in the points for the ugrid.

        Args:
            grid_str (str): name of grid in the data dictionary
            pt_str (str): name of the point list in the data dictionary
        """
        if self._data[grid_str].cell_centers is not None:
            self._data[pt_str] = [[p[0], p[1], p[2]] for p in self._data[grid_str].cell_centers]
        else:
            ug = self._data[grid_str].ugrid
            self._data[pt_str] = [_get_cell_centroid(ug, i) for i in range(ug.cell_count)]

    def _scale_point_locations(self):
        """Apply anisotropy, in place, to the source and target point locations.

        X and Y are multiplied by a symmetric 2x2 matrix built from the horizontal anisotropy factor and the
        azimuth (in degrees). For 3D interpolation, Z is divided by the vertical anisotropy factor. Nothing is
        changed when anisotropy is off, or when both the horizontal and vertical factors are 1.0.

        Returns:
            (dict): ``self._data``, on every path
        """
        if not self._args[ARG_INPUT_ANISOTROPY].value:
            return self._data
        hor_anis = self._args[ARG_INPUT_ANIS_HORIZ].value
        vert_anis = self._args[ARG_INPUT_ANIS_VERTICAL].value
        if hor_anis == 1.0 and vert_anis == 1.0:
            return self._data
        self.logger.info('Applying anisotropy.')

        azimuth = self._args[ARG_INPUT_ANIS_AZIMUTH].value
        sin_az = math.sin(math.pi * (azimuth / 180.0))
        cos_az = math.cos(math.pi * (azimuth / 180.0))
        m00 = (hor_anis * cos_az * cos_az) + (sin_az * sin_az)
        m01 = (-hor_anis * cos_az * sin_az) + (sin_az * cos_az)  # the matrix is symmetric: m10 == m01
        m11 = (hor_anis * sin_az * sin_az) + (cos_az * cos_az)
        # Z is only scaled for 3D interpolation. Dividing only in that case means a zero vertical factor
        # cannot raise ZeroDivisionError in 2D mode, where the factor is irrelevant.
        if self._args[ARG_INPUT_INTERP_DIMENSION].value == '2D':
            scale_z = 1.0
        else:
            scale_z = 1 / vert_anis
        for pt_list in (self._data['src_pts'], self._data['dst_pts']):
            for p in pt_list:
                x, y = p[0], p[1]
                p[0] = m00 * x + m01 * y
                p[1] = m01 * x + m11 * y
                p[2] = p[2] * scale_z
        return self._data


def _get_cell_centroid(ug, idx):
    """Get the cell centroid as a list of floats.

    Args:
        ug (UGrid): the ugrid
        idx (int): index for the cell

    Returns:
        (list(float)): the centroid
    """
    _, p = ug.get_cell_centroid(idx)
    return [p[0], p[1], p[2]]
