"""OcWriter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data import data_util
from xms.mf6.data import oc_data
from xms.mf6.data.oc_data import OcPresetOutputEnum
from xms.mf6.file_io import database, io_util
from xms.mf6.file_io.list_package_writer import ListPackageWriter
from xms.mf6.misc.settings import Settings


class OcWriter(ListPackageWriter):
    """Writes an OC6 package file."""
    def __init__(self):
        """Initializes the class."""
        super().__init__()

    # @overrides
    def _write_options(self, fp):
        """Writes the options block.

        When auto file naming is on, the FILEOUT options are renamed before the block is written.

        Args:
            fp (_io.TextIOWrapper): The file.
        """
        self._set_fileout_filenames_if_auto_naming()
        super()._write_options(fp)

    def _write_dimensions(self, fp) -> None:
        """Writes nothing: the OC package has no DIMENSIONS block.

        Args:
            fp (_io.TextIOWrapper): The file (unused).
        """

    def _set_fileout_filenames_if_auto_naming(self) -> None:
        """If auto file naming is on, set every FILEOUT option to an auto-generated filename.

        The filename prefix is the model name ('model' if there is no model) and the extension comes from
        oc_data.fileout_extension() for that option.
        """
        if not self._auto_file_naming():
            return
        prefix = self._data.model.mname if self._data.model else 'model'
        # Collect the new names first, then apply them, so options_block is not modified while its dict() is
        # being iterated.
        changed_options = {
            key: data_util.auto_file_name(prefix, oc_data.fileout_extension(key))
            for key in self._data.options_block.dict()
            if 'FILEOUT' in key
        }
        for key, value in changed_options.items():
            self._data.options_block.set(key, True, value)

    def _auto_file_naming(self):
        """Returns True if the Auto name checkbox is checked.

        Returns:
            (bool): The mfsim 'AUTO_FILE_NAMING' gms option (False if absent), or True if there is no mfsim.
        """
        if not self._data.mfsim:
            return True
        return self._data.mfsim.gms_options.get('AUTO_FILE_NAMING', False)

    def _create_preset_external_file(self) -> str:
        """Creates a temp file with the preset (OC_EVERY_TIME_STEP, OC_LAST_TIME_STEPS) lines.

        Returns:
            The filepath (as returned by the file object's ``name``).

        Raises:
            ValueError: If preset_output is not one of the two presets.
        """
        # 'HEAD', 'CONCENTRATION', or 'TEMPERATURE'
        word = oc_data.oc_first_word(self._data.model.ftype) if self._data.model else 'HEAD'
        if self._data.preset_output == OcPresetOutputEnum.OC_EVERY_TIME_STEP:
            setting = 'ALL'
        elif self._data.preset_output == OcPresetOutputEnum.OC_LAST_TIME_STEPS:
            setting = 'LAST'
        else:
            raise ValueError(f'Unsupported preset_output for a preset file: {self._data.preset_output}')
        list_lines = [
            f'PRINT BUDGET {setting}\n',
            f'PRINT {word} {setting}\n',
            f'SAVE BUDGET {setting}\n',
            f'SAVE {word} {setting}\n',
        ]

        with open(io_util.get_temp_filename(suffix='.mf6_tmp'), mode='wt') as new_file:
            new_file.writelines(list_lines)
        return new_file.name

    def _write_stress_periods_preset(self, fp) -> None:
        """Writes the stress periods when one of the presets has been selected.

        Replaces self._data.period_files with a single stress period 1 file holding the preset lines.

        Args:
            fp (_io.TextIOWrapper): The file.
        """
        # We only need to write stress period 1. All remaining stress periods will then 'use previous'
        self._data.period_files = {1: self._create_preset_external_file()}
        super()._write_stress_periods(fp)

    def _write_stress_periods(self, fp):
        """Writes the stress periods.

        If preset_output is OC_USER_SPECIFIED but the period data is equivalent to one of the presets,
        preset_output is changed to that preset and the preset is written instead.

        Args:
            fp (_io.TextIOWrapper): The file.
        """
        # Build database first, as it can speed up determining the preset
        if self._writer_options.dmi_sim_dir:
            database.build(self._data)
        periods_db = self._data.periods_db
        db_path = Path(periods_db) if periods_db else None
        if db_path is None or not db_path.exists() or db_path.stat().st_size == 0:
            # Create the periods database in a temp file so we can determine the preset.
            # NOTE(review): this temp file is not removed here; confirm io_util cleans up its temp files.
            periods_db = io_util.get_temp_filename(suffix='.db')
            database.build(self._data, fill=True, db_filepath=periods_db)

        # If preset_output is OC_USER_SPECIFIED, see if the data matches OC_EVERY_TIME_STEP or OC_LAST_TIME_STEPS and,
        # if so, change it to that.
        if self._data.preset_output == OcPresetOutputEnum.OC_USER_SPECIFIED:
            self._data.preset_output = self._compute_preset_output_from_data(periods_db)

        if self._data.preset_output == OcPresetOutputEnum.OC_USER_SPECIFIED:
            super()._write_stress_periods(fp)
        else:
            self._write_stress_periods_preset(fp)

    def _compute_preset_output_from_data(self, periods_db: str | Path) -> OcPresetOutputEnum:
        """Returns the preset the stress period data is equivalent to, or OC_USER_SPECIFIED.

        A stress period matches OC_EVERY_TIME_STEP when it has the four PRINT/SAVE BUDGET/<word> ALL records
        and nothing other than those and the corresponding LAST records. It matches OC_LAST_TIME_STEPS only
        when its records are exactly the four PRINT/SAVE BUDGET/<word> LAST records. Every defined stress
        period must match, and all must match the same preset; otherwise OC_USER_SPECIFIED is returned.

        This method does not modify self._data.preset_output.

        Args:
            periods_db: The db filepath.

        Returns:
            OC_EVERY_TIME_STEP, OC_LAST_TIME_STEPS, or OC_USER_SPECIFIED.
        """
        user_specified = OcPresetOutputEnum.OC_USER_SPECIFIED

        # First stress period must be defined to match OC_EVERY_TIME_STEP or OC_LAST_TIME_STEPS, because the
        # preset is written as stress period 1 only.
        if not self._data.period_files or min(self._data.period_files.keys()) != 1:
            return user_specified

        model_ftype = self._data.model.ftype if self._data.model else 'GWF6'
        word = oc_data.oc_first_word(model_ftype).lower()  # 'head' or 'concentration'
        all_keys = {'print_budget_all', f'print_{word}_all', 'save_budget_all', f'save_{word}_all'}
        last_keys = {'print_budget_last', f'print_{word}_last', 'save_budget_last', f'save_{word}_last'}
        known_keys = all_keys | last_keys

        preset = None
        for sp in sorted(self._data.period_files.keys()):
            df = self._data.get_period_df(sp, periods_db)

            # If it's an empty stress period, can't do preset
            if len(df) == 0:
                return user_specified

            # One key per record, e.g. 'print_budget_all'
            keys = {
                f'{print_save.lower()}_{rtype.lower()}_{ocsetting.lower()}'
                for print_save, rtype, ocsetting in zip(
                    df[oc_data.PRINT_SAVE], df[oc_data.RTYPE], df[oc_data.OCSETTING]
                )
            }

            # ALL output already covers the last time step, so extra LAST records don't change an
            # every-time-step period. LAST must match exactly: any ALL record would be lost by the preset.
            if all_keys <= keys <= known_keys:
                period_preset = OcPresetOutputEnum.OC_EVERY_TIME_STEP
            elif keys == last_keys:
                period_preset = OcPresetOutputEnum.OC_LAST_TIME_STEPS
            else:
                return user_specified

            # Every stress period must agree on the same preset
            if preset is not None and period_preset != preset:
                return user_specified
            preset = period_preset

        return preset if preset is not None else user_specified

    def _check_preset_validity(self, flags, word):
        """See if we can still do the presets and if not, flag them as 0.

        Retained for compatibility; _compute_preset_output_from_data no longer calls it.

        Args:
            flags (dict of str -> int): Dict of our flags. 'every_time_step' / 'last_time_steps' are set to 0 when
                their four record flags are not all truthy.
            word (str): Lowercase OC word for the model, e.g. 'head'.

        Returns:
            done (bool): True if, when a preset flag was cleared, the other preset flag was already falsy.
        """
        done = False

        # See if every_time_step is still valid
        every_keys = ('print_budget_all', f'print_{word}_all', 'save_budget_all', f'save_{word}_all')
        if not all(flags[key] for key in every_keys):
            flags['every_time_step'] = 0
            if not flags['last_time_steps']:
                done = True

        # See if last_time_steps is still valid
        last_keys = ('print_budget_last', f'print_{word}_last', 'save_budget_last', f'save_{word}_last')
        if not all(flags[key] for key in last_keys):
            flags['last_time_steps'] = 0
            if not flags['every_time_step']:
                done = True

        return done

    def write_settings(self, data):
        """If writing to components area, save preset_output to a settings file.

        Args:
            data: The package data.
        """
        super().write_settings(data)
        if not self._writer_options.dmi_sim_dir:
            return
        settings = Settings.read_settings(data.filename)
        settings['preset_output'] = data.preset_output
        Settings.write_settings(data.filename, settings)
