"""SolutionReader class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from collections import Counter
from dataclasses import dataclass
import glob
import os
from pathlib import Path

# 2. Third party modules
from PySide2.QtWidgets import QWidget

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.components.bases.component_base import ActionRv
from xms.core.filesystem import filesystem as fs
from xms.data_objects.parameters import Component, Coverage
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid import UGrid
from xms.guipy.dialogs import process_feedback_dlg
from xms.guipy.dialogs.feedback_thread import FeedbackThread

# 4. Local modules
from xms.mf6.components import dmi_util
from xms.mf6.components.xms_data import XmsData
from xms.mf6.data.oc_data import oc_first_word
from xms.mf6.file_io import (dependent_variable_file_reader, model_reader_base, tdis_reader)
from xms.mf6.file_io.gwf.oc_reader import OcReader
from xms.mf6.file_io.pest import pest_obs_coverage_builder, pest_obs_results_reader, pest_obs_stats_writer
from xms.mf6.file_io.pest.pest_obs_results_reader import ObsResults
from xms.mf6.gui import units_util
from xms.mf6.misc import log_util
from xms.mf6.simulation_runner import pest_obs_runner
from xms.mf6.simulation_runner import pest_obs_runner


@dataclass
class CoverageCompInfo:
    """Stores info needed to add the coverage component to the query.

    SolutionReader collects these as solution items and later passes the fields to query.add_coverage().
    Field order matters: instances are constructed positionally.
    """
    coverage: Coverage  # The coverage passed to add_coverage()
    component: Component  # Sent to add_coverage() as the single entry of 'components'
    comp_data: list[dict]  # Sent to add_coverage() as 'component_keywords'
    coverage_type: str  # Sent to add_coverage() as 'coverage_type' (e.g. 'Observation Targets')
    model_name: str  # Sent to add_coverage() as 'model_name' (e.g. 'Generic Coverages'); NOT the MODFLOW model name


def read(query: Query, params: list[dict], win_cont: QWidget) -> ActionRv:
    """Read the simulation solution inside a feedback dialog.

    Args:
        query: Object for communicating with XMS.
        params: Parameters from XMS. params[0] must supply 'run_dir' and the per-model lists
            ('model_files', 'model_names', 'model_ftypes', 'model_uuids') used by SolutionReader.
        win_cont: The window container, passed on to the feedback dialog.

    Returns:
        A pair of empty lists.
    """
    process_feedback_dlg.run_feedback_dialog(ReadSolutionFeedbackThread(query, params), win_cont)
    return [], []


class ReadSolutionFeedbackThread(FeedbackThread):
    """Feedback thread whose work is reading the solution with SolutionReader."""
    def __init__(self, query: Query, params: list[dict]):
        """Initializes the class.

        Args:
            query: Object for communicating with XMS.
            params: Parameters from XMS. params[0]['run_dir'] is shown in the dialog prompts.
        """
        super().__init__(query)
        self._query = query
        self._params = params

        # This is the run directory, not the mfsim.nam file itself
        run_dir = Path(params[0]['run_dir'])
        prompts = {
            'title': 'Read Solution',
            'working_prompt': f'Reading solution at "{run_dir}".',
            'success_prompt': f'Successfully read solution at "{run_dir}".',
        }
        self.display_text |= prompts

    def _run(self):
        """Does the work: reads the solution via SolutionReader."""
        SolutionReader(self._query, self._params).read()


class SolutionReader:
    """Reads the solution.

    For each model listed in params[0], whichever of these exist are collected:
    - the listing file
    - the dependent variable dataset
    - the budget file
    - the output CSV files
    - the PEST observation results
    - the SWI zeta dataset

    Everything is then sent to XMS in one pass at the end, grouped into a folder per model name.
    """
    def __init__(self, query: Query, params: list[dict]):
        """Initializes the class.

        Args:
            query: a Query object to communicate with GMS.
            params: Generic map of parameters. Contains the structures for various components that
             are required for adding vertices to the Query Context with Add().
        """
        self._query = query
        self._params = params

        self._log = log_util.get_logger()
        self._sim_uuid: str = ''
        self._sim_node: TreeNode | None = None
        self._components_dir: str = ''
        self._ftype_counts: dict[str, int] = {}
        self._dset_time_units: str = ''
        self._start_date_time: str = ''
        self._times: list[float] = []  # Timestep times from the most recently read dependent variable file
        self._mfsim_dir: str = params[0]['run_dir']
        self._solution_items: dict[str, list] = {}  # solution items to be added to XMS, grouped by model
        self._ugrid_uuids: dict[str, str] = {}  # model name -> ugrid uuid
        self._xms_data = XmsData(query)
        self.pest_coverages = None  # Result of the most recent pest_obs_coverage_builder.build_coverages() call

        # These are for the current model and are updated for each model
        self._model_name: str = ''
        self._model_filename: str = ''
        self._model_ftype: str = ''
        self._model_uuid: str = ''
        self._model_ugrid_uuid: str = ''
        self._model_ugrid: UGrid | None = None  # Use self.model_ugrid() method, not this

    def model_ugrid(self) -> UGrid | None:
        """Return the UGrid of the current model, fetching and caching it on first use.

        The cache is cleared by _set_current_model_variables() when moving to another model.

        Returns:
            The UGrid, or None if no cogrid could be found for the current model.
        """
        if self._model_ugrid is None:
            cogrid = self._xms_data.get_cogrid(self._model_uuid)
            if cogrid:
                self._model_ugrid = cogrid.ugrid
        return self._model_ugrid

    def read(self):
        """Reads the solution of every model and adds the results to XMS."""
        self._sim_uuid = self._query.parent_item_uuid()
        self._components_dir = os.path.join(self._query.xms_temp_directory, 'Components')
        self._sim_node = tree_util.find_tree_node_by_uuid(self._query.project_tree, self._sim_uuid)
        self._init_ugrid_uuids()
        self._ftype_counts = _calculate_ftype_counts(self._params[0]['model_ftypes'])

        models_count: int = len(self._params[0]['model_files'])
        if models_count:
            # The TDIS lookup depends only on the simulation node, not on the current model, so do it once.
            self._set_up_time_info()

        # Iterate through models
        for i in range(models_count):
            self._set_current_model_variables(i)
            self._log.info(f'Reading solution for "{self._model_filename}"')

            dv_filename, budget_filename = _dep_var_and_budget_filenames(self._model_filename, self._model_ftype)

            # Add components for the different parts of the solution (in the order they will appear in the tree)
            self._add_list_file()
            dset_uuid = self._add_dep_var_file(dv_filename)
            self._add_budget_file(budget_filename)
            self._add_csv_files()
            self._add_pest_obs(dset_uuid)  # Uses the dataset uuid and times set by _add_dep_var_file()
            self._add_swi_zeta_file()

        self._add_all_to_query()
        self._log.info('Solution read complete.')

    def _add_solution_item(self, solution_item) -> None:
        """Add the solution item to the list of solution items to be added at the end.

        Args:
            solution_item: An item that is part of the solution (a DatasetWriter, CoverageCompInfo or Component).
        """
        self._solution_items.setdefault(self._model_name, []).append(solution_item)

    def _add_all_to_query(self):
        """Adds all the components to the query.

        We do it this way so that they are added in the tree in the same order that the log indicates.
        """
        for model_name, item_list in reversed(list(self._solution_items.items())):
            for item in reversed(item_list):
                if isinstance(item, DatasetWriter):
                    self._query.add_dataset(item, folder_path=model_name)
                elif isinstance(item, CoverageCompInfo):
                    self._query.add_coverage(
                        item.coverage,
                        folder_path=model_name,
                        coverage_type=item.coverage_type,
                        model_name=item.model_name,
                        components=[item.component],
                        component_keywords=item.comp_data
                    )
                else:
                    self._query.add_component(item, folder_path=model_name)
        # Reset to an empty dict (same type as in __init__) so the reader can be reused safely
        self._solution_items = {}

    def _set_current_model_variables(self, index: int) -> None:
        """Sets the variables to the current model.

        Args:
            index: Index of the model in the params[0] model lists.
        """
        self._model_filename = self._params[0]['model_files'][index]
        self._model_name = self._params[0]['model_names'][index]
        self._model_ftype = self._params[0]['model_ftypes'][index]
        self._model_uuid = self._params[0]['model_uuids'][index]
        self._model_ugrid_uuid = self._ugrid_uuids.get(self._model_name, '')
        # Drop the previous model's cached UGrid so model_ugrid() fetches this model's grid
        self._model_ugrid = None

    def _add_list_file(self):
        """Adds the model's .lst file as a text file component, if it exists."""
        list_filename = model_reader_base.list_file_from_model_name_file(self._model_filename)
        if list_filename and os.path.isfile(list_filename):
            txt_comp = self._build_solution_component(list_filename, 'TXT_SOL', 'Text File')
            self._add_solution_item(txt_comp)

    def _add_dep_var_file(self, dv_filename) -> str:
        """Adds the dependent variable (head, concentration, temperature) file as a dataset.

        Args:
            dv_filename: Dependent variable filepath.

        Returns:
            Uuid of the dataset, or '' if no dataset could be created.
        """
        if dv_filename and self._model_ugrid_uuid:  # Head or Concentration
            # Make a more unique name for the Head dataset if there are multiple models in this simulation.
            word_up = oc_first_word(self._model_ftype).title()
            dset_name = word_up if self._ftype_counts[self._model_ftype] == 1 else f'{word_up} ({self._model_name})'
            return self._add_dep_var_dset(self._model_ugrid_uuid, dv_filename, dset_name)
        return ''

    def _add_budget_file(self, budget_filename):
        """Adds the budget file as a CBC component, if it exists.

        Args:
            budget_filename: Budget filepath, or '' if none.
        """
        if budget_filename and os.path.isfile(budget_filename):
            cbc_comp = self._build_solution_component(budget_filename, 'CBC', 'CBC')
            self._add_solution_item(cbc_comp)

    def _add_csv_files(self):
        """Adds any .csv files found in the '<model file without extension>_output' folder."""
        output_dir = os.path.splitext(self._model_filename)[0] + '_output'
        # Escape the directory so characters like '[' in the path aren't treated as wildcards, and sort so the
        # order doesn't depend on the file system.
        csv_files = sorted(glob.glob(f'{glob.escape(output_dir)}/*.csv'))
        for file in csv_files:
            txt_comp = self._build_solution_component(file, 'CSV_SOL', 'CSV File')
            self._add_solution_item(txt_comp)

    def _add_pest_obs(self, dset_uuid: str) -> None:
        """Adds the PEST observation components.

        Args:
            dset_uuid: Uuid of the dependent variable (Head, Concentration, Temperature) dataset.
        """
        pest_obs_dir: str = pest_obs_runner.pest_obs_dir_from_model_filename(self._model_filename)
        if not os.path.isdir(pest_obs_dir):
            return

        pest_obs_runner.run_pest_batch_files(self._mfsim_dir, pest_obs_dir)
        obs_results: ObsResults = pest_obs_results_reader.read(pest_obs_dir, self._times)
        if not obs_results:
            return

        obs_results.dset_uuid = dset_uuid

        stats_filepath = self._write_obs_statistics(pest_obs_dir, obs_results)
        self._add_obs_statistics(stats_filepath)
        self._add_obs_coverages(pest_obs_dir, obs_results)

    def _add_swi_zeta_file(self):
        """Adds the SWI package's ZETA FILEOUT file as a 'Zeta' dataset, if one is specified."""
        swi_zeta_file = _swi_zeta_file(self._model_filename)
        self._add_dep_var_dset(self._model_ugrid_uuid, swi_zeta_file, 'Zeta')

    def _add_dep_var_dset(self, ugrid_uuid, dv_filename, dset_name) -> str:
        """Adds the dependent variable (head, concentration, temperature) file as a dataset.

        Also stores the file's timestep times in self._times.

        Args:
            ugrid_uuid: UGrid uuid.
            dv_filename: Dependent variable filepath.
            dset_name: Dataset name.

        Returns:
            Uuid of the dataset, or '' if no dataset could be created.
        """
        if dv_filename and ugrid_uuid:
            # Convert the binary MODFLOW file to XMDF for XMS's reading pleasure.
            dis_enum = model_reader_base.dis_from_model_name_file(self._model_filename)
            dataset_writer, self._times = self._read_dep_var_file(dv_filename, ugrid_uuid, dset_name, dis_enum)
            self._log_component_add(dataset_writer.name, 'dataset', dv_filename)
            self._add_solution_item(dataset_writer)
            return dataset_writer.uuid
        return ''

    def _log_component_add(self, name: str, component_type: str, filepath: str) -> None:
        """Adds a log message for the component being added.

        Args:
            name: Name of the item being added.
            component_type: Kind of item, used in the message.
            filepath: Source file; logged relative to the simulation directory. May be ''.
        """
        if filepath:
            relative_path = fs.compute_relative_path(self._mfsim_dir, filepath)
            self._log.info(f'Adding "{name}" {component_type} from "{relative_path}"')
        else:
            self._log.info(f'Adding "{name}" {component_type}')

    def _read_dep_var_file(self, filename, ugrid_uuid, dset_name, dis_enum):
        """Reads a MODFLOW binary head file.

        Args:
            filename (str): File path.
            ugrid_uuid (str): UUID of the UGrid.
            dset_name (str): Name to give the Head dataset.
            dis_enum (DisEnum): Tells what type of DIS/DISV/DISU we're dealing with.

        Returns:
            tuple[DatasetWriter, list[float]]: The dataset writer and the list of timestep times.
        """
        return dependent_variable_file_reader.read(
            dis_enum, filename, dset_name, ugrid_uuid, self._dset_time_units, self._start_date_time
        )

    def _build_solution_component(self, filepath: str, unique_name: str, display_name: str) -> Component:
        """Convenience function because the lines were too long.

        Args:
            filepath: Filepath to the solution file.
            unique_name: XML definition unique_name of the component to build
            display_name: display_name from xml. Gets used in log message 'as display_name'.

        Returns:
            The data_objects component to send back to XMS for the new component.
        """
        self._log_component_add(Path(filepath).name, display_name, filepath)
        return dmi_util.build_solution_component(
            filepath, self._sim_uuid, self._model_uuid, self._components_dir, unique_name
        )

    def _write_obs_statistics(self, pest_obs_dir: str, obs_results: ObsResults) -> str:
        """Writes the PEST observation statistics file and returns its path (falsy if none was written)."""
        return pest_obs_stats_writer.write(self._model_ftype, pest_obs_dir, obs_results)

    def _add_obs_statistics(self, stats_filepath):
        """Adds the PEST statistics file as a text file component, if a path was given.

        Args:
            stats_filepath: Path returned by _write_obs_statistics().
        """
        if stats_filepath:
            txt_comp = self._build_solution_component(stats_filepath, 'TXT_PEST_STATS', 'Text File')
            self._add_solution_item(txt_comp)

    def _add_obs_coverages(self, pest_obs_dir: str, obs_results: ObsResults):
        """Builds the PEST observation coverages and queues them as 'Observation Targets' coverages.

        Args:
            pest_obs_dir: Directory containing the PEST observation files.
            obs_results: The observation results.
        """
        if not obs_results:
            return

        ugrid = self.model_ugrid()
        # Z range of the current model's grid, or None if its UGrid isn't available
        min_max_ugrid_pt_z = (ugrid.extents[0][2], ugrid.extents[1][2]) if ugrid else None
        self.pest_coverages = pest_obs_coverage_builder.build_coverages(pest_obs_dir, obs_results, min_max_ugrid_pt_z)
        if self.pest_coverages:
            for coverage, component, comp_data in self.pest_coverages:
                self._log_component_add(coverage.name, 'PEST observation coverage', '')
                info = CoverageCompInfo(coverage, component, comp_data, 'Observation Targets', 'Generic Coverages')
                self._add_solution_item(info)

    def _init_ugrid_uuids(self) -> None:
        """Initialize the dict of the model names -> UGrid UUIDs.

        The key is the name of each UGrid pointer's parent tree node, which is assumed to be the model name
        (it is looked up by model name in _set_current_model_variables()).
        """
        xms_types = ['TI_UGRID_PTR']
        ugrid_nodes = tree_util.descendants_of_type(tree_root=self._sim_node, xms_types=xms_types, allow_pointers=True)
        self._ugrid_uuids = {ugrid_node.parent.name: ugrid_node.uuid for ugrid_node in ugrid_nodes}

    def _set_up_time_info(self) -> None:
        """Sets the dataset time units and starting date/time from the simulation's TDIS package."""
        tdis_node = tree_util.descendants_of_type(
            tree_root=self._sim_node,
            xms_types=['TI_COMPONENT'],
            unique_name='TDIS6',
            model_name='MODFLOW 6',
            recurse=False,
            only_first=True
        )
        tdis_time_units, self._start_date_time = tdis_reader.time_units_and_start_date_time(tdis_node.main_file)
        self._dset_time_units = units_util.dataset_time_units_from_tdis_time_units(tdis_time_units)


def _calculate_ftype_counts(model_ftypes):
    """Return a dict that has the unique ftypes in model_ftypes and how many times they appear.

    Args:
        model_ftypes (list[str]: The model ftypes.

    Returns:
        (dict[str, int]): See description.
    """
    ftype_counts = {}
    for ftype in model_ftypes:
        ftype_counts[ftype] = ftype_counts.get(ftype, 0) + 1
    return ftype_counts


def _dep_var_and_budget_filenames(model_filename: str, model_ftype: str):
    """Returns the dependent variable (head, concentration) and budget filenames.

    Relative filenames from the OC file are resolved against the model file's directory.

    Args:
        model_filename: File path of model file.
        model_ftype: ftype of model (e.g. 'GWF6', 'GWT6')

    Returns:
        (tuple): tuple containing:
            - dependent_variable_filename (str): The head or concentration filename, or '' if none specified.
            - budget_filename (str): The budget filename, or '' if none specified.
    """
    packages = model_reader_base.packages_from_model_name_file(model_filename, 'OC6', first=True)
    if not packages:
        # No OC package, so there is no OC file to get the output filenames from
        return '', ''
    oc_filename = packages[0].fname
    dv, budget = OcReader.dependent_variable_and_budget_files_from_oc_file(oc_filename, model_ftype)
    model_dir = os.path.dirname(model_filename)
    if dv:
        dv = fs.resolve_relative_path(model_dir, dv)
    if budget:
        budget = fs.resolve_relative_path(model_dir, budget)
    return dv, budget


def _swi_zeta_file(model_filename: Path | str) -> str:
    """Returns the ZETA FILEOUT full filepath, if it's specified in the SWI package, else ''.

    Only the first SWI6 package in the model name file is examined. The filename is the rest of the
    first 'ZETA FILEOUT ...' line that names a file. If that text starts with a quote, only the quoted
    part is used. A relative filename is resolved against the model file's directory.

    Args:
        model_filename: File path of model file.

    Returns:
        See description.
    """
    packages = model_reader_base.packages_from_model_name_file(model_filename, 'SWI6', first=False)
    if not packages:
        return ''
    swi_filename = packages[0].fname
    if not Path(swi_filename).is_file():
        return ''

    with open(swi_filename) as file:
        for line in file:
            words = line.split()
            # Need at least 'ZETA', 'FILEOUT' and a filename
            if len(words) < 3 or words[0].upper() != 'ZETA' or words[1].upper() != 'FILEOUT':
                continue

            # Take everything after FILEOUT so unquoted names containing spaces are kept whole.
            # find() cannot return -1 here because words[1] is FILEOUT.
            pos = line.upper().find('FILEOUT')
            zeta_file = line[pos + len('FILEOUT'):].strip()
            if zeta_file[:1] in ('"', "'"):
                # Quoted filename: keep only the text between the quotes
                closing = zeta_file.find(zeta_file[0], 1)
                zeta_file = zeta_file[1:closing] if closing != -1 else zeta_file[1:]
            if zeta_file:
                return fs.resolve_relative_path(Path(model_filename).parent, zeta_file)
    return ''
