"""Class for managing interprocess communication with XMS."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.tree import tree_util

# 4. Local modules
from xms.rsm.dmi.xms_data import XmsData


def cell_dataset_selection_tree(query, sim_component):
    """Build a project explorer tree trimmed down to cell-based scalar datasets.

    Args:
        query (Query): The XMS interprocess communication object
        sim_component (SimComponent): The simulation component

    Returns:
        (TreeNode): The filtered project explorer tree
    """
    # NOTE(review): sim_component reaches XmsData but is not forwarded to _DatasetTreeFilter,
    # so the filter searches every UGrid rather than only the simulation's. Confirm intended.
    return _DatasetTreeFilter(XmsData(query=query, sim_component=sim_component)).filtered_tree()


class _DatasetTreeFilter:
    def __init__(self, xms_data, sim_component=None):
        """Constructor.

        Args:
            xms_data (XmsData): The XMS interprocess communication object
            sim_component (SimComponent): The simulation component
        """
        self.xms_data = xms_data
        self.sim_comp = sim_component
        self._ugrid_item = None
        self._tree_items_to_keep = set()
        self._dsets = []
        self._pe_tree = None

    def filtered_tree(self):
        """Filter the project tree for dataset selection.

        Returns:
            (TreeNode): Root of project explorer tree to filter
        """
        if self.sim_comp is not None and self.xms_data.ugrid_item is None:
            return None
        self._get_ugrid_item()
        self._get_cell_scalar_dsets()
        self._get_all_parents_of_scalar_dsets()
        self._filter_tree()
        return self._pe_tree

    def _filter_tree(self):
        """Filter the project tree."""
        tree_util.filter_project_explorer(self._pe_tree, self._filter_project_tree)

    def _get_ugrid_item(self):
        """Get the UGrid tree item."""
        self._pe_tree = self.xms_data.copy_project_tree()
        if self.sim_comp is not None:
            uuid = self.xms_data.ugrid_item.uuid
            self._ugrid_item = [tree_util.find_tree_node_by_uuid(self._pe_tree, uuid)]
        else:
            self._ugrid_item = tree_util.descendants_of_type(self._pe_tree, xms_types=['TI_UGRID_SMS'])

    def _filter_project_tree(self, item):
        """Filter the project tree.

        Args:
            item (:obj:`TreeNode`): The item to check

        Returns:
            (:obj:`bool`): True if the tree item is in the self._tree_items_to_keep set, False otherwise
        """
        return item in self._tree_items_to_keep

    def _get_all_parents_of_scalar_dsets(self):
        """Get all parents of scalar datasets."""
        # we also want all parents of _dsets up to the UGRID root
        for item in self._dsets:
            while item:
                self._tree_items_to_keep.add(item.parent)
                item = item.parent

    def _get_cell_scalar_dsets(self):
        """Get all cell based scalar datasets."""
        for ugrid_item in self._ugrid_item:
            dsets = tree_util.descendants_of_type(ugrid_item, xms_types=['TI_SCALAR_DSET'])
            dsets = [ds for ds in dsets if ds.data_location == 'CELL']
            self._dsets.extend(dsets)
        self._tree_items_to_keep.update(self._dsets)
