"""Module for XmsData."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['XmsData']

# 1. Standard Python modules
from functools import cached_property
import importlib
from importlib.metadata import entry_points
from pathlib import Path
from typing import Any, Callable, cast, Optional
import uuid
import xml.etree.ElementTree as Tree
from xml.etree.ElementTree import ParseError

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment as XmEnv
from xms.api.tree import tree_util, TreeNode
from xms.constraint import read_grid_from_file
from xms.constraint.grid import Grid as CoGrid
from xms.data_objects.parameters import Component, Coverage, Projection, Simulation, UGrid as DoGrid
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.components.bases.component_with_menus_base import ComponentWithMenusBase
from xms.components.bases.visible_coverage_component_base import VisibleCoverageComponentBase
from xms.components.bases.visible_coverage_component_base_data import VisibleCoverageComponentBaseData
from xms.components.display.display_options_helper import DisplayOptionsHelper

# Tree item type names passed as `xms_types` when searching a simulation's tree node for linked grids.
# These are believed to be all the grid pointer types there are, unless you want to count TI_SCAT2D_PTR.
# TI_UGRID_PTR seems to encompass all types of UGrids, including quadtrees and Cartesian constrained ones.
all_grid_pointer_types = ['TI_MESH2D_PTR', 'TI_CGRID2D_PTR', 'TI_UGRID_PTR']


class LazyDict(dict):
    """A dict that computes and stores the value for a key the first time that key is indexed."""
    def __init__(self, factory: Callable[[Any], Any]):
        """
        Create an empty lazy dictionary.

        Args:
            factory: Callable invoked with a key when that key is indexed but not present. Whatever it returns is
                stored under the key, so the factory is not called again for a key that is still in the dictionary.
        """
        super().__init__()
        self.default_factory = factory

    def __missing__(self, key):
        """
        Produce, store, and return the value for a key that is not in the dictionary yet.

        Args:
            key: The key that was indexed but not found.

        Returns:
            The value the factory produced for the key.
        """
        value = self.default_factory(key)
        self[key] = value
        return value


class XmsData:
    """Higher level wrapper around Query."""
    def __init__(self, query: Optional[Query] = None):
        """
        Initialize the class.

        Args:
            query: Interprocess communication object. If not provided, read operations will fail and write operations
                will store data in attributes on the class. When running under XMS, this should generally be provided.
                Omitting it is mainly useful for when running under tests.
        """
        #
        # Public attributes
        #
        #

        # Dictionary of dataset_uuid -> dataset. Accessing a dataset will fetch it from XMS the first time its UUID is
        # requested, then retrieve it from the cache each time thereafter. If XMS doesn't have an item with the
        # requested UUID, then None is retrieved instead. If the requested UUID exists but is not a dataset, then an
        # AssertionError is raised on the assumption that this likely indicates a bug (this may be reduced to KeyError
        # or returning None later if someone finds a use case where it's actually normal, but for now it's assumed bad).
        # Tests can assign an ordinary dictionary to this when mocking.
        self.datasets: dict[str, Optional[DatasetReader]] = LazyDict(self._get_dataset)
        # Dictionary of items that exist in XMS (item_uuid -> exists)
        self.existing_items: dict[str, bool] = {}

        #
        # Internal attributes
        #
        self._query = query
        # (model_name, unique_name) -> (module_name, class_name)
        self._component_info: dict[tuple[str, str], tuple[str, str]] = {}
        # (module_name, class_name) -> (model_name, unique_name)
        self._model_info: dict[tuple[str, str], tuple[str, str]] = {}
        self._xml_files: Optional[list[Path]] = None
        self._sim_uuid: str = ''

        #
        # Attributes used for mocking
        #

        # get_linked_coverage looks here first before consulting Query. Maps component_type -> (coverage, data).
        self.linked_coverages: dict[type, tuple[Optional[Coverage], Optional[VisibleCoverageComponentBaseData]]] = {}
        # add_simulation populates added_simulations, added_simulation_names, and added_simulation_uuids when called.
        self.added_simulations = []
        self.added_simulation_names: list[str] = []
        self.added_simulation_uuids: list[str] = []
        # add_coverage adds the coverage's UUID into this list when instructed to link the item.
        self.linked_uuids: list[str] = []
        # add_coverage places its parameters into this list.
        self.added_coverages: list[tuple[Coverage, VisibleCoverageComponentBase]] = []
        # add_generic_coverage places its parameters into this list.
        self.added_generic_coverages: list[Coverage] = []
        # add_particle_set places its parameters into this list.
        self.added_particle_sets: list[Path] = []
        # unlink_all_grids_except sets this to the UUID of the one grid that was kept.
        self.unlinked_all_grids_except: str = ''
        # add_grid places its parameters into this list.
        self.added_grids: list[tuple[CoGrid, str, Projection]] = []
        # add_dataset places its parameters into this list.
        self.added_datasets: list[DatasetReader | DatasetWriter] = []

        # The progress callback passed to self.set_progress_callback.
        self.progress_callback: Optional[Callable[[], None]] = None
        # The percentage passed to self.set_progress()
        self.progress_percent: float = -1.0

    @cached_property
    def file_to_read(self) -> Path:
        """The file that XMS wants the import script to read."""
        return Path(self._query.read_file)

    @cached_property
    def model_name(self) -> str:
        """The name of the model."""
        node = self._sim_tree_node
        assert node
        return node.model_name

    @property
    def added_simulation(self) -> ComponentWithMenusBase:
        """The simulation that was added."""
        assert len(self.added_simulations) == 1
        return self.added_simulations[0]

    @property
    def added_simulation_name(self) -> str:
        """The simulation name that was added."""
        assert len(self.added_simulation_names) == 1
        return self.added_simulation_names[0]

    def add_simulation(
        self, sim_component: ComponentWithMenusBase, name: str, update_default_parent: bool = True
    ) -> str:
        """
        Add a simulation to XMS.

        Args:
            sim_component: Hidden component to attach to the simulation.
            name: Name to assign the simulation. Appears in the project explorer tree.
            update_default_parent: Whether to update the parent uuid used by self.link_item to be this one.
        """
        sim_uuid = str(uuid.uuid4())
        if not self._query:
            self.added_simulations.append(sim_component)
            self.added_simulation_names.append(name)
            self.added_simulation_uuids.append(sim_uuid)
            return sim_uuid

        model_name = self._find_model_name(sim_component.module_name, sim_component.class_name)
        do_component = Component(
            name=name,
            comp_uuid=sim_component.uuid,
            main_file=str(sim_component.main_file),
            class_name=sim_component.class_name,
            module_name=sim_component.module_name,
        )

        do_sim = Simulation(model=model_name, sim_uuid=sim_uuid, name=name)

        if update_default_parent:
            self._sim_uuid = sim_uuid
        self._query.add_simulation(do_sim, [do_component])
        return sim_uuid

    @cached_property
    def simulation_data(self) -> Optional:
        """The data manager for the current simulation, or None."""
        sim_node = self._sim_tree_node
        if sim_node is None:
            return None
        model_name = sim_node.model_name
        unique_name = sim_node.unique_name
        item_uuid = sim_node.uuid
        # sim_node has a component_uuid on it, but Query.item_with_uuid won't give us the component itself without the
        # model and unique names, even if we ask for the component UUID specifically.
        main_file = self._query.item_with_uuid(item_uuid, model_name=model_name, unique_name=unique_name).main_file
        component_type: type[ComponentWithMenusBase] = self._find_component_type(model_name, unique_name)
        component = component_type(main_file)
        # component.data will probably be there in practice, but it would be good to actually add it somewhere.
        return component.data

    @cached_property
    def simulation_name(self) -> str:
        """The name of the simulation, or an empty string."""
        sim_node = self._sim_tree_node
        if sim_node is None:
            return ''
        return sim_node.name

    def add_coverage(self, coverage: Coverage, component: VisibleCoverageComponentBase):
        """
        Add a coverage to be sent to XMS.

        Args:
            coverage: The coverage to send.
            component: The component to attach to the coverage.
        """
        if not self._query:
            self.added_coverages.append((coverage, component))
            return

        model_name = self._find_model_name(component.module_name, component.class_name)

        keywords = _build_keywords(component)

        do_component = Component(
            comp_uuid=component.uuid,
            main_file=str(component.main_file),
            class_name=component.class_name,
            module_name=component.module_name,
        )
        coverage_type = component.unique_name()
        self._query.add_coverage(
            coverage,
            model_name=model_name,
            coverage_type=coverage_type,
            components=[do_component],
            component_keywords=keywords
        )

    def add_generic_coverage(self, coverage: Coverage):
        """
        Add a generic coverage to be sent to XMS.

        Args:
            coverage: The coverage to send.
        """
        if not self._query:
            self.added_generic_coverages.append(coverage)
            return

        self._query.add_coverage(coverage)

    def add_particle_set(self, file: str | Path):
        """
        Add a particle set to XMS.

        Args:
            file: Path to the file containing the particle set.
        """
        if not self._query:
            self.added_particle_sets.append(Path(file))
            return

        self._query.add_particles(str(file))

    def add_dataset(self, dataset: DatasetReader | DatasetWriter):
        """
        Add a dataset to XMS.

        Note that `self.link_item` doesn't work on datasets. Attempting to use it with a dataset UUID will fail
        silently. To make the dataset link to the desired geometry, ensure the dataset's `geom_uuid` attribute is set to
        the UUID of the geometry you want it linked to.
        """
        if not self._query:
            self.added_datasets.append(dataset)
            return

        self._query.add_dataset(dataset)

    def link_item(self, item_uuid: str):
        """
        Link an item to the simulation.

        Args:
            item_uuid: The UUID of the item to link.

                When linking coverages (with or without components), this should be the UUID of the coverage
                (not the component). For other geometry (UGrids, scatter sets, etc.) this should be the geometry's UUID.

                XMS links datasets automatically based on their parent geometry UUID, so this method isn't necessary for
                them.
        """
        if self._query:
            self._query.link_item(self._sim_uuid, item_uuid)
        else:
            self.linked_uuids.append(item_uuid)

    def get_linked_coverage(
        self, coverage_type: type[VisibleCoverageComponentBase]
    ) -> tuple[Optional[Coverage], Optional[VisibleCoverageComponentBaseData]]:
        """
        Get the specified coverage that is linked to the simulation.

        This method is intended for cases where all three of the following apply:
        - Only one coverage with the specified type can be linked to the simulation at a time
        - The coverage has an attached component
        - The attached component derives from VisibleCoverageComponentBase

        Cases where one or more of those is false are unsupported.

        Code that calls this will typically look something like
        ```
        coverage, data = xms_data.get_linked_coverage(ExampleCoverageComponent)
        if coverage is not None:
            do_stuff_with(coverage, data)
        else:
            do_stuff_without_coverage()
        ```

        Args:
            coverage_type: The type of the coverage component that should be attached to the desired coverage. This
                should be the type itself, not an instance.

        Returns:
            Tuple of (coverage, data). If no coverage with the requested component type is linked to the simulation,
            returns (None, None) instead.
        """
        self._ensure_linked_coverage_exists(coverage_type)
        return self.linked_coverages[coverage_type]

    def add_grid(self, cogrid: CoGrid, name: str, projection: Projection):
        """
        Add a constrained UGrid to XMS.

        Args:
            cogrid: The UGrid to add.
            name: Name to assign the grid.
            projection: The UGrid's native projection.
        """
        if not self._query:
            self.added_grids.append((cogrid, name, projection))
            return

        xmc_file = Path(XmEnv.xms_environ_process_temp_directory()) / f'{cogrid.uuid}.xmc'
        xmc_file = str(xmc_file)
        cogrid.write_to_file(xmc_file)
        ugrid = DoGrid(xmc_file, name=name, uuid=cogrid.uuid, projection=projection)
        self._query.add_ugrid(ugrid)

    @cached_property
    def linked_grid(self) -> Optional[CoGrid]:
        """
        The UGrid currently linked to the simulation, or None if no grid is linked.

        This assumes there is only one grid currently linked to the simulation (e.g. because your link event kicks out
        the old grid when there's more than one), and that you don't care what type of grid it happens to be (e.g.
        because the XML restricts to only one type).
        """
        node = self._linked_grid_node
        if not node:
            return None

        grid_uuid = node.uuid
        do_ugrid: DoGrid = self._query.item_with_uuid(grid_uuid)
        file = do_ugrid.cogrid_file
        co_grid = read_grid_from_file(file)
        return co_grid

    def unlink_all_grids_except(self, grid_uuid: str):
        """Unlink all grids that are linked to the simulation except the one with the given UUID."""
        if not self._query:
            self.unlinked_all_grids_except = grid_uuid
            return

        sim_node = self._sim_tree_node
        grid_nodes = tree_util.descendants_of_type(sim_node, xms_types=all_grid_pointer_types, allow_pointers=True)
        for node in grid_nodes:
            if node.uuid != grid_uuid:
                self._query.unlink_item(sim_node.uuid, node.uuid)

    def send(self):
        """Send any data to XMS if this was constructed with a query, else does nothing."""
        if self._query is not None:
            self._query.send()

    def start_progress_loop(self, progress_callback: Callable[[], None]):
        """
        Start a loop which will call `progress_callback` periodically to send progress updates to XMS.

        The progress loop is responsible for updating the progress bar on the Model Run Queue dialog. Calling this
        method and starting a loop only makes sense from a progress script. If this was initialized with a real `Query`,
        that `Query` must have been created via `Query(progress_script=True)`.

        Args:
            progress_callback: A callable that takes no parameters and returns nothing. The progress loop will
                call it when it wants the progress tracker to update the progress. The frequency of calls is user
                configurable and should not be assumed.
        """
        if not self._query:
            self.progress_callback = progress_callback
            return

        prog = self._query.xms_agent.session.progress_loop
        prog.set_progress_function(progress_callback)
        prog.start_loop()

    @cached_property
    def model_stdout_file(self) -> Path:
        """
        The file that the currently running model should be writing its stdout to.

        This property only works in progress scripts. If this was initialized with a real `Query`, that `Query` must
        have been created via `Query(progress_script=True)`.
        """
        prog = self._query.xms_agent.session.progress_loop
        stdout_file = prog.command_line_output_file
        return Path(stdout_file)

    def set_progress(self, progress_percent: float):
        """
        Set the current progress percent for a model run script.

        This effectively controls the progress bar on the Model Run Queue dialog.

        This method is only usable from model run progress trackers. If this was initialized with a real `Query`, that
        `Query` must have been created via `Query(progress_script=True)`.

        Args:
            progress_percent: The model's current progress toward completing its run. This should be a value from
                0.0 (just started) to 1.0 (completely finished).

                Note that this is different from `Query.update_progress_percent` (which it aims to replace), which takes
                values from 0 to 100. Support for the alternative could be added, but it's expected that 0.0 to 1.0 will
                be a more natural range for most models, so that's what this takes for now.
        """
        progress_percent = int(progress_percent * 100)
        assert 0 <= progress_percent <= 100

        if not self._query:
            self.progress_percent = progress_percent
            return

        self._query.update_progress_percent(progress_percent)

    @cached_property
    def simulation_folder(self) -> Optional[Path]:
        """
        The folder where the simulation's model-native files were or will be exported to.

        Will be `None` if the project wasn't saved yet, since the folder could end up almost anywhere depending on where
        the user decides to save the project to.

        If a real path, the path may not exist, e.g. because the user deleted it or hasn't exported the simulation yet.
        """
        project_path = self._query.xms_project_path
        if not project_path:
            return None
        project_path = Path(project_path)
        models_path = project_path.with_name(project_path.stem + '_models')
        folder = models_path / self.model_name / self.simulation_name
        return folder

    def item_exists(self, item_uuid: str) -> bool:
        """
        Checks if an item exists in the project tree.

        Args:
            item_uuid: The UUID of the item to search for.

        Returns:
            bool: True if the item is found, False otherwise.
        """
        if item_uuid not in self.existing_items:
            tree = self._query.copy_project_tree()
            node = tree_util.find_tree_node_by_uuid(tree, item_uuid)
            self.existing_items[item_uuid] = node is not None

        return self.existing_items[item_uuid]

    def _ensure_linked_coverage_exists(self, coverage_type: type[VisibleCoverageComponentBase]):
        """
        Populate self.linked_coverages with the appropriate value for the given coverage type.

        self.linked_coverages[coverage_type] will be assigned after this returns. If the requested type was linked, the
        value will be a tuple of (coverage, data), and the data manager will have its component IDs initialized. If not
        found, the value will be a tuple of (None, None).
        """
        if coverage_type in self.linked_coverages:
            return

        sim_node = self._sim_tree_node
        model_name = sim_node.model_name
        unique_name = coverage_type.unique_name()
        coverage_node = tree_util.descendants_of_type(
            sim_node, model_name=model_name, unique_name=unique_name, allow_pointers=True, only_first=True
        )
        if not coverage_node:
            self.linked_coverages[coverage_type] = (None, None)
            return

        coverage_uuid = coverage_node.uuid
        coverage = self._query.item_with_uuid(coverage_uuid)
        main_file = self._query.item_with_uuid(coverage_uuid, model_name=model_name, unique_name=unique_name).main_file
        component_type = self._find_component_type(model_name, unique_name)
        # self._find_component_type can find anything derived from ComponentWithMenusBase, so that's what type it
        # declares it returns. But we're asking for a coverage, so it should specifically derive from
        # VisibleCoverageComponentBase instead.
        component_type = cast(type[VisibleCoverageComponentBase], component_type)
        component = component_type(main_file)
        self._query.load_component_ids(component, points=True, arcs=True, polygons=True)
        self.linked_coverages[coverage_type] = (coverage, component.data)

    @cached_property
    def _sim_tree_node(self) -> Optional[TreeNode]:
        """The current sim tree node, or None."""
        tree = self._query.copy_project_tree()
        current_node = tree_util.find_tree_node_by_uuid(tree, self._query.current_item_uuid())
        if current_node is None:
            current_node = tree_util.find_tree_node_by_uuid(tree, self._query.parent_item_uuid())
        if current_node is None or current_node.item_typename != 'TI_DYN_SIM':
            return None
        return current_node

    def _get_dataset(self, dataset_uuid: str) -> Optional[DatasetReader]:
        """
        Get a dataset from XMS.

        This method doesn't do caching. Outside code is responsible for that.

        Args:
            dataset_uuid: UUID of the dataset to retrieve.

        Returns:
            The dataset with the given UUID if it exists, or None if nothing matches the UUID.
        """
        dataset = self._query.item_with_uuid(dataset_uuid)
        if not dataset or isinstance(dataset, DatasetReader):
            return dataset

        # We got something, but it wasn't a dataset. Given how unlikely this is to happen by chance, we're going to
        # assume the caller passed the wrong variable and tell them something is wrong.
        raise AssertionError(f'Expected a dataset, XMS sent {dataset} instead.')

    @cached_property
    def _linked_grid_node(self) -> Optional[TreeNode]:
        """The tree node for the UGrid linked to the simulation, or None if no grid is linked."""
        sim_node = self._sim_tree_node
        grid_node = tree_util.descendants_of_type(
            sim_node, xms_types=all_grid_pointer_types, only_first=True, allow_pointers=True
        )
        return grid_node

    def _find_component_type(self, model_name: str, unique_name: str) -> type[ComponentWithMenusBase]:
        """Find the component type that matches a given model and unique name."""
        self._ensure_component_info_exists(model_name, unique_name)
        module_name, class_name = self._component_info[(model_name, unique_name)]

        module = importlib.import_module(module_name)
        component_type = getattr(module, class_name)
        return component_type

    def _find_model_name(self, module_name: str, class_name: str) -> str:
        """Find the name of the model that defines a given module and class name."""
        self._ensure_model_info_exists(module_name, class_name)
        model_name, _unique_name = self._model_info[(module_name, class_name)]
        return model_name

    def _ensure_component_info_exists(self, model_name: str, unique_name: str):
        """Ensure self._component_info contains an entry for the given model and unique name."""
        self._find_xml_files()

        while self._xml_files and (model_name, unique_name) not in self._component_info:
            self._process_next_xml()

        if (model_name, unique_name) not in self._component_info:
            raise AssertionError(
                f'Could not find component info for {model_name}|{unique_name}. '
                'Ensure it is listed in the XML and that the package is installed.'
            )

    def _ensure_model_info_exists(self, module_name: str, class_name: str):
        """Ensure self._model_info contains an entry for the given module and class name."""
        self._find_xml_files()

        while self._xml_files and (module_name, class_name) not in self._model_info:
            self._process_next_xml()

        if (module_name, class_name) not in self._model_info:
            raise AssertionError(
                f'Could not find model info for {module_name}|{class_name}. '
                'Ensure it is listed in the XML and that the package is installed.'
            )

    def _find_xml_files(self):
        """
        Find all the XML files that probably contain XMS DMI definitions.

        Only does anything the first time. Does nothing on subsequent calls.
        """
        if self._xml_files is not None:
            return

        self._xml_files = []
        xms_entry_points = entry_points(group='xms.dmi.interfaces')
        for entry_point in xms_entry_points:
            # If there are classifiers, .get_all() returns a list of them. But if there are *not* classifiers, it
            # returns None instead.
            classifiers = entry_point.dist.metadata.get_all('Classifier') or []
            classifiers = filter(lambda c: c.startswith('XMS DMI Definition'), classifiers)
            for classifier in classifiers:
                self._add_xml_file(entry_point, classifier)

    def _add_xml_file(self, entry_point, classifier: str):
        """
        Add an XML file for an entry point.

        Args:
            entry_point: The entry point containing the file.
            classifier: The classifier to parse the XML file out of.
        """
        xml_file = classifier.split('::')[-1].strip()
        # This returns a "SimplePath", which is different from Path in ways I don't want to bother figuring out.
        file_path = entry_point.dist.locate_file(xml_file)
        if not file_path.is_absolute() or not file_path.exists():  # pragma: nocover
            # This happens for editable installations. I couldn't figure out a good way to test it.
            dist_path = Path(str(entry_point.load().__path__[0]))
            dist_path = dist_path.parent.parent
            file_path = dist_path / xml_file
        self._xml_files.append(file_path)

    def _process_next_xml(self) -> None:
        """
        Process the next XML file.

        Assumes there is a next XML file to process.
        """
        # The files we're interested in should have this kind of structure:
        #
        # <dynamic_model filetype="dynamic model">
        #   <model_name="..."></model>
        #   <component unique_name="..." module_name="..." class_name="...">
        #
        # Where all the ... are pieces of an entry we want to create.
        xml_path = self._xml_files.pop()

        try:
            root = Tree.parse(xml_path).getroot()
        except ParseError:
            return

        if root.tag != 'dynamic_model':
            return

        if 'filetype' not in root.attrib or root.attrib['filetype'] != 'dynamic model':
            return

        model_name = ''
        info = []

        for child in root:
            if child.tag == 'model' and 'name' in child.attrib:
                model_name = child.attrib['name']

            if child.tag == 'component':
                if 'module_name' in child.attrib and 'class_name' in child.attrib and 'unique_name' in child.attrib:
                    module_name = child.attrib['module_name']
                    class_name = child.attrib['class_name']
                    unique_name = child.attrib['unique_name']
                    info.append((module_name, class_name, unique_name))

        if not model_name:
            return

        for module_name, class_name, unique_name in info:
            self._component_info[(model_name, unique_name)] = (module_name, class_name)
            self._model_info[(module_name, class_name)] = (model_name, unique_name)


def _build_keywords(component: VisibleCoverageComponentBase):
    """
    Build the component keywords to send to XMS along with a coverage.

    In an import script, XMS ignores the keywords we send and runs the component's create event to get the ID mapping.
    In an ActionRequest callback, however, XMS ignores the create event instead and insists on using these keywords to
    associate IDs.

    It would be nice if XMS did the same thing in both cases, but it doesn't, and we don't have a good way to tell
    which case we're in, so keywords are always sent and XMS is left to ignore them when they're unnecessary.

    Args:
        component: The coverage component to build keywords for.

    Returns:
        A list holding a single dictionary with the component's coverage IDs and its display option messages.
    """
    coverage_ids = component.get_component_coverage_ids()
    with DisplayOptionsHelper(component.main_file) as helper:
        update_messages = helper.get_update_messages(component.cov_uuid)

    display_options = [list(message) for message in update_messages]
    return [{
        'component_coverage_ids': coverage_ids,
        'display_options': display_options,
    }]
