"""Module for ExportSimRunner."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"
__all__ = ['ExportSimRunner']

# 1. Standard Python modules
from itertools import count
from pathlib import Path
from typing import Iterable

# 2. Third party modules

# 3. Aquaveo modules
from xms.data_objects.parameters import Coverage
from xms.gmi.data.generic_model import UNASSIGNED_MATERIAL_ID
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs.feedback_thread import ExpectedError, FeedbackThread

# 4. Local modules
from xms.hydroas.data.coverage_data import CoverageData
from xms.hydroas.dmi.xms_data import XmsData
from xms.hydroas.feedback.snapper import SnapMode, Snapper
from xms.hydroas.file_io.errors import GmiError, Messages
from xms.hydroas.file_io.gmi_writer import GmiWriter


class ExportSimRunner(FeedbackThread):
    """Export thread that gathers simulation, mesh, boundary, and material data and writes them with `GmiWriter`."""

    # Arc parameter group names that, when any of them is active on an arc, make `_get_snapping_modes` assign
    # `SnapMode.ENDPOINTS` to that arc.
    _ENDPOINT_GROUPS = ('6', '7', '8')

    def __init__(self, data: XmsData):
        """
        Constructor.

        Args:
            data: Interprocess communication object.
        """
        super().__init__(query=None, is_export=True, create_query=False)
        self.display_text = {
            'title': 'Export Simulation',
            'working_prompt': 'Exporting simulation file. Please wait...',
            'warning_prompt': 'Warning(s) encountered while exporting simulation. Review log output for more details.',
            'error_prompt': 'Error(s) encountered while exporting simulation. Review log output for more details.',
            'success_prompt': 'Successfully exported simulation',
            'note': '',
            'auto_load': 'Close this dialog automatically when exporting is finished.'
        }

        self._data = data

        # Snapped points, as parallel lists: mesh node ID and the point's attribute values.
        self._points: list[int] = []
        self._point_values = []

        # Snapped arcs, as parallel lists: node string, arc ID (made unique by `_renumber_arcs`), name, and values.
        self._arcs: list[tuple[int, ...]] = []
        self._arc_ids: list[int] = []
        self._arc_names: list[str] = []
        self._arc_values = []

        self._material_names = []  # Human-readable name of the material used by a polygon.
        self._material_numbers: list[int] = []  # Integer ID of the material
        self._material_values = []  # Material parameter group for each material, parallel to the lists above.
        # Taken from the mapped material data's `cell_materials`; stays None when no materials are mapped.
        self._cell_materials = None

    def _run(self):
        """
        Export a simulation.

        Collects the simulation's model and values, the mesh, projection/version metadata, snapped boundary
        conditions, and mapped materials, then writes everything to `mesh.2dm` in the current working directory.

        Raises:
            ExpectedError: If writing fails with a `GmiError`, or if collecting boundaries/materials fails.
        """
        self._log.info('Exporting simulation...')

        self._log.info('Retrieving simulation data...')
        sim_data = self._data.sim_data
        model = sim_data.generic_model
        model_instantiation = sim_data.model_values
        global_instantiation = sim_data.global_values

        self._log.info('Retrieving mesh...')
        ugrid = self._data.ugrid

        self._log.info('Retrieving metadata...')
        projection = self._data.projection
        app_version = self._data.xms_version

        self._get_boundaries()
        self._get_materials()
        # Reuse the mesh fetched above rather than reading `self._data.ugrid` a second time.
        name = self._data.ugrid_name if ugrid else ''

        writer = GmiWriter(
            self._log,
            name=name,
            ugrid=ugrid,
            model=model,
            model_instantiation=model_instantiation,
            global_instantiation=global_instantiation,
            points=self._points,
            point_values=self._point_values,
            arcs=self._arcs,
            arc_ids=self._arc_ids,
            arc_names=self._arc_names,
            arc_values=self._arc_values,
            material_names=self._material_names,
            material_numbers=self._material_numbers,
            cell_materials=self._cell_materials,
            material_groups=self._material_values,
            projection=projection,
            xms_version=app_version,
        )

        path = Path('mesh.2dm')

        try:
            writer.write(path)
        except GmiError as error:
            # Chain the original error so its traceback is preserved for debugging.
            raise ExpectedError(f'Error exporting simulation: {error}') from error

    def _get_materials(self):
        """
        Get all the necessary data to write definitions and values for the material coverages.

        Does nothing when no materials are mapped. Otherwise fills the parallel material lists (skipping the
        unassigned material) and stores the per-cell material assignments.

        Raises:
            ExpectedError: If a mesh exists and its hash differs from the one recorded when materials were mapped.
        """
        material_data = self._data.mapped_material_data
        if material_data is None:
            return
        if self._data.ugrid and material_data.grid_hash != self._data.grid_hash:
            raise ExpectedError(
                'Error exporting simulation: Mesh was modified after materials were mapped. Remove the mapped materials'
                ' and apply new ones to export materials.'
            )

        self._log.info('Processing mapped materials...')
        section = material_data.generic_model.material_parameters
        section.restore_values(material_data.material_values)

        for name in section.group_names:
            if name == UNASSIGNED_MATERIAL_ID:
                continue
            group = section.group(name)
            # Convert before appending anything so a bad name cannot leave the parallel lists out of step.
            number = int(name)
            self._material_names.append(group.label)
            self._material_numbers.append(number)
            self._material_values.append(group)

        self._cell_materials = material_data.cell_materials

    def _get_boundaries(self):
        """
        Get all the necessary data to write definitions and values for the boundary condition coverages.

        Snaps the points and arcs of every boundary condition coverage to the mesh, records their attribute values,
        then makes the collected arc IDs unique.

        Raises:
            ExpectedError: If there are coverages but no mesh, or if snapping arcs fails with a `GmiError`.
        """
        bc_coverages = self._data.bc_coverages
        if not bc_coverages:
            return
        ugrid = self._data.ugrid
        if not ugrid:
            raise ExpectedError(Messages.features_without_domain)

        snapper = Snapper(ugrid, self._log)

        for coverage, data in bc_coverages:
            self._log.info(f'Processing coverage: "{coverage.name}"')
            if TargetType.point in data.component_id_map:
                self._log.info('Snapping points...')
                points = snapper.points(coverage)
                self._log.info('Collecting point attributes...')
                self._add_points(points, data)
            if TargetType.arc in data.component_id_map:
                self._log.info('Snapping arcs...')
                snapping_modes = self._get_snapping_modes(coverage, data)
                try:
                    arcs = snapper.arcs(coverage, snapping_modes)
                except GmiError as err:
                    raise ExpectedError(str(err)) from err
                self._log.info('Collecting arc attributes...')
                self._add_arcs(arcs, data)
            self._log.info('Done processing coverage.')

        # Arcs from different coverages may share feature IDs.
        self._renumber_arcs()

    @staticmethod
    def _get_snapping_modes(coverage: Coverage, data: CoverageData) -> dict[int, SnapMode]:
        """
        Get the snapping modes for arcs in a coverage.

        Each arc in `coverage` starts as `SnapMode.INTERIOR` and becomes `SnapMode.EXTERIOR` if any of its active
        arc groups is not legal on the interior. After that, every arc in the data's arc component-ID map that has an
        active endpoint group is set to `SnapMode.ENDPOINTS`, overriding the earlier choice.

        Args:
            coverage: Coverage whose arcs will be snapped.
            data: Data source with values for each arc.

        Returns:
            Mapping of `feature_id -> SnapMode` describing how each arc should be snapped.
        """
        mapping = {}
        # `restore_values` loads one arc's values into this section, so it is reloaded before every check below.
        section = data.generic_model.arc_parameters

        for arc in coverage.arcs:
            values = data.feature_values(TargetType.arc, feature_id=arc.id)
            section.restore_values(values)
            interior_ok = all(section.group(group_name).legal_on_interior for group_name in section.active_group_names)
            mapping[arc.id] = SnapMode.INTERIOR if interior_ok else SnapMode.EXTERIOR

        for feature_id in data.component_id_map[TargetType.arc]:
            values = data.feature_values(TargetType.arc, feature_id=feature_id)
            section.restore_values(values)
            if any(
                section.has_group(endpoint_group) and section.group(endpoint_group).is_active
                for endpoint_group in ExportSimRunner._ENDPOINT_GROUPS
            ):
                mapping[feature_id] = SnapMode.ENDPOINTS

        return mapping

    def _add_points(self, points: Iterable[tuple[int, int]], data: CoverageData):
        """
        Add some snapped points to be written.

        Args:
            points: `(feature_id, node_id)` pairs to be added. Should come from snapper.
            data: Data manager to extract point values from.
        """
        for feature_id, node_id in points:
            values = data.feature_values(TargetType.point, feature_id=feature_id)
            self._points.append(node_id)
            self._point_values.append(values)

    def _add_arcs(self, arcs: Iterable[tuple[int, tuple[int, ...]]], data: CoverageData):
        """
        Add some snapped arcs to be written.

        Args:
            arcs: `(feature_id, node_string)` pairs to be added. Should come from snapper.
            data: Data manager to extract arc values from.
        """
        for feature_id, node_string in arcs:
            values = data.feature_values(TargetType.arc, feature_id=feature_id)
            self._arcs.append(node_string)
            self._arc_ids.append(feature_id)
            self._arc_names.append('')
            self._arc_values.append(values)

    def _renumber_arcs(self):
        """
        Renumber arcs so they don't have duplicate IDs.

        The first arc with a given ID keeps it. Each later duplicate gets the smallest positive integer not yet
        handed out that does not appear anywhere in the original ID list. This keeps a new ID from colliding with
        an arc that is processed later.
        """
        original_ids = set(self._arc_ids)
        seen_ids = set()
        # `count` is monotonic and skips every original ID, so fresh IDs never repeat or clash with existing ones.
        fresh_ids = (candidate for candidate in count(start=1) if candidate not in original_ids)
        for index, arc_id in enumerate(self._arc_ids):
            if arc_id in seen_ids:
                self._arc_ids[index] = next(fresh_ids)
            else:
                seen_ids.add(arc_id)
