"""SWI package tools."""

__copyright__ = '(C) Copyright Aquaveo 2025'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import copy
from dataclasses import dataclass, field
from pathlib import Path
import shutil
import sqlite3

# 2. Third party modules
import pandas as pd

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.guipy.models.qx_pandas_table_model import is_datetime_or_timedelta_dtype
from xms.testing import tools

# 4. Local modules
from xms.mf6.components import dmi_util
from xms.mf6.components.component_creator import (
    add_and_link, create_add_component_action, create_data_objects_component
)
from xms.mf6.components.gwf.drn_component import DrnComponent
from xms.mf6.components.gwf.ghb_component import GhbComponent
from xms.mf6.components.gwf.rch_component import RchComponent
from xms.mf6.components.package_component_base import PackageComponentBase
from xms.mf6.data import data_util
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.data.grid_info import GridInfo
from xms.mf6.data.gwf.drn_data import DrnData
from xms.mf6.data.gwf.ghb_data import GhbData
from xms.mf6.data.gwf.rch_array_data import RchArrayData
from xms.mf6.data.gwf.rch_list_data import RchListData
from xms.mf6.data.gwf.swi_data import DEFAULT_DRN_CONDUCTANCE, DEFAULT_GHB_CONDUCTANCE, SwiData
from xms.mf6.data.mfsim_data import MfsimData
from xms.mf6.data.tva_data import TvaData
from xms.mf6.file_io import database, io_factory
from xms.mf6.file_io.writer_options import WriterOptions
from xms.mf6.gui.dialog_input import DialogInput
from xms.mf6.mapping.package_builder_base import (
    add_list_package_period_data, cell_active, cell_has_chd, update_chd_cells_from_package
)


@dataclass
class SwiSetUpInputs:
    """Information needed to set up the SWI related packages.

    The ``create_*`` and ``fix_rch`` flags each switch one step of the set up on or off.
    """
    # Not read anywhere in this module; presumably the file the sea level time series came from -- TODO confirm.
    sea_level_csv: str = ''
    # Sea level time series. Column 0 holds the times and column 1 holds the sea levels.
    sea_level_df: pd.DataFrame | None = None
    # If True, a GHB package is created with a BC wherever the cell top is below the sea level.
    create_ghb: bool = True
    # Conductance assigned to every GHB BC that gets created.
    ghb_cond: float = DEFAULT_GHB_CONDUCTANCE
    # If True, a DRN package is created with a BC wherever the cell top is not below the sea level.
    create_drn: bool = True
    # Conductance assigned to every DRN BC that gets created.
    drn_cond: float = DEFAULT_DRN_CONDUCTANCE
    # If True, RCH packages are copied and recharge is set to 0.0 where the cell top is below the sea level.
    fix_rch: bool = True
    # If True, a TVA file with the SALTWATER_HEAD for each stress period is written.
    create_tva_file: bool = True  # SALTWATER_HEAD


@dataclass
class SwiSetUpOutputs:
    """Results (new package data) from creating/fixing the SWI related packages."""
    # The new GHB package data, or None if one was not created.
    ghb: GhbData | None = None
    # The new DRN package data, or None if one was not created.
    drn: DrnData | None = None
    # Copies of the RCH packages in which recharge values were changed.
    rch_data: list[RchListData | RchArrayData] = field(default_factory=list)
    # Path to the TVA file that was written, or None if one was not created.
    tva_filepath: Path | None = None  # SALTWATER_HEAD
    # Error messages. If the inputs could not be validated, this is the only field that is filled in.
    errors: list[str] = field(default_factory=list)


def set_up_swi(
    mfsim: MfsimData, swi_package: SwiData, inputs: SwiSetUpInputs, temp_dir: Path | None = None
) -> SwiSetUpOutputs:
    """Creates/fixes the packages related to the SWI package and returns the results.

    This is a convenience wrapper that builds a SwiSetterUpper and runs it.

    Args:
        mfsim: The MODFLOW simulation.
        swi_package: The SWI6 package.
        inputs: Data needed to create/fix the related packages.
        temp_dir: Only used when testing. Temporary directory where packages are created.

    Return:
        SwiSetUpOutputs: The new packages and any error messages.
    """
    return SwiSetterUpper(mfsim, swi_package, inputs, temp_dir).do_set_up()


def add_packages(dlg_input: DialogInput, outputs: SwiSetUpOutputs) -> None:
    """Adds the new GHB, DRN and RCH package components to XMS.

    Each package is copied into the components directory, a component is created and added to XMS, and an action
    to update its display is appended to ``dlg_input.actions``.

    Args:
        dlg_input: Dialog data.
        outputs: Results of running the tool.
    """
    component_classes = {'GHB6': GhbComponent, 'DRN6': DrnComponent, 'RCH6': RchComponent}
    # ghb and drn are None when those packages were not created
    packages = [pkg for pkg in (outputs.ghb, outputs.drn, *outputs.rch_data) if pkg is not None]
    for package in packages:
        component_class = component_classes[package.ftype]
        _copy_to_components_dir(dlg_input, package, component_class)
        _create_component_and_add_to_xms(dlg_input, package, component_class)
        package.update_displayed_cell_indices()
        model = package.model
        display_action = dmi_util.update_display_action(model.ftype, model.filename, package.filename)
        dlg_input.actions.append(display_action)


class SwiSetterUpper:
    """Creates and fixes the packages that accompany the SWI package.

    Depending on the inputs, this creates GHB and DRN packages, makes fixed copies of the RCH packages, and writes a
    TVA file containing SALTWATER_HEAD.
    """
    def __init__(
        self, mfsim: MfsimData, swi_package: SwiData, inputs: SwiSetUpInputs, temp_dir: Path | None = None
    ) -> None:
        """Initializes the class.

        Problems found with the sea level time series are saved as error messages which do_set_up() returns.

        Args:
            mfsim: The MODFLOW simulation.
            swi_package: The SWI6 package.
            inputs: Data needed to create/fix the related packages.
            temp_dir: Only used when testing. Temporary directory where packages are created.
        """
        self._mfsim = mfsim
        self._swi_package = swi_package
        self._inputs = inputs
        self._temp_dir: Path | None = temp_dir if temp_dir else XmEnv.xms_environ_process_temp_directory()

        self._model = swi_package.model
        self._chd_cells: dict[int, set[int]] = {}  # Cells containing a CHD bc. Stress period # -> set of cell indexes
        self._date_times: bool = False  # True if we're using dates/times
        self._sp_times: list = []  # Stress period times, size of stress periods + 1 (for end time)
        self._sea_level_times: list = []  # Will be column[0] from _sea_level_df
        self._sea_levels: list = []  # Will be column[1] from _sea_level_df
        self._ghb_data: GhbData | None = None
        self._drn_data: DrnData | None = None
        self._rch_data: list[RchListData | RchArrayData] = []  # List of fixed RCH packages
        self._grid_info: GridInfo | None = self._model.grid_info() if self._model else None
        self._idomain: list[int] = []  # Use the idomain() @property and don't use this directly
        self._tops: list[float] = []  # Use the tops() @property and don't use this directly
        self._errors: list[str] = []  # List of errors

        self._split_sea_level_df()
        self._setup_period_times()

    def do_set_up(self) -> SwiSetUpOutputs:
        """Creates/fixes the packages requested in the inputs and writes them to disk.

        If errors were found when the class was initialized, nothing is created and only the errors are returned.

        Return:
            The new package data, the TVA file path, and any error messages.
        """
        if self._errors:
            return SwiSetUpOutputs(errors=self._errors)
        if self._inputs.create_ghb or self._inputs.create_drn:
            self._find_chd_cells()
            if self._inputs.create_ghb:
                self._ghb_data = self._init_data(GhbData, None)
            if self._inputs.create_drn:
                self._drn_data = self._init_data(DrnData, None)
            self._add_ghb_and_drn_bcs(self._inputs.create_ghb, self._inputs.create_drn)
        if self._inputs.fix_rch:
            self._fix_rch()
        tva_filepath = None
        if self._inputs.create_tva_file:
            tva_filepath = self._create_saltwater_head_tva()
        self._write()

        return SwiSetUpOutputs(
            ghb=self._ghb_data,
            drn=self._drn_data,
            rch_data=self._rch_data,
            tva_filepath=tva_filepath,
            errors=self._errors
        )

    @property
    def idomain(self) -> list[int]:
        """Returns the idomain.

        Using a property to make testing easier. This way I don't have to set up idomain in the tests.

        The value is read from the DIS package the first time it is needed. If the DIS package has no IDOMAIN,
        _get_idomain() returns None, which is returned as is.
        """
        if not self._idomain:
            self._idomain = _get_idomain(self._model.get_dis())
        return self._idomain

    @property
    def tops(self) -> list[float]:
        """Returns the cell top elevations.

        Using a property to make testing easier. This way I don't have to set up tops in the tests.

        The values are read from the DIS package the first time they are needed.
        """
        if not self._tops:
            self._tops = self._model.get_dis().get_tops()
        return self._tops

    def get_writer_options(self) -> WriterOptions:
        """Returns the WriterOptions."""
        # mfsim_dir and dmi_sim_dir need to be the temp directory in which we created the temporary component dirs
        mfsim_dir = Path(self._temp_dir) / 'dummy'
        dmi_sim_dir = str(Path(self._temp_dir))
        return WriterOptions(mfsim_dir=mfsim_dir, use_open_close=False, dmi_sim_dir=dmi_sim_dir, use_periods_db=True)

    def _setup_period_times(self) -> bool:
        """Gets the stress periods from the TDIS package and saves them to self._sp_times.

        The type of self._sp_times is set to match the type of the times in the sea level df (floats, datetimes...).

        Returns:
            True if the sim period times and the sea level dataframe times are compatible. False (and an error is
            added) if they are not, or if the sea level dataframe is missing, empty, or has fewer than two columns.
        """
        if not self._mfsim:  # This can be None only when testing
            return True

        # Report unusable input as an error instead of letting the constructor raise
        sea_level_df = self._inputs.sea_level_df
        if sea_level_df is None or sea_level_df.empty or sea_level_df.shape[1] < 2:
            self._errors.append('Sea level time series is missing or empty. Process aborted.')
            return False

        tdis = self._mfsim.tdis
        column = sea_level_df.columns[0]
        self._date_times = is_datetime_or_timedelta_dtype(sea_level_df[column])
        df = tdis.get_period_times(as_date_times=self._date_times)
        self._sp_times = df['Time'].to_list()
        return self._validate_time_series()

    def _split_sea_level_df(self) -> None:
        """Splits the sea level df into two lists for convenience.

        Does nothing if there is no dataframe or it has fewer than two columns.
        """
        sea_level_df = self._inputs.sea_level_df
        if sea_level_df is None:  # This can be None only when testing
            return
        if sea_level_df.shape[1] < 2:  # Unusable input. _setup_period_times() reports the error.
            return

        self._sea_level_times = sea_level_df[sea_level_df.columns[0]].to_list()
        self._sea_levels = sea_level_df[sea_level_df.columns[1]].to_list()
        assert len(self._sea_level_times) == len(self._sea_levels)

    def _get_sea_level(self, sp_idx: int) -> float | None:
        """Returns the last sea level that is before or at the end of the stress period.

        If no sea levels are before or at the end of the stress period, returns None.

        The search stops at the first sea level time that is after the end of the stress period, so the sea level
        times are expected to be in increasing order.

        Args:
            sp_idx: The stress period index.

        Returns:
            See description.
        """
        sp_end_time = self._sp_times[sp_idx + 1]
        sea_level_idx = None
        for i, sea_level_time in enumerate(self._sea_level_times):
            if sea_level_time <= sp_end_time:
                sea_level_idx = i
            else:
                break
        return self._sea_levels[sea_level_idx] if sea_level_idx is not None else None

    def _find_chd_cells(self) -> None:
        """Gets and saves the CHD cells from all the CHD packages in the model."""
        chds = self._model.packages_from_ftype('CHD6')
        for chd in chds:
            update_chd_cells_from_package(self._grid_info, chd, self._chd_cells)

    def _validate_time_series(self) -> bool:
        """Checks that sea level time series and stress periods are compatible.

        Returns False and adds an error if there's a problem.
        """
        # Check dates/times
        sp_is_date_times = is_datetime_or_timedelta_dtype(self._sp_times[0])
        if self._date_times != sp_is_date_times:
            self._errors.append(
                'Sea level time series uses dates/times, but TDIS does not. Turn on START_DATE_TIME option'
                ' in TDIS, or do not use dates/times in sea level time series. Process aborted.'
            )
            return False

        # Check if all sea level rise occurs after the sim and, if so, abort. If all sea level rise occurs before the
        # sim, it's OK, and we'll continue because the last sea level will still affect the sim.
        first_sea_level_time = self._inputs.sea_level_df.iloc[0, 0]
        if first_sea_level_time >= self._sp_times[-1]:
            self._errors.append(
                'Sea level time series starts after the end of all model stress periods. Process aborted.'
            )
            return False

        return True

    def _init_data(self, klass: type[BaseFileData] | None, orig) -> GhbData | DrnData | RchListData | RchArrayData:
        """Initializes and returns the package data object.

        The data is put in a new, uuid-named directory inside the temp directory. Exactly one of klass and orig
        must be given.

        Args:
            klass: Data class. Pass this to create a new, empty package from scratch.
            orig: Original data. Pass this to copy an existing package (only used for RCH packages).

        Returns:
            The data object.

        Raises:
            ValueError: If both or neither of klass and orig are given.
        """
        # Check before creating anything on disk
        if (klass is None) == (orig is None):
            raise ValueError('Pass klass to create from scratch, orig to copy existing.')

        # Put it in a component-like uuid dir inside a temp dir
        new_comp_uuid = tools.new_uuid()
        new_comp_dir = Path(self._temp_dir) / new_comp_uuid
        new_comp_dir.mkdir()
        if klass is not None:
            extension = data_util.extension_from_ftype(klass().ftype)
            new_main_file = new_comp_dir / f'model{extension}'
            data = klass(filename=str(new_main_file), mfsim=self._mfsim, model=self._model)
            data.periods_db = database.database_filepath(str(new_main_file))
        else:
            new_main_file = new_comp_dir / Path(orig.filename).name
            shutil.copytree(Path(orig.filename).parent, new_main_file.parent, dirs_exist_ok=True)
            component = RchComponent(orig.filename)
            component.duplicate(orig.filename, new_main_file)
            reader = io_factory.reader_from_ftype(orig.ftype)
            data = reader.read(new_main_file, mfsim=orig.mfsim, model=orig.model)
        return data

    def _add_ghb_and_drn_bcs(self, create_ghb: bool, create_drn: bool) -> None:
        """Adds the GHB and DRN BCs.

        For each stress period where the sea level changes, the highest active cell in each column gets a GHB BC
        (head = sea level) if its top is below the sea level, otherwise it gets a DRN BC (elevation = top).

        Args:
            create_ghb: True to add the GHB BCs.
            create_drn: True to add the DRN BCs.
        """
        last_sea_level = None
        for sp_idx in range(len(self._sp_times) - 1):
            sea_level = self._get_sea_level(sp_idx)
            if sea_level is None:
                continue
            if sea_level == last_sea_level:
                continue  # Not defining the stress period makes it the same as the previous one
            last_sea_level = sea_level

            highest_active_cells = self._highest_active_cells(self.idomain, sp_idx)
            ghb_period_rows = []
            drn_period_rows = []
            for cellid_tuple in highest_active_cells.values():
                cell_idx = self._grid_info.cell_index_from_modflow_cellid(cellid_tuple)
                top = self.tops[cell_idx]
                row = list(cellid_tuple)
                if top < sea_level:  # Add ghb bc
                    if create_ghb:
                        row.extend([float(sea_level), self._inputs.ghb_cond])
                        ghb_period_rows.append(row)
                elif create_drn:  # Add drn bc
                    row.extend([top, self._inputs.drn_cond])
                    drn_period_rows.append(row)

            # NOTE(review): if a row list is empty here, nothing is added for this period, so the BCs from the
            # previously defined period presumably remain in effect -- confirm that is intended.
            if ghb_period_rows:
                add_list_package_period_data(self._ghb_data, sp_idx + 1, ghb_period_rows)
            if drn_period_rows:
                add_list_package_period_data(self._drn_data, sp_idx + 1, drn_period_rows)

    def _fix_rch(self) -> None:
        """Fixes the RCH packages to turn off recharge when sea level rises above the top.

        Each RCH package is copied and the copy is fixed. Only copies in which changes were made are kept.
        """
        rchs = self._model.packages_from_ftype('RCH6')
        for rch in rchs:
            new_rch = self._init_data(klass=None, orig=rch)
            if isinstance(rch, RchListData):
                changes_made = self._fix_list_rch(new_rch)
            else:
                changes_made = self._fix_array_rch(new_rch)
            if changes_made:
                self._rch_data.append(new_rch)

    def _fix_array_rch(self, rch: RchArrayData) -> bool:
        """Fixes a RchArrayData package.

        Args:
            rch: The RCH package.

        Returns:
            True if changes were made.
        """
        period_count = len(self._sp_times) - 1
        rch.fill_missing_periods(period_count)

        # Iterate through stress periods
        changes_made: bool = False
        for sp_idx in range(period_count):
            sea_level = self._get_sea_level(sp_idx)
            if sea_level is None:
                continue

            # Get IRCH if it is defined
            irch_array = rch.period_data[sp_idx + 1].array('IRCH')
            irch = irch_array.get_values() if irch_array is not None else None

            # Set RECHARGE to 0.0 wherever the top of the cell where recharge is applied is below sea level
            rch_array = rch.period_data[sp_idx + 1].array('RECHARGE')
            values = rch_array.get_values()
            changes_in_sp: bool = False
            for i, value in enumerate(values):
                # Get the cell index where recharge is applied
                if irch:
                    cell_idx = self._grid_info.cell_index_from_lay_cell2d(irch[i], i + 1)
                else:  # If IRCH is omitted, recharge by default is applied to cells in layer 1.
                    cell_idx = i

                # Set recharge to 0.0 if top is below sea level
                if self.tops[cell_idx] < sea_level and value != 0.0:
                    values[i] = 0.0
                    changes_in_sp = True

            if changes_in_sp:
                shape = rch_array.layer(0).shape
                rch.period_data[sp_idx + 1].array('RECHARGE').set_values(values, shape, True)
                changes_made = True

        return changes_made

    def _fix_list_rch(self, rch: RchListData) -> bool:
        """Fixes a RchListData package.

        Everything is done on the database using SQL.

        Args:
            rch: The RCH package.

        Returns:
            True if changes were made.

        Raises:
            RuntimeError: If anything other than an sqlite3 error occurs. sqlite3 errors are logged instead.
        """
        db_filename = database.filepath_from_package(rch)
        period_count = len(self._sp_times) - 1
        changes_made: bool = False
        cxn = None
        try:
            cxn = sqlite3.connect(db_filename)
            # Using the connection as a context manager commits or rolls back, but does not close it
            with cxn:
                rch.fill_missing_periods(period_count, cxn)
                database.add_tops_table(rch, cxn)

                # Iterate through stress periods
                for sp_idx in range(period_count):
                    sea_level = self._get_sea_level(sp_idx)
                    if sea_level is None:
                        continue

                    if _update_recharge_sql(sea_level, sp_idx + 1, cxn):
                        changes_made = True
                _drop_tops_table_sql(cxn)
                cxn.commit()

        except sqlite3.Error as er:  # pragma no cover
            database.log_sqlite_error(er)
            changes_made = False  # The uncommitted updates were rolled back
        except Exception as error:  # pragma no cover
            raise RuntimeError(str(error)) from error
        finally:
            if cxn is not None:
                cxn.close()  # Release the database file
        return changes_made

    def _create_saltwater_head_tva(self) -> Path | None:
        """Create the TVA6 file containing SALTWATER_HEAD.

        Returns:
            Path to the file that was written.
        """
        tva_dict = self._get_period_sea_levels()
        return self._build_tva_data(tva_dict)

    def _build_tva_data(self, tva_dict: dict[int, float]) -> Path:
        """Build the TvaData object, write it to disk, and return the filepath.

        The file is named 'saltwater_head.tva' and is located next to the SWI package file.

        Args:
            tva_dict: dict of period -> constant; the sea levels at each stress period.

        Returns:
            See description.
        """
        filepath = Path(self._swi_package.filename).with_name('saltwater_head.tva')
        tva_data = TvaData(filename=str(filepath))
        tva_data.options_block.set('AUXILIARY', True, ['SALTWATER_HEAD'])
        tva_data._griddata_names = []
        for period, constant in tva_dict.items():
            tva_data.add_period(period)
            array = tva_data.add_transient_array(period, 'SALTWATER_HEAD')
            array.ensure_layer_exists(name='SALTWATER_HEAD', constant=constant)
        tva_data.write(self.get_writer_options())
        return filepath

    def _get_period_sea_levels(self) -> dict[int, float]:
        """Create and return a dict of period -> constant; the sea levels at each stress period.

        Only stress periods where the sea level differs from the previous one are included.

        Returns:
            See description.
        """
        tva_dict: dict[int, float] = {}
        last_sea_level = None
        for sp_idx in range(len(self._sp_times) - 1):
            sea_level = self._get_sea_level(sp_idx)
            if sea_level is None:
                continue
            if sea_level == last_sea_level:
                continue  # Not defining the stress period makes it the same as the previous one
            tva_dict[sp_idx + 1] = sea_level
            last_sea_level = sea_level
        return tva_dict

    def _highest_active_cells(self, idomain: list[int], sp_idx: int) -> dict:
        """Returns a dict of a cell index -> cellid_tuple of the highest active cell in that column.

        The cell index is the index of the column's cell in layer 1. Columns with no active cell, and columns
        whose highest active cell has a CHD bc in the stress period, are left out.

        Args:
            idomain: The idomain array.
            sp_idx: The stress period index.

        Returns:
            See description.
        """
        highest_active = {}
        for i in range(self._grid_info.cells_per_layer()):
            layer_1_cell_id = self._grid_info.fix_cellid(i, layer=1)
            cell_id, cell_idx = _highest_active_cell(self._grid_info, layer_1_cell_id, idomain)
            if not cell_id:
                continue
            if cell_has_chd(sp_idx + 1, cell_idx, self._chd_cells):
                continue
            highest_active[i] = cell_id
        return highest_active

    def _write(self) -> None:
        """Writes the packages to disk."""
        writer_options = self.get_writer_options()
        if self._ghb_data:
            self._ghb_data.write(writer_options)
        if self._drn_data:
            self._drn_data.write(writer_options)
        for rch in self._rch_data:
            options = copy.deepcopy(writer_options)
            # If array-based, write each stress period to separate files since there's no periods.db
            options.use_open_close = not isinstance(rch, RchListData)
            rch.write(options)


def _get_idomain(dis):
    """Returns the IDOMAIN values from the DIS package as ints.

    Args:
        dis: The DIS package. May be None.

    Returns:
        (list): The IDOMAIN values, or None if there is no DIS package, no GRIDDATA block, or no IDOMAIN array.
    """
    if not dis:
        return None
    griddata = dis.block('GRIDDATA')
    if not griddata or not griddata.has('IDOMAIN'):
        return None
    return list(map(int, griddata.array('IDOMAIN').get_values()))


def _copy_to_components_dir(dlg_input: DialogInput, data, klass) -> None:
    """Copy the temporary data folder to the components directory.

    The folder keeps its (uuid) name. Afterward, ``data.filename`` and ``data.periods_db`` point to the files in
    the new location.

    Args:
        dlg_input: DialogInput.
        data: The package data object.
        klass: Component class.
    """
    # The components directory is two levels up from the main file of the dialog's data
    components_dir = Path(dlg_input.data.filename).parent.parent
    temp_main_file = Path(data.filename)
    temp_comp_dir = temp_main_file.parent
    new_main_file = components_dir / temp_comp_dir.name / temp_main_file.name

    # dirs_exist_ok=False: the destination folder must not already exist
    shutil.copytree(temp_comp_dir, new_main_file.parent, dirs_exist_ok=False)
    klass(data.filename).duplicate(data.filename, new_main_file)

    data.filename = str(new_main_file)
    data.periods_db = str(new_main_file.with_name('periods.db'))


def _create_component_and_add_to_xms(
    dlg_input: DialogInput, data: BaseFileData, component_class: type[PackageComponentBase]
) -> None:
    """Creates the component and adds it to XMS.

    The component is added and linked under the tree node of the model in ``dlg_input.data``.

    Args:
        dlg_input: DialogInput.
        data: The package data (an instance, not a class).
        component_class: The component class.
    """
    main_file = data.filename
    ftype = data.ftype
    # The name given is the ftype minus its last character, plus '-from-SWI'
    do_comp = create_data_objects_component(main_file, f'{ftype[:-1]}-from-SWI', 'MODFLOW 6', ftype)
    comp = component_class(str(main_file))
    actions = [create_add_component_action(comp, 'get_initial_display_options')]
    parent_uuid = dlg_input.data.model.tree_node.uuid
    add_and_link(dlg_input.query, do_comp, actions, parent_uuid)


def _highest_active_cell(grid_info: GridInfo, cell_id, idomain):
    """Returns the highest active cell in the column of cells that cell_id is in.

    Layers are checked from layer 1 down and the first active cell found is returned.

    Args:
        grid_info: The GridInfo object.
        cell_id: The cell id. Only the parts after the layer are used.
        idomain: The idomain array.

    Returns:
        cell id and cell index (or None, None)
    """
    below_layer = tuple(cell_id[1:])  # Everything in the cell id except the layer
    for layer in range(1, grid_info.nlay + 1):
        candidate = (layer, *below_layer)
        candidate_idx = grid_info.cell_index_from_modflow_cellid(candidate)
        if cell_active(idomain, candidate_idx):
            return candidate, candidate_idx
    return None, None


def _drop_tops_table_sql(cxn: sqlite3.Connection):
    """Removes the 'tops' table from the database.

    Args:
        cxn: An sqlite3 connection.
    """
    drop_stmt = 'DROP TABLE tops'
    cursor = cxn.cursor()
    cursor.execute(drop_stmt)


def _update_recharge_sql(sea_level: float, period: int, cxn: sqlite3.Connection) -> int:
    """SQL for setting RECHARGE to 0.0 for cells whose top elevation is below the sea level.

    Rows whose RECHARGE is already 0.0 are left alone so that the returned count only includes rows that were
    actually changed.

    Args:
        sea_level: The sea level.
        period: The stress period.
        cxn: An sqlite3 connection. The database must have the 'data' and 'tops' tables.

    Returns:
        Number of rows changed
    """
    stmt = (
        'UPDATE data'
        ' SET RECHARGE = 0.0'
        ' FROM (SELECT TOP, CELLIDX FROM tops) as t'
        ' WHERE data.CELLIDX == t.CELLIDX AND t.TOP < ? AND data.PERIOD = ?'
        # IS NOT (rather than !=) so that rows where RECHARGE is NULL are still updated
        ' AND data.RECHARGE IS NOT 0.0'
    )
    cursor = cxn.cursor()
    cursor.execute(stmt, (sea_level, period))
    return cursor.rowcount
