"""Class for managing interprocess communication with XMS."""

__all__ = ['XmsData', 'MISSING']

# 1. Standard Python modules
import binascii
from functools import cached_property
from pathlib import Path
from typing import Any, Callable, Optional, Type
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query, XmsEnvironment as XmEnv
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file, UGrid2d
from xms.data_objects.parameters import Component, Coverage, Projection, Simulation, UGrid
from xms.datasets.dataset_reader import DatasetReader
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.hydroas.components.coverage_component import CoverageComponent
from xms.hydroas.components.material_component import MaterialComponent
from xms.hydroas.components.roughness_component import RoughnessComponent
from xms.hydroas.data.coverage_data import CoverageData
from xms.hydroas.data.mapped_material_data import MappedMaterialData
from xms.hydroas.data.material_data import MaterialData
from xms.hydroas.data.sim_data import SimData


class DefaultDict(dict):
    """
    A dict that builds missing values from the key that was requested.

    Unlike `collections.defaultdict`, whose factory takes no arguments, the factory here receives the missing key, so
    values can be loaded on demand from that key. The created value is stored before being returned, so subscript
    lookups only invoke the factory for keys that are not already present.
    """

    def __init__(self, factory: Callable):
        """
        Initialize an empty dictionary.

        Args:
            factory: Called with a missing key; its return value becomes that key's value.
        """
        super().__init__()
        self.default_factory = factory

    def __missing__(self, key: str):
        """
        Create, store, and return the value for a key that is not yet present.

        Args:
            key: The key that was looked up.

        Returns:
            The newly created value.
        """
        value = self.default_factory(key)
        self[key] = value
        return value


# Sentinel accepted by several XmsData.__init__ parameters. Passing it sets the corresponding attribute to an "empty"
# value (e.g. None) instead of leaving it to be looked up from a Query. A unique object() is used so it can never be
# confused with a legitimate argument, including None.
MISSING = object()


class XmsData:
    """Class for managing interprocess communication with XMS."""

    model_name = 'HydroAS'
    sim_component_unique_name = 'SimComponent'
    bc_component_unique_name = 'CoverageComponent'
    material_component_unique_name = 'MaterialComponent'
    mapped_material_component_unique_name = 'MappedMaterialComponent'
    roughness_component_unique_name = 'RoughnessComponent'

    def __init__(
        self,
        query: Optional[Query] = None,
        sim_data: Optional[SimData] = None,
        ugrid: Optional[UGrid2d] = None,
        projection: Optional[Projection] = None,
        linked_ugrid_uuids: Optional[list[str]] = None,
        xms_version: str = '',
        bc_coverages: Optional[list] = None,
        material_coverage: Optional[tuple] = None,
        mapped_material_data: Optional[MappedMaterialData] = None,
        ugrid_name: str = '',
        grid_hash: str = '',
        import_file: str | Path = '',
    ):
        """
        Initialize the class.

        Most data attributes are looked up from `query` the first time they are accessed. Passing a value here
        overrides that lookup, which allows the class to be used without a Query. The `ugrid`, `projection`,
        `material_coverage`, `mapped_material_data`, and `ugrid_name` parameters also accept the MISSING object defined
        above, which sets the attribute to its "nothing linked" value (None, `(None, None)`, or `''`) without a Query.

        Args:
            query: Interprocess communication object if passed.
            sim_data: Overrides self.sim_data if passed.
            ugrid: Overrides self.ugrid if passed.
            projection: Overrides self.projection if passed.
            linked_ugrid_uuids: Overrides self.linked_ugrid_uuids if passed.
            xms_version: Overrides self.xms_version if passed.
            bc_coverages: Overrides self.bc_coverages if passed.
            material_coverage: Overrides self.material_coverage if passed.
            mapped_material_data: Overrides self.mapped_material_data if passed.
            ugrid_name: Overrides self.ugrid_name if passed.
            grid_hash: Overrides self.grid_hash if passed.
            import_file: Overrides self.import_file if passed.
        """
        self.query = query
        # Caches keyed by item UUID; a missing key is fetched through the Query on first access.
        self.datasets: dict[str, Optional[DatasetReader]] = DefaultDict(self._get_dataset)
        self.geometries: dict[str, Optional[UGrid2d]] = DefaultDict(self._get_ugrid)

        # Assigning to a cached_property name stores the value on the instance, so the Query lookup never happens.
        if sim_data is not None:
            self.sim_data = sim_data
        if ugrid is MISSING:
            self.ugrid = None
        elif ugrid is not None:
            self.ugrid = ugrid
        if linked_ugrid_uuids is not None:
            self.linked_ugrid_uuids = linked_ugrid_uuids
        if projection is MISSING:
            self.projection = None
        elif projection is not None:
            self.projection = projection
        if xms_version:
            self.xms_version = xms_version
        if bc_coverages is not None:
            self.bc_coverages = bc_coverages
        if material_coverage is MISSING:
            self.material_coverage = (None, None)
        elif material_coverage is not None:
            self.material_coverage = material_coverage
        if mapped_material_data is MISSING:
            self.mapped_material_data = None
        elif mapped_material_data is not None:
            self.mapped_material_data = mapped_material_data
        if ugrid_name is MISSING:
            self.ugrid_name = ''
        elif ugrid_name:
            self.ugrid_name = ugrid_name
        if import_file:
            self.import_file = Path(import_file)
        # An empty hash means "not known yet"; the grid_hash property will compute it on demand.
        self._grid_hash: Optional[str] = grid_hash or None

        # Things recorded instead of being sent to XMS when there is no Query.
        self.added_sim: Optional[SimData] = None
        self.added_sim_name: str = ''
        self.added_ugrid: Optional[UGrid2d] = None
        self.added_ugrid_name: str = ''
        self.added_projection: Optional[Projection] = None
        self.added_coverages: list[tuple[Coverage, CoverageComponent]] = []
        self.added_materials: Optional[MappedMaterialData] = None
        self.linked_items: list[str] = []
        self.unlinked_items: list[str] = []
        self.finished = False

        # CRCs of UGrid files loaded through self.geometries, keyed by the UGrid's own UUID.
        self._grid_hashes: dict[str, str] = {}

    @cached_property
    def sim_data(self) -> SimData:
        """The data manager for the simulation the script is running for."""
        sim_id = self.query.current_item_uuid()
        simulation = self.query.item_with_uuid(
            sim_id, model_name=self.model_name, unique_name=self.sim_component_unique_name
        )

        sim_main_file = simulation.main_file
        data = SimData(sim_main_file)
        return data

    @cached_property
    def projection(self) -> Projection:
        """The current display projection."""
        return self.query.display_projection

    @cached_property
    def xms_version(self) -> str:
        """The version of XMS the script is running under."""
        return self.query.xms_app_version

    @cached_property
    def import_file(self) -> Path:
        """Path to the file being imported."""
        return Path(self.query.read_file)

    def add_ugrid(self, ugrid_2d: UGrid2d, name: str, projection: Projection):
        """
        Add a UGrid to XMS and link it to the simulation.

        The UGrid is written to an .xmc file in the process temp directory, and the CRC of that file becomes
        `self.grid_hash`. Without a Query, the UGrid is only recorded in the `added_ugrid*` attributes.

        Args:
            ugrid_2d: The UGrid to add. Its `uuid` attribute is overwritten with a newly generated UUID.
            name: Name to assign the mesh.
            projection: The UGrid's native projection.
        """
        ugrid_uuid = str(uuid.uuid4())
        xmc_file = str(Path(XmEnv.xms_environ_process_temp_directory()) / f'{ugrid_uuid}.xmc')
        # SMS will assign a UUID if we don't, but if we leave it blank then it won't be included in the file, which
        # means it won't be part of the grid hash. When we export later, SMS *will* include it, and it *will* be part
        # of the grid hash. We assign it here to ensure its presence is consistent.
        ugrid_2d.uuid = ugrid_uuid
        ugrid_2d.write_to_file(xmc_file)
        # Hash the file once; it was previously hashed a second time below for the same result.
        self._grid_hash = compute_crc(xmc_file)

        if not self.query:
            self.added_ugrid = ugrid_2d
            self.added_ugrid_name = name
            self.added_projection = projection
            return

        ugrid = UGrid(xmc_file, name=name, uuid=ugrid_uuid, projection=projection)
        self.query.add_ugrid(ugrid)
        self.query.link_item(self._sim_uuid, ugrid.uuid)

    def add_coverage(self, coverage: Coverage, component: CoverageComponent, keywords: Any = None):
        """
        Add a coverage to XMS.

        Args:
            coverage: The coverage to add.
            component: The coverage's hidden component.
            keywords: Component keywords to send to XMS. These should come from the coverage builder.
        """
        if not self.query:
            self.added_coverages.append((coverage, component))
            return

        do_component = Component(
            comp_uuid=component.uuid,
            main_file=str(component.main_file),
            class_name=component.class_name,
            module_name=component.module_name,
        )
        coverage_type = component.unique_name()
        self.query.add_coverage(
            coverage,
            model_name=self.model_name,
            coverage_type=coverage_type,
            components=[do_component],
            component_keywords=keywords
        )

    def add_sim(self, name: str, sim_component):
        """
        Add a simulation to XMS.

        With a Query, the new simulation's UUID becomes `self._sim_uuid`, so later links target it.

        Args:
            name: Name to assign the simulation.
            sim_component: Hidden component to attach to the simulation.
        """
        if not self.query:
            self.added_sim = sim_component.data
            self.added_sim_name = name
            return

        sim_uuid = str(uuid.uuid4())
        do_component = Component(
            name=name,
            comp_uuid=sim_component.uuid,
            main_file=str(sim_component.main_file),
            class_name=sim_component.class_name,
            module_name=sim_component.module_name,
        )

        do_sim = Simulation(model=self.model_name, sim_uuid=sim_uuid, name=name)

        self.query.add_simulation(do_sim, [do_component])
        self._sim_uuid = sim_uuid

    def add_mapped_materials(self, component):
        """
        Add a mapped material component to XMS.

        Without a Query, the component's data manager is recorded in `self.added_materials` instead.

        Args:
            component: The mapped material component to add.
        """
        if not self.query:
            self.added_materials = component.data
            return

        do_component = Component(
            comp_uuid=component.uuid,
            main_file=str(component.main_file),
            class_name=component.class_name,
            module_name=component.module_name,
        )
        self.query.add_component(do_component)

    @cached_property
    def _project_tree(self):
        """The project tree."""
        return self.query.copy_project_tree()

    @cached_property
    def _sim_uuid(self) -> str:
        """The UUID of the simulation, not its component."""
        tree = self._project_tree

        # We should be at either the simulation or its hidden component. The tree should contain the simulation, but
        # not its hidden component, so we can figure out where we are by checking if the current item's UUID is in
        # the tree. If it is, we must be at the simulation. If not, we're at its component, and its parent item is
        # the simulation.
        current_item_uuid = self.query.current_item_uuid()
        node = tree_util.find_tree_node_by_uuid(tree, current_item_uuid)
        if node is not None:
            return node.uuid

        parent_item_uuid = self.query.parent_item_uuid()
        node = tree_util.find_tree_node_by_uuid(tree, parent_item_uuid)
        return node.uuid

    @cached_property
    def _sim_tree_node(self) -> tree_util.TreeNode:
        """The simulation's tree node."""
        node = tree_util.find_tree_node_by_uuid(self._project_tree, self._sim_uuid)
        return node

    @cached_property
    def ugrid(self) -> Optional[UGrid2d]:
        """The UGrid linked to the simulation, or None if none is linked."""
        ugrid_node = self._ugrid_tree_node
        if ugrid_node is None:
            return None

        ugrid_uuid = ugrid_node.uuid
        return self.geometries[ugrid_uuid]

    @property
    def grid_hash(self) -> str:
        """
        The hash of the UGrid linked to the simulation.

        This is an empty string if no UGrid is linked (previously this raised AttributeError on None).
        """
        if self._grid_hash is None:
            ugrid = self.ugrid
            if ugrid is None:
                # Not cached, so a UGrid added later via add_ugrid() is still picked up.
                return ''
            self._grid_hash = self._grid_hashes[ugrid.uuid]
        return self._grid_hash

    @cached_property
    def linked_ugrid_uuids(self) -> list[str]:
        """
        The UUIDs of all linked UGrids.

        0 or 1 UGrid is normal. 2 should only be a transient state between a new one being linked and the old one being
        kicked out to maintain the usual limit.
        """
        ugrid_nodes = tree_util.descendants_of_type(
            self._sim_tree_node, allow_pointers=True, xms_types=['TI_MESH2D_PTR'], only_first=False
        )
        uuids = [ugrid_node.uuid for ugrid_node in ugrid_nodes]
        return uuids

    @cached_property
    def ugrid_name(self) -> str:
        """The name of the linked UGrid, or an empty string if none is linked."""
        ugrid_node = self._ugrid_tree_node
        if ugrid_node:
            return ugrid_node.name
        return ''

    @cached_property
    def _ugrid_tree_node(self) -> Optional[tree_util.TreeNode]:
        """The tree node of the first linked UGrid, or None if none is linked."""
        ugrid_node = tree_util.descendants_of_type(
            self._sim_tree_node, allow_pointers=True, xms_types=['TI_MESH2D_PTR'], only_first=True
        )
        return ugrid_node

    @cached_property
    def roughness_coverage(self) -> tuple[Optional[Coverage], Optional[RoughnessComponent]]:
        """
        The coverage and component for the linked roughness coverage, or `(None, None)` if not linked.

        This only returns one coverage, unlike the BC and Material coverages, since the roughness coverage should only
        be linked long enough to convert it to another coverage.
        """
        roughness_node: tree_util.TreeNode = tree_util.descendants_of_type(
            self._sim_tree_node,
            allow_pointers=True,
            model_name=self.model_name,
            coverage_type=self.roughness_component_unique_name,
            only_first=True
        )
        if not roughness_node:
            return None, None

        coverage_uuid = roughness_node.uuid
        coverage = self.query.item_with_uuid(coverage_uuid)

        do_component: Component = self.query.item_with_uuid(
            coverage_uuid, model_name=self.model_name, unique_name=self.roughness_component_unique_name
        )
        main_file = do_component.main_file
        component = RoughnessComponent(main_file)
        self.query.load_component_ids(component, points=True, arcs=True, polygons=True)
        # Ensure the polygon ID map exists even if XMS sent no polygon IDs, so callers can index it unconditionally.
        component.comp_to_xms.setdefault(component.cov_uuid, {}).setdefault(TargetType.polygon, {})

        return coverage, component

    def _get_dataset(self, dataset_uuid: str) -> Optional[DatasetReader]:
        """
        Get a dataset.

        Args:
            dataset_uuid: UUID of the dataset to get.

        Returns:
            The dataset with the given UUID, or None if no datasets had that UUID.
        """
        dataset = self.query.item_with_uuid(dataset_uuid)
        if not dataset or not isinstance(dataset, DatasetReader):
            return None
        return dataset

    def _get_ugrid(self, ugrid_uuid: str) -> Optional[UGrid2d]:
        """
        Get a UGrid and record the CRC of its file in `self._grid_hashes`.

        Args:
            ugrid_uuid: UUID of the UGrid to get.

        Returns:
            The UGrid, or None if XMS returned no item for that UUID.
        """
        do_ugrid = self.query.item_with_uuid(ugrid_uuid)
        if not do_ugrid:
            return None

        grid_hash = compute_crc(do_ugrid.cogrid_file)
        ugrid = read_grid_from_file(do_ugrid.cogrid_file)

        self._grid_hashes[ugrid.uuid] = grid_hash
        return ugrid

    def link(self, item_uuid: str):
        """
        Link the item with the given UUID to the simulation.

        Args:
            item_uuid: UUID of the item to link.
        """
        if self.query:
            self.query.link_item(self._sim_uuid, item_uuid)
        else:
            self.linked_items.append(item_uuid)

    def unlink(self, item_uuid: str):
        """
        Unlink an item from the simulation.

        Args:
            item_uuid: UUID of the item to unlink.
        """
        if self.query is not None:
            self.query.unlink_item(self._sim_uuid, item_uuid)
        else:
            self.unlinked_items.append(item_uuid)

    def unlink_materials(self):
        """Unlink all the mapped and unmapped materials from the simulation."""
        uuids = []

        coverage, _component = self.material_coverage
        if coverage is not None:
            uuids.append(coverage.uuid)

        data = self.mapped_material_data
        if data is not None:
            uuids.append(data.uuid)

        if self.query:
            for item_uuid in uuids:
                self.query.unlink_item(self._sim_uuid, item_uuid)
        else:
            self.unlinked_items.extend(uuids)

    @cached_property
    def bc_coverages(self) -> list[tuple[Coverage, CoverageData]]:
        """The boundary condition coverages and their data managers that are linked to the simulation."""
        return self._coverages(self.bc_component_unique_name, CoverageComponent)

    @cached_property
    def material_coverage(self) -> tuple[Optional[Coverage], Optional[MaterialData]]:
        """The first linked material coverage and its data manager, or `(None, None)` if none is linked."""
        coverages = self._coverages(self.material_component_unique_name, MaterialComponent)
        return coverages[0] if coverages else (None, None)

    @cached_property
    def mapped_material_data(self) -> Optional[MappedMaterialData]:
        """The data manager for the linked mapped materials, or None."""
        node = tree_util.descendants_of_type(
            self._sim_tree_node, only_first=True, unique_name=self.mapped_material_component_unique_name
        )
        if node is None:
            return None

        do_component = self.query.item_with_uuid(node.uuid)
        main_file = do_component.main_file
        data = MappedMaterialData(main_file)
        return data

    def _coverages(self, unique_name: str, component_type: Type) -> list[tuple[Coverage, CoverageData | MaterialData]]:
        """
        Get all the coverages of a particular type that are linked to the simulation.

        Args:
            unique_name: The unique name of the coverage component to look for.
            component_type: Type of component to construct for its data manager.

        Returns:
            List of tuples of (coverage, data). The data managers will have their component_id_map attributes
            initialized.
        """
        coverages = []

        nodes: list[tree_util.TreeNode] = tree_util.descendants_of_type(
            self._sim_tree_node,
            allow_pointers=True,
            model_name=self.model_name,
            coverage_type=unique_name,
            only_first=False
        )
        for node in nodes:
            coverage_uuid = node.uuid
            coverage = self.query.item_with_uuid(coverage_uuid)
            do_component = self.query.item_with_uuid(coverage_uuid, model_name=self.model_name, unique_name=unique_name)
            main_file = do_component.main_file
            component = component_type(main_file)
            self.query.load_component_ids(component, points=True, arcs=True, polygons=True)
            coverages.append((coverage, component.data))

        return coverages

    def get_component_data(self, item_uuid: str, unique_name: str) -> CoverageData | MaterialData:
        """
        Get the data manager for a coverage or material component.

        Args:
            item_uuid: The UUID of the coverage the component belongs to.
            unique_name: The unique name of the component to get the data manager for.

        Returns:
            Data manager for the coverage's component.

        Raises:
            AssertionError: If `unique_name` is neither the BC nor the material component's unique name.
        """
        do_component = self.query.item_with_uuid(item_uuid, model_name=self.model_name, unique_name=unique_name)
        if unique_name == self.bc_component_unique_name:
            data = CoverageData(do_component.main_file)
            return data
        if unique_name == self.material_component_unique_name:
            data = MaterialData(do_component.main_file)
            return data
        raise AssertionError('Unrecognized unique name')

    def finish(self):
        """Mark the operation as finished and send changes to XMS."""
        if self.query:
            self.query.send()
        else:
            self.finished = True


def compute_crc(file: str | Path) -> str:
    """
    Calculate a file's CRC hash.

    Args:
        file: File to compute the CRC of.

    Returns:
        The file's CRC.
    """
    with open(file, 'rb') as f:
        crc = hex(binascii.crc32(f.read()))
        return crc
