"""UGridSubset class."""

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
import xms.grid.ugrid

# 4. Local modules


class UGridSubset:
    """Create a subset of a 2d UGrid from a bounding box and stitch a grid back into the original.

    ``subset_from_box`` extracts every cell whose bounding box overlaps a box. ``stitch_grid`` merges
    a grid covering that same region back into the original grid. It relies on state computed by a
    prior ``subset_from_box`` call.
    """

    # Number of cells processed between progress messages in calc_cell_boxes.
    _LOG_INTERVAL = 10000

    def __init__(self, ugrid):
        """Initializes the class.

        Args:
            ugrid (:obj:`xms.grid.ugrid.UGrid`): 2d Unstructured grid
        """
        self._ugrid = ugrid
        self._ug_pts = ugrid.locations
        self._ug_cellstream = ugrid.cellstream
        self._cell_boxes = None  # per cell: ((xmin, ymin), (xmax, ymax)), computed lazily
        self._pt_classify = None  # per point: 1 if used by a selected cell, else 0
        self._cell_classify = None  # per cell: 1 if its bounding box overlaps the box, else 0
        self._old_to_new_pt_idx = None  # original point index -> subset point index (-1 = dropped)
        self._box_min = None
        self._box_max = None
        self._clipped_ugrid = None
        self._boundary_idxes = None  # original indices of selected-cell points outside the box
        self._logger = None

    def subset_from_box(self, box):
        """Create a subset ugrid from a bounding box.

        Cell bounding boxes are computed on the first call and reused afterwards.

        Args:
            box (:obj:`list`): [(xmin, ymin), (xmax, ymax)]

        Returns:
            (:obj:`xms.grid.ugrid.UGrid`): new ugrid
        """
        self._box_min = box[0]
        self._box_max = box[1]
        if self._cell_boxes is None:
            self.calc_cell_boxes()
        self._classify_cells_from_box()
        self._ugrid_from_classified_points_cells()
        return self._clipped_ugrid

    def set_logger(self, logger):
        """Sets the logger used to report progress while computing cell bounding boxes.

        Args:
            logger (:obj:`logging.Logger`): logger
        """
        self._logger = logger

    def calc_cell_boxes(self):
        """Create an axis-aligned bounding box for each cell in the UGrid.

        Fills ``self._cell_boxes`` with one ``((xmin, ymin), (xmax, ymax))`` tuple per cell, in
        cellstream order. Only the first two coordinates of each location are used. If a logger is
        set, a progress message is emitted after every ``_LOG_INTERVAL`` cells. The logger is cleared
        when this method finishes.
        """
        pts = self._ug_pts
        self._cell_boxes = []
        for cell_idx, _cell_type, cell_pts in self._iter_cells():
            xs = [pts[idx][0] for idx in cell_pts]
            ys = [pts[idx][1] for idx in cell_pts]
            self._cell_boxes.append(((min(xs), min(ys)), (max(xs), max(ys))))
            processed = cell_idx + 1
            # Report only after a full batch has actually been processed (not before the first cell).
            if self._logger and processed % self._LOG_INTERVAL == 0:
                self._logger.info(f'Processed {processed:,} cells.')
        # NOTE(review): the logger is released after one pass; call set_logger again to get
        # progress output from a later calc_cell_boxes call.
        self._logger = None

    def stitch_grid(self, ugrid):
        """Stitch ugrid into self._ugrid in place of the cells selected by subset_from_box.

        The points of ``ugrid`` that lie outside the box must coincide exactly (same x, y) with the
        subset's boundary points, so that the two grids share those points.

        Args:
            ugrid (:obj:`xms.grid.ugrid.UGrid`): grid that replaces the subset region

        Returns:
            (:obj:`xms.grid.ugrid.UGrid`): stitched ugrid. Its points are all of ``ugrid``'s points,
            followed by the original points that lie outside the box and are not subset boundary
            points. Its cells are all of ``ugrid``'s cells, followed by the original cells that
            were not selected.

        Raises:
            RuntimeError: if subset_from_box has not been called yet.
            KeyError: if a subset boundary point has no matching point in ``ugrid``.
        """
        if self._boundary_idxes is None:
            raise RuntimeError('subset_from_box must be called before stitch_grid.')

        # ugrid points outside the box, keyed by (x, y), so original boundary points can be matched.
        pts = ugrid.locations
        bound_dict = {(pt[0], pt[1]): i for i, pt in enumerate(pts) if self._pt_out_of_box(pt)}

        # Keep original points outside the box. Boundary points reuse the matching ugrid point;
        # all others are appended after ugrid's points.
        set_bound = set(self._boundary_idxes)
        num_ugrid_pts = len(pts)
        old_to_new = [-1] * len(self._ug_pts)
        new_pts = []
        for idx, pt in enumerate(self._ug_pts):
            if not self._pt_out_of_box(pt):
                continue
            if idx in set_bound:
                old_to_new[idx] = bound_dict[(pt[0], pt[1])]
            else:
                old_to_new[idx] = num_ugrid_pts + len(new_pts)
                new_pts.append(pt)

        # Keep the original cells that were not selected, renumbering their points.
        new_cellstream = []
        for cell_idx, cell_type, cell_pts in self._iter_cells():
            if self._cell_classify[cell_idx] == 0:
                new_cellstream.append(cell_type)
                new_cellstream.append(len(cell_pts))
                new_cellstream.extend(old_to_new[idx] for idx in cell_pts)

        stitched_pts = list(pts) + new_pts
        stitched_cells = list(ugrid.cellstream) + new_cellstream
        return xms.grid.ugrid.UGrid(stitched_pts, stitched_cells)

    def _iter_cells(self):
        """Yield ``(cell_idx, cell_type, cell_pts)`` for each cell of the original cellstream.

        The cellstream is a flat sequence of ``cell_type, num_pts, pt_0, ..., pt_{num_pts-1}``
        records, one per cell.
        """
        stream = self._ug_cellstream
        pos = 0
        cell_idx = 0
        while pos < len(stream):
            num_pts = stream[pos + 1]
            start = pos + 2
            end = start + num_pts
            yield cell_idx, stream[pos], stream[start:end]
            pos = end
            cell_idx += 1

    def _ugrid_from_classified_points_cells(self):
        """Build self._clipped_ugrid from the points and cells flagged by _classify_cells_from_box."""
        # Renumber the selected points consecutively, preserving their original order.
        self._old_to_new_pt_idx = [-1] * len(self._ug_pts)
        new_pts = []
        for old_idx, flag in enumerate(self._pt_classify):
            if flag == 1:
                self._old_to_new_pt_idx[old_idx] = len(new_pts)
                new_pts.append(self._ug_pts[old_idx])

        new_cellstream = []
        for cell_idx, cell_type, cell_pts in self._iter_cells():
            if self._cell_classify[cell_idx] == 1:
                new_cellstream.append(cell_type)
                new_cellstream.append(len(cell_pts))
                new_cellstream.extend(self._old_to_new_pt_idx[old_idx] for old_idx in cell_pts)
        self._clipped_ugrid = xms.grid.ugrid.UGrid(new_pts, new_cellstream)

    def _classify_cells_from_box(self):
        """Flag the cells whose bounding box overlaps the input box, and the points those cells use.

        Also records in self._boundary_idxes the points of selected cells that lie outside the box.
        These become boundary points of the subset grid; an index may appear once per cell that
        uses it.
        """
        self._pt_classify = [0] * len(self._ug_pts)
        self._cell_classify = [0] * self._ugrid.cell_count
        self._boundary_idxes = []
        box_min, box_max = self._box_min, self._box_max
        for cell_idx, _cell_type, cell_pts in self._iter_cells():
            cell_min, cell_max = self._cell_boxes[cell_idx]
            # The boxes are disjoint when they are separated along either axis; touching counts as overlap.
            separated_x = cell_min[0] > box_max[0] or cell_max[0] < box_min[0]
            separated_y = cell_min[1] > box_max[1] or cell_max[1] < box_min[1]
            if separated_x or separated_y:
                continue
            self._cell_classify[cell_idx] = 1
            for idx in cell_pts:
                self._pt_classify[idx] = 1
                # A point of a selected cell that is outside the box is a boundary point of the subset.
                if self._pt_out_of_box(self._ug_pts[idx]):
                    self._boundary_idxes.append(idx)

    def _pt_out_of_box(self, pt):
        """Determines if a location is outside the min/max box.

        Points exactly on the box edge count as inside.

        Args:
            pt (x, y, z): location

        Return:
            (:obj:`bool`): True if the point is outside the box
        """
        x_out = pt[0] < self._box_min[0] or pt[0] > self._box_max[0]
        y_out = pt[1] < self._box_min[1] or pt[1] > self._box_max[1]
        return bool(x_out or y_out)
