"""PlotPresenter class that will create a model-view for the plot."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import sys

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.FhwaVariable.core_data.units.unit_conversion import ConversionCalc
from xms.FhwaVariable.interface_adapters.view_model.dialog.plot import Plot


class PlotPresenter:
    """Create the view-model (``Plot``) for a plot from the plot options stored on a CalcData."""

    def __init__(self, theme):
        """Initialize the PlotPresenter Class.

        Args:
            theme (dict): Theme values; 'Plot background color', 'Plot text color' and 'Plot tick color'
                are read by ``set_default_colors``.
        """
        self.theme = theme

    def build_plot(self, calcdata, plot_name, plot_options=None):
        """Build the plot.

        Args:
            calcdata (CalcData): The CalcData.
            plot_name (str): name of the plot
            plot_options (dict): optional plot options to use instead of the ones stored on ``calcdata``

        Returns:
            Plot: The view-model plot (all the data needed to display the plot).
        """
        vm_plot = Plot()
        self.set_default_colors(vm_plot)

        # TODO: Refactor this to set this option in the plot options (where this sets default,
        # but the user can change it for plots)
        _, vm_plot.is_equal_aspect = calcdata.get_setting('Create plots with equal aspect')

        self.setup_plot_data(vm_plot, calcdata, plot_name, plot_options)

        return vm_plot

    def set_default_colors(self, vm_plot):
        """Copy the background, text and tick colors from the theme onto the plot."""
        vm_plot.bg_color = self.theme['Plot background color']
        vm_plot.text_color = self.theme['Plot text color']
        vm_plot.xtick_color = self.theme['Plot tick color']
        vm_plot.ytick_color = self.theme['Plot tick color']

    def setup_plot_data(self, vm_plot, calcdata, plot_name, plot_options=None):
        """Setup the plot data.

        Resolves the plot options, converts every series/points/lines entry to display units, records the
        plot extents on ``vm_plot`` and, when requested, stores the data-derived axis limits back into the
        stored 'Plot options' input.

        Args:
            vm_plot (Plot): The plot to setup.
            calcdata (CalcData): The CalcData.
            plot_name (str): name of the plot
            plot_options (dict): The plot options

        Returns:
            None. Returns early (leaving ``vm_plot.plot_options_dict`` unset) when no plot options exist.
        """
        _, null_data = calcdata.get_setting('Null data')
        _, selected_unit_system = calcdata.get_setting('Selected unit system', 'U.S. Customary Units')
        unit_converter = ConversionCalc(calcdata.app_data)

        # NOTE(review): first_time is never cleared, so every entry re-sets the axis units and labels
        # (the last one processed wins). Kept as-is to preserve the current labelling.
        first_time = True

        vm_plot.min_x = sys.float_info.max
        vm_plot.max_x = -sys.float_info.max
        vm_plot.min_y = sys.float_info.max
        vm_plot.max_y = -sys.float_info.max

        vm_plot.null_data = null_data

        plot_options_dict = self._resolve_plot_options_dict(calcdata, plot_name, plot_options)
        if plot_options_dict is None:
            return None
        vm_plot.plot_options_dict = plot_options_dict

        # NOTE(review): this minimum is taken from the raw (native-unit) values while the fill baseline is
        # drawn against converted columns; confirm native units always match the displayed units.
        min_y = self._find_plot_min_y(plot_options_dict, null_data)

        for series in plot_options_dict.get('series', {}).values():
            self._add_series_to_plot(vm_plot, plot_name, series, unit_converter, first_time, selected_unit_system,
                                     null_data, min_y)

        for points in plot_options_dict.get('points', {}).values():
            self._add_points_to_plot(vm_plot, plot_name, points, unit_converter, first_time, selected_unit_system,
                                     null_data)

        # Lines are last, because we want the lines to span the entire plot and need the min/max determined
        for lines in plot_options_dict.get('lines', {}).values():
            self._add_lines_to_plot(vm_plot, lines, unit_converter, first_time, selected_unit_system, null_data)

        self._store_data_axis_limits(vm_plot, calcdata, plot_name, plot_options)
        return None

    @staticmethod
    def _resolve_plot_options_dict(calcdata, plot_name, plot_options):
        """Return the plot options dict for ``plot_name``, or None when none is available.

        Priority: ``calcdata.plot_dict[plot_name]``, then the caller's ``plot_options``, then the stored
        'Plot options' input converted with ``get_plot_options_dict()``.
        """
        plot_dict = getattr(calcdata, 'plot_dict', None)
        if plot_dict and plot_name in plot_dict:
            return plot_dict[plot_name]
        if plot_options is not None:
            return plot_options
        if 'Plot options' in calcdata.input and plot_name in calcdata.input['Plot options']:
            return calcdata.input['Plot options'][plot_name].get_val().get_plot_options_dict()
        return None

    def _find_plot_min_y(self, plot_options_dict, null_data):
        """Return the smallest non-null y value across the series, points and lines of a plot.

        Returns ``sys.float_info.max`` when there is no non-null y data.
        """
        min_y = sys.float_info.max
        for key in ('series', 'points'):
            for entry in plot_options_dict.get(key, {}).values():
                min_y = self.find_min_y_in_series(entry, null_data, min_y)
        for line_options in plot_options_dict.get('lines', {}).values():
            line_y_var = line_options.get('y var', None)
            if line_y_var is not None:
                min_y = self._min_non_null(line_y_var.get_val(), null_data, min_y)
        return min_y

    @staticmethod
    def _store_data_axis_limits(vm_plot, calcdata, plot_name, plot_options):
        """Copy the data-derived axis limits into the stored 'Plot options' input, when requested.

        Nothing is written when the caller supplied ``plot_options`` or when ``calcdata.input`` holds no
        stored options for ``plot_name``.
        """
        if plot_options is not None:
            return
        if 'Plot options' not in calcdata.input or plot_name not in calcdata.input['Plot options']:
            return

        options = vm_plot.plot_options_dict
        limits_to_store = []
        if options.get('Determine X-Axis limits from data'):
            limits_to_store += [('Minimum X-Axis limit', vm_plot.min_x), ('Maximum X-Axis limit', vm_plot.max_x)]
        if options.get('Determine Y-Axis limits from data'):
            limits_to_store += [('Minimum Y-Axis limit', vm_plot.min_y), ('Maximum Y-Axis limit', vm_plot.max_y)]
        if not limits_to_store:
            return

        stored_inputs = calcdata.input['Plot options'][plot_name].get_val().input
        for input_name, value in limits_to_store:
            stored_inputs[input_name].set_val(value)

    @staticmethod
    def _min_non_null(values, null_data, current_min):
        """Return the smaller of ``current_min`` and the smallest entry of ``values`` not equal to ``null_data``."""
        valid = [value for value in values if value != null_data]
        if valid:
            return min(current_min, min(valid))
        return current_min

    def find_min_y_in_series(self, series, null_data, current_min_y):
        """Find the minimum y value in the series.

        The yy values are included only when the series has a 'yy var' and a truthy 'Plot yy'.

        Args:
            series (dict): The series to check
            null_data (float): The null data value
            current_min_y (float): The current minimum y value

        Returns:
            float: The minimum y value
        """
        y_var = series['y var']
        if y_var is not None:
            current_min_y = self._min_non_null(y_var.get_val(), null_data, current_min_y)

        # .get: a series may define 'yy var' without a 'Plot yy' flag.
        yy_var = series.get('yy var') if series.get('Plot yy') else None
        if yy_var is not None:
            current_min_y = self._min_non_null(yy_var.get_val(), null_data, current_min_y)

        return current_min_y

    def _add_series_to_plot(self, vm_plot, plot_name, series, unit_converter, first_time, selected_unit_system,
                            null_data, min_y=None):
        """Convert a series and store its plot data (columns, closure, fill polygons/baselines) on it.

        Keys written on ``series``: 'x_data', 'y_data', 'xx_data', 'yy_data', 'is_closed', 'polygons', 'y2'
        and, when yy data exists, 'is_closed_yy', 'polygons_yy', 'yy2'. Nothing is written when the x/y data
        cannot be converted.

        Args:
            vm_plot (Plot): the plot view model
            plot_name (str): name of the plot
            series (dict): the series to plot
            unit_converter (ConversionCalc): class to convert units
            first_time (bool): whether to set the titles of the axes and units
            selected_unit_system (str): the selected unit system
            null_data (float): the null data value
            min_y (float): fill-below baseline; computed from the data when None
        """
        x_column, y_column, xx_column, yy_column = self._convert_x_y_data_to_plot(
            vm_plot, series, unit_converter, first_time, selected_unit_system, null_data)
        if x_column is None or y_column is None:
            return

        series['x_data'] = x_column
        series['y_data'] = y_column
        series['yy_data'] = yy_column
        series['xx_data'] = xx_column

        fill = series.get('Fill pattern', None) != 'no fill'

        # A closed set (first point repeats) is filled as polygons; otherwise it may be filled below the line.
        is_closed = self.check_if_geometry_is_closed(x_column, y_column, null_data)
        series['is_closed'] = is_closed
        series['polygons'] = None
        series['y2'] = None
        if fill and is_closed:
            series['polygons'] = self.split_points(x_column, y_column, null_data)
        elif fill and series.get('Fill below line', False):
            if min_y is None:
                min_y = self.get_min_y(vm_plot, y_column, null_data)
            series['y2'] = [min_y] * len(y_column)

        if yy_column is None:
            return

        # The yy values pair with the xx column when there is one, otherwise with x.
        yy_x_column = xx_column if xx_column is not None else x_column
        is_closed_yy = self.check_if_geometry_is_closed(yy_x_column, yy_column, null_data)
        series['is_closed_yy'] = is_closed_yy
        series['polygons_yy'] = None
        series['yy2'] = None
        if fill and is_closed_yy:
            series['polygons_yy'] = self.split_points(yy_x_column, yy_column, null_data)
        elif fill and series.get('Fill below line yy', False):
            if min_y is None:
                min_y = self.get_min_y(vm_plot, yy_column, null_data)
            series['yy2'] = [min_y] * len(yy_column)

    def get_min_y(self, vm_plot, y_column, null_data):
        """Get the minimum y value from the y_column, excluding the null_data values.

        ``vm_plot.min_y`` also takes part unless it equals ``null_data``.

        Args:
            vm_plot (Plot): The plot view model
            y_column (list): The y column data
            null_data (float): The null data value

        Returns:
            float: The minimum y value, or None when there is nothing to compare
        """
        candidates = [y for y in y_column if y != null_data]
        if vm_plot.min_y != null_data:
            candidates.append(vm_plot.min_y)
        return min(candidates) if candidates else None

    @staticmethod
    def _label_position(low, high, alignment, low_alignment):
        """Return a label coordinate within [low, high].

        Returns 2.5% in from ``low`` when ``alignment == low_alignment``, the midpoint for 'center', and
        2.5% in from ``high`` otherwise.
        """
        span = high - low
        if alignment == low_alignment:
            return low + span * 0.025
        if alignment == 'center':
            return (high + low) / 2
        return high - span * 0.025

    def _add_lines_to_plot(self, vm_plot, lines_options, unit_converter, first_time, selected_unit_system, null_data):
        """Convert a set of vertical or horizontal reference lines and position their label.

        Vertical lines have x intercepts and use the plot's x units; horizontal lines have y intercepts and
        use the plot's y units. Writes 'vertical_intercepts' + 'label y' or 'horizontal_intercepts' +
        'label x' on ``lines_options`` (the label only when there is at least one intercept).

        Args:
            vm_plot (Plot): The plot to add the lines to
            lines_options (dict): The lines options
            unit_converter (ConversionCalc): class to convert units
            first_time (bool): whether to set the units of the axis the lines are measured on
            selected_unit_system (str): the selected unit system
            null_data (float): the null data value
        """
        if lines_options['Line color'] is None:
            return None

        is_vertical = lines_options['Line alignment'] == 'vertical'
        units_attr = 'x_units' if is_vertical else 'y_units'

        if first_time:
            selected_unit = ''
            if lines_options['selected_us_unit'] != '' and lines_options['selected_si_unit'] != '':
                if selected_unit_system == 'SI Units':
                    selected_unit = lines_options['selected_si_unit']
                else:
                    selected_unit = lines_options['selected_us_unit']
            # Do not clobber the units set by the series with an empty (unitless) value.
            if selected_unit != '':
                setattr(vm_plot, units_attr, selected_unit)
        else:
            selected_unit = getattr(vm_plot, units_attr)

        conversion = 1.0
        if lines_options['native_unit'] != '' and selected_unit != '':
            _, conversion = unit_converter.convert_units(lines_options['native_unit'], selected_unit, 1.0)
        # Intercepts are x values for vertical lines and y values for horizontal lines.
        intercepts = [value * conversion if value != null_data else value
                      for value in lines_options['Line intercepts']]

        if is_vertical:
            lines_options['vertical_intercepts'] = intercepts
            if intercepts:
                lines_options['label y'] = self._label_position(
                    vm_plot.min_y, vm_plot.max_y, lines_options['Text alignment'], 'bottom')
        else:
            lines_options['horizontal_intercepts'] = intercepts
            if intercepts:
                lines_options['label x'] = self._label_position(
                    vm_plot.min_x, vm_plot.max_x, lines_options['Text alignment'], 'left')
        return None

    def _add_points_to_plot(self, vm_plot, plot_name, points_options, unit_converter, show_axis_titles,
                            selected_unit_system, null_data):
        """Convert a set of points and store 'x_data', 'y_data', 'xx_data', 'yy_data' on ``points_options``.

        Args:
            vm_plot (Plot): the plot view model
            plot_name (str): name of the plot
            points_options (dict): The points options
            unit_converter (ConversionCalc): class to convert units
            show_axis_titles (bool): whether to set the axis units and titles
            selected_unit_system (str): the selected unit system
            null_data (float): the null data value
        """
        x_column, y_column, xx_column, yy_column = self._convert_x_y_data_to_plot(
            vm_plot, points_options, unit_converter, show_axis_titles, selected_unit_system, null_data)
        points_options['x_data'] = x_column
        points_options['y_data'] = y_column
        points_options['xx_data'] = xx_column
        points_options['yy_data'] = yy_column

    @staticmethod
    def _scale_column(values, factor, null_data):
        """Return ``values`` multiplied by ``factor`` (as floats), passing ``null_data`` entries through."""
        return [float(value) * factor if value != null_data else value for value in values]

    @staticmethod
    def _merge_limits(current_min, current_max, values, null_data):
        """Return (min, max) widened to include the non-null entries of ``values``.

        ``None`` limits are treated as unset; the limits are returned unchanged when every entry is null.
        """
        valid = [value for value in values if value != null_data]
        if not valid:
            return current_min, current_max
        low, high = min(valid), max(valid)
        if current_min is None or current_min > low:
            current_min = low
        if current_max is None or current_max < high:
            current_max = high
        return current_min, current_max

    def _convert_x_y_data_to_plot(self, vm_plot, series, unit_converter, first_time, selected_unit_system,
                                  null_data):
        """Convert a series' data to display units and widen the plot extents.

        Null entries are passed through unconverted and are ignored when updating the plot extents.

        Args:
            vm_plot (Plot): the plot whose units, labels and extents are updated
            series (dict): the series to convert ('x var', 'y var', optional 'xx var', 'yy var', 'Plot yy')
            unit_converter (ConversionCalc): class to convert units
            first_time (bool): whether to set x, y units and label
            selected_unit_system (str): the selected unit system
            null_data (float): the null data value

        Returns:
            tuple: (x_column, y_column, xx_column, yy_column). All are None when the x/y data is missing,
            shorter than two values, or entirely zero. yy is only used when 'Plot yy' is truthy.
        """
        x_var = series['x var']
        y_var = series['y var']
        xx_var = series.get('xx var')
        yy_var = series.get('yy var') if series.get('Plot yy') else None

        # Validate that we have good data
        if x_var is None or y_var is None:
            return None, None, None, None
        x_vals = x_var.get_val()
        y_vals = y_var.get_val()
        min_length = 2  # Define the minimum length as needed
        too_short = len(x_vals) < min_length or len(y_vals) < min_length
        if too_short or (all(x == 0 for x in x_vals) and all(y == 0 for y in y_vals)):
            return None, None, None, None

        if first_time:
            x_units = x_var.get_selected_unit(selected_unit_system)
            y_units = y_var.get_selected_unit(selected_unit_system)
            xx_units = xx_var.get_selected_unit(selected_unit_system) if xx_var is not None else None
            yy_units = yy_var.get_selected_unit(selected_unit_system) if yy_var is not None else None
            vm_plot.x_units = x_units
            vm_plot.y_units = y_units
            vm_plot.x_axis_label = f'{x_var.name} ({x_units})'
            vm_plot.y_axis_label = f'{y_var.name} ({y_units})'
            # yy units/label are only recorded when they differ from the y units.
            if yy_units is not None and yy_units != y_units:
                vm_plot.yy_units = yy_units
                vm_plot.yy_axis_label = f'{yy_var.name} ({yy_units})'
        else:
            x_units = vm_plot.x_units
            y_units = vm_plot.y_units
            xx_units = vm_plot.xx_units
            yy_units = vm_plot.yy_units

        _, x_conversion = unit_converter.convert_units(x_var.native_unit, x_units, 1.0)
        _, y_conversion = unit_converter.convert_units(y_var.native_unit, y_units, 1.0)
        x_column = self._scale_column(x_vals, x_conversion, null_data)
        y_column = self._scale_column(y_vals, y_conversion, null_data)
        vm_plot.min_x, vm_plot.max_x = self._merge_limits(vm_plot.min_x, vm_plot.max_x, x_column, null_data)
        vm_plot.min_y, vm_plot.max_y = self._merge_limits(vm_plot.min_y, vm_plot.max_y, y_column, null_data)

        xx_column = None
        if xx_var is not None:
            _, xx_conversion = unit_converter.convert_units(xx_var.native_unit, xx_units, 1.0)
            xx_column = self._scale_column(xx_var.get_val(), xx_conversion, null_data)
            vm_plot.min_x, vm_plot.max_x = self._merge_limits(vm_plot.min_x, vm_plot.max_x, xx_column, null_data)

        yy_column = None
        if yy_var is not None:
            _, yy_conversion = unit_converter.convert_units(yy_var.native_unit, yy_units, 1.0)
            yy_column = self._scale_column(yy_var.get_val(), yy_conversion, null_data)
            vm_plot.min_y, vm_plot.max_y = self._merge_limits(vm_plot.min_y, vm_plot.max_y, yy_column, null_data)

        return x_column, y_column, xx_column, yy_column

    @staticmethod
    def check_if_geometry_is_closed(x_data, y_data, null_data):
        """Check if the geometry is closed.

        The geometry counts as closed when the first point equals the last point, when the first segment
        closes (the point just before a null separator equals the first point), or when the last segment
        closes (the first point after the final null separator equals the last point).

        Args:
            x_data (list): The x data
            y_data (list): The y data
            null_data (float): The null data value

        Returns:
            bool: True if the geometry is closed
        """
        if len(x_data) < 1 or len(y_data) < 1:
            return False
        if x_data[0] == x_data[-1] and y_data[0] == y_data[-1]:
            return True

        null_indices = [i for i, (x, y) in enumerate(zip(x_data, y_data)) if x == null_data or y == null_data]
        if not null_indices:
            return False

        for index in null_indices:
            if index > 0 and x_data[0] == x_data[index - 1] and y_data[0] == y_data[index - 1]:
                return True

        # The trailing segment needs at least two points, otherwise it would trivially "close" on itself.
        start = null_indices[-1] + 1
        if start < len(x_data) - 1 and start < len(y_data) - 1:
            return x_data[start] == x_data[-1] and y_data[start] == y_data[-1]
        return False

    @staticmethod
    def split_points(x_column, y_column, null_data):
        """Split points into separate polygons based on null separators or when points close.

        Returns:
            list of list of tuple: each polygon as a list of (x, y) points
        """
        polygons = []
        current_polygon = []

        for point in zip(x_column, y_column):
            if point[0] == null_data or point[1] == null_data:
                if current_polygon:
                    polygons.append(current_polygon)
                    current_polygon = []
                continue
            current_polygon.append(point)
            # A repeat of the polygon's first point closes it.
            if len(current_polygon) > 1 and point == current_polygon[0]:
                polygons.append(current_polygon)
                current_polygon = []

        if current_polygon:
            polygons.append(current_polygon)

        return polygons
