"""The station points solution plots dialog."""

__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import os
from pathlib import Path
import sys
import webbrowser

# 2. Third party modules
from matplotlib import patches
from matplotlib.backends.backend_qt5agg import FigureCanvas
from matplotlib.figure import Figure
import numpy as np
from PySide2.QtCore import Qt
from PySide2.QtWidgets import (
    QDialogButtonBox, QGroupBox, QHBoxLayout, QLabel, QLineEdit, QListWidget, QListWidgetItem, QPushButton, QVBoxLayout
)

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.guipy.dialogs import message_box
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.validators.qx_double_validator import QxDoubleValidator

# 4. Local modules
from xms.adcirc.file_io.solution_importer import ASCII_STATIONS, NETCDF_STATIONS
from xms.adcirc.file_io.station_solution_importer import StationSolutionReader


class StationPlotsDialog(XmsDlg):
    """A dialog for showing station point solution plots."""
    def __init__(self, parent, feature_id, cov_uuid, pe_tree):
        """Initializes the dialog, sets up the ui.

        Args:
            parent (:obj:`QWidget`): Parent dialog
            feature_id (:obj:`int`): id of point or arc
            cov_uuid (:obj:`str`): uuid of the monitor coverage
            pe_tree (:obj:`xms.guipy.tree.tree_node.TreeNode`): The SMS project explorer tree
        """
        super().__init__(parent, 'xms.adcirc.gui.station_plots_dialog')
        self._feature_id = feature_id
        self._cov_uuid = cov_uuid
        self._pe_tree = pe_tree
        self._project_file = XmEnv.xms_environ_project_path()
        self._default_plot_on = 'Water Surface (eta)'
        self._simulations = []  # Names of the simulations that link to this coverage
        self._plot_data_files = {}  # {sim_name: [files]}
        self._plot_list = {}  # {sim_name: {file: [y-columns]}}
        self._file_data = {}  # {file: data returned by StationSolutionReader.read()}
        self._checked_sims = []
        self._checked_plots = []
        self._err_files = []  # Files that failed to read and have not been reported to the user yet
        self._get_simulations()
        self._max_time = self._compute_max_time()
        self._updating_time_range = False  # Set while swapping inverted min/max times to suppress re-entry

        self.figure = None  # matplotlib.figure Figure
        self.canvas = None  # matplotlib.backends.backend_qt5agg FigureCanvas
        self.ax = None  # matplotlib Axes

        self.help_url = 'https://www.xmswiki.com/wiki/SMS:ADCIRC#Output_files'
        self.widgets = {}
        self.setWindowTitle('ADCIRC Solution Plots')
        self._setup_ui()

    def _compute_max_time(self):
        """Returns the latest time value found in any successfully read solution file.

        Returns:
            float: The largest value of column 0 (the plot x-axis) over all files, or sys.float_info.max
            if no file has any time values.
        """
        end_times = []
        for file_data in self._file_data.values():
            values = file_data[1]
            if len(values) < 1:
                continue
            times = np.asarray(values[0])
            if times.size > 0:
                end_times.append(float(np.max(times)))
        return max(end_times) if end_times else sys.float_info.max

    def _get_simulations(self):
        """Gets the simulations that use this Recording Stations coverage and reads their station output files."""
        self._find_linked_simulations()
        if not self._project_file:
            return  # No project path, so there is no ADCIRC models directory to search.
        netcdf_files, ascii_files = self._find_station_files()
        # Only add files that are in the list of taking simulations. NetCDF files are added before ASCII files.
        for filename in netcdf_files + ascii_files:
            self._add_solution_if_taken_by_sim(filename)
        self._read_solution_files()

    def _find_linked_simulations(self):
        """Appends the name of every simulation with a child link to this coverage to self._simulations."""
        if self._pe_tree is None:
            return
        # Breadth-first walk of the project explorer tree. Extending the list being iterated is intentional:
        # the for loop picks up the appended children on later iterations.
        tree_nodes = [self._pe_tree]
        for tree_node in tree_nodes:
            tree_nodes.extend(tree_node.children)
            if tree_node.uuid == self._cov_uuid and tree_node.is_ptr_item:
                self._simulations.append(tree_node.parent.name)

    def _find_station_files(self):
        """Finds the ADCIRC station output files below the project's ADCIRC models directory.

        Returns:
            tuple: (netcdf_files, ascii_files), each a list of absolute file paths as strings.
        """
        proj_dir = os.path.dirname(self._project_file)
        proj_name = os.path.basename(os.path.splitext(self._project_file)[0])
        adcirc_dir = Path(proj_dir) / f'{proj_name}_models' / 'ADCIRC'
        netcdf_files = []
        ascii_files = []
        # ADCIRC filenames are hardcoded. Walk the directory once and sort matches into both lists.
        for path in adcirc_dir.rglob('*'):
            name = path.name.lower()
            if name in NETCDF_STATIONS:
                netcdf_files.append(str(path.absolute()))
            if name in ASCII_STATIONS:
                ascii_files.append(str(path.absolute()))
        return netcdf_files, ascii_files

    def _read_solution_files(self):
        """Reads every collected station file, dropping (and recording) files that fail to read or have no data."""
        if not self._plot_data_files:
            return
        reader = StationSolutionReader(self._feature_id)
        for sim_name, fnames in self._plot_data_files.items():
            good_files = []
            for fname in fnames:
                try:
                    file_data = reader.read(fname)
                    if file_data is None or len(file_data[1]) < 1:
                        raise RuntimeError()
                    plot_list = file_data[0][1:]  # Column 0 is time; the rest are plottable datasets
                except Exception:
                    # Best effort: remember the file so the user is warned once the dialog is built.
                    self._err_files.append(fname)
                    continue
                self._file_data[fname] = file_data
                self._plot_list.setdefault(sim_name, {})[fname] = plot_list
                good_files.append(fname)
            self._plot_data_files[sim_name] = good_files

    def _add_solution_if_taken_by_sim(self, filename):
        """Add a solution file to the list of plots if the Recording Stations coverage is taken by the simulation."""
        # Get the simulation name from the file's path: .../<project_name>_models/ADCIRC/<sim_name>/<filename>
        sim_name = os.path.basename(os.path.dirname(filename))
        if sim_name in self._simulations:
            self._plot_data_files.setdefault(sim_name, []).append(filename)

    def _setup_ui(self):
        """Sets up the dialog controls."""
        self.widgets['main_vert_layout'] = QVBoxLayout()
        self.setLayout(self.widgets['main_vert_layout'])

        main_warning = 'WARNING. No station output files found. Run ADCIRC before viewing output plots.'
        self.widgets['main_warning'] = QLabel(main_warning)
        self.widgets['main_warning'].setStyleSheet('font-weight: bold; color: red')
        if not self._simulations or not self._plot_list:
            self.widgets['main_vert_layout'].addWidget(self.widgets['main_warning'])

        # horizontal layout
        self.widgets['main_horiz_layout'] = QHBoxLayout()
        self.widgets['main_vert_layout'].addLayout(self.widgets['main_horiz_layout'])
        # 2 vertical layouts: 1 for the list boxes and time range controls and 2 for the plot on the right
        self.widgets['left_vert_layout'] = QVBoxLayout()
        self.widgets['main_horiz_layout'].addLayout(self.widgets['left_vert_layout'], stretch=1)
        self.widgets['right_vert_layout'] = QVBoxLayout()
        self.widgets['main_horiz_layout'].addLayout(self.widgets['right_vert_layout'])

        # left side of the dialog
        self._setup_ui_simulation_list()
        self._setup_ui_plot_type_list()
        self._setup_ui_time_range()
        self.widgets['left_vert_layout'].addStretch()

        # right side of the dialog
        self.figure = Figure()
        self.figure.set_tight_layout(True)  # Frames the plots
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setMinimumWidth(300)  # So user can't resize it to nothing
        self.widgets['right_vert_layout'].addWidget(self.canvas)
        self.ax = self.figure.add_subplot(111)

        msg = 'WARNING. If the coverage was edited after running ADCIRC, rerun ADCIRC with the updated coverage to ' \
              'ensure that the plot is consistent with ADCIRC outputs.'
        self.widgets['warning_label'] = QLabel(msg)
        self.widgets['warning_label'].setWordWrap(True)
        self.widgets['main_vert_layout'].addWidget(self.widgets['warning_label'])

        # add the close and help buttons at the bottom of the dialog
        self._setup_ui_bottom_button_box()

        self._on_list_state_change()

    def _setup_ui_simulation_list(self):
        """Adds the checkable simulation list and its All on/All off buttons to the left side of the dialog."""
        self.widgets['simulation_label'] = QLabel('Simulations:')
        self.widgets['left_vert_layout'].addWidget(self.widgets['simulation_label'])
        self.widgets['simulation_list'] = QListWidget()
        self.widgets['simulation_list'].setMaximumWidth(200)
        for sim in self._simulations:
            item = QListWidgetItem(sim)
            item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
            item.setCheckState(Qt.Checked)
            self.widgets['simulation_list'].addItem(item)
        # Connect after populating so building the list cannot invoke the handler before the UI exists.
        self.widgets['simulation_list'].itemChanged.connect(self._on_list_state_change)
        self.widgets['left_vert_layout'].addWidget(self.widgets['simulation_list'])
        self.widgets['button_horiz_layout'] = QHBoxLayout()
        self.widgets['left_vert_layout'].addLayout(self.widgets['button_horiz_layout'])
        self.widgets['sim_all_on_btn'] = QPushButton('All on')
        self.widgets['sim_all_on_btn'].clicked.connect(self._on_btn_sim_all_on)
        self.widgets['button_horiz_layout'].addWidget(self.widgets['sim_all_on_btn'])
        self.widgets['sim_all_off_btn'] = QPushButton('All off')
        self.widgets['sim_all_off_btn'].clicked.connect(self._on_btn_sim_all_off)
        self.widgets['button_horiz_layout'].addWidget(self.widgets['sim_all_off_btn'])

    def _setup_ui_plot_type_list(self):
        """Adds the checkable list of plot types (one entry per unique dataset name) to the left side."""
        self.widgets['plot_type_label'] = QLabel('Plots:')
        self.widgets['left_vert_layout'].addWidget(self.widgets['plot_type_label'])
        self.widgets['plot_type_list'] = QListWidget()
        self.widgets['plot_type_list'].setMaximumWidth(200)
        unique_plots = set()
        for sim_plots in self._plot_list.values():
            for plot_names in sim_plots.values():
                for plot_name in plot_names:
                    if plot_name in unique_plots:
                        continue  # Already added this plot (same dataset from another file or simulation)
                    unique_plots.add(plot_name)
                    item = QListWidgetItem(plot_name)
                    item.setFlags(item.flags() | Qt.ItemIsUserCheckable)
                    default_on = plot_name.startswith(self._default_plot_on)
                    item.setCheckState(Qt.Checked if default_on else Qt.Unchecked)
                    self.widgets['plot_type_list'].addItem(item)
        # Connect after populating so building the list cannot invoke the handler before the UI exists.
        self.widgets['plot_type_list'].itemChanged.connect(self._on_list_state_change)
        self.widgets['left_vert_layout'].addWidget(self.widgets['plot_type_list'])

    def _setup_ui_time_range(self):
        """Adds the optional (checkable) minimum/maximum time range controls to the left side of the dialog."""
        self.widgets['time_range_group'] = QGroupBox('Specify time range')
        self.widgets['time_range_group'].setMaximumWidth(200)
        self.widgets['time_range_group'].setCheckable(True)
        self.widgets['time_range_group'].setChecked(False)
        self.widgets['time_range_group'].toggled.connect(self._on_min_max_time_changed)
        self.widgets['left_vert_layout'].addWidget(self.widgets['time_range_group'])
        self.widgets['time_range_vlayout'] = QVBoxLayout()
        self.widgets['time_range_group'].setLayout(self.widgets['time_range_vlayout'])

        self.widgets['min_time_label'] = QLabel('Minimum time:')
        self.widgets['time_range_vlayout'].addWidget(self.widgets['min_time_label'])
        self.widgets['min_time_edit'] = QLineEdit('0.0')
        self.widgets['min_time_edit'].setValidator(QxDoubleValidator(bottom=0.0, top=self._max_time, decimals=3))
        self.widgets['min_time_edit'].editingFinished.connect(self._on_min_max_time_changed)
        self.widgets['time_range_vlayout'].addWidget(self.widgets['min_time_edit'])
        self.widgets['max_time_label'] = QLabel('Maximum time:')
        self.widgets['time_range_vlayout'].addWidget(self.widgets['max_time_label'])
        self.widgets['max_time_edit'] = QLineEdit(str(self._max_time))
        self.widgets['max_time_edit'].setValidator(QxDoubleValidator(bottom=0.0, top=self._max_time, decimals=3))
        self.widgets['max_time_edit'].editingFinished.connect(self._on_min_max_time_changed)
        self.widgets['time_range_vlayout'].addWidget(self.widgets['max_time_edit'])

    def _setup_ui_bottom_button_box(self):
        """Add the Close and Help buttons to the bottom of the dialog."""
        self.widgets['btn_horiz_layout'] = QHBoxLayout()
        self.widgets['btn_box'] = QDialogButtonBox()
        self.widgets['btn_box'].setOrientation(Qt.Horizontal)
        self.widgets['btn_box'].setStandardButtons(QDialogButtonBox.Close | QDialogButtonBox.Help)
        self.widgets['btn_box'].accepted.connect(self.accept)
        self.widgets['btn_box'].rejected.connect(self.reject)
        self.widgets['btn_box'].helpRequested.connect(self._help_requested)
        self.widgets['btn_horiz_layout'].addWidget(self.widgets['btn_box'])
        self.widgets['main_vert_layout'].addLayout(self.widgets['btn_horiz_layout'])

    def _help_requested(self):  # pragma: no cover
        """Called when the Help button is clicked."""
        webbrowser.open(self.help_url)

    def _read_time_range(self):
        """Parses the minimum/maximum time edit fields.

        Returns:
            tuple: (min_time, max_time) as floats, or None if either field does not hold a number
            (for example while it is empty).
        """
        try:
            min_time = float(self.widgets['min_time_edit'].text())
            max_time = float(self.widgets['max_time_edit'].text())
        except ValueError:
            return None
        return min_time, max_time

    def _get_file_data_trimmed_to_time_extents(self, file_name):
        """Gets the file data that is within the min/max range if that option is specified.

        The range is inclusive at both ends. The time column is assumed to be sorted in ascending order.

        Args:
            file_name (:obj:`str`): name of file containing the data

        Returns:
            The stored file data unchanged when no trimming applies. Otherwise, a deep copy whose value arrays
            are sliced to the samples with min_time <= time <= max_time.
        """
        ret_val = self._file_data[file_name]
        if not self.widgets['time_range_group'].isChecked():
            return ret_val
        time_range = self._read_time_range()
        if time_range is None:
            return ret_val
        min_time, max_time = time_range
        if min_time >= max_time:
            return ret_val
        time_idx = -1
        for idx, heading in enumerate(ret_val[0]):
            if heading.startswith('Time'):
                time_idx = idx  # The last heading starting with 'Time' is used
        if time_idx < 0:
            return ret_val
        ret_val = copy.deepcopy(ret_val)  # Never trim the cached data in place
        times = np.asarray(ret_val[1][time_idx])
        if times.size < 1:
            return ret_val
        start_idx = np.searchsorted(times, min_time, side='left')  # first sample >= min_time
        end_idx = np.searchsorted(times, max_time, side='right')  # one past the last sample <= max_time
        for idx, values in enumerate(ret_val[1]):
            ret_val[1][idx] = values[start_idx:end_idx]
        return ret_val

    def _on_min_max_time_changed(self, checked=True):
        """Called when the min or max time has changed, or the time range group is toggled.

        Args:
            checked (:obj:`bool`): New state of the time range group. Unused; it lets this slot accept the
                ``toggled`` signal's argument.
        """
        if self._updating_time_range:
            return
        time_range = self._read_time_range()
        # make sure the min/max are not inverted
        if time_range is not None and time_range[0] > time_range[1]:
            min_time, max_time = time_range
            self._updating_time_range = True
            try:
                self.widgets['min_time_edit'].setText(str(max_time))
                self.widgets['max_time_edit'].setText(str(min_time))
            finally:
                self._updating_time_range = False

        self._on_list_state_change()

    @staticmethod
    def _checked_item_texts(list_widget):
        """Returns the text of every checked item in a QListWidget, in list order.

        Args:
            list_widget (:obj:`QListWidget`): The list to inspect

        Returns:
            list[str]: Texts of the checked items
        """
        texts = []
        for row in range(list_widget.count()):
            item = list_widget.item(row)
            if item.checkState() == Qt.Checked:
                texts.append(item.text())
        return texts

    def _report_read_errors(self):
        """Shows a warning listing any solution files that failed to read, then clears the list."""
        if not self._err_files:
            return
        msg = 'Error reading the following files:\n' + ''.join(f'{fname}\n' for fname in self._err_files)
        app_name = XmEnv.xms_environ_app_name()
        message_box.message_with_ok(
            parent=self.parent(), message=msg, app_name=app_name, icon='Warning', win_icon=self.windowIcon()
        )
        self._err_files = []

    def _on_list_state_change(self):
        """Update plots when state changes."""
        self._checked_sims = self._checked_item_texts(self.widgets['simulation_list'])
        self._checked_plots = self._checked_item_texts(self.widgets['plot_type_list'])
        self._update_plots()
        # Report after plotting so files that fail to read during this update are included too.
        self._report_read_errors()

    def _update_plots(self):
        """Draws all plots based on the selections in the list boxes."""
        self.ax.clear()
        reader = None  # Created only if a file still needs to be read
        show_legend = False
        for sim_name in self._checked_sims:
            sim_plot_lists = self._plot_list.get(sim_name, {})
            for data_file in self._plot_data_files.get(sim_name, []):
                if data_file not in self._file_data:
                    if reader is None:
                        reader = StationSolutionReader(self._feature_id)
                    try:
                        data_vals = reader.read(data_file)
                        if data_vals is None:
                            raise RuntimeError()
                    except Exception:
                        self._err_files.append(data_file)
                        continue
                    self._file_data[data_file] = data_vals
                data = self._get_file_data_trimmed_to_time_extents(data_file)
                fname_plots = sim_plot_lists.get(data_file, [])
                for plot in self._checked_plots:
                    try:
                        plot_idx = fname_plots.index(plot) + 1  # +1 skips the time column
                    except ValueError:
                        continue  # This file does not contain the dataset
                    if plot_idx >= len(data[1]):
                        continue
                    time_vals = data[1][0]
                    if plot.endswith(('Vector', 'Vector - ASCII')):
                        self._draw_vector_arrows(time_vals, data, plot_idx, sim_name)
                    else:
                        self.ax.plot(time_vals, data[1][plot_idx], label=f'{sim_name}:{data[0][plot_idx]}')
                    show_legend = True
        if show_legend:
            self.ax.legend()
        self.ax.set_xlabel('Time (hrs)')
        self.canvas.draw()

    def _draw_vector_arrows(self, time_vals, data, plot_idx, sim_name):
        """Draw arrows for a vector dataset.

        Args:
            time_vals (:obj:`numpy.ndarray`): The time values for the curve
            data (:obj:`list[numpy.ndarray]`): The array of curve values for the point and a linked simulation
            plot_idx (:obj:`int`): Index of the vector curve values
            sim_name (:obj:`str`): Name of the simulation we are plotting curves for
        """
        # Create arrows from Vx and Vy datasets
        label = f'{sim_name}:{data[0][plot_idx]}'
        self.ax.plot(time_vals, [0.0] * len(time_vals), label=label)
        # Vx should come immediately after the '* Vector' dataset and Vy immediately after Vx in the plot list.
        if plot_idx + 2 < len(data[1]):
            x_vals = np.asarray(data[1][plot_idx + 1])
            y_vals = np.asarray(data[1][plot_idx + 2])
        else:
            x_vals = y_vals = np.empty(0)  # Missing component columns are reported as inconsistent below

        # Check for empty or inconsistent datasets. See Mantis issue 14544
        if x_vals.size != y_vals.size or x_vals.size == 0 or len(time_vals) == 0:
            msg = f'Empty or inconsistent dataset values in file: {label}\n'
            app_name = XmEnv.xms_environ_app_name()
            message_box.message_with_ok(
                parent=self.parent(), message=msg, app_name=app_name, icon='Warning', win_icon=self.windowIcon()
            )
            return

        range_x = np.max(x_vals) - np.min(x_vals)
        range_y = np.max(y_vals) - np.min(y_vals)
        dx_vals = x_vals.tolist()
        dy_vals = y_vals.tolist()
        times = np.asarray(time_vals).tolist()
        min_t = times[0]
        max_t = times[-1]
        if range_x == 0:
            # Constant x-component: the y/x range ratio is infinite, which collapses the x scaling to zero.
            scale_factor = 0.0
        elif range_y == 0:
            # Constant y-component: the y/x range ratio is zero and would give an infinite scale. Fall back to
            # drawing the x-component unscaled.
            scale_factor = 1.0
        else:
            aspect = range_y / range_x
            scale_factor = (max_t - min_t) / aspect
            # Widen the time span to cover the scaled first/last arrows, then rescale for the wider span.
            min_t = min(min_t, scale_factor * dx_vals[0])
            max_t = max(max_t, scale_factor * dx_vals[-1])
            scale_factor = (max_t - min_t) / aspect
        for time, dx_val, dy_val in zip(times, dx_vals, dy_vals):
            self.ax.add_patch(
                patches.FancyArrowPatch(
                    (time, 0.0), (scale_factor * dx_val + time, dy_val), mutation_scale=10, arrowstyle='->'
                )
            )

    def _set_all_simulation_check_states(self, state):
        """Sets every simulation item to the given check state, then redraws the plots once.

        Args:
            state (:obj:`Qt.CheckState`): The check state to apply
        """
        sim_list = self.widgets['simulation_list']
        # Block itemChanged while toggling so the plots are redrawn once instead of once per item.
        was_blocked = sim_list.blockSignals(True)
        try:
            for row in range(sim_list.count()):
                sim_list.item(row).setCheckState(state)
        finally:
            sim_list.blockSignals(was_blocked)
        self._on_list_state_change()

    def _on_btn_sim_all_on(self):
        """Signal for when the user clicks the All on button."""
        self._set_all_simulation_check_states(Qt.Checked)

    def _on_btn_sim_all_off(self):
        """Signal for when the user clicks the All off button."""
        self._set_all_simulation_check_states(Qt.Unchecked)