"""SRH-2D solution plots class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import glob
import os
from pathlib import Path
import re

# 2. Third party modules
from matplotlib.figure import Figure
import numpy as np

# 3. Aquaveo modules
from xms.guipy.dialogs.plot_and_table_dialog import np_arrays_from_file

# 4. Local modules


class SimulationPlots:
    """Creates solution plots, e.g. Net_Q/INLET_Q, Mass Balance, Wet Elements, Monitor Points WSE..."""
    def __init__(self, inf_files):  # pragma: no cover
        """Initializes the class.

        Args:
            inf_files (:obj:`list[str]`): full paths to inf_files
        """
        file_patterns = ['PT', 'LN', 'GATE', 'HY', 'INTERNAL', 'WEIR']
        self.scenarios = {}
        for inf in inf_files:
            run_name = os.path.basename(inf).replace('_INF.dat', '')
            self.scenarios[run_name] = {}
            self.scenarios[run_name]['inf_file'] = inf
            self.scenarios[run_name]['plot_files'] = {'INF': inf}
            base_dir = os.path.dirname(inf)
            base_file_name = os.path.splitext(os.path.basename(inf))[0].replace('_INF', '')
            for pattern in file_patterns:
                # Get files like 'I35_LN1.dat', 'I35_PT1_SED.dat', but not 'I35_LN_InletQ1.dat
                file_name = f'{base_file_name}_{pattern}[1234567890]*.dat'
                self.scenarios[run_name]['plot_files'][pattern] =\
                    [str(f.absolute()) for f in Path(base_dir).rglob(file_name)]
            self.scenarios[run_name]['file_data'] = {inf: np_arrays_from_file(inf)}

        self.show_legend = True
        self._inf_file = None
        self._plot_files = None
        self._file_data = None
        run_name = os.path.basename(inf_files[0]).replace('_INF.dat', '')
        self.scenarios_to_plot = [run_name]
        self._set_current_scenario(run_name)

        self.figure = None  # matplotlib.figure Figure
        self.ax = None  # matplotlib Axes

        self.figure = Figure(tight_layout=True)
        self.ax = self.figure.add_subplot(111)

    def _set_current_scenario(self, scenario_name):
        """Set local variables based on the scenario name.

        Args:
            scenario_name (:obj:`str`): name of the scenario
        """
        if scenario_name in self.scenarios:
            self._inf_file = self.scenarios[scenario_name]['inf_file']
            self._plot_files = self.scenarios[scenario_name]['plot_files']
            self._file_data = self.scenarios[scenario_name]['file_data']

    def get_max_time(self):
        """Returns the maximum time.

        Returns:
            (:obj:`float`): The max time.
        """
        return max(self._file_data[self._inf_file][1][1])

    def get_unique_solution_plot_file_name(self, last_part):
        """Returns a unique filename for use with a solution plot.

        Args:
            last_part (:obj:`str`): The extension to use (e.g. '.png').

        Returns:
            (:obj:`str`): Parent directory name of inf_file followed by last_part
            (e.g. 'Steady State plot_monitor_lines_q.png').
        """
        return os.path.basename(os.path.dirname(self._inf_file)) + last_part

    def _get_file_data_trimmed_to_time_extents(self, file_name, min_time, max_time):
        """Gets the file data that is within the min/max range if that option is specified.

        Args:
            file_name (:obj:`str`): name of file containing the data
            min_time (:obj:`float`): A specified minimum time, or None if not specifying.
            max_time (:obj:`float`): A specified maximum time, or None if not specifying.
        """
        ret_val = self._file_data[file_name]
        for d in ret_val[1]:
            d[d == -999] = np.nan
        if min_time is None or max_time is None:
            return ret_val

        if min_time < max_time:
            time_idx = -1
            for idx, heading in enumerate(ret_val[0]):
                if heading.startswith('Time'):
                    time_idx = idx
            if time_idx > -1:
                ret_val = copy.deepcopy(ret_val)
                start_idx = np.argmax(ret_val[1][time_idx] > min_time)
                end_idx = np.argmax(ret_val[1][time_idx] > max_time)
                if end_idx == 0:
                    end_idx = len(ret_val[1][time_idx]) - 1
                for idx, aa in enumerate(ret_val[1]):
                    ret_val[1][idx] = aa[start_idx:end_idx]
        return ret_val

    def create_net_q_plot(
        self,
        min_time=None,
        max_time=None,
        min_value: float | None = None,
        max_value: float | None = None,
        log: bool = False
    ):
        """Creates the plot.

        Args:
            min_time: A specified minimum time, or None if not specifying.
            max_time: A specified maximum time, or None if not specifying.
            min_value: A specified minimum value, or None if not specifying.
            max_value: A specified maximum value, or None if not specifying.
            log: A flag to indicate if the y-axis should be log scale.
        """
        plot_has_data = False
        self.ax.clear()
        for run in self.scenarios_to_plot:
            self._set_current_scenario(run)
            data = self._get_file_data_trimmed_to_time_extents(self._inf_file, min_time, max_time)
            if log:
                self.ax.plot(data[1][1], np.abs(data[1][2]), label=f'{run} - {data[0][2]}')
                self.ax.set_yscale('log', nonpositive='mask')
            else:
                self.ax.plot(data[1][1], data[1][2], label=f'{run} - {data[0][2]}')
            plot_has_data = True
        self.ax.set_ybound(min_value, max_value)
        self.ax.set_xlabel('Time (hrs)')
        self.ax.set_ylabel('Percent')
        self.ax.set_title('Net_Q/INLET_Q')
        if plot_has_data and self.show_legend:
            self.ax.legend()

    def create_mass_balance_plot(
        self,
        min_time=None,
        max_time=None,
        min_value: float | None = None,
        max_value: float | None = None,
        log: bool = False
    ):
        """Creates the plot.

        Args:
            min_time: A specified minimum time, or None if not specifying.
            max_time: A specified maximum time, or None if not specifying.
            min_value: A specified minimum value, or None if not specifying.
            max_value: A specified maximum value, or None if not specifying.
            log: A flag to indicate if the y-axis should be log scale.
        """
        plot_has_data = False
        self.ax.clear()
        for run in self.scenarios_to_plot:
            self._set_current_scenario(run)
            data = self._get_file_data_trimmed_to_time_extents(self._inf_file, min_time, max_time)
            if log:
                self.ax.plot(data[1][1], np.abs(data[1][3]), label=f'{run} - {data[0][3]}')
                self.ax.plot(data[1][1], np.abs(data[1][4]), label=f'{run} - {data[0][4]}')
                self.ax.set_yscale('log', nonpositive='mask')
            else:
                self.ax.plot(data[1][1], data[1][3], label=f'{run} - {data[0][3]}')
                self.ax.plot(data[1][1], data[1][4], label=f'{run} - {data[0][4]}')
            plot_has_data = True
        self.ax.set_ybound(min_value, max_value)
        self.ax.set_xlabel('Time (hrs)')
        self.ax.set_ylabel('Error')
        self.ax.set_title('Mass Balance')
        if plot_has_data and self.show_legend:
            self.ax.legend()

    def create_wet_elements_plot(
        self,
        min_time=None,
        max_time=None,
        min_value: float | None = None,
        max_value: float | None = None,
        log: bool = False
    ):
        """Creates the plot.

        Args:
            min_time: A specified minimum time, or None if not specifying.
            max_time: A specified maximum time, or None if not specifying.
            min_value: A specified minimum value, or None if not specifying.
            max_value: A specified maximum value, or None if not specifying.
            log: A flag to indicate if the y-axis should be log scale.
        """
        plot_has_data = False
        self.ax.clear()
        for run in self.scenarios_to_plot:
            self._set_current_scenario(run)
            data = self._get_file_data_trimmed_to_time_extents(self._inf_file, min_time, max_time)
            if log:
                self.ax.plot(data[1][1], np.abs(data[1][5]), label=f'{run} - {data[0][5]}')
                self.ax.set_yscale('log', nonpositive='mask')
            else:
                self.ax.plot(data[1][1], data[1][5], label=f'{run} - {data[0][5]}')
            plot_has_data = True
        self.ax.set_ybound(min_value, max_value)
        self.ax.set_xlabel('Time (hrs)')
        self.ax.set_ylabel('Number of Wet Elements')
        self.ax.set_title('Wet Elements')
        if plot_has_data and self.show_legend:
            self.ax.legend()

    def create_monitor_point_plot(
        self,
        min_time=None,
        max_time=None,
        min_value: float | None = None,
        max_value: float | None = None,
        log: bool = False,
        column=''
    ):
        """Creates the plot.

        Args:
            min_time: A specified minimum time, or None if not specifying.
            max_time: A specified maximum time, or None if not specifying.
            min_value: A specified minimum value, or None if not specifying.
            max_value: A specified maximum value, or None if not specifying.
            log: A flag to indicate if the y-axis should be log scale.
            column (:obj:`str`): the column to plot from each monitor file
        """
        self.ax.clear()
        plot_has_data = False
        col = 3  # bed elevation
        if column == 'wse':
            col = 4
        y_label = 'Elevation (ft)'
        for run in self.scenarios_to_plot:
            self._set_current_scenario(run)
            i = 0
            for f in self._plot_files['PT']:
                if Path(f).name.endswith('_SED.dat'):  # skip sediment files for now
                    continue
                if f not in self._file_data:
                    self._file_data[f] = np_arrays_from_file(f)
                data = self._get_file_data_trimmed_to_time_extents(f, min_time, max_time)
                if log:
                    self.ax.plot(data[1][0], np.abs(data[1][col]), label=f'{run} - {data[0][col]} - PT{i+1}')
                    self.ax.set_yscale('log', nonpositive='mask')
                else:
                    self.ax.plot(data[1][0], data[1][col], label=f'{run} - {data[0][col]} - PT{i + 1}')
                if data[0][col].endswith('m') or data[0][col].endswith('meter'):
                    y_label = 'Elevation (m)'
                plot_has_data = True
                i += 1
        self.ax.set_ybound(min_value, max_value)
        self.ax.set_xlabel('Time (hrs)')
        self.ax.set_ylabel(y_label)
        if column == 'wse':
            self.ax.set_title('Monitor Point Water Surface Elevation (WSE)')
        else:
            self.ax.set_title('Monitor Point Bed Elevation')
        if plot_has_data and self.show_legend:
            self.ax.legend()

    def create_monitor_line_q_plot(
        self,
        min_time=None,
        max_time=None,
        min_value: float | None = None,
        max_value: float | None = None,
        log: bool = False
    ):
        """Creates the monitor line Q plot.

        Args:
            min_time: A specified minimum time, or None if not specifying.
            max_time: A specified maximum time, or None if not specifying.
            min_value: A specified minimum value, or None if not specifying.
            max_value: A specified maximum value, or None if not specifying.
            log: A flag to indicate if the y-axis should be log scale.
        """
        self.create_monitor_line_plot(min_time, max_time, min_value, max_value, log, 'Q', 1)

    def create_monitor_line_qs_plot(
        self,
        min_time=None,
        max_time=None,
        min_value: float | None = None,
        max_value: float | None = None,
        log: bool = False
    ):
        """Creates the monitor line QS plot (sediment).

        Args:
            min_time: A specified minimum time, or None if not specifying.
            max_time: A specified maximum time, or None if not specifying.
            min_value: A specified minimum value, or None if not specifying.
            max_value: A specified maximum value, or None if not specifying.
            log: A flag to indicate if the y-axis should be log scale.
        """
        self.create_monitor_line_plot(min_time, max_time, min_value, max_value, log, 'QS', 2)

    def create_monitor_line_plot(
        self, min_time, max_time, min_value: float, max_value: float, log: bool, y_label: str, column: int
    ):
        """Creates the plot.

        Args:
            min_time: A specified minimum time, or None if not specifying.
            max_time: A specified maximum time, or None if not specifying.
            min_value: A specified minimum value, or None if not specifying.
            max_value: A specified maximum value, or None if not specifying.
            log: A flag to indicate if the y-axis should be log scale.
            y_label: The label for the y-axis.
            column: The column to plot
        """
        self.ax.clear()
        plot_has_data = False
        for run in self.scenarios_to_plot:
            self._set_current_scenario(run)
            for i, f in enumerate(self._plot_files['LN']):
                if f not in self._file_data:
                    self._file_data[f] = np_arrays_from_file(f)
                data = self._get_file_data_trimmed_to_time_extents(f, min_time, max_time)
                if log:
                    self.ax.plot(data[1][0], np.abs(data[1][column]), label=f'{run} - {data[0][column]} - LN{i+1}')
                    self.ax.set_yscale('log', nonpositive='mask')
                else:
                    self.ax.plot(data[1][0], data[1][column], label=f'{run} - {data[0][column]} - LN{i + 1}')
                y_label = data[0][column]
                plot_has_data = True
        self.ax.set_ybound(min_value, max_value)
        self.ax.set_xlabel('Time (hrs)')
        self.ax.set_ylabel(y_label)
        self.ax.set_title('Monitor Line ' + y_label)
        if plot_has_data and self.show_legend:
            self.ax.legend()
