"""SimQueryHelper class."""
# 1. Standard python modules
import binascii
import logging
import os

# 2. Third party modules

# 3. Aquaveo modules
import xms.api._xmsapi.dmi as xmd
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.srh.components.bc_component import BcComponent
from xms.srh.components.material_component import MaterialComponent

# 4. Local modules
from xms.srhw.components.rainfall_component import RainfallComponent

__copyright__ = "(C) Copyright Aquaveo 2022"
__license__ = "All rights reserved"


class SimQueryHelper:
    """Class used to get data from XMS related to SRH.

    Gathers the mesh/UGrid linked to a simulation, the simulation's coverages and their components, and the
    simulation's solution datasets, storing the results as member variables.
    """

    def __init__(self, query, at_sim=False):
        """Constructor.

        If ``query`` is not None, simulation data is retrieved from XMS immediately.

        Args:
            query (:obj:`Query`): class to communicate with XMS. May be None, in which case nothing is retrieved
                from XMS and only the default member values are set.
            at_sim (:obj:`bool`): True if query Context is at the simulation level, False if it
                is at the simulation component level
        """
        # Create the logger and every member first so they all exist before any XMS communication happens.
        self._logger = logging.getLogger('xms.srhw')
        self._query = query
        self.sim_uuid = None
        self.sim_comp_file = ''
        self.component_folder = ''
        self.sim_component = None
        self.sim_tree_item = None
        self.grid_name = ''
        self.grid_units = ''
        self.grid_uuid = ''
        self.grid_wkt = ''
        self.co_grid = None
        self.co_grid_file = ''
        self.co_grid_file_crc32 = ''
        self.existing_mapped_component_uuids = []
        self.coverages = {}  # coverage type -> (coverage dump, coverage component main file)
        self.bc_component = None
        self.material_component = None
        self.rainfall_component = None
        self.grid_error = ''
        self.using_ugrid = False
        self.solution_tree_item = None
        self.mesh_link = None
        self.mesh_tree_item = None
        if self._query is not None:
            self._initialize_from_xms(at_sim)

    def _initialize_from_xms(self, at_sim):
        """Initialize member variables with data retrieved from XMS.

        Sets the simulation uuid, the simulation component main file and folder, the SimComponent, and the
        simulation's project tree item.

        Args:
            at_sim (:obj:`bool`): True if query Context is at the simulation level, False if it
                is at the simulation component level
        """
        if at_sim:  # Get simulation data and then move to component level
            self.sim_uuid = self._query.current_item_uuid()
            sim_comp = self._query.item_with_uuid(self.sim_uuid, model_name='SRH-W', unique_name='SimComponent')
            self.sim_comp_file = sim_comp.main_file
        else:
            self.sim_comp_file = self._query.current_item().main_file
            self.sim_uuid = self._query.parent_item_uuid()
        self.component_folder = os.path.dirname(os.path.dirname(self.sim_comp_file))
        from xms.srhw.components.sim_component import SimComponent  # avoid circular dependencies
        self.sim_component = SimComponent(self.sim_comp_file)
        self.sim_tree_item = tree_util.find_tree_node_by_uuid(self._query.project_tree, self.sim_uuid)

    def get_geometry_data(self):
        """Get the mesh linked to the simulation."""
        self._get_mesh()

    def get_sim_data(self):
        """Get the mesh, coverages, coverage components, and existing mapped component uuids of the simulation."""
        self._get_mesh()
        self._get_coverages()
        self._get_uuids_of_existing_mapped_components()
        self._get_coverage_comp_ids()

    def get_solution_data(self):
        """Get solution datasets for a simulation.

        Looks under the linked mesh's tree item for a folder named "<simulation name> (SRH-2D)" and retrieves
        every dataset beneath it. On success also sets ``solution_tree_item`` and ``mesh_tree_item``.

        Returns:
            (:obj:`list`): List of the solution data_object Dataset dumps for this simulation. Empty if the mesh
            or the solution folder could not be found.
        """
        dset_dumps = []
        # NOTE(review): the folder suffix is "(SRH-2D)" even though this is the SRH-W helper - confirm this is
        # the folder name XMS uses for SRH-W solutions.
        sim_folder = f'{self.sim_tree_item.name} (SRH-2D)'

        # Get the mesh tree item.
        self._get_mesh_link()
        if not self.mesh_link:
            self._logger.error('Unable to find SRH-2D solution datasets.')
            return dset_dumps
        mesh_item = tree_util.find_tree_node_by_uuid(self._query.project_tree, self.mesh_link.uuid)

        # Get the simulation solution folder. Guard against the mesh not being found in the project tree.
        solution_folder = None
        if mesh_item is not None:
            solution_folder = tree_util.first_descendant_with_name(mesh_item, sim_folder)
        if not solution_folder:
            self._logger.error('Unable to find SRH-2D solution datasets.')
            return dset_dumps
        self.solution_tree_item = solution_folder
        self.mesh_tree_item = mesh_item

        # Get dumps of all the children datasets (best effort: log and return what was retrieved on failure).
        solution_dsets = tree_util.descendants_of_type(solution_folder, xmd.DatasetItem)
        try:
            for dset in solution_dsets:
                dset_dumps.append(self._query.item_with_uuid(dset.uuid))
        except Exception:  # pragma no cover - hard to test exceptions using QueryPlayback
            self._logger.exception('Error getting solution dataset.')
        return dset_dumps

    def _get_coverages(self):
        """Gets the coverages associated with a simulation.

        Each coverage type is looked up first as an SRH-W coverage and, if not found, as an SRH-2D coverage.
        """
        covs = [('Materials', 'Material_Component'), ('Rainfall', 'RainfallComponent'),
                ('Boundary Conditions', 'Bc_Component')]
        for pair in covs:
            self._get_coverage_info('SRH-W', pair)
            if pair[0] not in self.coverages:
                self._get_coverage_info('SRH-2D', pair)

    def _get_coverage_info(self, model_name, pair):
        """Look up a coverage linked to the simulation and store it in ``self.coverages``.

        Only the first matching coverage pointer directly under the simulation tree item is considered. It is
        stored only if XMS returns a component with the requested unique name for that coverage.

        Args:
            model_name (:obj:`str`): XMS model name of the coverage and its component (e.g. 'SRH-W').
            pair (:obj:`tuple`): (coverage type, component unique name), e.g. ('Materials', 'Material_Component').
        """
        cov_item = tree_util.descendants_of_type(self.sim_tree_item, xms_types=['TI_COVER_PTR'],
                                                 allow_pointers=True, only_first=True, recurse=False,
                                                 coverage_type=pair[0], model_name=model_name)
        if cov_item:
            cov_comp = self._query.item_with_uuid(item_uuid=cov_item.uuid, model_name=model_name, unique_name=pair[1])
            if cov_comp is not None:
                cov_dump = self._query.item_with_uuid(cov_item.uuid)
                self.coverages[pair[0]] = (cov_dump, cov_comp.main_file)

    def _get_mesh_link(self):
        """Find the mesh/UGrid pointer linked to the simulation and store it in ``self.mesh_link``.

        Nothing is set unless exactly one mesh or UGrid pointer is found under the simulation.
        ``self.using_ugrid`` is set to True when the linked item is a UGrid.
        """
        # Find the mesh tree item under the simulation.
        mesh_links = tree_util.descendants_of_type(self.sim_tree_item, xms_types=['TI_MESH2D_PTR', 'TI_UGRID_PTR'],
                                                   allow_pointers=True)
        if len(mesh_links) != 1:
            return

        if mesh_links[0].item_typename == 'TI_UGRID_PTR':
            self.using_ugrid = True
        self.mesh_link = mesh_links[0]

    @staticmethod
    def _file_crc32(filename, chunk_size=1024 * 1024):
        """Compute the CRC32 of a file, reading it in chunks rather than all at once.

        Args:
            filename (:obj:`str`): Path to the file.
            chunk_size (:obj:`int`): Number of bytes read per chunk.

        Returns:
            (:obj:`str`): The unsigned 32-bit CRC formatted by ``hex`` (a '0x'-prefixed lowercase string).
        """
        crc = 0
        with open(filename, 'rb') as f:
            # binascii.crc32 accepts the running CRC, so chunking gives the same result as one large read.
            for chunk in iter(lambda: f.read(chunk_size), b''):
                crc = binascii.crc32(chunk, crc)
        return hex(crc & 0xFFFFFFFF)

    def _get_mesh(self):
        """Gets the mesh associated with a simulation.

        Reads the linked mesh's name, horizontal units, uuid, and WKT. Records the path and CRC32 of its CoGrid
        file and loads the CoGrid. Does nothing if the simulation does not have exactly one linked mesh/UGrid.

        Raises:
            RuntimeError: If the mesh's horizontal units are not "METERS", "FEET (U.S. SURVEY)" or
                "FEET (INTERNATIONAL)". ``self.grid_error`` is set to a description first.
        """
        self._logger.info('Getting mesh from simulation.')
        self._get_mesh_link()
        if self.mesh_link is None:
            return
        mesh_item = self._query.item_with_uuid(self.mesh_link.uuid)
        self.grid_name = mesh_item.name
        proj = mesh_item.projection
        unit_str = proj.horizontal_units
        if unit_str == 'METERS':
            self.grid_units = 'GridUnit "METER"'
        elif unit_str in ['FEET (U.S. SURVEY)', 'FEET (INTERNATIONAL)']:
            self.grid_units = 'GridUnit "FOOT"'
        else:
            err_str = 'Unable to get horizontal units from mesh'
            self._logger.error(err_str)
            self._logger.error(f'unit_str: {unit_str}.')
            msg = 'Units must be one of: "METERS", "FEET (U.S. SURVEY)", "FEET (INTERNATIONAL)"'
            self._logger.error(msg)
            self.grid_error = f'{err_str}. {msg}'
            raise RuntimeError(err_str)
        self.grid_uuid = mesh_item.uuid
        self.grid_wkt = proj.well_known_text
        self.co_grid_file = mesh_item.cogrid_file
        self.co_grid_file_crc32 = self._file_crc32(self.co_grid_file)
        self.co_grid = read_grid_from_file(self.co_grid_file)
        self._logger.info('Mesh successfully loaded.')

    def _get_uuids_of_existing_mapped_components(self):
        """Gets the uuids of any existing mapped components."""
        # Get children of the simulation that are component tree items
        comp_items = tree_util.descendants_of_type(self.sim_tree_item, xms_types=['TI_COMPONENT'])
        self.existing_mapped_component_uuids = [comp_item.uuid for comp_item in comp_items]

    def _get_coverage_comp_ids(self):
        """Create the coverage components found by ``_get_coverages`` and load their component ids.

        Loads polygon ids for materials, point ids for rainfall, and arc ids for boundary conditions.
        """
        self._logger.info('Getting feature ids and component ids for coverages.')
        if 'Materials' in self.coverages:
            self.material_component = MaterialComponent(self.coverages['Materials'][1])
            self._query.load_component_ids(self.material_component, polygons=True)
        if 'Rainfall' in self.coverages:
            self.rainfall_component = RainfallComponent(self.coverages['Rainfall'][1])
            self._query.load_component_ids(self.rainfall_component, points=True)
        if 'Boundary Conditions' in self.coverages:
            self.bc_component = BcComponent(self.coverages['Boundary Conditions'][1])
            self._query.load_component_ids(self.bc_component, arcs=True)
