"""Module for the simulation comonent."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['PtmSimComponent']

# 1. Standard Python modules
from pathlib import Path
from typing import Callable, Optional

# 2. Third party modules
from PySide2.QtWidgets import QWidget

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.components.dmi.xms_data import XmsData
from xms.gmi.component_bases.sim_component_base import SimComponentBase
from xms.gmi.data.generic_model import Parameter, Section
from xms.gmi.gui.dataset_callback import DatasetRequest
from xms.guipy.dialogs.dataset_selector import DatasetSelector
from xms.guipy.dialogs.process_feedback_dlg import run_feedback_dialog

# 4. Local modules
from xms.ptm.feedback.export_mesh_files_thread import export_mesh_files
from xms.ptm.feedback.import_datasets_thread import ImportDatasetsThread
from xms.ptm.model.model import simulation_model


class PtmSimComponent(SimComponentBase):
    """Component for the PTM simulation."""
    def __init__(self, main_file: Optional[str | Path] = None):
        """
        Initialize the component class.

        Args:
            main_file: The main file associated with this component.
        """
        super().__init__(main_file)
        self._section_dialog_keyword_args['hide_checkboxes'] = True
        # Each entry pairs the menu text with the callback that handles it.
        self.tree_commands.append(('Export mesh files...', self._run_file_export))
        self.tree_commands.append(('Import external datasets...', self._import_datasets))

    def _get_global_parameter_section(self) -> Section:
        """
        Get the global parameter section.

        Returns:
            The section produced by `simulation_model()`.
        """
        return simulation_model()

    @staticmethod
    def _run_file_export(query: Query, _params: list[dict], parent: QWidget):
        """
        Menu callback for handling the command to export mesh files.

        Args:
            query: Interprocess communication object, wrapped in an `XmsData` for the export.
            _params: Unused.
            parent: Parent widget handed to the export function.
        """
        data = XmsData(query)
        # This needs to display a dialog, and that doesn't work well when done from a feedback thread, so we use a
        # function instead of the usual feedback thread pattern here.
        export_mesh_files(data, parent)

    @staticmethod
    def _import_datasets(query: Query, _params: list[dict], parent: QWidget):
        """
        Menu callback for handling the command to import external datasets.

        Args:
            query: Interprocess communication object, wrapped in an `XmsData` for the import thread.
            _params: Unused.
            parent: Parent widget for the feedback dialog.
        """
        data = XmsData(query)
        thread = ImportDatasetsThread(data)
        run_feedback_dialog(thread, parent)

    def _dataset_callback(self, request: DatasetRequest, parameter: Parameter) -> Optional[str | TreeNode]:
        """
        Handle a request for information when picking a dataset.

        See the base class version. This override only handles `DatasetRequest.GetTree` for the dataset parameters
        listed below; every other request is forwarded to the base class.

        Args:
            request: The kind of information being requested.
            parameter: The parameter the request is for.

        Returns:
            For the handled parameters, the node of the grid linked to the simulation with its subtree filtered
            down to vector datasets (for 'linked_flow_dataset') or scalar datasets (for the others), or None if no
            such grid could be located. Otherwise, whatever the base class returns.
        """
        names = ('linked_elevation_dataset', 'linked_flow_dataset', 'd35_dataset', 'd50_dataset', 'd90_dataset')
        if request != DatasetRequest.GetTree or parameter.parameter_name not in names:
            return super()._dataset_callback(request, parameter)

        if parameter.parameter_name == 'linked_flow_dataset':
            filter_func = DatasetSelector.is_vector_if_dset
        else:
            filter_func = DatasetSelector.is_scalar_if_dset

        # DatasetSelectorMixin sets self._project_tree to this, but it only does it once, and we can't share that one
        # copy between the scalar and vector datasets, so we make our own clean copy here.
        clean_tree = self._query.copy_project_tree()
        sim_uuid = self._query.parent_item_uuid()  # We're at the component, so the sim is our parent.
        sim_node = tree_util.find_tree_node_by_uuid(clean_tree, sim_uuid)
        if sim_node is None:
            # Without the simulation node there is nothing to search for a linked grid.
            return None

        linked_grid_node = tree_util.descendants_of_type(
            sim_node,
            xms_types=['TI_MESH2D_PTR', 'TI_CGRID2D_PTR', 'TI_UGRID_PTR'],
            allow_pointers=True,
            only_first=True
        )
        if not linked_grid_node:
            return None

        # Look the grid up again by UUID in the clean tree; presumably the node found under the simulation is only
        # a pointer and the datasets live under the grid item itself.
        grid_node = tree_util.find_tree_node_by_uuid(clean_tree, linked_grid_node.uuid)
        if grid_node is None:
            # Treat an unresolvable link the same as having no linked grid instead of failing inside filter_tree.
            return None

        filter_tree(grid_node, filter_func)
        return grid_node


def filter_tree(node: TreeNode, condition: Callable[[TreeNode], bool]):
    """
    Prune a tree in place, keeping only matching nodes and the ancestors that lead to them.

    tree_util.filter_project_explorer does something similar, but treats children differently. Take this tree, in
    which Child-1 is the only node that passes the condition:

    Root --- Parent --- Child-1
                    |__ Child-2

    Neither function ever removes Root, whether or not it passes the condition.

    filter_project_explorer drops every failing node together with its whole subtree, so Parent (which fails) is
    removed along with both children, and only Root is left.

    This function works bottom-up instead: a node survives if it passes the condition or if any of its descendants
    survived. Parent is therefore kept because of Child-1, and only Child-2 is removed.

    Args:
        node: Root of the tree to prune. Its `children` list (and those of its descendants) is replaced.
        condition: Called with a node from the tree; returns whether that node (and with it, its ancestors) should
            stay. It is only consulted for nodes that have no children left after pruning.
    """
    survivors = []
    for candidate in node.children:
        # Prune the subtree first so that `candidate.children` reflects only what survived below it.
        filter_tree(candidate, condition)
        if not candidate.children and not condition(candidate):
            continue
        survivors.append(candidate)
    node.children = survivors
