"""Module for get_cell_id_to_material_id_map."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"
__all__ = ['get_cell_id_to_material_id_map']

# 1. Standard Python modules
import itertools

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.data_objects.parameters import Polygon
from xms.snap import SnapPolygon

# 4. Local modules


def get_cell_id_to_material_id_map(
    material_polygons: list[list[Polygon]], mesh, default_material: int = -1
) -> np.ndarray:
    """
    Get a mapping from cell ID to material ID.

    Every cell receives exactly one material ID, since the result holds a single entry per cell. Materials are
    applied in order of their index in `material_polygons`, and each assignment overwrites whatever was there
    before. So if SnapPolygon reports a cell as inside polygons that belong to more than one material, the material
    with the highest index wins. (Whether SnapPolygon itself already limits each cell to a single polygon is not
    visible from this function.)

    Cell IDs are their zero-based index in the grid's cell list, not the 1-based IDs shown in XMS.

    A material's ID is its index in `material_polygons`. So, for example, all the cells that are assigned to polygons
    in `material_polygons[0]` will be assigned the ID of 0, and all the cells assigned to `material_polygons[1]` will
    have the ID of 1.

    Cells that are not part of any polygon will be assigned `default_material`. This ensures that every cell is
    present in the output. In the event that every cell is covered by a polygon, `default_material` will not be used.

    Args:
        material_polygons: Polygons defining the areas covered by materials. Each element is a list of all the polygons
            covering the area of that material. An element may be an empty list, which indicates no cells have that
            material, so no cells will receive that ID (but see `default_material`).
        mesh: The mesh to make the mapping for. It must expose `ugrid.cell_count` and be accepted by
            `SnapPolygon.set_grid`.
        default_material: The material ID to assign any cell which is not part of any polygon. The default of -1 is an
            impossible material, which makes it easy to determine which cells were unassigned.

    Returns:
        A 1-D integer numpy array of length `mesh.ugrid.cell_count`. Element `i` is the material ID for cell `i`.
    """
    cell_id_to_material_id = np.full(mesh.ugrid.cell_count, default_material, dtype=int)

    if not material_polygons:
        return cell_id_to_material_id

    flattened_polygons = list(itertools.chain.from_iterable(material_polygons))
    if not flattened_polygons:
        # Every material has an empty polygon list, so no cell can change from the default. Skip constructing and
        # configuring the snapper entirely since it would have nothing to do.
        return cell_id_to_material_id

    snapper = SnapPolygon()
    # NOTE(review): the meaning of set_grid's second argument is not visible here -- confirm against xms.snap.
    snapper.set_grid(mesh, False)
    snapper.add_polygons(flattened_polygons)
    for material_id, polygon_list in enumerate(material_polygons):
        for material_polygon in polygon_list:
            # Polygons are looked up by `.id`, so IDs presumably must be unique across all materials -- TODO confirm.
            cells = snapper.get_cells_in_polygon(material_polygon.id)
            # Assign all of this polygon's cells in one integer-array indexing operation instead of a Python loop.
            # list() accepts any iterable of cell indices (including an empty one) and performs the same writes as
            # assigning the cells one at a time.
            cell_id_to_material_id[list(cells)] = material_id

    return cell_id_to_material_id
