"""DisDataBase class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy

# 2. Third party modules
from typing_extensions import override

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as fs

# 4. Local modules
from xms.mf6.data.grid_info import GridInfo
from xms.mf6.data.griddata_base import GriddataBase
from xms.mf6.data.options_block import OptionsBlock
from xms.mf6.file_io import io_util
from xms.mf6.gui.options_defs import Checkbox, CheckboxComboBox, CheckboxField


class DisDataBase(GriddataBase):
    """Base class for all DIS packages.

    Provides shared access to the TOP, bottom ('BOTM', or 'BOT' for DISU6) and IDOMAIN griddata arrays, and the
    option definitions common to the DIS packages.
    """
    def __init__(self, **kwargs):
        """Initializes the class.

        Args:
            **kwargs: Arbitrary keyword arguments, forwarded to the base class.

        Keyword Args:
            ftype (str): The file type used in the GWF name file (e.g. 'DIS6' or 'DISU6')
            mfsim (MfsimData): The simulation.
            model (GwfData or GwtData): The GWF model. Will be None for TDIS, IMS, Exchanges (things below mfsim)
            grid_info (GridInfo): Information about the grid. Only used when testing individual packages. Otherwise,
             it comes from model and dis
        """
        super().__init__(**kwargs)
        # Grid info owned by this package. While None, grid_info() defers to the base class.
        # NOTE(review): assigned after super().__init__(); if the base class stores the 'grid_info' kwarg in
        # self._grid_info, this line discards it. Confirm against GriddataBase.
        self._grid_info: GridInfo | None = None

    @override
    def grid_info(self) -> GridInfo | None:
        """Returns the GridInfo object.

        Returns this package's own grid info if it has been set, otherwise whatever the base class provides.

        Returns:
            (GridInfo | None): See description.
        """
        if self._grid_info is not None:
            return self._grid_info
        return super().grid_info()

    def make_copy(self) -> 'DisDataBase':
        """Returns a deep copy of this object."""
        return copy.deepcopy(self)

    def get_tops(self) -> list[float]:
        """Returns the top elevations of ALL the cells (not just the top layer).

        The TOP array holds the top of the first layer. The tops of the remaining layers are the bottoms of the
        layers above them, so every bottom except those of the last layer is appended.

        Returns:
            A new list: the TOP values followed by the bottoms of all but the last len(TOP) cells.
        """
        top_array = self.block('GRIDDATA').array('TOP')
        # Copy so that extending the result can never modify the list owned by the TOP array.
        tops = list(top_array.get_values())
        bottoms = self.get_bottoms()
        # Skip the last len(tops) bottoms (the bottom layer). When TOP already has one value per cell (DISU) the
        # slice is empty, so nothing is appended.
        tops.extend(bottoms[:-len(tops)])
        return tops

    def set_tops(self, values: list[float]) -> None:
        """Sets the top elevations.

        Only uses the values up to the number of cells in the first layer. For DISU, size of values should equal total
        number of cells.

        Args:
            values: Top elevations. Only the first ``cells_per_layer()`` values are stored, into layer 0 of TOP.
        """
        top_array = self.block('GRIDDATA').array('TOP')
        cells_per_layer = self.grid_info().cells_per_layer()
        top_array.layer(0).set_values(values[:cells_per_layer], self._get_shape())

    def get_bottoms(self) -> list[float]:
        """Returns the bottom elevations of all the cells."""
        bottom_array = self.block('GRIDDATA').array(self._bottom_array_name())
        return bottom_array.get_values()

    def set_bottoms(self, values: list[float]) -> None:
        """Set the bottom elevations.

        Args:
            values: The bottom elevations.
        """
        bottom_array = self.block('GRIDDATA').array(self._bottom_array_name())
        bottom_array.set_values(values, self._get_shape(), combine=False)

    def get_idomain(self) -> list[int] | None:
        """Return the idomain of all the cells, if it's defined, else None."""
        idomain_array = self.block('GRIDDATA').array('IDOMAIN')
        return idomain_array.get_values() if idomain_array is not None else None

    def set_idomain(self, values: list[int]) -> None:
        """Sets the IDOMAIN values.

        Does nothing if the package has no IDOMAIN array.

        Args:
            values: The IDOMAIN values.
        """
        idomain_array = self.block('GRIDDATA').array('IDOMAIN')
        if idomain_array is None:
            return
        idomain_array.set_values(values, self._get_shape(), False)

    def get_update_ugrid_data(self) -> tuple[list[float], list[float], list[int] | None]:
        """Get the tops, bottoms, and idomain arrays needed for updating the UGrid in GMS.

        Returns:
            (tuple): (tops of all cells, bottoms of all cells, idomain or None if undefined).
        """
        tops = self.get_tops()
        bottoms = self.get_bottoms()
        idomain = self.get_idomain()
        return tops, bottoms, idomain

    def _bottom_array_name(self) -> str:
        """Return the name of the bottom-elevation griddata array: 'BOT' for DISU6, 'BOTM' otherwise."""
        return 'BOT' if self.ftype == 'DISU6' else 'BOTM'

    def _get_shape(self) -> tuple[int, int]:
        """Return the array shape passed to the griddata arrays' set_values().

        Returns:
            (nrow, ncol) for DIS6; (cells_per_layer, 1) for every other ftype.
        """
        info = self.grid_info()
        if self.ftype == 'DIS6':
            return info.nrow, info.ncol
        return info.cells_per_layer(), 1

    @override
    def _setup_options(self) -> OptionsBlock:
        """Returns the definition of all the available options.

        Returns:
            See description.
        """
        angrot_brief = 'Counter-clockwise rotation angle (in degrees) of the lower-left corner of the model grid'
        return OptionsBlock(
            [
                CheckboxComboBox(
                    'LENGTH_UNITS',
                    brief='length units used for this model',
                    items=['FEET', 'METERS', 'CENTIMETERS', 'UNKNOWN'],
                    value='UNKNOWN',
                    check_box_method='on_chk_length_units',
                    combo_box_method='on_cbx_length_units'
                ),
                Checkbox('NOGRB', brief='No binary grid file'),
                CheckboxField('GRB6 FILEOUT', brief='Name of binary grid output file', type_='str'),
                # The origin and rotation are shown read-only in the options dialog.
                CheckboxField(
                    'XORIGIN',
                    brief='X-position of the lower-left corner of the model grid',
                    type_='float',
                    value=0.0,
                    read_only=True
                ),
                CheckboxField(
                    'YORIGIN',
                    brief='Y-position of the lower-left corner of the model grid',
                    type_='float',
                    value=0.0,
                    read_only=True
                ),
                CheckboxField('ANGROT', brief=angrot_brief, type_='float', value=0.0, read_only=True),
            ]
        )


def copy_list_blocks(list_blocks: dict) -> dict:
    """Copy every list-block file to a fresh temporary file.

    Args:
        list_blocks: Maps a block name to the path of the file holding that block.

    Returns:
        A new dict with the same keys, each mapped to the path of its temporary copy.
    """
    copied = {}
    for block_name, source_path in list_blocks.items():
        temp_path = io_util.get_temp_filename()
        copied[block_name] = temp_path
        fs.copyfile(source_path, temp_path)
    return copied
