"""This module represents a single sediment material."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.guipy.data.polygon_texture import PolygonOptions

# 4. Local modules


class SedimentMaterial:
    """A storage class for a single sediment material.

    Besides the name, display options and boolean switches, a material holds three override tables
    (pandas DataFrames):

    * ``constituents`` -- columns ``layer_id``, ``constituent_id``, ``fraction``.
    * ``bed_layers`` -- columns ``layer_id``, ``thickness``, ``porosity``, ``critical_shear``,
      ``erosion_constant``, ``erosion_exponent``.
    * ``consolidation`` -- columns ``time_id``, ``elapsed_time``, ``porosity``, ``critical_shear``,
      ``erosion_constant``, ``erosion_exponent``.
    """
    def __init__(self, number):
        """Constructs a new material with empty override tables.

        Args:
            number (int): The material id; only used here to build the default name.
        """
        self.name = f'Sediment Material {number}'
        self.display = PolygonOptions()
        # Boolean options; all start disabled.
        self.bed_layer_override = False
        self.bed_layer_cohesive_override = False
        self.consolidation_override = False
        self.displacement_off = False
        self.local_scour = False
        self.use_bedload_diffusion = False
        self.bedload_diffusion = 1.0
        self._constituent_columns = ['layer_id', 'constituent_id', 'fraction']
        self._bed_layer_columns = [
            'layer_id', 'thickness', 'porosity', 'critical_shear', 'erosion_constant', 'erosion_exponent'
        ]
        self._consolidation_columns = [
            'time_id', 'elapsed_time', 'porosity', 'critical_shear', 'erosion_constant', 'erosion_exponent'
        ]
        self.constituents = pd.DataFrame(data=[], columns=self._constituent_columns)
        self.bed_layers = pd.DataFrame(data=[], columns=self._bed_layer_columns)
        self.consolidation = pd.DataFrame(data=[], columns=self._consolidation_columns)
        self.sediment_material_properties = {}

    @staticmethod
    def _append_rows(frame, rows, columns):
        """Returns ``frame`` with ``rows`` appended and the index renumbered from 0.

        Empty parts are left out of the concatenation: recent pandas versions emit a FutureWarning when
        concatenating empty frames because their effect on the result dtype is changing.

        Args:
            frame (pandas.DataFrame): The existing table.
            rows (list): Row lists to append, each laid out like ``columns``.
            columns (list): Column names for the new rows.

        Returns:
            pandas.DataFrame: The combined table.
        """
        new_frame = pd.DataFrame(data=rows, columns=columns)
        parts = [part for part in (frame, new_frame) if not part.empty]
        if not parts:
            return new_frame
        return pd.concat(parts, ignore_index=True)

    @staticmethod
    def _global_row(global_values, id_column, item_id, width):
        """Returns the default values for a newly added bed layer / consolidation time.

        Args:
            global_values (pandas.DataFrame): Global values; must contain ``id_column``.
            id_column (str): Name of the id column to match on.
            item_id: The id being added.
            width (int): Number of columns in the destination table.

        Returns:
            list: The first row of ``global_values`` whose ``id_column`` equals ``item_id``, copied
            verbatim, or ``[item_id, 0.0, ...]`` (``width`` entries) when no row matches.
        """
        matches = global_values.loc[global_values[id_column] == item_id].values.tolist()
        if matches:
            # Duplicate ids in the global table: use the first row.
            return matches[0]
        # No global defaults for this id: fall back to the zeros used by set_bed_layers/set_consolidation.
        return [item_id] + [0.0] * (width - 1)

    def set_bed_layers(self, layers, constituents):
        """Clears out previous bed layer overrides and adds new values.

        Every layer gets zeros for all of its values, and the constituent table is rebuilt for the new
        layers via :meth:`set_constituents`.

        Args:
            layers (list): A list of int layer ids.
            constituents (list): A list of int constituent ids.
        """
        new_bed_layers = [[layer, 0.0, 0.0, 0.0, 0.0, 0.0] for layer in layers]
        self.bed_layers = pd.DataFrame(data=new_bed_layers, columns=self._bed_layer_columns)
        self.set_constituents(constituents)

    def set_consolidation(self, times):
        """Clears out previous consolidation overrides and adds new values.

        Every time id gets zeros for all of its values.

        Args:
            times (list): A list of int time ids.
        """
        new_consolidations = [[time, 0.0, 0.0, 0.0, 0.0, 0.0] for time in times]
        self.consolidation = pd.DataFrame(data=new_consolidations, columns=self._consolidation_columns)

    def set_constituents(self, constituents):
        """Clears out previous bed layer constituent overrides and adds new values.

        One zero-fraction row is created for every (bed layer, constituent) pair, grouped by constituent.

        Args:
            constituents (list): A list of int constituent ids.
        """
        layer_ids = self.bed_layers['layer_id'].tolist()
        new_constituent_layers = [
            [layer, constituent, 0.0] for constituent in constituents for layer in layer_ids
        ]
        self.constituents = pd.DataFrame(data=new_constituent_layers, columns=self._constituent_columns)

    def update_constituents(self, constituents):
        """Removes deleted constituent ids and adds new constituent ids.

        Rows whose constituent is not in ``constituents`` are dropped. Each constituent id with no remaining
        row gets a zero ``fraction`` row for every layer in ``bed_layers`` and every layer still referenced
        by the remaining constituent rows.

        Args:
            constituents (list): A list of current int constituent ids.
        """
        # Remove values for constituents that have been deleted.
        self.constituents = self.constituents[self.constituents.constituent_id.isin(constituents)]

        # Add values for constituents that have been added.
        con_ids = self.constituents['constituent_id'].values.tolist()
        missing_con_ids = np.setdiff1d(constituents, con_ids).tolist()
        if missing_con_ids:
            # Include the bed layer table (as set_constituents does): deriving layers only from the
            # remaining rows would add nothing once every previous constituent has been deleted.
            layer_ids = set(self.bed_layers['layer_id'].tolist())
            layer_ids.update(self.constituents['layer_id'].tolist())
            default_fraction = 0.0
            new_values = [
                [layer_id, con_id, default_fraction] for layer_id in sorted(layer_ids) for con_id in missing_con_ids
            ]
            self.constituents = self._append_rows(self.constituents, new_values, self._constituent_columns)

    def update_bed_layers(self, layer_ids, global_values):
        """Removes deleted bed layers and adds new bed layers.

        Each new layer is seeded with the first matching row of ``global_values``, or with zeros when
        ``global_values`` has no row for it.

        Args:
            layer_ids (list): A list of current int bed layer ids.
            global_values (pandas.DataFrame): The global bed layer values to use as default values. Must
                have a ``layer_id`` column; its rows are copied verbatim, so its column layout must match
                ``bed_layers``.
        """
        # Remove values for bed layers that have been deleted.
        self.bed_layers = self.bed_layers[self.bed_layers.layer_id.isin(layer_ids)]

        # Add values for bed layers that have been added.
        old_layer_ids = self.bed_layers['layer_id'].values.tolist()
        missing_layer_ids = np.setdiff1d(layer_ids, old_layer_ids).tolist()
        if missing_layer_ids:
            width = len(self._bed_layer_columns)
            new_values = [
                self._global_row(global_values, 'layer_id', layer_id, width) for layer_id in missing_layer_ids
            ]
            self.bed_layers = self._append_rows(self.bed_layers, new_values, self._bed_layer_columns)

    def update_consolidation_times(self, time_ids, global_values):
        """Removes deleted consolidation time ids and adds new consolidation time ids.

        Each new time is seeded with the first matching row of ``global_values``, or with zeros when
        ``global_values`` has no row for it.

        Args:
            time_ids (list): A list of current int consolidation time ids.
            global_values (pandas.DataFrame): The global consolidation values to use as default values. Must
                have a ``time_id`` column; its rows are copied verbatim, so its column layout must match
                ``consolidation``.
        """
        # Remove values for consolidation times that have been deleted.
        self.consolidation = self.consolidation[self.consolidation.time_id.isin(time_ids)]

        # Add values for consolidation times that have been added.
        old_time_ids = self.consolidation['time_id'].values.tolist()
        missing_time_ids = np.setdiff1d(time_ids, old_time_ids).tolist()
        if missing_time_ids:
            width = len(self._consolidation_columns)
            new_values = [
                self._global_row(global_values, 'time_id', time_id, width) for time_id in missing_time_ids
            ]
            self.consolidation = self._append_rows(self.consolidation, new_values, self._consolidation_columns)
