"""Functions used to Interpolate."""
__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules


class Interpolation:
    """Class for interpolating x, y data."""

    def __init__(self, x_values: list[float], y_values: list[float], input_dict: dict = None, use_log_x: bool = False,
                 use_log_y: bool = False, null_data: float = None, zero_tol: float = None):
        """Initialize Interpolation class.

        Args:
            x_values (list of floats): x values for the data
            y_values (list of floats): y values for the data
            input_dict (input_dict): The dict with the input data and settings (required if not using app_data or
                specify values)
            use_log_x (bool): Use log data for the X data
            use_log_y (bool): Use log data for the Y data
            null_data (float): Null data value
            zero_tol (float): Zero tolerance for the data
        """
        self.x = []
        self.y = []

        self.original_x = []
        self.original_y = []

        self.null_data = null_data
        self.zero_tol = zero_tol
        if self.null_data is None:
            if input_dict is not None:
                _, self.null_data = input_dict.get_data('Null data')
            else:
                raise ValueError("Null data not provided and no app_data or input_dict available.")
        if self.zero_tol is None:
            if input_dict is not None:
                _, self.zero_tol = input_dict.get_data('Zero tolerance')
            else:
                raise ValueError("Zero tolerance not provided and no app_data or input_dict available.")

        self.initializing = True
        self.set_xy(x_values, y_values)

        self.use_log_x = use_log_x
        self.use_log_y = use_log_y

        self.use_second_interpolation = False
        self.y_intercept_first = None
        self.y_intercept_second = None

        self.initializing = False

    def set_xy(self, x: list[float], y: list[float]):
        """Set X, Y data.

        Args:
            x (list of floats): x values for the data
            y (list of floats): y values for the data
        """
        if x is None or y is None:
            raise ValueError("X or Y values cannot be None.")
        if self.initializing and (len(x) == 0 and len(y) == 0):
            self.x = []
            self.y = []
            return
        if len(x) < 2 or len(y) < 2:
            raise ValueError("X and Y values must contain at least two points.")
        if len(x) != len(y):
            raise ValueError("X and Y values must have the same length.")

        self.original_x = x
        self.original_y = y

        self.x = []
        self.y = []

        size = len(x)
        for index in range(size):
            null_x = math.isclose(x[index], self.null_data, abs_tol=self.zero_tol)
            null_y = math.isclose(y[index], self.null_data, abs_tol=self.zero_tol)
            if not null_x and not null_y:
                self.x.append(x[index])
                self.y.append(y[index])

    def interpolate_y(self, x_interp: float, extrapolate: bool = True) -> tuple[float, int | None]:
        """Interpolate a y value from a given x value from the XY data.

        Args:
            x_interp (float): x value that we want a correlated y value
            extrapolate (bool): True if we want to determine a value outside the X,Y data.

        Returns:
            tuple: (y value, index of the lower x value used for interpolation)
        """
        y_interp = 0.0
        found_index = 1  # Default for extrapolating below
        size = len(self.x)
        inside_list = False

        if len(self.x) < 2 or len(self.y) < 2:
            # Check if both original lists have at least two points, but were removed due to null data
            if len(self.original_x) > 2 or len(self.original_y) > 2:
                for index in range(1, len(self.original_x)):
                    if self.original_x[index - 1] <= x_interp <= self.original_x[index] or \
                            self.original_x[index - 1] >= x_interp >= self.original_x[index]:
                        if math.isclose(self.original_y[index - 1], self.null_data, abs_tol=self.zero_tol) and \
                                math.isclose(self.original_y[index], self.null_data, abs_tol=self.zero_tol):
                            return float(self.null_data), index
            raise ValueError("X and Y values must contain at least two points.")

        min_x = min(self.x)
        max_x = max(self.x)
        if not extrapolate:
            if x_interp <= min_x:
                min_index = self.x.index(min_x)
                return float(self.y[min_index]), min_index
            elif x_interp >= max_x:
                max_index = self.x.index(max_x)
                return float(self.y[max_index]), max_index
            inside_list = True
        else:
            if max_x >= x_interp >= min_x:
                inside_list = True
            elif x_interp > max_x:
                # above_list
                found_index = self.x.index(max_x)
            elif x_interp < min_x:
                # below_list
                found_index = self.x.index(min_x)
        if inside_list:
            if self.x[0] == x_interp:  # If equal, following logic will skip past it; yet important to get index right
                return float(self.y[0]), 0
            for index in range(size - 1):
                if self.x[index] < x_interp <= self.x[index + 1] or self.x[index] >= x_interp > self.x[index + 1]:
                    found_index = index
                    break

        # shorthand
        index = found_index
        if index == size - 1:
            index -= 1
        prev = index + 1

        # Interpolate!
        y_interp = self._interpolate_y_with_indices(x_interp, index, prev)

        if self.use_second_interpolation:  # Improves guessing rate for flow interpolations
            next = index - 1
            if next >= 0:
                self.y_intercept_first = y_interp

                index_2 = index + 1
                next = index_2 + 1

                if next < size:
                    self.y_intercept_second = self._interpolate_y_with_indices(x_interp, index_2, next)
                    y_interp = (self.y_intercept_first + self.y_intercept_second) / 2.0

        return float(y_interp), index

    def _interpolate_y_with_indices(self, x_interp: float, index_1: int, index_2: int) -> float:
        """Interpolate the y value between two x indices.

        Args:
            x_interp (float): The x value to interpolate.
            index_1 (int): The index of the first x value.
            index_2 (int): The index of the second x value.

        Returns:
            float: The interpolated y value.
        """
        # log = math.log10
        x = self.x
        y = self.y

        x_diff = x[index_1] - x[index_2]
        y_diff = y[index_1] - y[index_2]
        if x_diff == 0.0:
            y_interp = y[index_2]
        else:
            y_interp = (y_diff / x_diff) * (x_interp - x[index_2]) + y[index_2]

        return y_interp

    # Logic for log_x or log_y calculations, that may be needed for gradations tool:
        #     if self.use_log_x:
        #     log_x_diff = log(x[index_1]) - log(x[index_2])
        #     if self.use_log_y:
        #         log_y_diff = log(y[index_1]) - log(y[index_2])
        #         if log_x_diff == 0.0:
        #             log_y = log(y[index_2])
        #         else:
        #             log_y = (log_y_diff / log_x_diff) * (log_x_diff)
        #         y_interp = 10**log_y
        #     else:
        #         y_diff = y[index_1] - y[index_2]
        #         if log_x_diff == 0.0:
        #             y_interp = y[index_2]
        #         else:
        #             y_interp = (y_diff / log_x_diff) * (log(x_interp) - log(x[index_2]))
        # else:
        #     x_diff = x[index_1] - x[index_2]
        #     if self.use_log_y:
        #         log_y_diff = log(y[index_1]) - log(y[index_2])
        #         if x_diff == 0.0:
        #             log_y = log(y[index_2])
        #         else:
        #             log_y = (log_y_diff / (x_diff)) * (x_interp - x[index_2])
        #         y_interp = 10**log_y
        #     else:
        #         y_diff = y[index_1] - y[index_2]
        #         if x_diff == 0.0:
        #             y_interp = y[index_2]
        #         else:
        #             y_interp = (y_diff / x_diff) * (x_interp - x[index_2]) + y[index_2]
