"""A Qt model for grain size distribution to be shown in a plot."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from itertools import accumulate

# 2. Third party modules
from PySide2.QtCore import QModelIndex

# 3. Aquaveo modules
from xms.guipy.models.rename_model import RenameModel

# 4. Local modules


class DistributionPlotModel(RenameModel):
    """A renaming proxy model that keeps only the percent and grain size columns for the distribution plot."""

    # Positions of the plotted columns in the source model's DataFrame, whose columns are expected to be
    # 'layer_id', 'constituent_id', 'percent', 'ID', 'NAME', 'GRAIN_DIAMETER'.
    # Defined on the class (not in __init__) so they already exist while the base initializer runs.
    _PERCENT_SOURCE_COLUMN = 2
    _SIZE_SOURCE_COLUMN = 5

    def __init__(self, parent=None):
        """Initializes the filter model.

        This model assumes that the source model has a pandas DataFrame with columns: 'layer_id', 'constituent_id',
        'percent', 'ID', 'NAME', 'GRAIN_DIAMETER'. The DataFrame should be sorted by 'GRAIN_DIAMETER'.

        Args:
            parent (Something derived from :obj:`QObject`): The parent object.
        """
        # Header names for the kept columns, listed in source-column order (percent at 2, size at 5).
        # NOTE(review): presumably the proxy keeps accepted columns in source order - confirm in RenameModel.
        super().__init__(['Distribution Fraction', 'Grain diameter (m)'], parent)

    def filterAcceptsColumn(self, source_column: int, source_parent: QModelIndex) -> bool:  # noqa: N802
        """Filters out columns the plot doesn't need.

        Args:
            source_column (int): The column index in the source model.
            source_parent (QModelIndex): The parent index in the source model. Unused; columns are kept purely by
                position.

        Returns:
            True if the model should keep the column, False if the column should be hidden.
        """
        return source_column in (self._PERCENT_SOURCE_COLUMN, self._SIZE_SOURCE_COLUMN)

    def get_plot_values(self):
        """Gets the plot data values, with the distribution values accumulated row by row.

        Each y value is the sum of the 'percent' column from the first row through the current row. It only
        represents the total over smaller grain sizes when the source DataFrame is sorted by 'GRAIN_DIAMETER',
        as this model expects.

        Returns:
            A tuple (sizes, cumulative_distribution), both of type list[float] with one entry per source row.
        """
        data_frame = self.sourceModel().data_frame
        # tolist() produces new Python lists, so nothing here can modify values in the source DataFrame.
        x_column = data_frame.iloc[:, self._SIZE_SOURCE_COLUMN].tolist()
        percents = data_frame.iloc[:, self._PERCENT_SOURCE_COLUMN].tolist()
        # Running total; an empty column yields an empty list.
        y_column = list(accumulate(percents))
        return x_column, y_column
