"""Plots and table for SRH."""

# 1. Standard Python modules
import copy
from collections import deque
import os
from pathlib import Path
import sys

# 2. Third party modules
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.guipy.data.plot_and_table_data_base import np_arrays_from_file as npfile, PlotsAndTableDataBase

# 4. Local modules


class PlotsAndTableDataSrh(PlotsAndTableDataBase):
    """SRH plot and table data derived from PlotsAndTableDataBase.

    Locates the SRH output ``.dat`` files that belong to the selected feature(s) of a coverage,
    reads them with ``np_arrays_from_file`` (cached in ``self.file_data``) and serves the data,
    optionally trimmed to a time range, to the plot/table dialog.
    """

    def __init__(self, pe_tree, feature_id, feature_type, cov_uuid, model_name='SRH-2D', feature_ids=None):
        """Initialize the plot data and discover the output files of each selected feature.

        Args:
            pe_tree: project explorer tree root; when falsy no simulations or files are searched
            feature_id (int): id of the selected feature
            feature_type (str): feature type, e.g. 'Point' or 'Arc'
            cov_uuid (str): uuid of the coverage that owns the feature
            model_name (str): model name used in the window title and messages
            feature_ids: presumably the ids of all selected features (forwarded to the base class)
        """
        super().__init__(pe_tree, feature_id, feature_type, cov_uuid, model_name, feature_ids)
        self._is_bc_coverage = False
        self._is_3d_structure = False

        self.default_plot_on = 'Q(' if self.feature_type == 'Arc' else 'Water_Elev'

        # Set before the early return so the dialog is labelled correctly even without a tree.
        self.help_url = 'https://www.xmswiki.com/wiki/SMS:SRH-2D_Plots'
        self.window_title = f'{model_name} Solution Plots'

        if not self._pe_tree:
            return
        # get list of simulations from the project explorer
        self._get_list_of_sims()

        for f_id in self.fids:
            self._get_plot_data_files(f_id)

    def _get_list_of_sims(self):
        """Find all the simulations that include this coverage.

        Walks the project explorer tree breadth-first. Every pointer item whose uuid is this
        coverage's uuid contributes its parent's name to ``self.simulations`` and
        ``self._project_simulations``. The coverage type of the pointer item sets the BC / 3D
        structure flags and switches the default plot to 'Discharge'.
        """
        pending = deque([self._pe_tree])
        while pending:
            tree_node = pending.popleft()
            pending.extend(tree_node.children)
            if tree_node.uuid != self._cov_uuid or not tree_node.is_ptr_item:
                continue
            self.simulations.append(tree_node.parent.name)
            self._project_simulations.append(tree_node.parent.name)
            if tree_node.coverage_type == 'Boundary Conditions':
                self._is_bc_coverage = True
                self.default_plot_on = 'Discharge'
            elif tree_node.coverage_type == '3D Structure':
                self._is_3d_structure = True
                self.default_plot_on = 'Discharge'

    def _get_plot_data_files(self, feature_id: int):
        """Find, read and record the output files associated with ``feature_id``.

        Each matching file is keyed by its simulation name, or by '<sim>/<scenario>' for scenario
        output. Readable files are cached in ``self.file_data``. Unreadable ones are appended to
        ``self.err_files`` and dropped. The resulting mapping is stored in
        ``self.feature_ids_plot_data_files[feature_id]`` (this feature only) and merged into
        ``self.plot_data_files``.

        Args:
            feature_id (int): id of the feature
        """
        if self._is_3d_structure:
            data_files = self._get_3d_structure_files(self._model_dir, feature_id)
        else:
            if self._is_bc_coverage:
                file_pattern = self._get_bc_file_pattern(self._model_dir, feature_id)
            elif self.feature_type == 'Point':
                file_pattern = f'*_PT{feature_id}.dat'
            else:
                file_pattern = f'*_LN{feature_id}.dat'
            data_files = [str(f.absolute()) for f in Path(self._model_dir).rglob(file_pattern)]

        # Collect this feature's files separately so earlier features' files do not leak into it.
        feature_files = {}
        scenario_sims = set()
        dirname = os.path.dirname
        for data_file in data_files:
            if 'Output_MISC' not in data_file:
                continue

            # The simulation name is taken from the file's grandparent directory. If that is not a
            # known simulation, it is treated as a scenario of the great-grandparent directory.
            sim_name = os.path.basename(dirname(dirname(data_file)))
            if sim_name in self.simulations:
                feature_files[sim_name] = data_file
                continue
            scenario = sim_name
            sim_name = os.path.basename(dirname(dirname(dirname(data_file))))
            if sim_name in self._project_simulations:
                scenario_sims.add(sim_name)
                scenario_sim_name = f'{sim_name}/{scenario}'
                if scenario_sim_name not in self.simulations:
                    self.simulations.append(scenario_sim_name)
                feature_files[scenario_sim_name] = data_file

        # Read every file, not just the first: some sims have sediment output, which adds plots,
        # so the longest heading list becomes the list of available plots.
        for sim, fname in list(feature_files.items()):
            try:
                if fname not in self.file_data:
                    self.file_data[fname] = npfile(fname)
                plot_list = self.file_data[fname][0][1:]
            except Exception:
                feature_files.pop(sim)
                self.err_files.append(fname)
                continue
            if len(plot_list) > len(self.plot_list):
                self.plot_list = plot_list

        # A simulation with scenarios is listed only through its '<sim>/<scenario>' entries.
        for item in scenario_sims:
            if item in self.simulations:
                self.simulations.remove(item)
        self.plot_data_files.update(feature_files)
        self.feature_ids_plot_data_files[feature_id] = feature_files

    def _get_3d_structure_files(self, srh_dir: str, feature_id=None):
        """Get the list of output files associated with a 3D structure arc.

        Every ``structure_index.txt`` under ``srh_dir`` that has a 'coverage_uuid' column is
        searched for the row of this coverage and arc. The matching ``*_<type><index>.dat`` file
        is then looked up beneath that index file's folder.

        Args:
            srh_dir (str): directory with all SRH models
            feature_id (int): arc id to look up; defaults to ``self._feature_id``

        Returns:
            (list[str]): list of the output files that are associated with the arc
        """
        if feature_id is None:
            feature_id = self._feature_id
        ret_files = []
        for idx_file in Path(srh_dir).rglob('structure_index.txt'):
            idx_file = idx_file.absolute()
            df = pd.read_csv(idx_file)
            if 'coverage_uuid' not in df.columns:
                continue
            df = df.loc[(df['arc_id'] == feature_id) & (df['coverage_uuid'] == self._cov_uuid)]
            if len(df) < 1:
                continue
            struct_type = df['structure_type'].item()
            struct_index = df['structure_index'].item()
            file_pattern = f'*_{struct_type}{struct_index}.dat'
            out_files = [str(f.absolute()) for f in idx_file.parent.rglob(file_pattern)]
            # Only accept an unambiguous match.
            if len(out_files) == 1:
                ret_files.append(out_files[0])
        return ret_files

    def _get_bc_file_pattern(self, srh_dir: str, feature_id=None):
        """Get the filename pattern for an arc of a BC coverage.

        Uses the first ``structure_index.txt`` found directly in a simulation folder, or in a
        sub-folder (scenario) of one, of a simulation that includes this coverage.

        Args:
            srh_dir (str): directory with all SRH models
            feature_id (int): arc id to look up; defaults to ``self._feature_id``

        Returns:
            (str): the file pattern to match, or 'NO_STRUCTURES_FOUND' when none applies
        """
        if feature_id is None:
            feature_id = self._feature_id
        no_match = 'NO_STRUCTURES_FOUND'
        # Match against the original project simulations: self.simulations loses sims that have
        # scenarios once the first feature has been processed.
        sims = self._project_simulations
        struct_index_file = None
        for candidate in Path(srh_dir).rglob('structure_index.txt'):
            folder = candidate.absolute().parent
            if folder.name in sims or folder.parent.name in sims:
                struct_index_file = candidate
                break

        if struct_index_file is None:
            return no_match
        df = pd.read_csv(struct_index_file)
        df = df.loc[df['arc_id'] == feature_id]
        if len(df) < 1:
            return no_match
        struct_type = df['structure_type'].item()
        # structure_type 'HY8' maps to the 'HY' file prefix.
        struct_type = 'HY' if struct_type == 'HY8' else struct_type
        struct_index = df['structure_index'].item()
        return f'*_{struct_type}{struct_index}.dat'

    def get_plot_data_for_feature(self, f_id, checked_sims, checked_plots):
        """Retrieve and prepare data for plotting for a given feature ID.

        Side effect: ``self.plot_data_files`` is set to the file mapping of ``f_id``.

        Args:
            f_id (int): feature id
            checked_sims: simulation names selected for plotting
            checked_plots: plot headings selected for plotting

        Returns:
            (tuple[dict, list[str]]): ``{sim: {'data', 'label', 'plots'}}`` for each checked simulation
            whose data is available, and the list of files that failed to read
        """
        plot_data = {}
        errors = []
        f_label = self.get_f_label(f_id)

        self.plot_data_files = self.feature_ids_plot_data_files.get(f_id, {})
        for sim in self.simulations:
            if sim not in checked_sims:
                continue
            data_file = self.plot_data_files.get(sim, "")
            if not data_file:
                continue

            # Read and cache data if not already loaded
            if data_file not in self.file_data:
                try:
                    self.file_data[data_file] = npfile(data_file)
                except Exception:
                    errors.append(data_file)
                    continue

            data = self.file_data[data_file]
            # NOTE(review): data[0] includes the time heading at index 0, so `index(plot) + 1` points
            # one column past `plot`; confirm how callers consume this index before changing it.
            plot_data[sim] = {
                "data": data,
                "label": f"{f_label} - {sim}",
                "plots": [(plot, data[0].index(plot) + 1) for plot in checked_plots if plot in data[0]],
            }

        return plot_data, errors

    def _get_file_data_trimmed_to_time_extents(self, file_name, min_time, max_time):
        """Return the cached data of ``file_name`` limited to ``min_time <= time <= max_time``.

        The cached data is returned as-is when ``min_time`` is None, when ``min_time`` is not less
        than ``max_time``, or when no heading starts with 'Time'. Otherwise a deep copy is trimmed,
        so the cache itself is never modified. The time column is assumed to be sorted ascending.

        Args:
            file_name (str): name of file containing the data
            min_time (float | None): minimum time, or None for no trimming
            max_time (float): maximum time

        Returns:
            (headings, columns) data in the same layout as ``self.file_data[file_name]``
        """
        file_data = self.file_data[file_name]
        if min_time is None or not min_time < max_time:
            return file_data

        # The last heading that starts with 'Time' is the time column.
        time_idx = -1
        for idx, heading in enumerate(file_data[0]):
            if heading.startswith('Time'):
                time_idx = idx
        if time_idx < 0:
            return file_data

        trimmed = copy.deepcopy(file_data)
        columns = trimmed[1]
        times = np.asarray(columns[time_idx])
        if times.size == 0 or times.max() < min_time:
            for idx in range(len(columns)):
                columns[idx] = []
            return trimmed

        # searchsorted yields inclusive bounds (matching the DataFrame trimming) and an empty range
        # when every time lies after max_time.
        start_idx = int(np.searchsorted(times, min_time, side='left'))
        end_idx = int(np.searchsorted(times, max_time, side='right'))
        for idx, column in enumerate(columns):
            columns[idx] = column[start_idx:end_idx]
        return trimmed

    def _get_file_data_trimmed_to_time_extents_in_df(self, file_name, min_time, max_time):
        """Build a DataFrame of the time column and the checked plots, trimmed to the time range.

        The first column of the file (time) is included whenever any plot is checked. Each checked
        plot is looked up by heading in this file's own headings and skipped if the file lacks it,
        so files whose layout differs from ``self.plot_list`` still get the right columns.

        Args:
            file_name (str): name of file containing the data
            min_time (float | None): minimum time, or None for no trimming
            max_time (float): maximum time

        Returns:
            (pd.DataFrame): selected columns, rows limited to ``min_time <= time <= max_time``
        """
        headings = self.file_data[file_name][0]
        columns = self.file_data[file_name][1]

        data_dict = {}
        if self.checked_plots:
            data_dict[headings[0]] = columns[0]
        for plot in self.checked_plots:
            try:
                col_idx = headings.index(plot, 1)
            except ValueError:
                continue  # this file does not contain the plot
            data_dict[headings[col_idx]] = columns[col_idx]

        # Determine the time heading
        time_heading = ''
        for heading in data_dict:
            if heading.startswith('Time'):
                time_heading = heading

        df = pd.DataFrame(data_dict)
        if min_time is not None and time_heading and len(df) > 0:
            df = df.loc[(df[time_heading] >= min_time) & (df[time_heading] <= max_time)]

        # Do NOT trim the dataframe by value range: a row trim would require every column to be in
        # range, which is unlikely, and trimming a single dataset is inconsistent when several are
        # shown together.
        return df

    def get_plot_data_and_dataframe(self, sim, min_x, max_x):
        """Retrieves data and dataframe from simulation.

        Args:
            sim: retrieved from self.simulations
            min_x (float | None): minimum x value.
            max_x (float): maximum x value.

        Returns:
            (tuple): (data, df, update_dlg_min_max). data and df are None when ``sim`` has no file
            or the file cannot be read (it is then added to ``self.err_files``).
            update_dlg_min_max is True when the file was read for the first time.
        """
        data_file = self.plot_data_files.get(sim, '')
        if not data_file:
            return None, None, False

        update_dlg_min_max = False
        if data_file not in self.file_data:
            try:
                self.file_data[data_file] = npfile(data_file)
            except Exception:
                self.err_files.append(data_file)
                return None, None, False
            update_dlg_min_max = True
        data = self._get_file_data_trimmed_to_time_extents(data_file, min_x, max_x)
        df = self._get_file_data_trimmed_to_time_extents_in_df(data_file, min_x, max_x)
        return data, df, update_dlg_min_max

    def get_plot_max_x(self):
        """Return the largest first-column (time) value over all cached files, plus 1e-7.

        Returns ``-sys.float_info.max`` when nothing has been read. Empty time columns are skipped.
        """
        max_time = -sys.float_info.max
        for data in self.file_data.values():
            times = np.asarray(data[1][0])
            if times.size:
                max_time = max(max_time, times.max() + 1.0e-7)
        return max_time

    def get_plot_min_max_y(self):
        """Return the value extents over all cached data columns except the first.

        Returns:
            (tuple): (min value, max value, smallest non-zero magnitude, largest non-zero magnitude).
            The last two are intended for log-scale axes. Empty columns are skipped.
        """
        max_value = -sys.float_info.max
        min_value = sys.float_info.max
        log_min_value = sys.float_info.max
        # Smallest positive float; fine as a start value since magnitudes compared are positive.
        log_max_value = sys.float_info.min
        for data in self.file_data.values():
            headings, columns = data[0], data[1]
            for index in range(1, len(headings)):
                values = np.asarray(columns[index])
                if values.size == 0:
                    continue
                min_value = min(min_value, values.min())
                max_value = max(max_value, values.max())
                # For log axes, use the magnitudes of the non-zero values only.
                magnitudes = np.abs(values[values != 0.0])
                if magnitudes.size:
                    log_min_value = min(log_min_value, magnitudes.min())
                    log_max_value = max(log_max_value, magnitudes.max())
        return min_value, max_value, log_min_value, log_max_value

    def get_f_label(self, f_id):
        """Return the feature label, 'PT<id>' for points and 'LN<id>' otherwise."""
        return f'PT{f_id}' if self.feature_type == 'Point' else f'LN{f_id}'

    def get_x_label(self, x_heading):
        """Return the x axis label for the given time heading."""
        if x_heading in ('Time(hr)', 'Time(hour)'):
            return 'Time (hrs)'
        if x_heading == 'Time_s':
            return 'Time (s)'
        return 'Time'

    def get_main_warning(self):
        """Return the main warning shown at the top of the widget when no output is found."""
        main_warning = (
            f'WARNING. No {self.model_name} output files found. '
            f'Run {self.model_name} before viewing output plots.'
        )
        if self._is_bc_coverage:
            main_warning += '\nNote: Bridge structures only produce plots if the overtopping option is checked.'
        if self._is_3d_structure:
            main_warning += '\nNote: 3D structures only produce plots if the overtopping option is checked.'
        return main_warning

    def get_bottom_warning(self):
        """Return the warning shown at the bottom of the widget."""
        return (
            f'WARNING. If the coverage was edited after running {self.model_name}, '
            f'rerun {self.model_name} with the updated coverage to '
            f'ensure that the plot is consistent with {self.model_name} outputs.'
        )

    def get_error(self):
        """Return (and clear) the message listing files that could not be read, or ''."""
        if not self.err_files:
            return ''
        msg = 'Error reading the following files:\n' + ''.join(f'{fname}\n' for fname in self.err_files)
        self.err_files = []
        return msg
