"""GridIJCreator class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np
import pandas as pd

# 3. Aquaveo modules

# 4. Local modules

# Adjacent cell neighbors. The values are consecutive so they can be iterated with range(); the trailing *_END
# entry is only a loop bound, not a real neighbor location.
(
    NEIGHBOR_LOC_TOP,
    NEIGHBOR_LOC_LEFT,
    NEIGHBOR_LOC_BOTTOM,
    NEIGHBOR_LOC_RIGHT,
    NEIGHBOR_LOC_END,  # Just for iteration
) = range(5)
# Point locations of quad cells. Same scheme: consecutive values with a trailing loop bound.
(
    LOC_BOTTOM_LEFT,
    LOC_BOTTOM_RIGHT,
    LOC_TOP_RIGHT,
    LOC_TOP_LEFT,
    LOC_END,  # Just for iteration
) = range(5)

# Only for CH3D format, just needed a shared place to put this.
NULL_COORD = 9.0e18
NULL_POINT = -1


class GridIJCreator:
    """Class for building i-j datasets from a UGrid with all 2D quad cells.

    The i-j coordinates are either supplied by the caller through ``set_cell_ij`` or derived from the UGrid's cell
    adjacency when ``create_ij_dataset`` is called without them. Problems found along the way are collected in
    ``errors`` and reported by ``create_ij_dataset`` as a single ``RuntimeError``.
    """

    # (delta_i, delta_j) of a neighbor relative to its origin cell, indexed by (edge index + orientation) % 4.
    # Rotating the origin cell's orientation by one step shifts which edge points in a given i-j direction by one,
    # so a single four entry table covers every edge/orientation combination.
    _IJ_OFFSETS = ((0, 1), (-1, 0), (0, -1), (1, 0))

    def __init__(self, ugrid, logger):
        """Constructor.

        Args:
            ugrid (UGrid): The UGrid to create cell i-j datasets for.
            logger (logging.Logger): Feedback logger
        """
        self._logger = logger
        self._ugrid = ugrid
        self._cell_i = None
        self._cell_j = None
        self._cell_orientation = None
        self._processed = None  # Have we recursed through a cells neighbors?
        self._visited = None  # Have we assigned a cell IJ values?
        self.errors = []
        # Optional cell-based datasets that will be added as columns to the output DataFrame. key=column name,
        # value=column values
        self.cell_datasets = {}

    def _find_neighbor_edge(self, origin_idx, neighbor_idx, edge_idx):
        """Find which edge of a neighbor cell is the edge it shares with an origin cell.

        Args:
            origin_idx (int): Cellstream index of the origin cell
            neighbor_idx (int): Cellstream index of the neighbor cell
            edge_idx (int): Edge index (in the origin cell's edge list) of the neighbors' shared edge

        Returns:
            int: Index, in the neighbor cell's edge list, of the edge made of the same two points as the origin
            cell's edge. Implicitly None if no edge of the neighbor matches.
        """
        # Find the points of the two cells that make up the shared edge
        edge_pts = self._ugrid.get_cell_edge(origin_idx, edge_idx)
        for ei, edge in enumerate(self._ugrid.get_cell_edges(neighbor_idx)):
            # Membership tests make the match independent of the point order within the edge.
            if edge[0] in edge_pts and edge[1] in edge_pts:
                return ei
        # return None  - I prefer explicit returns, but I didn't want to come up with a test to cover this.

    def _edge_adcent_cell(self, cell_idx, edge_idx):
        """Get the cell adjacent to a cell on a particular edge.

        Args:
            cell_idx (int): Cellstream index of the cell
            edge_idx (int): Edge index of the cell to look for a neighbor on

        Returns:
            int: Index of the neighboring cell, or None unless exactly one cell is adjacent across the edge
        """
        adj_cells = self._ugrid.get_cell_edge_adjacent_cells(cell_idx, edge_idx)
        if len(adj_cells) == 1:
            return adj_cells[0]
        return None

    def _neighbor(self, cell_idx, neighbor_idx):
        """Finds a cell neighbor across the specified edge.

        Args:
            cell_idx (int): Cellstream index of the start cell
            neighbor_idx (int): edge index of cell_idx for neighbor being found

        Returns:
            tuple(int, int): The neighbor cell's index and the edge of the neighbor cell back to cell_idx. Both are
            None when there is no neighbor across the edge.
        """
        adj_cell = self._edge_adcent_cell(cell_idx, neighbor_idx)
        if adj_cell is None:
            return None, None
        return adj_cell, self._find_neighbor_edge(cell_idx, adj_cell, neighbor_idx)

    def _check_cell_neighbor(self, cell_idx, neighbor_idx):
        """Assign I/J to one neighbor of the cell if not yet assigned, otherwise validate consistency.

        Args:
            cell_idx (int): cell that is being assigned (it must already have i, j and orientation values)
            neighbor_idx (int): which of the four neighbors of the cell_idx we are checking, one of the
                NEIGHBOR_LOC_* constants. It is also used as the edge index of cell_idx.
        """
        # get the neighbor
        neighbor_cell_idx, neighbor_cell_edge = self._neighbor(cell_idx, neighbor_idx)
        if neighbor_cell_idx is None:
            return  # boundary edge, nothing to classify

        # get the cells classification
        orientation = self._cell_orientation[cell_idx]

        # compute the IJ for the neighbor from the edge it lies across and the orientation of this cell
        delta_i, delta_j = self._IJ_OFFSETS[int((neighbor_idx + orientation) % 4)]
        neighbor_i = self._cell_i[cell_idx] + delta_i
        neighbor_j = self._cell_j[cell_idx] + delta_j

        # The neighbor sees the shared edge as its own edge 'neighbor_cell_edge'. If the two cells were numbered
        # the same way that would be the opposite edge (neighbor_idx + 2), so any difference is how far the
        # neighbor is rotated relative to this cell. Everything wraps at four edges.
        neighbor_orientation = (orientation + 2 + neighbor_idx - neighbor_cell_edge) % 4

        if self._visited[neighbor_cell_idx]:  # make sure we are consistent
            if self._cell_i[neighbor_cell_idx] != neighbor_i or \
               self._cell_j[neighbor_cell_idx] != neighbor_j or \
               self._cell_orientation[neighbor_cell_idx] != neighbor_orientation:
                # add this cell to the list of error cells
                # NOTE(review): 'Expected' reports the origin cell's coordinates rather than the coordinates
                # previously assigned to the neighbor. The text is left unchanged here - confirm the intent.
                self.errors.append(f'Invalid ij numbering found at cell {(neighbor_cell_idx + 1)}. '
                                   f'Expected: ({int(self._cell_i[cell_idx])}, {int(self._cell_j[cell_idx])}) '
                                   f'Found: ({int(neighbor_i)}, {int(neighbor_j)})')
        else:  # classify neighbor
            self._visited[neighbor_cell_idx] = True
            self._cell_i[neighbor_cell_idx] = neighbor_i
            self._cell_j[neighbor_cell_idx] = neighbor_j
            self._cell_orientation[neighbor_cell_idx] = neighbor_orientation

    def _visit_cell_neighbors(self, cell_idx):
        """Assign I/J to neighbors of the cell that are not yet assigned validating consistency.

        Args:
            cell_idx (int): cell index for cell whose neighbors are being checked
        """
        # only do this if the cell has not been processed already (no need to repeat it)
        if self._processed[cell_idx]:
            return
        self._processed[cell_idx] = True
        # Top, left, bottom, right - in that order
        for neighbor_loc in range(NEIGHBOR_LOC_TOP, NEIGHBOR_LOC_END):
            self._check_cell_neighbor(cell_idx, neighbor_loc)

    def _compute_cell_ij_from_ugrid(self):
        """Build cell i and j coordinate datasets from the input UGrid geometry.

        Cell 0 is used as the seed and the numbering spreads outward across shared edges. If some cells can never
        be reached that way, a message is appended to ``errors`` instead of looping forever.
        """
        self._logger.info('Generating cell i and j datasets from input UGrid...')
        # initialize the arrays for computing IJ datasets
        num_cells = self._ugrid.cell_count
        self._cell_i = np.full(num_cells, np.nan)
        self._cell_j = np.full(num_cells, np.nan)
        self._cell_orientation = np.full(num_cells, np.nan)
        self._visited = np.full(num_cells, False)
        self._processed = np.full(num_cells, False)
        if num_cells == 0:
            return  # No seed cell to start from

        # set the first cell as I,J = 1,1 of the grid orientation matching cell orientation
        cell_idx = 0
        self._visited[cell_idx] = True
        self._cell_i[cell_idx] = 1
        self._cell_j[cell_idx] = 1
        self._cell_orientation[cell_idx] = 2
        # classify the neighbors of the first cell
        self._visit_cell_neighbors(cell_idx)

        # Loop until all cells are processed
        need_to_visit_cells = True
        while need_to_visit_cells:
            need_to_visit_cells = False
            processed_before = np.count_nonzero(self._processed)
            # Loop through all the cells
            for cell_idx in range(num_cells):
                if not self._visited[cell_idx]:
                    need_to_visit_cells = True
                else:
                    self._visit_cell_neighbors(cell_idx)
            if need_to_visit_cells and np.count_nonzero(self._processed) == processed_before:
                # Every visited cell has already handed numbers to its neighbors, so another pass cannot reach
                # the remaining cells. Report them rather than spinning forever.
                unreached = np.flatnonzero(~self._visited)
                self.errors.append(
                    f'Unable to assign i-j coordinates to {unreached.size} cell(s) that are not connected to cell 1 '
                    f'through shared edges. First unreachable cell: {int(unreached[0]) + 1}'
                )
                break

    def _renumber(self, df, start_ij):
        """Renumber the i and j indices so there are no gaps and start from specified value.

        Args:
            df (pd.DataFrame): The sorted i-j DataFrame
            start_ij (int): Number to start i and j coordinates from

        Returns:
            pd.DataFrame: The grid cell i-j dataset indexed by 'i' and 'j' with a 'cell_idx' column containing the
            XmUGrid cellstream index for the i-j coordinate pair index.
        """
        # The inverse from np.unique is, for every row, the position of its value among the sorted unique values.
        # Offsetting that rank by start_ij makes the lowest i and lowest j == start_ij with no gaps, while the row
        # order of the data is left untouched.
        i_rank = np.unique(np.asarray(df.index.get_level_values('i')), return_inverse=True)[1]
        j_rank = np.unique(np.asarray(df.index.get_level_values('j')), return_inverse=True)[1]
        # Build the new DataFrame with the normalized i-j MultiIndex.
        idx = pd.MultiIndex.from_arrays([i_rank + start_ij, j_rank + start_ij], names=['i', 'j'])
        columns = {'cell_idx': df.cell_idx.values, 'cell_orientation': df.cell_orientation.values}
        columns.update({column_name: df[column_name].values for column_name in self.cell_datasets})
        return pd.DataFrame(columns, index=idx, copy=False)

    def _check_for_duplicate_ij(self, df):
        """Check the cell i-j dataset for cells with non-unique i-j coordinates.

        Appends a message listing the offending (1-based) cell ids to ``errors`` if any are found.

        Args:
            df (pd.DataFrame): The cell i-j dataset

        Returns:
            bool: True if all the cells have unique coordinates
        """
        # keep=False flags every member of a duplicated group, not just the repeats
        dup_df = df[df.index.duplicated(keep=False)]
        if dup_df.empty:
            return True
        cellstring = ', '.join((dup_df.cell_idx.values + 1).astype(str))
        self.errors.append(
            f'The following cells have duplicate i-j coordinates with at least one other cell: {cellstring}'
        )
        return False

    def set_cell_ij(self, cell_i, cell_j):
        """Set cell i-j coordinates from existing datasets.

        Args:
            cell_i (list): The cell i-coordinates in cellstream order
            cell_j (list): The cell j-coordinates in cellstream order

        Raises:
            ValueError: If cell_i and cell_j do not have the same number of values
        """
        self._logger.info('Using preexisting cell i and j datasets.')
        # Store as arrays so plain lists work with the array operations used when building the dataset.
        cell_i = np.asarray(cell_i)
        cell_j = np.asarray(cell_j)
        if cell_i.size != cell_j.size:
            raise ValueError(
                f'cell_i and cell_j must be the same length, got {cell_i.size} and {cell_j.size} values.'
            )
        self._cell_i = cell_i
        self._cell_j = cell_j
        num_cells = len(self._cell_i)
        self._cell_orientation = np.full(num_cells, 2)  # mark all cells as aligned with the curvilinear grid
        self._processed = np.full(num_cells, False)

    def create_ij_dataset(self, start_ij=1):
        """Build a DataFrame associating i-j values with cell indices.

        Args:
            start_ij (int): Number to start i and j coordinates from

        Returns:
            pd.DataFrame: The grid cell i-j dataset indexed by 'i' and 'j' with a 'cell_idx' column containing the
            XmUGrid cellstream index for the i-j coordinate pair index, a 'cell_orientation' column, and one column
            for each entry in ``cell_datasets``. Rows are sorted by i then j.

        Raises:
            RuntimeError: If any errors were recorded while generating or checking the i-j coordinates. The
                message is the recorded errors joined by newlines.
        """
        if self._cell_i is None or self._cell_j is None:  # If input datasets not specified, generate them
            self._compute_cell_ij_from_ugrid()
        if self.errors:
            raise RuntimeError('\n'.join(self.errors))

        # Build the DataFrame
        data_cols = {'cell_idx': np.arange(len(self._cell_i)), 'cell_orientation': self._cell_orientation}
        data_cols.update(self.cell_datasets)  # Add any optional datasets that need to tag along with the cell index.
        df_index = pd.MultiIndex.from_arrays([self._cell_i, self._cell_j], names=['i', 'j'])
        df = pd.DataFrame(data=data_cols, index=df_index)

        self._logger.info('Checking for duplicate i-j coordinates...')
        self._check_for_duplicate_ij(df)
        if self.errors:
            raise RuntimeError('\n'.join(self.errors))

        self._logger.info('Sorting dataset by i and j indices...')
        df.sort_values(['i', 'j'], inplace=True)

        self._logger.info('Renumbering i and j indices...')
        return self._renumber(df, start_ij)
