"""SolutionPlotsWindow class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from pathlib import Path

# 2. Third party modules
from matplotlib.axes import Axes
from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT
from matplotlib.figure import Figure
from PySide2.QtCore import Qt
from PySide2.QtGui import QKeyEvent
from PySide2.QtWidgets import QApplication, QMainWindow, QToolBar, QTreeWidgetItem, QTreeWidgetItemIterator

# 3. Aquaveo modules
from xms.coverage.xy.xy_io import XyReader
from xms.coverage.xy.xy_series import XySeries
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.settings import SettingsManager
from xms.guipy.widgets import widget_builder

# 4. Local modules
from xms.hgs.gui import gui_util, plot_data_dialog
from xms.hgs.gui.solution_plots_window_ui import Ui_solution_plots_window


def remove_filename_from_names(xy_series_list) -> tuple[list[str], list[str]]:
    """Removes filename from xy series names (XySeries.name has both, separated by '|') and returns lists of both.

    Each XySeries.name is modified in place so that only the part after the first '|' remains. A name without a '|'
    is left unchanged and gets an empty filename. This way the returned lists always hold exactly one entry per xy
    series, in the same order as xy_series_list. SolutionPlotsWindow relies on this, using the list position as an
    index into xy_series_list.

    Args:
        xy_series_list: The XySeries whose names are split. Their `name` attributes are modified.

    Returns:
        (tuple[list[str], list[str]]): The filenames and the names, one of each per xy series.
    """
    filenames = []
    names = []
    for xy_series in xy_series_list:
        filename, separator, name = xy_series.name.partition('|')
        if not separator:
            # No filename present: keep the whole name so positions stay aligned with xy_series_list.
            filename, name = '', xy_series.name
        filenames.append(filename)
        names.append(name)
        xy_series.name = name
    return filenames, names


def run(xy_series_file: Path, parent=None) -> None:
    """Creates and shows a SolutionPlotsWindow, then runs the Qt event loop unless running under tests.

    Args:
        xy_series_file (Path): Path to a .h5 file containing XySeries.
        parent: Parent widget.
    """
    app = QApplication.instance()
    plots_window = SolutionPlotsWindow(xy_series_file, parent)

    # Show the window and bring it to the front.
    plots_window.show()
    plots_window.activateWindow()
    plots_window.raise_()

    from xms.api.dmi import XmsEnvironment
    running_tests = XmsEnvironment.xms_environ_running_tests() == 'TRUE'
    if running_tests:
        plots_window.accept()  # If testing, just accept immediately.
        return
    app.exec_()  # pragma no cover - can't hit this line if testing


class SolutionPlotsWindow(QMainWindow, XmsDlg):
    """Window that shows a tree of XySeries (grouped by file) and plots the checked ones.

    Derives from QMainWindow to avoid a problem where the matplotlib configure dialog, accessible from one of the
    toolbar buttons, would flash and then disappear behind the dialog. See NavigationToolbar2QT and
    https://matplotlib.org/stable/gallery/user_interfaces/embedding_in_qt_sgskip.html.
    """

    VIEW_DATA_SVG = ':/resources/icons/scalar_dataset_active.svg'
    MARKERS = ['o', 'v', '^', '<', '>', '8', 's', 'p', '*', 'h', 'X', 'D']  # Cycled through for single-point series

    def __init__(self, xy_series_file: Path, parent=None) -> None:
        """Initializes the class.

        Args:
            xy_series_file (Path): Path to a .h5 file containing XySeries.
            parent: Parent widget.
        """
        self._dlg_name = 'xms.hgs.gui.solution_plots_window'
        QMainWindow.__init__(self, parent)
        XmsDlg.__init__(self, parent, self._dlg_name)
        self.ui = Ui_solution_plots_window()
        self.ui.setupUi(self)

        self._xy_series_list: list[XySeries] = []
        self.figure: Figure | None = None  # matplotlib.figure Figure
        self.canvas: FigureCanvas | None = None  # matplotlib.backends.backend_qt5agg FigureCanvas
        self.ax: Axes | None = None  # matplotlib Axes
        self._xy_names_to_model_index = {}  # xy series name -> QModelIndex (assumes names are not duplicated)
        self._marker = 0  # Index into MARKERS of the next marker to use
        self._help_getter = gui_util.help_getter(self._dlg_name)

        self._read_xy_series(xy_series_file)
        self._remove_newton_runtime_info()
        self.ui.splitter.setSizes([350, 500])  # Initialize this to something decent
        self._setup_tree()
        self._setup_plot_canvas()
        self._setup_plot_toolbars()
        self._update_plot()
        self._enable_xms_toolbar()  # Make the View Data button reflect the initial check state

        # Signals (connected after setup so the initial check in _setup_tree doesn't trigger them)
        self.ui.tree_xy_series.itemChanged.connect(self._on_tree_change)  # Handles check/uncheck events
        self.ui.btn_expand_all.clicked.connect(self._on_expand_all)
        self.ui.btn_collapse_all.clicked.connect(self._on_collapse_all)
        self.ui.buttonBox.accepted.connect(self.accept)
        self.ui.buttonBox.helpRequested.connect(self.help_requested)

    def _read_xy_series(self, xy_series_file: Path) -> None:
        """Reads the XySeries from the file.

        Args:
            xy_series_file (Path): Path to a .h5 file containing XySeries.
        """
        reader = XyReader()
        self._xy_series_list = reader.read_from_h5(xy_series_file)

    def _remove_newton_runtime_info(self) -> None:
        """Removes any series whose name contains 'newton_info.dat'.

        Aquanty wanted it removed. It has a different number of times than the rest of the data so it caused a bug in
        the Plot Data dialog.
        """
        self._xy_series_list = [series for series in self._xy_series_list if 'newton_info.dat' not in series.name]

    def _expand_or_collapse_all(self, expand: bool) -> None:
        """Expands or collapses all the tree items.

        Args:
            expand (bool): If True, all items are expanded, otherwise they are all collapsed.
        """
        it = QTreeWidgetItemIterator(self.ui.tree_xy_series)
        while it.value():
            item = it.value()
            if expand and not item.isExpanded():
                self.ui.tree_xy_series.expandItem(item)
            elif not expand and item.isExpanded():
                self.ui.tree_xy_series.collapseItem(item)
            it += 1

    def _on_expand_all(self) -> None:
        """Expands all the tree items."""
        self._expand_or_collapse_all(expand=True)

    def _on_collapse_all(self) -> None:
        """Collapses all the tree items."""
        self._expand_or_collapse_all(expand=False)

    def _setup_tree(self) -> None:
        """Builds the tree: one top-level item per file, with a checkable child item per xy series.

        The water_balance.dat file is expanded. The initially checked item is its first series whose name ends with
        '-pm' (case-insensitive), or the very first series if there is no such item.
        """
        filenames, names = remove_filename_from_names(self._xy_series_list)
        file_groups = self._group_by_file_and_add_index(filenames, names)
        self.ui.tree_xy_series.setColumnCount(1)
        self.ui.tree_xy_series.header().hide()
        first_plot = None  # First plot of them all
        initial_plot = None  # The plot we want to be the initial plot if it exists
        for file, names_and_indexes in file_groups.items():
            file_item = QTreeWidgetItem(self.ui.tree_xy_series, [file])
            expanded_file = 'water_balance.dat' in file.lower()  # This is the file we want expanded
            file_item.setExpanded(expanded_file)
            for name, index in names_and_indexes:
                name_item = QTreeWidgetItem(file_item, [name])
                self._xy_names_to_model_index[name] = self.ui.tree_xy_series.indexFromItem(name_item)
                name_item.setCheckState(0, Qt.Unchecked)  # This gets the checkbox to appear
                name_item.setData(0, Qt.UserRole, index)  # Store index into self._xy_series_list in the user data
                if not first_plot:
                    first_plot = name_item
                if not initial_plot and expanded_file and name.lower().endswith('-pm'):
                    initial_plot = name_item

        # Check the default item
        default_item = initial_plot if initial_plot is not None else first_plot
        if default_item:
            self.ui.tree_xy_series.setCurrentItem(default_item)
            default_item.setCheckState(0, Qt.Checked)

    def _group_by_file_and_add_index(self, filenames: list[str], names: list[str]) -> dict[str, list[tuple[str, int]]]:
        """Groups the xy series names by their file name, and adds the xy series index (into self._xy_series_list).

        Files appear in the order they are first encountered. The index is the position in filenames/names.

        Args:
            filenames (list[str]): List of filenames.
            names (list[str]): List of xy series names, parallel to filenames.

        Returns:
            (dict[str, list[tuple[str, int]]]): filename -> list of (xy series name, index).
        """
        file_groups: dict[str, list[tuple[str, int]]] = {}
        for index, (filename, name) in enumerate(zip(filenames, names)):
            file_groups.setdefault(filename, []).append((name, index))
        return file_groups

    def _setup_plot_canvas(self) -> None:
        """Makes the plot canvas and inserts it at the top of the group box layout."""
        self.figure = Figure()
        self.figure.set_layout_engine(layout='tight')
        self.canvas = FigureCanvas(self.figure)
        self.canvas.setMinimumWidth(300)  # So user can't resize it to nothing
        self.ui.vlay_grp_box.insertWidget(0, self.canvas)
        self.ax = self.figure.add_subplot(111)
        self.ax.grid(True)

    def _setup_plot_toolbars(self):
        """Sets up the plot toolbars.

        I tried to add a tool to the matplotlib NavigationToolbar2QT but never got that to
        work so we just use our own toolbar.
        """
        self._setup_xms_toolbar()
        self.ui.hlay_toolbars.addWidget(NavigationToolbar2QT(self.canvas))

    def _setup_xms_toolbar(self):
        """Sets up the XMS plot toolbar (currently just the View Data button)."""
        self._toolbar = QToolBar(self)
        button_list = [
            [SolutionPlotsWindow.VIEW_DATA_SVG, 'View Data', self._view_data],
        ]
        self._actions = widget_builder.setup_toolbar(self._toolbar, button_list)
        self.ui.hlay_toolbars.addWidget(self._toolbar)

    def _enable_xms_toolbar(self):
        """Enables the View Data button only when at least one xy series is checked."""
        displayed_xy_idxs = self._get_series_to_display()
        enable = bool(displayed_xy_idxs)
        self._toolbar.widgetForAction(self._actions[SolutionPlotsWindow.VIEW_DATA_SVG]).setEnabled(enable)

    def _view_data(self) -> None:
        """Displays the data of the checked xy series in a dialog."""
        displayed_xy_idxs = self._get_series_to_display()
        if displayed_xy_idxs:
            xy_series = [self._xy_series_list[index] for index in displayed_xy_idxs]
            plot_data_dialog.run(self, xy_series)

    def _on_tree_change(self) -> None:
        """Called when something with the tree changes."""
        self._update_plot()
        self._enable_xms_toolbar()

    def _remove_old_lines(self, displayed_xy_idxs) -> None:
        """Removes lines from the plot for xy series that are no longer checked.

        Lines are matched to xy series by label (the xy series name), so series with the same name in different files
        cannot be told apart.

        Args:
            displayed_xy_idxs: Indices into self._xy_series_list of the xy series that should remain plotted.
        """
        xy_names = {self._xy_series_list[idx].name for idx in displayed_xy_idxs}
        for line in self.ax.get_lines():
            if line.get_label() not in xy_names:
                line.remove()

    def _plot_shows_xy_series(self, xy_series) -> bool:
        """Returns True if the plot has a line whose label is the xy series name."""
        return any(line.get_label() == xy_series.name for line in self.ax.get_lines())

    def _update_plot(self) -> None:
        """Updates the plot so it shows exactly the currently checked XySeries.

        Lines for unchecked series are removed and lines for newly checked series are added. The axes limits, legend
        and axis labels are then refreshed.
        """
        displayed_xy_idxs = self._get_series_to_display()
        self._remove_old_lines(displayed_xy_idxs)
        x_label = ''
        y_labels = []
        for index in displayed_xy_idxs:
            xy_series = self._xy_series_list[index]
            if not self._plot_shows_xy_series(xy_series):
                if len(xy_series.x) == 1:  # Use a marker if there's just one value so it can be seen
                    self.ax.plot(xy_series.x, xy_series.y, label=xy_series.name, marker=self.MARKERS[self._marker])
                    self._marker = (self._marker + 1) % len(self.MARKERS)
                elif len(xy_series.x) > 0:  # Empty series are skipped
                    self.ax.plot(xy_series.x, xy_series.y, label=xy_series.name)
            x_label = xy_series.x_title  # The last checked series determines the x axis label
            y_labels.append(xy_series.y_title)
        self.ax.relim()  # With next line, frames the lines in the plot
        self.ax.autoscale()  # With previous line, frames the lines in the plot
        if self.ax.get_lines():
            self.ax.legend()
        else:
            # Nothing to label: drop any stale legend instead of drawing an empty one
            legend = self.ax.get_legend()
            if legend is not None:
                legend.remove()
        self.ax.set_xlabel(x_label)
        self.ax.set_ylabel(', '.join(dict.fromkeys(y_labels)))  # dict.fromkeys drops repeated titles, keeps order
        self.canvas.draw()

    def _get_series_to_display(self) -> list[int]:
        """Returns the indices (into self._xy_series_list) of the checked xy series, in tree order."""
        displayed_xy_idxs = []
        it = QTreeWidgetItemIterator(self.ui.tree_xy_series)
        while it.value():
            item = it.value()
            if item.checkState(0) == Qt.Checked:
                index = item.data(0, Qt.UserRole)
                if index is not None:  # File items have no index stored
                    displayed_xy_idxs.append(index)
            it += 1
        return displayed_xy_idxs

    def showEvent(self, event):  # noqa: N802 - function name should be lowercase
        """Restore the last splitter position when showing the window."""
        super().showEvent(event)
        self._restore_splitter_geometry()

    def _save_splitter_geometry(self) -> None:
        """Save the current position of the splitter."""
        settings = SettingsManager()
        settings.save_setting('xms.hgs', f'{self._dlg_name}.splitter', self.ui.splitter.sizes())

    def _restore_splitter_geometry(self) -> None:
        """Restore the position of the splitter, if one was saved."""
        splitter = self._get_splitter_sizes()
        if not splitter:
            return
        splitter_sizes = [int(size) for size in splitter]
        self.ui.splitter.setSizes(splitter_sizes)

    def _get_splitter_sizes(self):
        """Returns the splitter sizes saved in the settings (falsy if none were saved)."""
        settings = SettingsManager()
        splitter = settings.get_setting('xms.hgs', f'{self._dlg_name}.splitter')
        return splitter

    def accept(self) -> None:
        """Called on OK. Saves the splitter position and closes the window."""
        self._save_splitter_geometry()
        super().accept()
        self.close()

    def keyPressEvent(self, event: QKeyEvent) -> None:  # noqa: N802 - function name should be lowercase
        """Handles key presses so that ESC causes window to close.

        Args:
            event (QKeyEvent): The event.
        """
        if event.key() == Qt.Key_Escape:
            self.close()
        else:
            super().keyPressEvent(event)
