"""Qt table model using an SQL database for storage."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from PySide2.QtCore import QModelIndex, Qt, Signal, Slot
from PySide2.QtGui import QColor
from PySide2.QtSql import QSqlDatabase, QSqlTableModel

# 3. Aquaveo modules

# 4. Local modules


class QxSqlTableModel(QSqlTableModel):
    """An XMS class derived from QSqlTableModel.

    Adds three features on top of the base SQL table model:

    * read-only columns (not editable, drawn with a gray background),
    * watched columns (a successful edit emits ``watched_column_change``),
    * tooltips for the horizontal header.
    """

    # Emitted with the edited index after a successful edit in a watched column.
    watched_column_change = Signal(QModelIndex)

    def __init__(self, parent=None, db=None):
        """Initializes the class.

        Args:
            parent (Something derived from QWidget): The parent window.
            db (QSqlDatabase): The database. If None, a default-constructed QSqlDatabase is used.
        """
        super().__init__(parent, db if db is not None else QSqlDatabase())

        self.watched_column_change.connect(self.on_watched_column_change)
        self.read_only_columns = set()  # Columns that will be read only
        self.watched_columns = set()  # Columns that, when changed, will cause a signal
        self.horizontal_header_tooltips = None  # dict{int, str} of section -> tooltip, or None

    @Slot(QModelIndex)
    def on_watched_column_change(self, index):
        """Slot that is called when a change is made in a watched column.

        Does nothing here; it is a hook that derived classes can override.

        Args:
            index (QModelIndex): The index of the cell that changed.
        """
        pass

    def set_read_only_columns(self, read_only_columns):
        """Sets which columns are supposed to be read-only.

        The container is stored as-is (not copied).

        Args:
            read_only_columns (set{int}): The read only columns.
        """
        self.read_only_columns = read_only_columns

    def set_watched_columns(self, watched_columns):
        """Sets which columns are being watched.

        A successful change made in a watched column results in the watched_column_change signal.
        The container is stored as-is (not copied).

        Args:
            watched_columns (set{int}): The watched columns.
        """
        self.watched_columns = watched_columns

    def data(self, index, role=Qt.DisplayRole):
        """Depending on the index and role given, return data, or None.

        Args:
            index (QModelIndex): The index.
            role (int): The role.

        Returns:
            The data at index, or None.
        """
        if role == Qt.UserRole:  # Just return the data (don't convert to string...)
            # Qt.UserRole, which is used in QxTableView._fill_data(), isn't supported by QSqlTableModel, so switch
            # the role to Qt.DisplayRole
            return super().data(index, role=Qt.DisplayRole)

        if role == Qt.BackgroundColorRole and index.column() in self.read_only_columns:
            # Light gray background marks the cell as read-only
            return QColor(240, 240, 240)

        # Everything else (including the background of editable columns) is left to the base class
        return super().data(index, role)

    def setData(self, index, value, role=Qt.EditRole):  # noqa: N802
        """Adjust the data (set it to <value>) depending on index and role.

        Only Qt.EditRole and Qt.CheckStateRole are handled. The value is first passed through
        _validate_value() and then handed to the base class. The watched_column_change signal is
        emitted only if the base class reports that the value was actually stored.

        Args:
            index (QModelIndex): The index.
            value (object): The value.
            role (int): The role.

        Returns:
            (bool): True if successful; otherwise False.
        """
        # Always return a bool, as Qt's setData contract (and the docstring) promise
        if not index.isValid():
            return False

        column = index.column()
        if column in self.read_only_columns:
            return False

        if role != Qt.EditRole and role != Qt.CheckStateRole:
            return False

        new_value = self._validate_value(index, value)
        # Propagate the base-class result instead of unconditionally reporting success, so callers
        # and watchers are not told about a change that was rejected.
        success = super().setData(index, new_value, role)
        if success and column in self.watched_columns:
            self.watched_column_change.emit(index)  # Send a signal

        return success

    def headerData(self, section, orientation, role=Qt.DisplayRole):  # noqa: N802
        """Returns the data for the given role and section in the header.

        Args:
            section (int): The section.
            orientation (Qt.Orientation): The orientation.
            role (int): The role.

        Returns:
            The data. For Qt.ToolTipRole this is the tooltip registered for a horizontal section, or None if there
            is none (vertical header tooltips are not implemented).
        """
        if role != Qt.ToolTipRole:
            return super().headerData(section, orientation, role)

        if orientation == Qt.Horizontal and self.horizontal_header_tooltips:
            return self.horizontal_header_tooltips.get(section)
        return None  # I didn't implement vertical header tooltips

    def set_horizontal_header_tooltips(self, tooltips):
        """Sets the tooltips for the header.

        Args:
            tooltips (dict{int, str}): Tooltips dict where int is section number.
        """
        self.horizontal_header_tooltips = tooltips

    def _validate_value(self, index, value):
        """Validate value by trying to convert it to the same type as the column.

        Falsy values (None, '', 0, False, ...) are returned unchanged without conversion.

        Args:
            index (QModelIndex): The index.
            value (object): The value.

        Returns:
            The validated value.

        """
        if not value:
            return value

        field = self.record().field(index.column())
        # Assumes field.type() is callable as a converter (e.g. int, float, str) - TODO confirm for all column types
        field_type = field.type()
        try:
            new_value = field_type(value)
        except ValueError:
            # NOTE(review): record() is called without a row, so this fallback is presumably the field's
            # empty/default value rather than the cell's current value - confirm this is intended.
            new_value = field.value()
        return new_value

    def flags(self, index):
        """Returns the item flags for the given index.

        Read-only columns have Qt.ItemIsEditable cleared; every other column has it set.

        Args:
            index (QModelIndex): The index.

        Returns:
            (int): The flags.
        """
        if not index.isValid():
            return Qt.ItemIsEnabled

        # Make it non-editable if needed
        flags = super().flags(index)
        if index.column() in self.read_only_columns:
            return flags & (~Qt.ItemIsEditable)
        return flags | Qt.ItemIsEditable
