"""CalcData for performing Culvert crossing operations."""
__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import sys

# 2. Third party modules
from sortedcontainers import SortedDict

# 3. Aquaveo modules
from xms.FhwaVariable.core_data.calculator.calculator import Calculator

# 4. Local modules
from xms.HydraulicToolboxCalc.util.interpolation import Interpolation


class CulvertCrossingCalc(Calculator):
    """A class that defines a culvert crossing and performs Culvert crossing computations.

    A crossing is made of one or more culvert barrels plus a roadway (weir). For each crossing flow, the headwater
    elevation is searched for so that the flow through all barrels plus the flow over the roadway matches it.
    """

    def assign_results(self):
        """Assigns the results from a computational run to the appropriate locations.

        Computes the boundary shear stress (gamma * hydraulic radius * energy slope) and the maximum shear stress
        (gamma * depth * energy slope) for every profile point, copies the profile lists into ``self.results``,
        selects the plot coordinates and adds a warning when a hydraulic jump may be swept out.
        """
        # NOTE(review): the profile attributes read below (energy_slope, hydraulic_radius, y, wse_stations, ...)
        # and self.input are not set anywhere in this class; confirm they are provided by the base class.
        _, gamma = self.get_data('Unit weight of water (γw)')

        self.boundary_shear_stress = []
        self.max_shear_stress = []
        for energy_slope, hyd_rad, depth in zip(self.energy_slope, self.hydraulic_radius, self.y):
            self.boundary_shear_stress.append(gamma * hyd_rad * energy_slope)
            self.max_shear_stress.append(gamma * depth * energy_slope)

        self.results['Station'] = self.wse_stations
        self.results['WSE'] = self.wse_elevations
        self.results['Distance from inlet'] = self.x
        self.results['Depth'] = self.y
        self.results['Flow area'] = self.flow_area
        self.results['Wetted perimeter'] = self.wetted_perimeter
        self.results['Top width'] = self.top_width
        self.results['Hydraulic radius'] = self.hydraulic_radius
        self.results['Manning n'] = self.manning_n
        self.results['Velocity'] = self.velocity
        self.results['Energy'] = self.energy
        self.results['Energy loss'] = self.energy_loss
        self.results['Energy slope'] = self.energy_slope
        self.results['Boundary shear stress'] = self.boundary_shear_stress
        self.results['Max shear stress'] = self.max_shear_stress

        # Plot distance/depth when the slope is displayed as flat, otherwise station/water surface elevation.
        if self.input['Display channel slope as flat']:
            self.plot_x = self.x
            self.plot_y = self.y
        else:
            self.plot_x = self.wse_stations
            self.plot_y = self.wse_elevations

        if self.hyd_jump_swept_out:
            self.warnings['Hydraulic jump'] = (
                'An hydraulic jump forms near the culvert outlet but may be swept into the tailwater channel'
            )

    def _get_can_compute(self):
        """Determines if there is enough data to make a computation and if there isn't, add a warning for each reason.

        Side effects: passes the crossing flows to the tailwater calculator, rebuilds ``self.culvert_barrels`` from
        the 'Culvert data' entries and resets the per-barrel headwater -> flow tables.

        Returns:
            bool: True if can compute
        """
        result = True
        calc_data = self.input_dict['calc_data']
        tailwater_calc = calc_data['Tailwater data']['calculator']
        roadway_calc = calc_data['Roadway data']['calculator']

        found = False
        flows = calc_data['Flows']
        if hasattr(flows, '__len__') and len(flows) < 1:
            self.warnings['Flows'] = "Please enter a flow"
            result = False
        else:
            found = any(flow > 0.0 for flow in flows)
            if not found:
                self.warnings['Flows'] = "Please enter a positive, non-zero flow"
                result = False

        # Without a usable flow, give the tailwater calculator a placeholder flow so that it does not add its own
        # flow warnings on top of the one reported above.
        tailwater_calc.input_dict['calc_data']['Flows'] = flows if found else [10]
        if not tailwater_calc.get_can_compute():
            self.warnings.update(tailwater_calc.warnings)
            result = False

        if not roadway_calc.get_can_compute():
            self.warnings.update(roadway_calc.warnings)
            result = False

        # Each barrel calculator gets a shallow copy of the crossing input whose 'calc_data' is replaced by a
        # shallow copy of that barrel's own entry.
        self.culvert_barrels = []
        culvert_entries = calc_data['Culvert data']
        for culvert_key in culvert_entries:
            if culvert_key in ['Selected item']:
                continue
            culvert = culvert_entries[culvert_key]['calculator']
            culvert.input_dict = copy.copy(self.input_dict)
            culvert.input_dict['calc_data'] = copy.copy(culvert_entries[culvert_key])
            self.culvert_barrels.append(culvert)

        # Check every barrel (no short-circuit) so each one gets a chance to run its own checks.
        for culvert in self.culvert_barrels:
            if not culvert.get_can_compute():
                result = False

        # One empty headwater -> flow table per barrel; filled in during the computation.
        self.culv_hw_flow_list = [SortedDict() for _ in self.culvert_barrels]

        return result

    def _compute_data(self):
        """Computes the headwater for every crossing flow; stores results in self.

        Determines the crossing's reference elevations (lowest inlet invert, highest crown, highest
        inlet invert + 3 * rise and the roadway overtopping elevation), computes the tailwater for each flow, and
        balances each flow across the barrels and the roadway. The headwater for each flow is stored in
        ``self.results['Headwater']`` in the same order as the flows.

        Returns:
            bool: True if successful
        """
        flows = self.input_dict['calc_data']['Flows']

        self.cross_invert_elevation = sys.float_info.max
        self.cross_crown_elevation = -sys.float_info.max
        self.overtopping_elevation = self.input_dict['calc_data']['Roadway data']['calculator'].\
            determine_overtopping_elevation()
        self.cross_max_3Rise_elevation = -sys.float_info.max
        self.max_rise = 0.0
        for culv_barrel_calc in self.culvert_barrels:
            barrel_data = culv_barrel_calc.input_dict['calc_data']
            culv_inlet_elev = barrel_data['Site data']['calculator'].input_dict['calc_data']['Inlet invert elevation']
            inlet_data = barrel_data['Inlet data']
            # Circular barrels store their rise as a diameter.
            if inlet_data['Culvert shape'] == 'Circular':
                rise = inlet_data['Geometry']['Diameter']
            else:
                rise = inlet_data['Geometry']['Rise']

            self.cross_invert_elevation = min(self.cross_invert_elevation, culv_inlet_elev)
            self.max_rise = max(self.max_rise, rise)
            self.cross_crown_elevation = max(self.cross_crown_elevation, culv_inlet_elev + rise)
            self.cross_max_3Rise_elevation = max(self.cross_max_3Rise_elevation, culv_inlet_elev + 3 * rise)
        self.cross_hw_flow = SortedDict({self.cross_invert_elevation: 0.0})

        # Determine the tailwater (downstream water depth) for every flow.
        self.determine_tw_depths_and_velocities(flows)

        # Balance each flow across the culvert barrels and the roadway (weir).
        max_flow = max(flows)
        # Start a fresh list so repeated runs do not append to a previous run's headwaters.
        self.results['Headwater'] = []
        for index, flow in enumerate(flows):
            headwater = self.balance_flows_across_barrels_and_roadway(
                flow, self.tw_depths[index], self.tw_elevations[index], self.tw_velocities[index],
                self.tw_constant, max_flow)
            self.results['Headwater'].append(headwater)

        return True

    def determine_tw_depths_and_velocities(self, flows):
        """Determines the tailwater depths, elevations and velocities for the given flow rates.

        Also stores them on self (``tw_depths``, ``tw_elevations``, ``tw_velocities``) and sets
        ``self.tw_constant`` when the tailwater type is 'Constant tailwater elevation'.

        Args:
            flows (list): The flow rates to evaluate.

        Returns:
            tuple: A tuple containing three lists - the tailwater depths, elevations and velocities.
        """
        self.tailwater_channel = self.input_dict['calc_data']['Tailwater data']['calculator']
        self.tailwater_channel.input_dict['calc_data']['Flows'] = flows
        self.tailwater_channel.compute_data()
        self.tw_crossing_flows = flows
        self.tw_depths = self.tailwater_channel.results['Tailwater depths']
        self.tw_elevations = self.tailwater_channel.results['Tailwater elevations']
        self.tw_velocities = self.tailwater_channel.results['Tailwater velocities']
        self.tw_constant = self.tailwater_channel.input_dict[
            'calc_data']['Tailwater type'] == 'Constant tailwater elevation'
        return self.tw_depths, self.tw_elevations, self.tw_velocities

    def run_single_culvert_run(self, culvert_index, culvert_flow, crossing_flow, update_tw=False):
        """Runs one culvert barrel for a single flow, optionally with an updated tailwater condition.

        Args:
            culvert_index (int): The index of the culvert to run
            culvert_flow (float): the flow to pass through the culvert
            crossing_flow (float): the flow passing through the tailwater structure (only used when update_tw)
            update_tw (bool): if True, recompute the tailwater channel for crossing_flow and pass its depths and
                velocities to the barrel's downstream water depth calculator before running the barrel

        Returns:
            (bool): Whether the computation succeeded or not
        """
        culvert = self.culvert_barrels[culvert_index]
        culvert.input_dict['calc_data']['Flows'] = [culvert_flow]
        if update_tw:
            self.tailwater_channel.input_dict['calc_data']['Flows'] = [crossing_flow]
            self.tailwater_channel.compute_data()
            # NOTE(review): determine_tw_depths_and_velocities reads 'Tailwater depths' / 'Tailwater velocities'
            # from this same calculator; confirm 'Depths' / 'Average velocity' are still produced.
            tw_depths = self.tailwater_channel.results['Depths']
            tw_velocities = self.tailwater_channel.results['Average velocity']
            culvert.input_dict['calc_data']['Downstream water depth']['calculator'].\
                set_specified_depths_and_velocities(tw_depths, tw_velocities)
        return culvert._compute_data()

    def balance_flows_across_barrels_and_roadway(self, target_flow, tailwater_depth, tw_elevation,
                                                 tw_velocity, tw_constant=False, max_flow=None):
        """Balance the flows across all of the culvert barrels and the roadway (weir).

        Iteratively interpolates a headwater for ``target_flow`` from the crossing's headwater -> flow table,
        computes the crossing flow at that headwater and adds the point to the table. This repeats until the
        flow is within tolerance or the iteration limit is reached.

        Args:
            target_flow (float): the flow rate that we need to match by finding the correct headwater.
            tailwater_depth (float): the tailwater depth at this flow rate
            tw_elevation (float): the tailwater elevation at this flow rate
            tw_velocity (float): the velocity of the tailwater flow
            tw_constant (bool): Whether the tailwater is constant across all flow rates
            max_flow (float): the largest crossing flow; the per-barrel tables are extended up to at least this
                flow

        Returns:
            hw_guess (float): The headwater that results in the target flow.
        """
        hw_guess = 0.0
        _, hw_tol = self.get_data('HW error')
        _, max_loops = self.get_data('Max number of iterations')

        # The per-barrel headwater -> flow tables depend on the tailwater, so rebuild them for every flow. With a
        # constant tailwater they only need building once, i.e. while they are still empty. Previously they were
        # never built in that case, leaving empty tables that fail during interpolation.
        barrel_tables_missing = any(len(hw_flow) == 0 for hw_flow in self.culv_hw_flow_list)
        if not tw_constant or barrel_tables_missing:
            for culvert_calc in self.culvert_barrels:
                culvert_calc.input_dict['calc_data']['Downstream water depth'][
                    'calculator'].set_specified_depths_and_velocities([tailwater_depth], [tw_velocity])
            self.setup_hw_to_flow_lists_per_culvert_barrel(tw_elevation, max_flow)

        _, null_data = self.get_data('Null data')
        _, zero_tol = self.get_data('Zero tolerance')
        flow_interp = Interpolation([], [], null_data=null_data, zero_tol=zero_tol)
        flow_interp.use_second_interpolation = True

        self.setup_hw_to_flow_lists_for_crossing(tw_elevation, tw_velocity, hw_tol, max_flow, null_data, zero_tol)

        # Add the crown elevation, the overtopping elevation and (conditionally) 3 * rise to bracket the
        # interpolation.
        _, max_crown_crossing_flow = self.compute_crossing_flow_for_hw(self.cross_crown_elevation, tw_elevation,
                                                                       hw_tol, null_data, zero_tol)
        self.cross_hw_flow[self.cross_crown_elevation] = max_crown_crossing_flow

        _, overtopping_crossing_flow = self.compute_crossing_flow_for_hw(self.overtopping_elevation, tw_elevation,
                                                                         hw_tol, null_data, zero_tol)
        self.cross_hw_flow[self.overtopping_elevation] = overtopping_crossing_flow

        # The 3 * rise point is only added when the crossing passes less flow at overtopping than at the crown.
        if overtopping_crossing_flow < max_crown_crossing_flow:
            _, max_3rise_crossing_flow = self.compute_crossing_flow_for_hw(self.cross_max_3Rise_elevation,
                                                                           tw_elevation, hw_tol, null_data, zero_tol)
            self.cross_hw_flow[self.cross_max_3Rise_elevation] = max_3rise_crossing_flow

        _, flow_err = self.get_data('Flow error', 0.001)
        _, flow_err_p = self.get_data('Flow % error', 0.005)
        curr_flow_err = min(target_flow * flow_err_p, flow_err)

        # Start with an effectively infinite difference so at least one iteration runs. A fixed 1.0 skipped the
        # loop (returning a headwater of 0.0) whenever the flow tolerance was 1.0 or larger.
        difference = sys.float_info.max
        count = 0
        while curr_flow_err < difference and count < max_loops:
            # Interpolate a headwater for the target flow from the crossing's flow -> headwater table.
            flow_interp.x = list(self.cross_hw_flow.values())
            flow_interp.y = list(self.cross_hw_flow.keys())
            hw_guess, _ = flow_interp.interpolate_y(target_flow, True)

            _, computed_crossing_flow = self.compute_crossing_flow_for_hw(hw_guess, tw_elevation, hw_tol, null_data,
                                                                          zero_tol)

            # Add the new point so the next interpolation is bracketed more tightly.
            self.cross_hw_flow[hw_guess] = computed_crossing_flow
            difference = abs(target_flow - computed_crossing_flow)
            count += 1
            if count == 5:
                # After five iterations, switch off the interpolator's second interpolation mode.
                flow_interp.use_second_interpolation = False

        return hw_guess

    def setup_hw_to_flow_lists_per_culvert_barrel(self, tw_elevation, max_flow):
        """Setup a headwater to flow table for each culvert barrel.

        Each table starts with zero flow at the higher of the barrel's inlet invert and the tailwater elevation.
        The barrel is then run for flows of 1, 5, 25, ... (times 5 each step), and each computed headwater is
        recorded. This continues until a flow of at least ``max_flow`` has been run or a run fails.

        Args:
            tw_elevation (float): elevation of the tailwater
            max_flow (float): maximum flow to compute
        """
        for culvert_index, culvert in enumerate(self.culvert_barrels):
            invert_elevation = culvert.input_dict['calc_data']['Site data']['calculator'].\
                input_dict['calc_data']['Inlet invert elevation']
            # TODO figure out why in some cases, then computed HW is less than TW
            min_elevation = max(invert_elevation, tw_elevation)
            hw_flow = SortedDict({min_elevation: 0.0})
            self.culv_hw_flow_list[culvert_index] = hw_flow

            flow = 0.0
            # Reset per barrel: a shared flag let one failed barrel leave every later barrel with a
            # single-point table.
            result = True
            while flow < max_flow and result:
                # Step the flow before running, so the last flow run is at least max_flow.
                flow = 1.0 if flow <= 0 else flow * 5.0
                result = self.run_single_culvert_run(culvert_index, flow, flow)
                if result:
                    hw_flow[culvert.hw] = flow

    def setup_hw_to_flow_lists_for_crossing(self, tw_elevation, tw_velocity, hw_tol, max_flow, null_data, zero_tol):
        """Reset the crossing headwater -> flow table to a single zero-flow point at the crossing invert.

        The arguments are currently unused and are kept for interface compatibility.

        Args:
            tw_elevation (float): elevation of the tailwater
            tw_velocity (float): velocity of the tailwater
            hw_tol (float): headwater tolerance
            max_flow (float): maximum flow to compute
            null_data (float): null data value
            zero_tol (float): zero tolerance value
        """
        self.cross_hw_flow = SortedDict({self.cross_invert_elevation: 0.0})

    def compute_crossing_flow_for_hw(self, hw_target, tw_elevation, hw_tol, null_data, zero_tol):
        """Compute the total crossing flow (all barrels plus the roadway) for a headwater.

        For each barrel, a flow is repeatedly interpolated for ``hw_target`` from that barrel's headwater -> flow
        table. The barrel is run at that flow and the computed headwater is added to the table. This continues
        until the headwater is within ``hw_tol`` of the target, the iteration limit is reached, or the flow guess
        stops increasing. The roadway flow at ``hw_target`` is then added.

        Args:
            hw_target (float): headwater that we are trying to match.
            tw_elevation (float): elevation of the tailwater.
            hw_tol (float): headwater tolerance for the per-barrel iteration.
            null_data (float): null data value passed to the interpolator.
            zero_tol (float): zero tolerance passed to the interpolator.

        Returns:
            tuple(bool, float): False if any barrel reached the iteration limit or the roadway computation failed
                (True otherwise), and the computed crossing flow.
        """
        _, max_loops = self.get_data('Max number of iterations')
        computed_crossing_flow = 0.0
        result = True

        flow_interp = Interpolation([], [], null_data=null_data, zero_tol=zero_tol)

        for culvert_index, culvert in enumerate(self.culvert_barrels):
            hw_flow = self.culv_hw_flow_list[culvert_index]
            # Per-barrel state. The interpolator mode used to stay disabled for every barrel after the first one
            # that needed five iterations, and flow_guess used to carry over from the previous barrel.
            flow_interp.use_second_interpolation = True
            flow_guess = 0.0
            previous_guess = None
            count = 0
            # Effectively infinite, so at least one iteration runs even when hw_tol >= 1.0.
            difference = sys.float_info.max
            change = 1.0

            # NOTE(review): the loop also stops as soon as the guess does not increase (change > 0.0), so a guess
            # converging from above exits after one step; confirm whether abs(change) was intended.
            while hw_tol < difference and count < max_loops and change > 0.0:
                flow_list = list(hw_flow.values())
                flow_interp.x = list(hw_flow.keys())
                flow_interp.y = flow_list
                flow_guess, _ = flow_interp.interpolate_y(hw_target, True)

                # If the guess reaches the flow at the highest tabulated headwater, double it. This pushes the
                # table past the target, so later guesses interpolate rather than extrapolate.
                if flow_guess >= flow_list[-1]:
                    flow_guess *= 2.0

                if previous_guess is not None:
                    change = flow_guess - previous_guess
                previous_guess = flow_guess

                culvert.input_dict['calc_data']['Flows'] = [flow_guess]
                culvert._compute_data()
                computed_hw = culvert.hw
                hw_flow[computed_hw] = flow_guess
                difference = abs(hw_target - computed_hw)
                count += 1
                if count == 5:
                    flow_interp.use_second_interpolation = False

            computed_crossing_flow += flow_guess

            if count >= max_loops:
                result = False

        # Add the flow over the roadway at this headwater.
        # NOTE(review): tw_elevation is passed for submerged weir flow; confirm the roadway calculator uses it.
        roadway_result, computed_flow = self.input_dict['calc_data']['Roadway data']['calculator'].compute_data_for_wse(
            hw_target, tw_elevation)
        if roadway_result:
            computed_crossing_flow += computed_flow[0]
        else:
            result = False

        return result, computed_crossing_flow
