"""Qt delegate for a curve editor button displaying a curve preview."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import math

# 2. Third party modules
import pandas
from PySide2.QtCore import QEvent, QPointF, QSize, Qt
from PySide2.QtGui import QBrush, QPen, QPixmap
from PySide2.QtWidgets import QPushButton, QStyle, QStyledItemDelegate

# 3. Aquaveo modules

# 4. Local modules
from xms.adh.gui.bed_layer_constituent_dialog import BedLayerConstituentDialog


class GrainSizeDistributionDelegate(QStyledItemDelegate):
    """Qt delegate for a curve editor button displaying a grain size distribution curve preview.

    The cell data is interpreted as a bed layer id. Enabled cells paint a preview of the layer's cumulative
    constituent fractions (log10 grain diameter on x, running sum of fractions on y) when one is available, and
    otherwise render an 'Edit Curve...' / 'UNAVAILABLE' button. Releasing the mouse on an enabled cell opens a
    BedLayerConstituentDialog for that layer.
    """
    def __init__(self, data, material_id, parent=None):
        """Initializes the class.

        Args:
            data (SedimentMaterialsIO): The sediment options with materials.
            material_id (int): The id of the current material (0 for global unassigned).
            parent (QObject): The parent object.
        """
        super().__init__(parent)
        self.parent_dlg = parent  # also used as the parent of the constituent editor dialog
        self.pt_cache = {}  # layer id -> preview QPointF list, relative to the cell's top-left corner
        self.size_cache = {}  # layer id -> QSize the cached preview points were computed for
        self.data = data
        self.material_id = material_id
        self.constituents = None
        self.constituent_sizes = None  # DataFrame ['ID', 'NAME', 'GRAIN_DIAMETER'] sorted by diameter, or None
        self.paint_sizes = []  # log10 grain diameters, in the same (sorted) order as constituent_sizes
        self.constituent_id_to_size_idx = {}  # constituent id -> index into paint_sizes

    def updateEditorGeometry(self, editor, option, index):  # noqa: N802
        """Override of QStyledItemDelegate method of same name.

        Args:
            editor (QWidget): The editor widget.
            option (QStyleOptionViewItem): The style options.
            index (QModelIndex): The index in the model.
        """
        editor.setGeometry(option.rect)

    def paint(self, painter, option, index):
        """Override of QStyledItemDelegate method of same name.

        Args:
            painter (QPainter): The painter.
            option (QStyleOptionViewItem): The style options.
            index (QModelIndex): The index in the model.
        """
        if not index.isValid():
            return
        flags = index.flags()
        # Save/restore so the brush and pens set here do not leak into the painting of other cells.
        painter.save()
        try:
            if (flags & Qt.ItemIsEditable) == 0:
                painter.setBrush(QBrush(option.palette.window()))
            # Selection is view state carried by the style option; it is never part of the item flags.
            if option.state & QStyle.State_Selected:
                painter.setBrush(QBrush(option.palette.highlight()))
            if flags & Qt.ItemIsEnabled:
                self._paint_enabled_cell(painter, option, index)
        finally:
            painter.restore()

    def _paint_enabled_cell(self, painter, option, index):
        """Paint the curve preview of an enabled cell, or a button when no preview can be drawn.

        Args:
            painter (QPainter): The painter.
            option (QStyleOptionViewItem): The style options.
            index (QModelIndex): The index in the model.
        """
        rect = option.rect
        width = rect.width()
        height = rect.height()
        min_dim = min(height, width)
        icon_size = QSize(min_dim, min_dim)
        has_sizes = self.constituent_sizes is not None
        layer_id = self._layer_id(index)
        points = self._get_cached_preview_points(layer_id, icon_size) if has_sizes else []
        if points:
            self._draw_preview(painter, rect, points)
            return

        # No preview available: render a push button into the cell instead.
        btn = QPushButton('Edit Curve...' if has_sizes else 'UNAVAILABLE')
        btn.setIconSize(icon_size)
        btn.setFixedWidth(width)
        btn.setFixedHeight(height)
        pix = QPixmap(rect.size())
        btn.render(pix)
        painter.drawPixmap(rect.topLeft(), pix)

    @staticmethod
    def _draw_preview(painter, rect, points):
        """Draw the preview's bounding box (black) and curve (red), offset into the cell rectangle.

        Args:
            painter (QPainter): The painter; the caller is responsible for restoring its pen.
            rect (QRect): The cell rectangle.
            points (list[QPointF]): The cell-relative curve points (left unmodified).
        """
        shifted = copy.deepcopy(points)  # keep the cached, cell-relative points untouched
        for pt in shifted:
            pt.setX(pt.x() + rect.x())
            pt.setY(pt.y() + rect.y())
        xs = [pt.x() for pt in shifted]
        ys = [pt.y() for pt in shifted]
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)

        pen = QPen()
        pen.setStyle(Qt.SolidLine)
        pen.setColor(Qt.black)
        painter.setPen(pen)
        painter.drawRect(min_x, min_y, max_x - min_x, max_y - min_y)

        pen.setColor(Qt.red)
        painter.setPen(pen)
        painter.drawPolyline(shifted)

    @staticmethod
    def _layer_id(index):
        """Parse the layer id stored in a cell, accepting both integer and float text (e.g. '3' or '3.0').

        Args:
            index (QModelIndex): The index in the model.

        Returns:
            int: The layer id.
        """
        value = index.data()
        try:
            return int(value)
        except ValueError:
            return int(float(value))

    def _get_cached_preview_points(self, layer_id, icon_size):
        """Return the preview points for a layer, recomputing them when missing or computed for another size.

        Args:
            layer_id (int): The bed layer id; ids <= 0 have no preview.
            icon_size (QSize): The drawing area the points must fit.

        Returns:
            list[QPointF]: The cell-relative preview points (possibly empty).
        """
        if layer_id <= 0:
            return []
        if layer_id in self.pt_cache and self.size_cache[layer_id] == icon_size:
            return self.pt_cache[layer_id]
        constituents = self.data.materials[self.material_id].constituents
        layer_df = constituents.loc[constituents['layer_id'] == layer_id]
        points = self._get_curve_preview_points(layer_df, icon_size)
        self.pt_cache[layer_id] = points
        self.size_cache[layer_id] = icon_size
        return points

    def editorEvent(self, event, model, option, index):  # noqa: N802
        """Called when the XY series editor button is clicked.

        Args:
            event (QEvent): The editor event that was triggered.
            model (QAbstractItemModel): The data model.
            option (QStyleOptionViewItem): The style options.
            index (QModelIndex): The index in the model.

        Returns:
            bool: True if the event was handled here, otherwise the base class result.
        """
        if self.constituent_sizes is None:
            return True  # nothing can be edited without constituents; swallow the event

        is_click = (index.isValid() and index.flags() & Qt.ItemIsEnabled
                    and event.type() == QEvent.MouseButtonRelease)
        if not is_click:
            return super().editorEvent(event, model, option, index)

        self._edit_layer(self._layer_id(index))
        return True

    def _edit_layer(self, layer_id):
        """Run the constituent editor for a layer and store the fractions if the dialog is accepted.

        Args:
            layer_id (int): The bed layer id being edited.
        """
        # Build a frame with columns 'layer_id', 'constituent_id', 'fraction' merged with the constituent sizes
        # ('ID', 'NAME', 'GRAIN_DIAMETER').
        df = self.data.materials[self.material_id].constituents
        df = df.loc[df['layer_id'] == layer_id]
        if df.empty:
            # No stored fractions for this layer yet: split evenly across all constituents.
            ids = self.constituent_sizes['ID'].values.tolist()
            default_fraction = 1.0 / float(len(ids)) if ids else 0.0
            df = pandas.DataFrame(
                data=[[layer_id, con_id, default_fraction] for con_id in ids],
                columns=['layer_id', 'constituent_id', 'fraction']
            )
        df = df.merge(self.constituent_sizes, left_on='constituent_id', right_on='ID')
        df = df.sort_values(by=['GRAIN_DIAMETER'])  # smallest to largest grain size

        dialog = BedLayerConstituentDialog(self.parent_dlg, df)
        if not dialog.exec_():
            return

        # Replace this layer's rows with the edited values.
        material = self.data.materials[self.material_id]
        kept = material.constituents[material.constituents.layer_id != layer_id]
        edited = dialog.model.data_frame[['layer_id', 'constituent_id', 'fraction']]
        material.constituents = pandas.concat([kept, edited])
        # Invalidate only after the dialog closes: the view may repaint (and re-cache the old curve) while the
        # modal dialog's event loop is running.
        self.pt_cache.pop(layer_id, None)

    def _get_curve_preview_points(self, data_frame, graph_size):
        """Gets the points of the cumulative distribution curve to draw onto the button.

        Args:
            data_frame (pandas.DataFrame): One layer's rows; column 1 holds the constituent id and column 2 its
                fraction (positional access).
            graph_size (QSize): The size of the drawing area on the button.

        Returns:
            list[QPointF]: The curve points, ordered by increasing grain size. Rows whose constituent id is not
            known to this delegate are skipped.
        """
        if not self.paint_sizes:
            return []

        zero_tol = 0.000001
        x_col = 1
        y_col = 2

        # Get the drawing area bounds to calculate line locations.
        graph_width = graph_size.width()
        graph_height = graph_size.height()
        xmin = min(self.paint_sizes)
        xmax = max(self.paint_sizes)
        ymin = data_frame.iloc[:, y_col].min()
        ymax = data_frame.iloc[:, y_col].sum()  # top of a cumulative curve is the total fraction
        # add 10% margin and reset dx & dy
        dx = (xmax - xmin) * 0.1
        dy = (ymax - ymin) * 0.1
        xmin -= dx
        xmax += dx
        ymin -= dy
        ymax += dy
        dx = xmax - xmin
        dy = ymax - ymin
        # set scale factors; a degenerate range collapses that axis to the center of the drawing area
        sx = graph_width / dx if dx > zero_tol else 0.0
        sy = graph_height / dy if dy > zero_tol else 0.0

        # Pair each fraction with its size index. paint_sizes is sorted by grain size, so sorting by index
        # orders the points by grain size.
        size_and_y = []
        for con_id, y_data in zip(data_frame.iloc[:, x_col], data_frame.iloc[:, y_col]):
            size_idx = self.constituent_id_to_size_idx.get(con_id)
            if size_idx is None:
                continue  # unknown constituent (e.g. removed since the data was stored)
            size_and_y.append((size_idx, y_data))
        size_and_y.sort(key=lambda item: item[0])

        point_draw = []
        running_sum = 0.0
        for size_idx, y_data in size_and_y:
            x_data = self.paint_sizes[size_idx]
            x = 0.5 * graph_width if sx == 0.0 else (x_data - xmin) * sx
            running_sum += y_data
            # Qt's y axis points down, so flip relative to ymax.
            y = 0.5 * graph_height if sy == 0.0 else (running_sum - ymax) * sy * -1.0
            point_draw.append(QPointF(x, y))
        return point_draw

    def set_constituents(self, constituents, clear_cache=False):
        """Sets the sediment constituents to use for this delegate.

        The preview cache and the size lookups are always rebuilt because the previews depend on them.

        Args:
            constituents (SedimentConstituentsIO): The sediment constituent data (None for no constituents).
            clear_cache (bool): True to also discard the material's stored layer constituent fractions.
        """
        if clear_cache:
            self.data.materials[self.material_id].constituents = \
                self.data.materials[self.material_id].constituents.iloc[0:0]
        self.pt_cache.clear()
        self.size_cache.clear()
        self.paint_sizes = []
        self.constituent_sizes = None
        self.constituent_id_to_size_idx = {}
        self.constituents = constituents
        if constituents is None:
            return

        # Collect sand then clay constituents.
        ids = []
        names = []
        sizes = []
        for frame in (constituents.param_control.sand, constituents.param_control.clay):
            ids.extend(frame['ID'].values.tolist())
            names.extend(frame['NAME'].values.tolist())
            sizes.extend(frame['GRAIN_DIAMETER'].values.tolist())
        if not ids:
            return  # leaves constituent_sizes as None: nothing to preview or edit

        df = pandas.DataFrame(data=list(zip(ids, names, sizes)), columns=['ID', 'NAME', 'GRAIN_DIAMETER'])
        # Everything below is derived from the SAME sorted frame so that size indices follow grain size.
        self.constituent_sizes = df.sort_values(by=['GRAIN_DIAMETER'])
        sorted_ids = self.constituent_sizes['ID'].values.tolist()
        self.constituent_id_to_size_idx = {int(con_id): row for row, con_id in enumerate(sorted_ids)}
        # Logarithmic grain sizes for the preview's x axis; non-positive sizes map to 0.0.
        self.paint_sizes = [
            math.log10(size) if size > 0.0 else 0.0
            for size in self.constituent_sizes['GRAIN_DIAMETER'].values.tolist()
        ]
