"""Class for managing interprocess communication with XMS."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.constraint import Grid as CoGrid, read_grid_from_file
from xms.data_objects.parameters import Coverage
from xms.datasets.dataset_reader import DatasetReader

# 4. Local modules
from xms.adh.components.bc_conceptual_component import BcConceptualComponent
from xms.adh.components.material_conceptual_component import MaterialConceptualComponent
from xms.adh.components.sediment_constituents_component import SedimentConstituentsComponent
from xms.adh.components.sediment_material_conceptual_component import SedimentMaterialConceptualComponent
from xms.adh.components.sim_component import SimComponent
from xms.adh.components.transport_constituents_component import TransportConstituentsComponent
from xms.adh.components.vessel_component import VesselComponent
from xms.adh.data.adh_query_data import AdhQueryData
from xms.adh.data.xms_data import XmsData


class XmsQueryData(XmsData):
    """Class for managing interprocess communication with XMS.

    Implements the ``XmsData`` loader hooks by querying XMS through a ``Query`` object
    for tree items, coverages, components, grids, and datasets of an AdH simulation.
    """

    def __init__(self, query: Query, at_sim: bool = True, null_sim_item: bool = False):
        """Initialize the XmsQueryData object.

        Args:
            query: Query object for XMS data.
            at_sim: Are we at the simulation tree item or a child?
            null_sim_item: There is no simulation item (for testing).
        """
        super().__init__()
        self.query = query
        self._at_sim = at_sim
        self.adh_data = AdhQueryData(self)
        # _load_sim_item() reads self.query and self._at_sim, so those are assigned first.
        self.sim_item = None if null_sim_item else self._load_sim_item()

    def get_dataset_from_uuid(self, uuid: str) -> DatasetReader | None:
        """Get the dataset reader for the given UUID.

        The result is stored in ``self.dataset_readers``; a cached ``None`` is queried again.

        Args:
            uuid: The UUID of the dataset.

        Returns:
            The dataset reader for the given UUID, or None if XMS did not return one.
        """
        dataset = self.dataset_readers.get(uuid, None)
        if dataset is None:
            dataset = self.query.item_with_uuid(uuid)
            self.dataset_readers[uuid] = dataset
        return dataset

    def send(self):
        """Send data to XMS."""
        self.query.send()

    def _load_project_tree(self) -> TreeNode:
        """Load the project tree.

        Returns:
            The project tree root.
        """
        return self.query.project_tree

    def _load_xms_temp_directory(self) -> str:
        """Load the XMS temp directory.

        Returns:
            The XMS temp directory.
        """
        return self.query.xms_temp_directory

    def _load_sim_item(self) -> TreeNode | None:
        """Load the simulation tree item.

        Uses the query's current item when ``at_sim`` is True, otherwise its parent item.

        Returns:
            The simulation tree item, or None if that item is not a 'TI_DYN_SIM' tree node.
        """
        data_object = self.query.current_item() if self._at_sim else self.query.parent_item()
        current_uuid = data_object.uuid if data_object is not None else None
        current_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, current_uuid)
        # Only accept the item if it is actually a simulation.
        if current_item is None or current_item.item_typename != "TI_DYN_SIM":
            return None
        return current_item

    def _load_sim_name(self) -> str | None:
        """Load the simulation name.

        Returns:
            The simulation name, or None if there is no simulation item.
        """
        sim_name = None
        if self.sim_item is not None:
            sim_name = self.sim_item.name
        return sim_name

    def _load_sim_uuid(self) -> str | None:
        """
        Loads the simulation's UUID (not its component's UUID).

        Returns:
            str: The simulation UUID, or None if there is no simulation item.
        """
        if self.sim_item is not None:
            return self.sim_item.uuid
        return None

    def _load_sim_component(self) -> SimComponent | None:
        """Load the simulation component.

        Returns:
            The simulation component, or None if there is no simulation item or XMS
            returned no 'Sim_Manager' item for it.
        """
        # sim_item can be None (null_sim_item, or the current item is not a simulation).
        if self.sim_item is None:
            return None
        manager_item = self.query.item_with_uuid(
            item_uuid=self.sim_item.uuid, model_name='AdH', unique_name='Sim_Manager'
        )
        return SimComponent(main_file=manager_item.main_file) if manager_item else None

    def _load_bc_component(self) -> BcConceptualComponent | None:
        """Load the bc component.

        Returns:
            The bc component, or None if there is no bc coverage or XMS returned no
            'BcConceptual' item for it.
        """
        cov = self.bc_coverage
        if cov is None:
            return None

        do_comp = self.query.item_with_uuid(item_uuid=cov.uuid, model_name='AdH', unique_name='BcConceptual')
        if do_comp is None:
            return None
        bc_comp = BcConceptualComponent(do_comp.main_file)
        # Keep the component's stored coverage UUID in sync with the coverage it was found under.
        if bc_comp.cov_uuid != cov.uuid:
            bc_comp.cov_uuid = cov.uuid
            bc_comp.data.info.attrs['cov_uuid'] = cov.uuid

        self.query.load_component_ids(bc_comp, arcs=True, points=True)

        # Remove anything unused.
        bc_comp.clean_attributes()
        return bc_comp

    def _load_material_component(self) -> MaterialConceptualComponent | None:
        """Load the material component.

        Returns:
            The material component, or None if there is no material coverage or XMS
            returned no 'MaterialConceptual' item for it.
        """
        if self.material_coverage is None:
            return None
        do_comp = self.query.item_with_uuid(
            item_uuid=self.material_coverage.uuid, model_name='AdH', unique_name='MaterialConceptual'
        )
        if do_comp is None:
            return None
        mat_comp = MaterialConceptualComponent(do_comp.main_file)
        self.query.load_component_ids(mat_comp, polygons=True)
        return mat_comp

    def _load_transport_component(self) -> TransportConstituentsComponent | None:
        """Load the transport constituents component.

        The transport UUID is taken from the material component's 'transport_uuid' attribute,
        falling back to the bc component's attribute when the first is empty or unavailable.

        Returns:
            The transport constituents component, or None if no transport UUID was found or
            XMS returned no item for it.
        """
        transport_component = None
        transport_uuid = ""
        if self.material_coverage is not None and self.material_component is not None:
            if self.material_component.data.info.attrs['transport_uuid']:
                transport_uuid = self.material_component.data.info.attrs['transport_uuid']
        if not transport_uuid and self.bc_coverage is not None and self.bc_component is not None:
            if self.bc_component.data.info.attrs['transport_uuid']:
                transport_uuid = self.bc_component.data.info.attrs['transport_uuid']
        if transport_uuid:
            do_comp = self.query.item_with_uuid(transport_uuid)
            if do_comp:
                transport_component = TransportConstituentsComponent(do_comp.main_file)
        return transport_component

    def _load_sediment_material_component(self) -> SedimentMaterialConceptualComponent | None:
        """Load the sediment material conceptual component.

        Returns:
            The sediment material conceptual component, or None if there is no sediment
            material coverage or XMS returned no 'SedimentMaterialConceptual' item for it.
        """
        if not self.sediment_material_coverage:
            return None
        coverage_comp = self.query.item_with_uuid(
            self.sediment_material_coverage.uuid, model_name='AdH', unique_name='SedimentMaterialConceptual'
        )
        if coverage_comp is None:
            return None
        sediment_material_component = SedimentMaterialConceptualComponent(coverage_comp.main_file)
        self.query.load_component_ids(sediment_material_component, polygons=True)
        return sediment_material_component

    def _load_sediment_constituents_component(self) -> SedimentConstituentsComponent | None:
        """Load the sediment constituents component.

        The UUID comes from the sediment material component's 'sediment_transport_uuid' attribute.

        Returns:
            The sediment constituents component, or None if it could not be found.
        """
        sediment_constituents_component = None
        if self.sediment_material_component is not None:
            if self.sediment_material_component.data.info.attrs['sediment_transport_uuid']:
                sediment_uuid = self.sediment_material_component.data.info.attrs['sediment_transport_uuid']
                do_comp = self.query.item_with_uuid(sediment_uuid)
                if do_comp:
                    sediment_constituents_component = SedimentConstituentsComponent(do_comp.main_file)
        return sediment_constituents_component

    def _vessel_uuid_list(self) -> list:
        """Get the vessel UUIDs stored in the model control's 'uuids' entry.

        Returns:
            The vessel UUIDs, or an empty list if the model control has no 'uuids' entry.
        """
        vessel_uuids = self.adh_data.model_control.vessel_uuids
        if 'uuids' not in vessel_uuids:
            return []
        return vessel_uuids['uuids'].values.tolist()

    def _load_vessel_components(self) -> list[VesselComponent] | None:
        """Load the vessel components.

        UUIDs for which XMS returns no 'Vessel' item are skipped.

        Returns:
            The vessel components (possibly empty).
        """
        vessel_components = []
        for uuid in self._vessel_uuid_list():
            do_component = self.query.item_with_uuid(uuid, model_name='AdH', unique_name='Vessel')
            if do_component is not None:
                component = VesselComponent(do_component.main_file)
                vessel_components.append(component)
                self.query.load_component_ids(component, arcs=True)
        return vessel_components

    def _load_co_grid(self) -> CoGrid | None:
        """
        Loads the grid.

        Returns:
            CoGrid: The grid linked under the simulation, or None if there is none.
        """
        mesh_item = self._get_tree_link('TI_MESH2D_PTR')
        do_ugrid = self.query.item_with_uuid(mesh_item.uuid) if mesh_item else None
        return read_grid_from_file(do_ugrid.cogrid_file) if do_ugrid else None

    def _load_geom_uuid(self) -> str | None:
        """
        Loads the geom uuid.

        Returns:
            str: The UUID of the mesh linked under the simulation, or None if there is
            no simulation item or no linked mesh.
        """
        # Same lookup as _load_co_grid, which is also safe when sim_item is None.
        mesh_item = self._get_tree_link('TI_MESH2D_PTR')
        if mesh_item is not None:
            return mesh_item.uuid
        return None

    def _coverage_of_type(self, coverage_type: str) -> Coverage | None:
        """Get the AdH coverage of the given type that is linked under the simulation.

        Args:
            coverage_type: The coverage type to look for (e.g. 'Materials').

        Returns:
            The first matching coverage, or None if there is no simulation item or no such link.
        """
        cov_item = self._get_tree_link('TI_COVER_PTR', coverage_type)
        return self.query.item_with_uuid(cov_item.uuid) if cov_item else None

    def _load_bc_coverage(self) -> Coverage | None:
        """Load the bc coverage.

        Returns:
            The bc coverage, or None if there is none.
        """
        return self._coverage_of_type('Boundary Conditions')

    def _load_material_coverage(self) -> Coverage | None:
        """Load the material coverage.

        Returns:
            The material coverage, or None if there is none.
        """
        return self._coverage_of_type('Materials')

    def _load_sediment_material_coverage(self) -> Coverage | None:
        """Load the sediment material coverage.

        Returns:
            The sediment material coverage, or None if there is none.
        """
        return self._coverage_of_type('Sediment Materials')

    def _load_output_coverage(self) -> Coverage | None:
        """Load the output coverage.

        Returns:
            The output coverage, or None if there is none.
        """
        return self._coverage_of_type('Output')

    def _load_vessel_coverages(self) -> list[Coverage] | None:
        """Load the vessel coverages.

        UUIDs for which XMS returns no item are skipped.

        Returns:
            The vessel coverages (possibly empty).
        """
        vessel_coverages = []
        for uuid in self._vessel_uuid_list():
            coverage = self.query.item_with_uuid(uuid)
            if coverage is not None:
                vessel_coverages.append(coverage)
        return vessel_coverages

    def _load_horizontal_units(self) -> str | None:
        """Returns the horizontal units.

        Returns:
            str: The horizontal units of the display projection, or None if there is no
            display projection.
        """
        projection = self.query.display_projection
        if projection is None:
            return None
        return projection.horizontal_units

    def _get_tree_link(self, xms_type: str, coverage_type: str | None = None) -> TreeNode | None:
        """
        Get a tree link.

        Args:
            xms_type: The XMS type.
            coverage_type: The coverage type. When given, the search is also restricted to
                the 'AdH' model.

        Returns:
            The first direct child link of the simulation matching the criteria, or None if
            there is no simulation item or no match.
        """
        if self.sim_item is None:
            return None
        model_name = 'AdH' if coverage_type else None
        return tree_util.descendants_of_type(
            self.sim_item,
            xms_types=[xms_type],
            allow_pointers=True,
            model_name=model_name,
            coverage_type=coverage_type,
            recurse=False,
            only_first=True
        )
