"""Module for the RubbleMoundStructureWriter class."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"
__all__ = ['write_rubble_mound_structures']

# 1. Standard Python modules
from functools import cached_property
from itertools import count
from pathlib import Path
from typing import TextIO

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import QuadtreeGrid2d
from xms.data_objects.parameters import Coverage, Polygon
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.gmi.data_bases.coverage_base_data import CoverageBaseData
from xms.guipy.data.target_type import TargetType
from xms.snap import SnapPolygon

# 4. Local modules
from xms.cmsflow.data.model import get_model
from xms.cmsflow.file_io.card_writer import CardWriter


def write_rubble_mound_structures(
    coverage: Coverage, data: CoverageBaseData, ugrid: QuadtreeGrid2d, project_name: str, query: Query, logger,
    wrote_header: bool, cards: TextIO
) -> bool:
    """
    Write the cards and datasets for every rubble mound polygon in a structure coverage.

    Convenience wrapper that builds a `RubbleMoundStructureWriter` and runs it once.

    Args:
        coverage: Coverage whose polygons are written.
        data: Data manager for the coverage. Its component_id_map must already be initialized.
        ugrid: The QuadTree the coverage's polygons are snapped to.
        project_name: Project name, used as the base for generated file names.
        query: Interprocess communication object.
        logger: Destination for log messages.
        wrote_header: True if the `!Structures` header was already written before this call.
        cards: Open handle to write cards to, typically the *.cmcards file.

    Returns:
        True if the `!Structures` header has been written, whether before this call or during it.
    """
    rubble_mound_writer = RubbleMoundStructureWriter(
        coverage=coverage,
        data=data,
        ugrid=ugrid,
        project_name=project_name,
        query=query,
        logger=logger,
        wrote_header=wrote_header,
        cards=cards,
    )
    header_written = rubble_mound_writer.write()
    return header_written


class RubbleMoundStructureWriter:
    """
    Class for writing all the data needed by rubble mounds in a structure coverage.

    For each polygon that has attributes and intersects the grid, this writes a RUBBLE_MOUND_BEGIN/END card block,
    copies any datasets the polygon references into `<project>_RM.h5`, and stamps the polygon's cells in a per-cell
    ID dataset that is also written to that file.
    """

    # Path of the rubble mound ID dataset inside the *_RM.h5 file. It is referenced by the RUBBLE_MOUND_ID_DATASET card,
    # and it is reserved so that no user dataset is assigned the same path.
    _ID_DATASET_PATH = '/Datasets/ID'

    # Prefixes of the 'rubble_mound' parameters that may reference a dataset. For each prefix, `<prefix>_type` selects
    # 'Constant' or 'Dataset'. When it is 'Dataset', `<prefix>_dataset` holds the dataset's UUID.
    _DATASET_PARAMETER_PREFIXES = ('rock_diameter', 'porosity', 'base_depth')

    def __init__(
        self, coverage: Coverage, data: CoverageBaseData, ugrid: QuadtreeGrid2d, project_name: str, query: Query,
        logger, wrote_header: bool, cards: TextIO
    ):
        """
        Initialize the writer.

        Args:
            coverage: Coverage containing geometry to write.
            data: Data manager for the coverage. Should have its component_id_map initialized.
            ugrid: The QuadTree to snap the coverage to.
            project_name: Name of the project. Used as the base for various file names.
            query: Interprocess communication object.
            logger: Where to log messages to.
            wrote_header: Whether the structures header has already been written.
            cards: Where to write cards to. Typically obtained by calling `open(...)` on the *.cmcards file.
        """
        self._coverage = coverage
        self._data = data
        self._logger = logger
        self._ugrid = ugrid
        self._project_name = project_name
        self._h5_filename = f'{project_name}_RM.h5'
        self._query = query
        self._tree = query.copy_project_tree()
        self._cards = CardWriter(cards)
        # One entry per grid cell; 0 means "no rubble mound", otherwise the 1-based index of the polygon's card block.
        self._ids = np.zeros(ugrid.ugrid.cell_count, dtype=int)
        # Maps dataset UUID -> path inside the *_RM.h5 file where it was written.
        self._dataset_paths: dict[str, str] = {}
        # Maps polygon feature ID -> component ID.
        self._component_ids: dict[int, int] = data.component_id_map[TargetType.polygon]
        self._next_polygon_id = 1
        self._wrote_structures_header = wrote_header
        self._wrote_rubble_mound_header = False

    def write(self) -> bool:
        """
        Write all the rubble mound data needed for the coverage.

        Returns:
            Whether the `!Structures` header was written (either before or by calling this function).
        """
        # Remove any *_RM.h5 file left over from a previous run before writing datasets into it.
        Path(self._h5_filename).unlink(missing_ok=True)

        for polygon in self._coverage.polygons:
            self._write_polygon(polygon)

        self._write_id_dataset()

        if self._wrote_rubble_mound_header:
            self._cards.write_newline()

        return self._wrote_structures_header

    @cached_property
    def _snapper(self) -> SnapPolygon:
        """Snapper for the coverage's polygons onto the grid, built lazily on first use."""
        snapper = SnapPolygon()
        snapper.set_grid(self._ugrid, False)
        snapper.add_polygons(self._coverage.polygons)
        return snapper

    def _ensure_header_written(self):
        """
        Ensure the header for the rubble mound structure section is written.

        Writes the generic `!Structures` header if nobody has yet, then the rubble mound ID dataset card.
        Does nothing after the first time it was called.
        """
        if not self._wrote_structures_header:
            self._wrote_structures_header = True
            self._cards.write('!Structures', indent=0)

        if not self._wrote_rubble_mound_header:
            # Log here rather than alongside `!Structures`. That header may already have been written by another
            # structure writer, and rubble mounds should still be reported.
            self._logger.info('Writing rubble mound structures')
            self._wrote_rubble_mound_header = True
            self._cards.write('RUBBLE_MOUND_ID_DATASET', f'"{self._h5_filename}" "{self._ID_DATASET_PATH}"', indent=0)

    def _rubble_mound_group(self, polygon_id: int):
        """
        Get the polygon model's 'rubble_mound' parameter group with a polygon's values restored into it.

        Args:
            polygon_id: Feature ID of the polygon whose values should be restored.

        Returns:
            The 'rubble_mound' group of the polygon parameter section.
        """
        section = get_model().polygon_parameters
        component_id = self._component_ids[polygon_id]
        section.restore_values(self._data.feature_values(TargetType.polygon, component_id))
        return section.group('rubble_mound')

    def _write_polygon(self, polygon: Polygon):
        """
        Write all the data needed by a single polygon.

        Polygons without a component ID are skipped (presumably they have no assigned attributes). Polygons that do
        not intersect the grid are skipped with a warning.

        Args:
            polygon: The polygon to write data for.
        """
        if polygon.id not in self._component_ids:
            return

        cells = self._snapper.get_cells_in_polygon(polygon.id)
        if len(cells) == 0:
            self._logger.warning(f'Polygon {polygon.id} did not intersect grid and will not be written.')
            return

        self._ensure_header_written()

        # The snapper yields tuples, which numpy interprets as one index for each of many dimensions.
        # Converting to an array makes it interpret as many indices on one dimension.
        cells = np.array(cells, dtype=int)
        # IDs are assigned in the same order the RUBBLE_MOUND_BEGIN blocks are written, so the Nth block goes with
        # cells stamped N (presumably how the model associates them -- confirm against the CMS-Flow card docs).
        self._ids[cells] = self._next_polygon_id
        self._next_polygon_id += 1

        group = self._rubble_mound_group(polygon.id)
        self._write_datasets_for_polygon(polygon.id, group)
        self._write_cards_for_polygon(group)

    def _write_datasets_for_polygon(self, polygon_id: int, group):
        """
        Ensure all the datasets needed by a polygon have been written to the *_RM.h5 file.

        Args:
            polygon_id: Feature ID of the polygon to write datasets for. Used to report errors.
            group: The polygon's 'rubble_mound' parameter group, with its values already restored.
        """
        for prefix in self._DATASET_PARAMETER_PREFIXES:
            if group.parameter(f'{prefix}_type').value == 'Dataset':
                parameter = group.parameter(f'{prefix}_dataset')
                self._write_dataset(polygon_id, parameter.label, parameter.value)

    def _write_dataset(self, polygon_id, parameter_label, dataset_uuid):
        """
        Ensure a dataset has been written to the *_RM.h5 file.

        This is safe to call multiple times with the same dataset. It will only write the dataset once.
        Only the dataset's first time step is written, stored at time 0.0.

        Args:
            polygon_id: Feature ID of the polygon that needs the dataset. Used for error reporting.
            parameter_label: Human-readable description of the parameter on the polygon that referenced the dataset.
                Used for error reporting.
            dataset_uuid: UUID of the dataset to write.

        Raises:
            RuntimeError: If no dataset with the given UUID exists.
        """
        if dataset_uuid in self._dataset_paths:
            return

        reader: DatasetReader = self._query.item_with_uuid(dataset_uuid)
        if reader is None:
            message = f'Polygon {polygon_id}: parameter `{parameter_label}` refers to nonexistent dataset.'
            raise RuntimeError(message)

        path = self._get_dataset_path(dataset_uuid)
        self._dataset_paths[dataset_uuid] = path
        writer = DatasetWriter(
            h5_filename=self._h5_filename,
            name=path,
            dset_uuid=dataset_uuid,
            null_value=0,
            time_units=reader.time_units,
            location='cells',
            overwrite=False
        )
        writer.write_xmdf_dataset(times=[0.0], data=[reader.values[0]])

    def _get_dataset_path(self, dataset_uuid: str) -> str:
        """
        Get the path in the *_RM.h5 file where a dataset should be stored to.

        The path is unique among the datasets written so far, and it never equals the ID dataset's path.

        Args:
            dataset_uuid: UUID of the dataset to get a path for.

        Returns:
            Path to store the dataset at.
        """
        # The reader has a dataset name on it, but it isn't always reliable. If you duplicate a dataset in SMS, the name
        # for it on the reader is still the name of the original dataset. The tree node's name seems to be more
        # reliable, so we use that instead. If the node can't be found, fall back to the (unique) UUID rather than
        # failing on `None.name`.
        dataset_node = tree_util.find_tree_node_by_uuid(self._tree, dataset_uuid)
        base_name = dataset_node.name if dataset_node is not None else dataset_uuid

        # Normally dataset names are unique, but it's possible to have duplicates, e.g. by putting folders under the
        # UGrid. The ID dataset's path is also reserved, so a user dataset named "ID" can't clobber it.
        used_paths = set(self._dataset_paths.values())
        used_paths.add(self._ID_DATASET_PATH)
        path = f'/Datasets/{base_name}'
        suffixes = count(start=1)
        while path in used_paths:
            path = f'/Datasets/{base_name} ({next(suffixes)})'
        return path

    def _write_id_dataset(self):
        """Write the rubble mound ID dataset."""
        if not self._wrote_rubble_mound_header:
            return  # No polygon intersected the grid, so nothing references the ID dataset.

        # The RUBBLE_MOUND_ID_DATASET card points at `_ID_DATASET_PATH`; this writer is expected to place `name='ID'`
        # there.
        writer = DatasetWriter(h5_filename=self._h5_filename, name='ID', location='cells', overwrite=False)
        writer.write_xmdf_dataset(times=[0.0], data=[self._ids])

    def _write_constant_or_dataset(self, group, prefix: str, card_prefix: str, format_constant=str):
        """
        Write the card for a parameter that may be either a constant or a dataset.

        Args:
            group: The polygon's 'rubble_mound' parameter group, with its values already restored.
            prefix: Parameter prefix. `<prefix>_type` selects the mode, `<prefix>` holds the constant, and
                `<prefix>_dataset` holds the dataset UUID.
            card_prefix: Card name prefix. `_CONSTANT` or `_DATASET` is appended to it.
            format_constant: Converts the constant value to the card's text.
        """
        if group.parameter(f'{prefix}_type').value == 'Constant':
            constant = group.parameter(prefix).value
            self._cards.write(f'{card_prefix}_CONSTANT', format_constant(constant))
        else:
            dataset_uuid = group.parameter(f'{prefix}_dataset').value
            path = self._dataset_paths[dataset_uuid]
            self._cards.write(f'{card_prefix}_DATASET', f'"{self._h5_filename}" "{path}"')

    def _write_cards_for_polygon(self, group):
        """
        Write the necessary cards for a single polygon.

        Args:
            group: The polygon's 'rubble_mound' parameter group, with its values already restored.
        """
        self._cards.write('RUBBLE_MOUND_BEGIN', indent=0)

        name = group.parameter('name').value
        self._cards.write('NAME', f"'{name}'")

        self._write_constant_or_dataset(group, 'rock_diameter', 'ROCK_DIAMETER')
        self._write_constant_or_dataset(group, 'porosity', 'STRUCTURE_POROSITY')
        # NOTE(review): `:g` keeps only 6 significant digits, unlike `str()` used for the other constants. This is kept
        # to preserve existing output; confirm it is intended.
        self._write_constant_or_dataset(group, 'base_depth', 'STRUCTURE_BASE_DEPTH', '{:g}'.format)

        method_name = group.parameter('calc_method').value
        # The card expects a 1-based index into the parameter's option list.
        method_id = group.parameter('calc_method').options.index(method_name) + 1
        self._cards.write('FORCHHEIMER_COEFF_METHOD', f'{method_id}     !{method_name}')

        self._cards.write('RUBBLE_MOUND_END', indent=0)