#! python3
from ._xmssnap.snap import _SnapInteriorArc
from .snap_base import _SnapBase, SnapBase
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import UGrid, Arc, FilterLocation


class SnapInteriorArc(_SnapInteriorArc, SnapBase):
    """Snaps interior arcs to the points or cells of a geometry."""
    def __init__(self):
        """Constructor."""
        # Both bases are initialized explicitly (no cooperative super() chain).
        _SnapInteriorArc.__init__(self)
        SnapBase.__init__(self)

    def set_grid(self, grid, target_cells):
        """Sets the geometry that will be snapped to.

        Args:
            grid (xms.data_objects.parameters.UGrid or constraint grid): The grid that will be
                targeted. A data_objects UGrid is first converted to a constraint grid; any
                other object is used as is and must expose an ``_instance`` attribute.
            target_cells (bool): True if the snap targets cell centers, point locations if false.
        """
        co_grid = grid
        if isinstance(grid, UGrid):  # data_objects UGrid
            cogrid_file = grid.cogrid_file
            if cogrid_file:  # New CoGrid impl
                co_grid = read_grid_from_file(cogrid_file)
            else:  # Old C++ impl for H5 file format
                co_grid = super().get_co_grid(grid)
        # The wrapped base class is handed the grid's underlying ``_instance`` object.
        _SnapBase.set_grid(self, co_grid._instance, target_cells)

    def get_snapped_points(self, arc):
        """Gets snapped locations and ids of the geometry.

        Args:
            arc (xms.data_objects.parameters.Arc or list of tuples of size 3 of doubles): The arc locations.

        Returns:
            A dictionary with keys 'id' and 'location'. 'id' holds a list of integers representing
            point or cell ids depending on the target set. 'location' holds a list that is parallel
            to the one in 'id'. The 'location' list is made of tuples of 3 doubles representing
            the snapped locations.
        """
        locations = arc
        if isinstance(arc, Arc):
            # Convert data_object Points to plain (x, y, z) tuples. A new list is built so the
            # list returned by the Arc is never modified in place.
            locations = [
                (point.x, point.y, point.z)
                for point in arc.get_points(FilterLocation.LOC_ALL)
            ]
        return super().get_snapped_points(locations)
