"""A widget for assigning transport constituents for boundary conditions."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from adhparam.time_series import TimeSeries
import pandas
from PySide2.QtCore import QModelIndex, Qt

# 3. Aquaveo modules
from xms.guipy.delegates.curve_button_delegate import CurveButtonDelegate
from xms.guipy.delegates.qx_cbx_delegate import QxCbxDelegate
from xms.guipy.models.qx_pandas_table_model import QxPandasTableModel
from xms.guipy.models.rename_model import RenameModel

# 4. Local modules
from xms.adh.data.card_info import CardInfo
from xms.adh.gui.widgets.transport_constituent_assignment_widget import TransportConstituentAssignmentWidget


class BCAssignmentModel(RenameModel):
    """A model to rename header titles and hide unwanted columns of the BC assignment table.

    The source constituent ID column is always hidden. The source snapping column is hidden unless
    ``show_snapping`` is True. When it is shown, its displayed text is derived from the row's type.
    """
    def __init__(self, show_snapping, column_names, parent=None):
        """Initializes the filter model.

        Args:
            show_snapping (bool): True if the snapping column should be shown.
            column_names (list): The column names.
            parent (Something derived from :obj:`QObject`): The parent object.

        """
        # Column indices in the source model.
        self._ID_SOURCE_COLUMN = 0
        self._TYPE_SOURCE_COLUMN = 2
        self._SNAP_SOURCE_COLUMN = 4
        # Column indices in this (filtered) model.
        self._NAME_COLUMN = 0
        self._TYPE_COLUMN = 1
        self._SERIES_COLUMN = 2
        self._SNAP_COLUMN = 3
        self.show_snapping = show_snapping
        super().__init__(column_names, parent)

    def filterAcceptsColumn(self, source_column: int, source_parent: QModelIndex) -> bool:  # noqa: N802
        """Filters out the 'ID' column, and the snapping column when snapping is not shown.

        Args:
            source_column (int): The column index in the source model.
            source_parent (QModelIndex): The parent index of the view.

        Returns:
            True if the model should keep the column, false if the column should be hidden.
        """
        return source_column != self._ID_SOURCE_COLUMN and \
            (self.show_snapping or source_column != self._SNAP_SOURCE_COLUMN)

    def flags(self, index: QModelIndex):
        """Decides flags for things like enabled state.

        The name column is read-only. The time series column is disabled while the row's type is 'None'.
        When shown, the snapping column is always disabled.

        Args:
            index (QModelIndex): The index in this (filtered) model.

        Returns:
            Item flags for the model index.
        """
        local_flags = super().flags(index)
        column = index.column()
        if column == self._NAME_COLUMN:
            local_flags = local_flags & ~Qt.ItemIsEditable
        elif column == self._SERIES_COLUMN:
            current_type = index.model().index(index.row(), self._TYPE_COLUMN, index.parent()).data(Qt.EditRole)
            if current_type == 'None':
                local_flags = local_flags & ~Qt.ItemIsEnabled
            else:
                local_flags = local_flags | Qt.ItemIsEnabled
        elif self.show_snapping and column == self._SNAP_COLUMN:
            # The snapping text is computed from the type in data(), so the user never edits it directly.
            local_flags = local_flags & ~Qt.ItemIsEnabled
        return local_flags

    def data(self, index, role=Qt.DisplayRole):
        """Gets the data for the model.

        Args:
            index (QModelIndex): The location index in the Qt model.
            role (int): The role the data represents.

        Returns:
            An object of data for the role. For the snapping column, returns the snapping method
            derived from the row's type, or a blank string when the type has no snapping method.
        """
        # NOTE(review): this skips RenameModel.data and calls the next class in the MRO - confirm intentional.
        source_data = super(RenameModel, self).data(index, role)
        if self.show_snapping and index.column() == self._SNAP_COLUMN and role in (Qt.DisplayRole, Qt.EditRole):
            # Map through the proxy so the type is read from the matching source row even if rows are reordered.
            source_row = self.mapToSource(index).row()
            current_type = self.sourceModel().index(source_row, self._TYPE_SOURCE_COLUMN).data(Qt.EditRole)
            if current_type == 'Natural':
                source_data = 'Edgestring snap'
            elif current_type in ('Dirichlet', 'Equilibrium'):
                source_data = 'Point snap'
            else:
                source_data = ''
        return source_data


class TransportConstituentBCAssignmentWidget(TransportConstituentAssignmentWidget):
    """A widget for assigning transport constituent boundary conditions to strings."""
    def __init__(self, parent, pe_tree, is_arc, time_series, use_transport, is_sediment):
        """Allows the user to edit which transport constituents are used and how.

        Args:
            parent (Something derived from :obj:`QWidget`): The parent window.
            pe_tree (): The project explorer tree.
            is_arc (bool): True if this should be used for arcs.
            time_series (dict): A dictionary of TimeSeries with an integer key.
            use_transport (bool): True if transport is being used.
            is_sediment (bool): True if the transport is sediment transport.
        """
        super().__init__(parent, pe_tree, use_transport, is_sediment)
        self.time_series = time_series
        # The snapping column is only shown for arcs.
        self.filter_model = BCAssignmentModel(is_arc, ['Name', 'Type', 'Time Series', 'Snapping Method'], self)
        self.type_delegate = QxCbxDelegate(self)
        types = ['None', 'Dirichlet']
        if is_arc:
            types.append('Natural')
        if self.is_sediment:
            types.append('Equilibrium')
        self.type_delegate.set_strings(types)
        self.snap_delegate = QxCbxDelegate(self)
        self.snap_delegate.set_strings(CardInfo.snap_options)
        self.series_delegate = CurveButtonDelegate(self._get_series, self._set_series, self)
        self.series_delegate.dialog_title = 'XY Series Editor'
        self._TYPE_COLUMN = 1  # Needs to be the filtered column index.
        self._SERIES_COLUMN = 2  # Needs to be the filtered column index.
        self._SNAP_COLUMN = 3  # Needs to be the filtered column index.
        self.ui.constituents_table.setItemDelegateForColumn(self._TYPE_COLUMN, self.type_delegate)
        self.ui.constituents_table.setItemDelegateForColumn(self._SERIES_COLUMN, self.series_delegate)
        self.ui.constituents_table.setItemDelegateForColumn(self._SNAP_COLUMN, self.snap_delegate)

    def set_transport(self, constituents_uuid, constituents, assignments=None, transport_name=''):
        """Sets the transport constituents and current assignments.

        Rebuilds ``self.assignments`` with one row per available constituent and attaches it to the table.

        Args:
            constituents_uuid (str): The uuid of the transport constituents component.
            constituents (TransportConstituentsIO/SedimentConstituentsIO): The transport constituents data.
            assignments (pandas.DataFrame): The current assignments to edit.
            transport_name (str): The name of the transport constituents component as it shows in the project explorer.
        """
        super().set_transport(constituents_uuid, constituents, assignments, transport_name)

        assignment_list = []

        # Add assignments for constituents.
        if self.is_sediment and constituents:
            comp_ids = constituents.param_control.sand['ID'].values.tolist()
            comp_ids.extend(constituents.param_control.clay['ID'].values.tolist())
            names = constituents.param_control.sand['NAME'].values.tolist()
            names.extend(constituents.param_control.clay['NAME'].values.tolist())
            for comp_id, name in zip(comp_ids, names):
                self._add_constituent(comp_id, name, assignments, assignment_list)
        elif constituents:
            # Built-in constituents use fixed ids.
            if constituents.param_control.salinity:
                self._add_constituent(1, 'Salinity', assignments, assignment_list)
            if constituents.param_control.temperature:
                self._add_constituent(2, 'Temperature', assignments, assignment_list)
            if constituents.param_control.vorticity:
                self._add_constituent(3, 'Vorticity', assignments, assignment_list)
            # `.values` works for both pandas and xarray columns and matches the sediment branch above.
            comp_ids = constituents.user_constituents['ID'].values.tolist()
            names = constituents.user_constituents['NAME'].values.tolist()
            for comp_id, name in zip(comp_ids, names):
                self._add_constituent(comp_id, name, assignments, assignment_list)
        self.assignments = pandas.DataFrame(
            assignment_list, columns=['CONSTITUENT_ID', 'NAME', 'TYPE', 'SERIES_ID', 'SNAPPING']
        )
        self.model = QxPandasTableModel(self.assignments, self)
        self.filter_model.setSourceModel(self.model)
        self.ui.constituents_table.setModel(self.filter_model)
        self.ui.constituents_table.update()

    @staticmethod
    def _add_constituent(con_id, con_name, assignments, assignment_list):
        """Adds the constituent from the component if possible, otherwise it adds a default.

        Args:
            con_id (int): The constituent id to look for.
            con_name (str): The constituent name.
            assignments (pandas.DataFrame): The current constituent assignments.
            assignment_list (list): The full list of constituent assignments that the current constituent
                                    will be added to.
        """
        con_assignment = None
        if assignments is not None and not assignments.empty:
            con_assignment = assignments.loc[assignments['CONSTITUENT_ID'] == con_id].values.tolist()
        if con_assignment:
            # NOTE(review): the [2] / [1:] indexing assumes incoming rows carry one extra leading column
            # before CONSTITUENT_ID, so that NAME is at position 2 - confirm against the caller's layout.
            con_assignment[0][2] = con_name  # Reset the name since it may have changed.
            assignment_list.append(con_assignment[0][1:])
        else:
            assignment_list.append([con_id, con_name, 'None', 0, 'Point snap'])

    def _get_series(self, curve_id):
        """Returns the time series dataframe for a curve, creating a new curve when needed.

        Args:
            curve_id (int): The time series curve id. An id <= 0, or one not present in ``self.time_series``,
                is treated as invalid and a new, empty curve is created.

        Returns:
            A pandas dataframe for the curve and the curve id.
            If the given curve id was valid, then it will be the same id; otherwise it is the new curve's id.
        """
        if curve_id <= 0 or curve_id not in self.time_series:
            # Allocate the id after the largest existing one (or 1 when there are no curves yet).
            curve_id = int(max(self.time_series.keys())) + 1 if self.time_series else 1
            self.time_series[curve_id] = TimeSeries()
        return self.time_series[curve_id].time_series, curve_id

    def _set_series(self, curve_id, dataframe):
        """Sets the time series curve.

        Args:
            curve_id (int): The time series curve id.
            dataframe (pandas.Dataframe): The time series curve.
        """
        self.time_series[curve_id].time_series = dataframe
