"""Calculator for performing Manning's N operations."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import sys

# 2. Third party modules
import numpy

# 3. Aquaveo modules

# 4. Local modules


def get_manning_n_calc_data(exith_param):
    """Collect the series data and unit system needed to set up the Channel Calculator.

    The water surface elevation option of the Exit-H arc decides which units attribute is inspected and
    whether an existing xy series supplies the data:

    * ``'Constant'``: no series data; metric when the constant WSE units are ``'Meters'``.
    * ``'Time series'``: the ``'hrs'`` column of the time series; metric when the units are
      ``"hrs -vs- meters"``.
    * ``'Rating curve'``: the ``'vol/sec'`` column of the rating curve; metric when the units are
      ``'cms -vs- meters'``.

    Any other option yields no data and U.S. Customary units.

    Args:
        exith_param (:obj:`BcDataExitH`): Param data object of an Exit-H arc

    Returns:
        (:obj:`tuple(list[float] or None, str)`): Hours or flows from the xy series (None when the option has
        no series), and the unit system name
    """
    wse_option = exith_param.water_surface_elevation_option
    is_metric = False
    series_values = None

    if wse_option == 'Constant':
        is_metric = exith_param.constant_wse_units == 'Meters'
    elif wse_option == 'Time series':
        is_metric = exith_param.time_series_wse_units == "hrs -vs- meters"
        series_values = exith_param.time_series_wse['hrs'].to_list()
    elif wse_option == 'Rating curve':
        is_metric = exith_param.rating_curve_units == 'cms -vs- meters'
        series_values = exith_param.rating_curve['vol/sec'].to_list()

    if is_metric:
        unit_system = "SI Units (Metric)"
    else:
        unit_system = "U.S. Customary Units"
    return series_values, unit_system


class ManningNCalc:
    """A class that defines a channel and performs Manning's n computations.

    Typical use: fill in the input attributes (``original_stations``, ``original_elevations``, ``composite_n``,
    ``slope``, ``flows``), call ``compute_data()``, then read the result lists (``normal_depths``,
    ``normal_wses``, ``critical_depths``, ``critical_wses``) and ``check_warnings()``.
    """
    def __init__(self):
        """Initializes the Manning's n calculator."""
        # constants
        self.gravity = 32.2
        self.constant = 1.486
        self.gamma = 62.4

        # Input
        self.type = "WSE"
        self.original_stations = []
        self.original_elevations = []
        self.stations = []
        self.elevations = []
        self.units = "U.S. Customary Units"
        self.composite_n = 0.00
        self.slope = 0.0
        self.flows = []
        self.freeboard = 0.0

        # Intermediate
        self.lowest_station = 0.0
        self.perimeter = 0.0
        self.area = 0.0
        self.top_width = 0.0
        self.hydraulic_radius = 0.0
        self.high_elev = 0.0
        self.low_elev = 0.0
        # Last flows evaluated by the normal and critical flow functions
        self.flow = 0.0
        self.critical_flow = 0.0

        # Results
        self.normal_depths = []
        self.normal_wses = []
        self.normal_wses_stations = [[]]
        self.critical_depths = []
        self.critical_wses = []
        # Stations for the critical WSE are not tracked; that is beyond the current computations.

        self.original_channel_wse = 0.0
        self.warnings = []

    @staticmethod
    def determine_low_and_high_of_elevations(elevations):
        """Determines the low and high of a given set of elevations.

        For an empty list the result is (sys.float_info.max, -sys.float_info.max).

        Args:
            elevations (:obj:`list[float]`): the set of elevations used to find the low and high values

        Returns:
            (:obj:`tuple(float, float)`): The lowest elevation, then the highest elevation
        """
        high_elev = -sys.float_info.max
        low_elev = sys.float_info.max
        for elev in elevations:
            if elev > high_elev:
                high_elev = elev
            if elev < low_elev:
                low_elev = elev
        return low_elev, high_elev

    @staticmethod
    def add_vertical_wall_to_elevation(stations, elevations, elev):
        """Adds vertical walls to a set of stations and elevations to a specified elevation.

        Note that it also adds to the stations to keep them matching. Both lists are modified in place (and
        returned). Nothing is added when there are fewer than two elevations, or on a side whose end point is
        already at the wall elevation.

        Args:
            stations (:obj:`list[float]`) : stations along a channel cross section
            elevations (:obj:`list[float]`) : elevations at those points along a channel cross section
            elev (:obj:`float`): elevation of the top of the vertical wall

        Returns:
            (:obj:`tuple(list,list)`): The stations and elevations with a vertical wall (lists of doubles)
        """
        if len(elevations) > 1:
            if elevations[0] != elev:
                elevations.insert(0, elev)
                stations.insert(0, stations[0])
            if elevations[-1] != elev:
                elevations.append(elev)
                stations.append(stations[-1])
        return stations, elevations

    def _initialize_geometry(self):
        """Initialize the geometry of the cross section to prepare for calculations.

        First, correct for negative values.
        Second, Change the stations to begin at zero.
        Third, Add a vertical wall to even up the sides.
        Note, that this is done on a new station/elevation set so the originals remain unchanged.
        """
        self.low_elev, self.high_elev = self.determine_low_and_high_of_elevations(self.original_elevations)

        # Correct cross section for negative values: shift everything so the lowest elevation is zero.
        # NOTE(review): the computed WSE results are not shifted back by this correction, so for a section with
        # negative elevations they appear to be in the shifted frame -- confirm callers account for that.
        correction = 0.0
        if self.low_elev < 0.0:
            correction = self.low_elev
            self.low_elev -= correction
            self.high_elev -= correction
        self.elevations = [elev - correction for elev in self.original_elevations]

        # Shift the stations so the left-most one is at zero
        self.lowest_station = sys.float_info.max
        for station in self.original_stations:
            if station < self.lowest_station:
                self.lowest_station = station
        self.stations = [station - self.lowest_station for station in self.original_stations]

        # Add vertical wall to make both sides equal to max elev
        if self.low_elev == self.high_elev:
            # A perfectly flat section has no depth; give it an arbitrary 2-unit wall to work with
            self.high_elev = self.high_elev + 2
        self.stations, self.elevations = self.add_vertical_wall_to_elevation(
            self.stations, self.elevations, self.high_elev
        )

    def _check_for_vertical_wall_warning(self):
        """Checks if vertical walls were necessary to determine normal and critical flow depths.

        A wall was needed when a computed WSE is above the original channel WSE (the lower of the two end
        elevations of the cross section). Note that this function is run in the check_warnings function call.

        Returns:
            (:obj:`str`): Warning string indicating which flows needed vertical walls (empty if none)
        """
        wall_for_normal = any(wse > self.original_channel_wse for wse in self.normal_wses)
        wall_for_critical = any(wse > self.original_channel_wse for wse in self.critical_wses)
        warning = ""
        if wall_for_normal:
            if wall_for_critical:
                warning = "Vertical walls were added for normal and critical depth computations"
            else:
                warning = "Vertical walls were added for normal depth computations"
        elif wall_for_critical:
            warning = "Vertical walls were added for critical depth computations"
        return warning

    def check_warnings(self):
        """Checks for warnings that are given during computations or a check if we can compute (get_can_compute).

        Safe to call repeatedly: the vertical wall warning is only added once.

        Returns:
            (:obj:`list[str]`): The warnings found (if any)
        """
        warning = self._check_for_vertical_wall_warning()
        if warning and warning not in self.warnings:
            self.warnings.append(warning)
        return self.warnings

    def get_can_compute_critical(self):
        """Determines if there is enough data to make a computation and if there isn't, add a warning for each reason.

        Note that this clears any existing warnings before checking.

        Returns:
            (:obj:`bool`): True if can compute
        """
        result = True
        self.warnings = []
        if len(self.original_stations) < 2 or len(self.original_elevations) < 2:
            self.warnings.append("Please enter Cross Section geometry")
            result = False
        if len(self.original_stations) != len(self.original_elevations):
            self.warnings.append("The number of stations and elevations in the Cross Section does not match!")
            result = False
        if self.composite_n < 0.0001:
            self.warnings.append("Please enter a composite Manning's n value")
            result = False
        if len(self.flows) < 1:
            self.warnings.append("Please enter a flow")
            result = False
        if not any(flow > 0.0 for flow in self.flows):
            self.warnings.append("Please enter a positive, non-zero flow")
            result = False
        return result

    def get_can_compute(self):
        """Determines if there is enough data to make a computation and if there isn't, add a warning for each reason.

        This is the critical depth check plus the slope requirement of the normal depth computation.

        Returns:
            (:obj:`bool`): True if can compute
        """
        result = self.get_can_compute_critical()
        if self.slope < 0.0000000001:
            self.warnings.append("Please enter a slope for normal depth computations")
            result = False
        return result

    def compute_data(self):
        """Computes the data possible; stores results in self.

        Returns:
            (:obj:`bool`): True if successful
        """
        if self.get_can_compute_critical():
            # The channel can only hold water up to the lower of its two end points
            self.original_channel_wse = self.original_elevations[0]
            if self.original_channel_wse > self.original_elevations[-1]:
                self.original_channel_wse = self.original_elevations[-1]
            self._initialize_geometry()
            if self.type == "WSE":
                self._compute_wses_from_flow()
                return True
            # Add type = "Depth" or "Flow" if wanted
        return False

    def _compute_wses_from_flow(self):
        """Computes Water Surface normal and critical Elevations from the specified flows.

        Normal depth is only computed when a slope is available; critical depth is always computed.
        """
        if self.get_can_compute():
            self._compute_normal_depth_from_flow()
        self._compute_critical_depth_from_flow()

    def _compute_normal_depth_from_flow(self):
        """Computes Water Surface Elevations (wse-s) and depths from specified flows using normal depth function."""
        self.normal_wses_stations = [[]]
        wse_list, self.normal_wses_stations = self._compute_wse_from_flow(
            self._compute_flow_from_elevation, self.normal_wses_stations
        )
        self.normal_wses = wse_list
        self.normal_depths = [wse - self.low_elev for wse in wse_list]

    def _compute_critical_depth_from_flow(self):
        """Computes Water Surface Elevations (wse-s) and depths from specified flows using critical depth function."""
        # The stations are not kept for critical depth
        wse_list, _ = self._compute_wse_from_flow(self._compute_critical_flow_for_elevation, [[]])
        self.critical_wses = wse_list
        self.critical_depths = [wse - self.low_elev for wse in wse_list]

    def _compute_wse_from_flow(self, compute_function, wse_stations):
        """Computes the Water Surface Elevation (wse) for the specified flows given a computing functor.

        A table of wse -> flow is seeded over the depth of the channel. For each flow, the wse is interpolated
        from the table, the flow at that wse is computed and added to the table, and this repeats until the
        computed flow is within tolerance of the requested flow (or the iteration limit is hit, which adds a
        warning).

        Args:
            compute_function: functor that computes a flow for a given wse (normal or critical; perhaps shape
             specific)
            wse_stations: The stations of where the WSE meets the channels with 'nan's between segments (used for
             plotting)

        Returns:
            (:obj:`tuple(list,list)`): The WSE at each flow, the stations of where the WSE meets the channels with
            'nan's between segments (used for plotting)
        """
        station_list_to_drop = []
        # Add Zero (floor of computation)
        wse_flow_list = {self.low_elev: 0.0}

        # Add Midpoints
        max_depth = (self.high_elev - self.low_elev)
        if max_depth <= 0.0:
            max_depth = 20.0  # Just give us some value to start with

        number_of_divisions = 20
        for i in range(1, number_of_divisions):
            cur_wse = max_depth * (i / (number_of_divisions + 1)) + self.low_elev
            cur_flow = compute_function(cur_wse, station_list_to_drop)
            wse_flow_list[cur_wse] = cur_flow

        # Add Max Depth (top of computation)
        max_flow = compute_function(self.high_elev, station_list_to_drop)
        wse_flow_list[self.high_elev] = max_flow

        tol = 0.001
        wse_results = []
        wse_stations_initialized = False
        num_iterations = 500
        for flow in self.flows:
            cur_wse = -999.0
            cur_flow = 0.0
            count = 0
            cur_wse_stations = []
            if flow <= 0.0:
                cur_wse = self.low_elev
            elif flow > max_flow:
                # determine a good depth guess
                # (computing above the channel top makes the compute function add vertical walls and raise
                # self.high_elev)
                depth_guess = (flow / max_flow) * (self.high_elev - self.low_elev) + self.low_elev
                flow_guess = compute_function(depth_guess, station_list_to_drop)
                wse_flow_list[depth_guess] = flow_guess

                # Add Max Depth (top of computation)
                max_flow = compute_function(self.high_elev, station_list_to_drop)
                wse_flow_list[self.high_elev] = max_flow
            if flow > 0.0:
                while abs(cur_flow - flow) > tol and count < num_iterations:
                    cur_wse_stations = []
                    flow_list = []
                    wse_list = sorted(wse_flow_list.keys())
                    for wse in wse_list:
                        flow_list.append(wse_flow_list[wse])
                    cur_wse = numpy.interp(flow, flow_list, wse_list)
                    cur_flow = compute_function(cur_wse, cur_wse_stations)
                    wse_flow_list[cur_wse] = cur_flow
                    count += 1
                if count >= num_iterations:
                    # NOTE(review): the unit label is always "cfs", whatever self.units is -- confirm the flows
                    # given to this class are always in cfs.
                    warning = "Calculations were unable to converge on a solution when trying to find a flow of "
                    warning += str(flow) + " cfs"
                    self.warnings.append(warning)
            wse_results.append(cur_wse)
            if not wse_stations_initialized:
                # Replace whatever was passed in with the stations of the first flow
                wse_stations = [cur_wse_stations]
                wse_stations_initialized = True
            else:
                wse_stations.append(cur_wse_stations)

        return wse_results, wse_stations

    def _compute_flow_from_elevation(self, elevation, elevation_stations):
        """Compute the normal flow for one specified elevation.

        Uses Manning's equation: Q = (constant / n) * A * R^(2/3) * S^(1/2).

        Args:
            elevation: given water surface elevation
            elevation_stations: list of stations where the water surface elevation meets the channel walls
                (this will be updated in this function)

        Returns:
            (:obj:`float`): The normal flow at the given elevation
        """
        self._check_elevation_to_max_depth(elevation)

        self._compute_geometry_from_elevation(elevation, elevation_stations)

        if self.perimeter <= 0.0:
            self.hydraulic_radius = 0.0
        else:
            self.hydraulic_radius = self.area / self.perimeter
        self.flow = (self.constant / self.composite_n) * self.area * (self.hydraulic_radius**(2.0 / 3.0))
        self.flow *= self.slope**0.5

        return self.flow

    def _compute_critical_flow_for_elevation(self, elevation, elevation_stations):
        """Compute the critical flow for one specified elevation.

        Args:
            elevation (:obj:`float`): The specified elevation
            elevation_stations (:obj:`list`): Stations where wse meets channel. It is part of the signature so
                that the caller can use this function or the normal depth function interchangeably; it is
                passed on to the geometry computation.

        Returns:
            (:obj:`float`): The critical flow at the given elevation (0.0 when the section is dry)
        """
        self._check_elevation_to_max_depth(elevation)

        self._compute_geometry_from_elevation(elevation, elevation_stations)
        if self.top_width <= 0.0:
            # Water surface at or below the channel invert: nothing is wet, so there is no flow (and the
            # equation below would divide by zero)
            self.critical_flow = 0.0
        else:
            # Critical Depth (Q) = sqrt(A^3*g/(TopWidth))
            self.critical_flow = (self.area**3.0 * self.gravity / self.top_width)**0.5
        return self.critical_flow

    def _check_elevation_to_max_depth(self, elevation):
        """Check if an elevation is above the maximum depth, and extend the channel with vertical walls if so.

        The walls are taken to twice the requested depth, and the low and high elevations are updated.

        Args:
            elevation (:obj:`float`): The elevation to check
        """
        max_depth = self.high_elev - self.low_elev

        channel_depth = elevation - self.low_elev
        if channel_depth > max_depth:
            # Add vertical walls
            new_high_elevation = channel_depth * 2.0 + self.low_elev
            self.stations, self.elevations = self.add_vertical_wall_to_elevation(
                self.stations, self.elevations, new_high_elevation
            )
            self.low_elev, self.high_elev = self.determine_low_and_high_of_elevations(self.elevations)

    def _compute_geometry_from_elevation(self, elevation, elevation_stations):
        """Determines the geometry variables (area and perimeter mainly, wet stations too) for a given elevation.

        The cross section is walked left to right. Each wet portion is made up of a triangle where the water
        surface meets the channel on the left, trapezoids while both ends of a segment are under water, and a
        triangle where the water surface meets the channel on the right. Results are stored in self.area,
        self.perimeter, and self.top_width.

        Args:
            elevation (:obj:`float`): Elevation of the water surface (wse)
            elevation_stations (:obj:`list`): Stations where the wse meets the channel (computed here and passed back
                for plotting)
        """
        area_triangle_right = 0.0
        area_trapezoids = 0.0  # accumulated area of trap.

        area = 0.0
        top_width = 0.0
        perimeter = 0.0

        perimeter_triangle_right = 0.0
        perimeter_trapezoid = 0.0

        i = 1
        while i < len(self.stations):
            # Do analysis on all points
            if (self.elevations[i] < elevation) or (self.elevations[i - 1] < elevation):
                h_tri = abs(elevation - self.elevations[i])
                x_dist = self.stations[i] - self.stations[i - 1]

                # Similar Triangles-- 1 is the larger triangle of the segment
                #                     2 is the smaller triangle with the water elevation
                d_x1 = x_dist
                d_y1 = self.elevations[i - 1] - self.elevations[i]
                d_y2 = elevation - self.elevations[i]

                d_x2 = d_x1 * d_y2 / d_y1

                x_left = self.stations[i - 1] + d_x1 - d_x2
                # A 'nan' separates this wet segment from the previous one (for plotting)
                if len(elevation_stations):
                    elevation_stations.append(float("nan"))
                elevation_stations.append(self.lowest_station + x_left)

                if i < len(self.elevations):
                    if self.elevations[i] == elevation:
                        x_left = self.stations[i]
                    elif self.elevations[i - 1] == elevation:
                        x_left = self.stations[i - 1]

                # Computation of x position at waterlevel using similar triangles
                area_triangle_left = (0.5 * h_tri * (self.stations[i] - x_left))

                perimeter_triangle_left = ((x_left - self.stations[i])**2 + h_tri**2)**(1 / 2)

                i += 1

                # While still below water, you have trapezoidal areas
                count = 0
                while i < len(self.stations) and self.elevations[i] < elevation and self.elevations[i - 1] < elevation:
                    count += 1
                    x_dist = self.stations[i] - self.stations[i - 1]
                    h1 = abs(elevation - self.elevations[i - 1])
                    h2 = abs(elevation - self.elevations[i])
                    temp_area = (((h1 + h2) * 0.5) * x_dist)
                    area_trapezoids += temp_area
                    temp_perimeter = (x_dist**2 + (h1 - h2)**2)**0.5
                    perimeter_trapezoid += temp_perimeter

                    if self.stations[i] != self.stations[i - 1]:
                        slope = (self.elevations[i] - self.elevations[i - 1])
                        slope /= (self.stations[i] - self.stations[i - 1])
                    else:
                        # If vertical, does next point go up or down? --set slope
                        if self.elevations[i] > self.elevations[i - 1]:
                            slope = 1
                        else:
                            slope = -1
                    if i < len(self.elevations):
                        if self.elevations[i - 1] == elevation and slope < 0.0 and count == 1:
                            x_left = self.stations[i - 1]
                    i += 1
                # Now it's above water
                if i < len(self.stations):
                    # Only do this if haven't reached last point
                    h_tri = abs(elevation - self.elevations[i - 1])
                    x_dist = self.stations[i] - self.stations[i - 1]

                    # Similar Triangles-- 1 is the larger triangle of the segment
                    #                     2 is the smaller triangle with the water elevation
                    d_x1 = x_dist
                    d_y1 = self.elevations[i] - self.elevations[i - 1]
                    d_y2 = elevation - self.elevations[i - 1]

                    d_x2 = d_x1 * d_y2 / d_y1

                    x_right = self.stations[i - 1] + d_x2
                    elevation_stations.append(self.lowest_station + x_right)

                    if self.elevations[i] == elevation:
                        x_right = self.stations[i]
                    elif self.elevations[i - 1] == elevation:
                        x_right = self.stations[i - 1]

                    area_triangle_right = (.5 * h_tri * (x_right - self.stations[i - 1]))

                    perimeter_triangle_right = ((x_right - self.stations[i - 1])**2 + h_tri**2)**0.5

                else:
                    x_right = self.stations[len(self.stations) - 1]
                # Add final results
                area += area_triangle_left + area_trapezoids + area_triangle_right
                perimeter += perimeter_triangle_left + perimeter_trapezoid + perimeter_triangle_right
                top_width += abs(x_right - x_left)

                # Reset the accumulators for the next wet portion (the left triangle is always recomputed)
                area_trapezoids = area_triangle_right = 0.0
                perimeter_triangle_right = perimeter_trapezoid = 0.0

                if i < len(self.stations):
                    i += 1
            else:
                # Point is not below water yet, so check the next
                i += 1
        # end while
        self.perimeter = perimeter
        self.area = area
        self.top_width = top_width

    @staticmethod
    def _newton_raphson(func_ptr, x1_lower_range, x2_upper_range, x_precision, target):
        """Finds x such that the function value equals the target, using Newton-Raphson safeguarded by bisection.

        This function is used in Hydraulic Toolbox for defined shapes (that can have a derivative).
        I'm leaving this function for the day that this code is expanded to that (it should find a solution faster
        than simple interpolation).

        Args:
            func_ptr (:obj:`callable`): Method to calculate flow and flow eq derivative. Should take a float argument
                and return a tuple of two floats.
            x1_lower_range (:obj:`float`): Minimum allowable result
            x2_upper_range (:obj:`float`): Maximum allowable result
            x_precision (:obj:`float`): The allowable difference of result between iterations
            target (:obj:`float`): The target result

        Returns:
            (:obj:`float`): The x value at which the function value matches the target

        Raises:
            ValueError: If the target is not bracketed by the function values at the two range limits
            RuntimeError: If the solution does not converge within the iteration limit
        """
        flow_lower, _ = func_ptr(x1_lower_range)
        flow_higher, _ = func_ptr(x2_upper_range)

        if ((flow_lower > target) and (flow_higher > target)) or ((flow_lower < target) and (flow_higher < target)):
            raise ValueError("Root must be bracketed in NewtonRaphson")
        # Either end of the range may already be the answer
        if flow_lower == target:
            return x1_lower_range
        if flow_higher == target:
            return x2_upper_range
        # Orient the search so that f(x_lower) < target < f(x_higher)
        if flow_lower < target:
            x_lower = x1_lower_range
            x_higher = x2_upper_range
        else:
            x_higher = x1_lower_range
            x_lower = x2_upper_range
        # Start from the middle of the range so that the first guess is always inside the bracket
        rts = 0.5 * (x1_lower_range + x2_upper_range)
        dx_old = abs(x2_upper_range - x1_lower_range)
        dx = dx_old
        flow, flow_eq_derivative = func_ptr(rts)
        # The root being sought is where the function equals the target, so work with the residual
        f = flow - target

        for _ in range(1, 100):
            # Bisect if the Newton step would land outside the bracket, or is not converging fast enough
            if ((rts - x_higher) * flow_eq_derivative - f) * ((rts - x_lower) * flow_eq_derivative - f) > 0.0 or \
                    abs(2.0 * f) > abs(dx_old * flow_eq_derivative):
                dx_old = dx
                dx = 0.5 * (x_higher - x_lower)
                rts = x_lower + dx
                if x_lower == rts:
                    return rts
            else:
                dx_old = dx
                dx = f / flow_eq_derivative
                temp = rts
                rts -= dx
                if temp == rts:
                    return rts
            if abs(dx) < x_precision:
                return rts
            flow, flow_eq_derivative = func_ptr(rts)
            f = flow - target
            # Keep the root bracketed
            if f < 0.0:
                x_lower = rts
            else:
                x_higher = rts
        raise RuntimeError("Maximum number of iterations exceeded")
