"""DisuUgridStreamBuilder class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.mf6.misc import log_util


class DisuUgridStreamBuilder:
    """Stores disu information to help create a grid.

    Converts DISU cell definitions (a 2D polygon plus a top and bottom elevation per cell, and the IAC/JA/IHC
    connectivity arrays) into the point list (``ug_cell_pts``) and cell stream (``ug_cell_stream``) of a UGrid made
    of polyhedron cells.

    When the connectivity can be organized into vertical columns and horizontal layers, cells that touch share
    points (a "connected" grid). Otherwise every cell is built on its own and only points with identical
    coordinates are merged (an "unconnected" grid).
    """
    def __init__(self, locations2d, icvert, tops, bots, iac, ja, ihc):
        """Initializes the class.

        Args:
            locations2d: x,y locations from VERTICES block.
            icvert: for each cell, 0-based int vertex numbers (in VERTICES block) used to define the cell, in
             clockwise order. The polygon may be closed (first vertex repeated as the last one).
            tops: top elevation for each cell in the model grid
            bots: bottom elevation for each cell
            iac: number of connections (plus 1) for each cell.
            ja: list of 1-based cell number (n) followed by connecting cell numbers (m) for each m cells connected
             to cell n.
            ihc: index array indicating the direction between node n and all of its m connections. 0 is treated as
             vertical and 1 as horizontal; any 2 forces an unconnected grid.
        """
        self.locations2d = locations2d
        self.icvert = icvert
        self.tops = tops
        self.bots = bots
        self.ug_cell_pts = []
        self.ug_cell_stream = []
        self.iac = iac
        self.ja = ja
        self.ihc = ihc
        # cell_offset[n] is the index in ja/ihc where cell n's block of connections starts.
        self.cell_offset = [0] * len(self.iac)
        for i in range(len(self.iac) - 1):
            self.cell_offset[i + 1] = self.cell_offset[i] + self.iac[i]
        self.visited_cell = None
        self.column_idx = None
        self.columns = None
        self.layer_idx = None
        self.layers = None
        self.cell_layer_number = None
        self._log = log_util.get_logger()

    def build_ugrid_stream(self):
        """Builds the UGrid points and cell stream lists."""
        if self._must_build_unconnected_ugrid():
            self._build_unconnected_ugrid()
        else:
            self._build_connected_ugrid()

    def _must_build_unconnected_ugrid(self):
        """Returns True if the cells cannot be organized into consistent columns and layers.

        As a side effect, builds the columns, layers and cell layer numbers that _build_connected_ugrid uses.
        """
        self._log.info('Checking if possible to build ugrid with connected cells.')
        # if any IHC value is 2
        if 2 in self.ihc:
            return True

        self._build_columns()
        if not self.column_idx:
            return True
        self._build_layers()
        if not self.layer_idx:
            return True
        return self._determine_cell_layer_numbers()

    def _cell_vertices(self, cell_idx):
        """Returns the vertex numbers of a cell without the closing duplicate vertex.

        Args:
            cell_idx (int): 0-based cell index

        Returns:
            The cell's icvert entries, minus the last one when it repeats the first.
        """
        verts = self.icvert[cell_idx]
        if len(verts) > 0 and verts[0] == verts[-1]:
            return verts[:-1]
        return verts

    def _build_columns(self):
        """Creates columns of cell indexes based on the information in the DISU file.

        On success, ``self.columns`` holds one top-to-bottom list of cell indexes per column, and
        ``self.column_idx[cell]`` is the index of the column containing that cell. If any column violates the
        assumptions for a connected grid, both are set to None.
        """
        self._log.info('Building columns of cells.')
        num_cells = len(self.iac)
        self.visited_cell = [0] * num_cells
        self.column_idx = [-1] * num_cells
        self.columns = []
        for i in range(num_cells):
            if self.visited_cell[i]:
                continue
            self.visited_cell[i] = 1
            column = [i]
            if self._build_column(i, column):
                self.column_idx = None
                self.columns = None
                return
            idx = len(self.columns)
            self.columns.append(column)
            for cell in column:
                self.column_idx[cell] = idx

    def _build_column(self, cell_idx, column):
        """Creates a list of cells in a single column ordered from top to bottom.

        Walks down from cell_idx, appending each cell below it (and marking it visited) until the bottom of the
        column is reached. The walk is iterative, so deep columns cannot exceed Python's recursion limit, and an
        error found at any depth is reported.

        Args:
            cell_idx: the index of the current (top) cell in a column.
            column: the cells in the column (this is filled in)

        Returns:
            error (bool): violates our assumptions for building a connected grid
        """
        current = cell_idx
        while True:
            cell_below, error = self._get_cell_below(current)
            if error:
                return True
            if cell_below is None:
                return False
            below_idx = cell_below - 1  # cell_below is a 1-based cell number
            if self.visited_cell[below_idx]:
                # Already placed in a column (an earlier one, or this one via a cycle): the cell would be below two
                # different cells, which a connected grid cannot represent.
                return True
            self.visited_cell[below_idx] = 1
            column.append(below_idx)
            current = below_idx

    def _get_cell_below(self, cell_idx):
        """Find the cell below cell_idx.

        A vertical connection (IHC == 0) to a higher-numbered cell is treated as the cell below.

        Args:
            cell_idx: 0-based index of the current cell

        Returns:
            cell_below (int | None), error (bool): the 1-based cell number of the single cell below cell_idx (None if
            there is none). error is True if there is more than one cell below, more than two vertical connections,
            or the cell below has a different vertex list.
        """
        offset = self.cell_offset[cell_idx]
        error = False
        cell_below = None
        num_vert_connections = 0
        # The first entry of each cell's block in JA is the cell itself, so skip it.
        for idx in range(offset + 1, offset + self.iac[cell_idx]):
            if self.ihc[idx] != 0:
                continue
            num_vert_connections += 1
            neighbor_cell_id = self.ja[idx]  # 1-based
            if neighbor_cell_id - 1 > cell_idx:
                if cell_below is not None:
                    error = True
                # TODO may need to rotate icvert to minimum cell value and then compare
                if self.icvert[cell_idx] != self.icvert[neighbor_cell_id - 1]:
                    error = True
                cell_below = neighbor_cell_id
        if num_vert_connections > 2:
            error = True
        return cell_below, error

    def _build_layers(self):
        """Creates layers of cell indexes based on the information in the DISU file.

        A layer is a set of cells that can reach one another through horizontal (IHC == 1) connections. On success,
        ``self.layers`` holds one sorted list of cell indexes per layer, and ``self.layer_idx[cell]`` is the index
        of the layer containing that cell. On failure, both are set to None.
        """
        self._log.info('Building layers from horizontal connections.')
        num_cells = len(self.iac)
        self.visited_cell = [0] * num_cells
        self.layer_idx = [None] * num_cells
        self.layers = []
        for i in range(num_cells):
            if self.visited_cell[i]:
                continue
            self.visited_cell[i] = 1
            layer = [i]
            if self._build_layer(layer):
                self.layer_idx = None
                self.layers = None
                return
            layer.sort()
            idx = len(self.layers)
            self.layers.append(layer)
            for cell in layer:
                self.layer_idx[cell] = idx

    def _build_layer(self, layer):
        """Grows layer with every cell horizontally reachable from the cells already in it.

        Breadth-first: each scanned cell's unvisited horizontal neighbors are appended to layer.

        Args:
            layer (list): list of cell indexes in the layer; initially just the seed cell, extended in place

        Returns:
            error (bool) : true if we can't build a connected grid
        """
        idx = 0
        while idx < len(layer):
            adj_cells, error = self._get_unvisited_horizontal_neighbors_idx(layer[idx])
            if error:
                return True
            layer.extend(adj_cells)
            idx += 1
        return False

    def _get_unvisited_horizontal_neighbors_idx(self, cell_idx):
        """Get a list of the unvisited horizontal neighbors.

        Each horizontal neighbor must share a distinct edge of the cell's polygon. The returned neighbors are marked
        visited.

        Args:
            cell_idx (int): the current cell index

        Returns:
            neighbors (list), error (bool): list of adjacent cells, error flag
        """
        neighbors = []
        offset = self.cell_offset[cell_idx]
        cell_verts = self._cell_vertices(cell_idx)
        # Directed edges (previous vertex, vertex) around the polygon, including the closing edge. The closing
        # duplicate vertex was dropped, so no degenerate (v, v) edge is produced.
        cell_edges = {(cell_verts[i - 1], cell_verts[i]) for i in range(len(cell_verts))}
        matched_edges = set()
        for idx in range(offset + 1, offset + self.iac[cell_idx]):
            if self.ihc[idx] != 1:
                continue
            neigh_cell_idx = self.ja[idx] - 1
            shared_edge = self._find_shared_edge(cell_edges, neigh_cell_idx)
            # don't have a matching edge or multiple neighbors on same edge
            if not shared_edge or shared_edge in matched_edges:
                return neighbors, True
            matched_edges.add(shared_edge)
            if not self.visited_cell[neigh_cell_idx]:
                neighbors.append(neigh_cell_idx)
                self.visited_cell[neigh_cell_idx] = 1
        return neighbors, False

    def _find_shared_edge(self, cell_edges, neigh_cell_idx):
        """Finds the edge in cell_edges that is shared with the neighbor cell.

        Args:
            cell_edges (set): directed (previous vertex, vertex) edges of the current cell
            neigh_cell_idx (int): index of neighbor cell

        Returns:
            (tuple | bool): the shared edge as it appears in cell_edges, or False if no edge is shared.
        """
        neigh_verts = self._cell_vertices(neigh_cell_idx)
        for i in range(len(neigh_verts)):
            # make the edge backwards to compare with the cell_edges passed in
            edge = (neigh_verts[i], neigh_verts[i - 1])
            if edge in cell_edges:
                return edge
        return False

    def _determine_cell_layer_numbers(self):
        """Assigns a layer number to every cell in self.cell_layer_number, with the smallest number being 0.

        Returns:
            error (bool): True if a cell's layer number could not be assigned.
        """
        self._log.info('Assigning layer numbers to cells.')
        num_cells = len(self.iac)
        self.visited_cell = [0] * num_cells
        self.cell_layer_number = [None] * num_cells
        for i in range(num_cells):
            if self.cell_layer_number[i] is None and self._assign_layer_to_cell(i, 0):
                return True
        # Numbers are relative to the starting cell of each _assign_layer_to_cell call and may be negative, so shift
        # them so the smallest layer number is 0.
        min_layer = min(self.cell_layer_number, default=0)
        self.cell_layer_number = [x - min_layer for x in self.cell_layer_number]
        return False

    def _assign_layer_to_cell(self, cell_idx, layer_number):
        """Assigns relative layer numbers to every cell connected to cell_idx.

        Starting from the horizontal layer containing cell_idx (which gets layer_number), walks the column of each
        newly numbered cell. A cell j positions further down a column than another cell is j layers deeper. Each
        horizontal layer reached this way receives a single number: the first one computed for it.

        Args:
            cell_idx (int): 0-based index of the cell to start from
            layer_number (int): layer number given to the layer containing cell_idx

        Returns:
            error (bool): currently always False
        """
        error = False
        start_layer_idx = self.layer_idx[cell_idx]
        processed_layer_idx = {start_layer_idx}
        # Copy, so that growing the work list below does not modify self.layers.
        cell_idxs = list(self.layers[start_layer_idx])
        dict_idx_layer = {idx: layer_number for idx in cell_idxs}
        # cell_idxs grows while it is iterated: breadth-first over the layers reached through columns.
        for cidx in cell_idxs:
            if self.visited_cell[cidx]:
                continue
            self.visited_cell[cidx] = True
            layer = dict_idx_layer[cidx]
            self.cell_layer_number[cidx] = layer
            column = self.columns[self.column_idx[cidx]]
            col_idx = column.index(cidx)
            for j, cidx2 in enumerate(column):
                cidx2_layer_idx = self.layer_idx[cidx2]
                if cidx2_layer_idx in processed_layer_idx or self.visited_cell[cidx2]:
                    continue
                processed_layer_idx.add(cidx2_layer_idx)
                # The offset within the column is relative to cidx's own layer, not to the starting layer.
                new_layer_number = layer - col_idx + j
                cell_idxs2 = self.layers[cidx2_layer_idx]
                for idx in cell_idxs2:
                    dict_idx_layer.setdefault(idx, new_layer_number)
                cell_idxs.extend(cell_idxs2)  # noqa B038 editing a loop's mutable iterable
        # NOTE(review): conflicting layer numbers implied by different columns are not detected; the first number
        # computed for a layer wins.
        return error

    def _polyhedron_cell_stream(self, bot_face_pts, top_face_pts):
        """Returns the UGrid cell stream for one prism-shaped polyhedron cell.

        Args:
            bot_face_pts (list): ugrid point indexes around the bottom face, in icvert order
            top_face_pts (list): ugrid point indexes around the top face, matching bot_face_pts

        Returns:
            list: the cell type and face count, then each face's point count followed by its point indexes. The
            faces are the bottom face, one quadrilateral per side, and the top face in reversed point order.
        """
        num_pts = len(bot_face_pts)
        cell_stream = [UGrid.cell_type_enum.POLYHEDRON, num_pts + 2]  # cell type, number of faces
        # add bottom face
        cell_stream.append(num_pts)
        cell_stream.extend(bot_face_pts)
        # add one quadrilateral side face per polygon edge
        for j in range(num_pts):
            next_idx = (j + 1) % num_pts
            cell_stream.extend((4, bot_face_pts[j], top_face_pts[j], top_face_pts[next_idx], bot_face_pts[next_idx]))
        # add top face, in the reverse of the bottom face's order
        cell_stream.append(num_pts)
        cell_stream.extend(reversed(top_face_pts))
        return cell_stream

    def _build_unconnected_ugrid(self):
        """Creates a blocky ugrid that may have tears or overlaps.

        Each cell is built from its own top and bottom elevations. Only points with identical coordinates are
        merged.
        """
        self._log.info('Building ugrid with unconnected cells.')
        pt_hash = dict()
        self.ug_cell_pts = []
        self.ug_cell_stream = []
        for i in range(len(self.iac)):
            cell_bot_face_pts = []
            cell_top_face_pts = []
            for idx in self._cell_vertices(i):  # vertex numbers are already 0-based
                x, y = self.locations2d[idx][0], self.locations2d[idx][1]
                cell_bot_face_pts.append(self._hash_pt(pt_hash, (x, y, self.bots[i]), self.ug_cell_pts))
                cell_top_face_pts.append(self._hash_pt(pt_hash, (x, y, self.tops[i]), self.ug_cell_pts))
            self.ug_cell_stream.extend(self._polyhedron_cell_stream(cell_bot_face_pts, cell_top_face_pts))

    def _hash_pt(self, pt_hash, pt, ug_cell_pts):
        """Returns the ugrid index of pt, appending pt to ug_cell_pts the first time it is seen.

        Assumes pt_hash and ug_cell_pts start out holding the same number of points.
        """
        pt_idx = pt_hash.setdefault(pt, len(pt_hash))
        if pt_idx == len(ug_cell_pts):
            ug_cell_pts.append(pt)
        return pt_idx

    def _build_connected_ugrid(self):
        """Creates a ugrid in which cells share points at common vertices of the same layer boundary.

        Points are keyed by (layer boundary, vertex). A cell in layer L uses boundary L for its top and L + 1 for
        its bottom. A shared point takes the elevation of the first cell that creates it.
        """
        self._log.info('Building ugrid with connected cells.')
        max_layer = max(self.cell_layer_number)
        ugrid_pt_idx_layers = [[None] * len(self.locations2d) for _i in range(max_layer + 2)]
        self.ug_cell_pts = []
        self.ug_cell_stream = []
        for i in range(len(self.iac)):
            top_boundary = ugrid_pt_idx_layers[self.cell_layer_number[i]]
            bot_boundary = ugrid_pt_idx_layers[self.cell_layer_number[i] + 1]
            cell_bot_face_pts = []
            cell_top_face_pts = []
            for idx in self._cell_vertices(i):  # vertex numbers are already 0-based
                x, y = self.locations2d[idx][0], self.locations2d[idx][1]
                if top_boundary[idx] is None:
                    top_boundary[idx] = len(self.ug_cell_pts)
                    self.ug_cell_pts.append((x, y, self.tops[i]))
                if bot_boundary[idx] is None:
                    bot_boundary[idx] = len(self.ug_cell_pts)
                    self.ug_cell_pts.append((x, y, self.bots[i]))
                cell_bot_face_pts.append(bot_boundary[idx])
                cell_top_face_pts.append(top_boundary[idx])
            self.ug_cell_stream.extend(self._polyhedron_cell_stream(cell_bot_face_pts, cell_top_face_pts))
