"""Ugrid3dFromRasters class."""

__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path
import uuid

# 2. Third party modules
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.constraint.ugrid_extrude import extrude_grid
from xms.gdal.rasters import RasterInput
from xms.gdal.utilities import gdal_utils as gu

# 4. Local modules


class Columns:
    """Column names used to index the rasters table data frame."""
    # Not referenced by the code visible in this file.
    HORIZON = 'Horizon'
    # Raster file path for the row; it is handed to RasterInput when the row is filled or clipped.
    RASTER = 'Raster'
    # Truthy when the row's raster defines a sheet of the 3D grid (and so contributes layers).
    FILL = 'Fill'
    # Truthy when the filled raster sheets below this row must not rise above this row's raster.
    CLIP = 'Clip'
    # Number of layers added to the total layer count for a filled row.
    SUBLAYERS = 'Sublayers'
    # Space delimited integers like '1 1 2'; blank means the sublayers are evenly spaced.
    PROPORTIONS = 'Proportions'


class UGrid3dFromRastersCreator:
    """Creates a 3D UGrid from a 2D UGrid and rasters.

    A 'sheet' is the interface between cell layers. A 'raster sheet' is a sheet that corresponds to a raster
    (layers can be subdivided and thus there can be sheets that don't correspond to a raster).

    Non-raster sheets are ignored until _set_sublayer_elevations() - only raster sheets are considered in
    _fill_and_clip() and _fix_negative_thicknesses().

    """
    def __init__(self, default_wkt: str, logger) -> None:
        """Sets up an empty creator; the grid inputs arrive later through create_3d_co_grid().

        Args:
            default_wkt (str): The display projection wkt.
            logger: Logger used to report errors.
        """
        # Supplied at construction
        self._logger = logger
        self._default_wkt = default_wkt  # Display projection

        # Supplied later by create_3d_co_grid()
        self._minimum_layer_thickness: float = 0.0
        self._target_location: str = ''
        self._rasters: pd.DataFrame | None = None
        self._co_grid_2d = None

        # Working state, filled in while the grid is being built
        self._raster_sheets: list[int] = []  # Sheet index for each rasters table row, top to bottom
        self._sheet_count: int = 0  # Total number of sheets including subdivisions
        self._total_layer_count: int = 0
        self._locs_2d = []  # Point or cell center locations of the 2D grid
        self._elevs_3d = None  # Elevations array: one row per sheet, one column per 2D location
        self._ugrid_3d = None  # The new 3D UGrid
        self._co_grid_3d = None  # The new 3D CoGrid

    def create_3d_co_grid(
        self, co_grid_2d, rasters: pd.DataFrame, target_location: str, minimum_layer_thickness: float
    ):
        """Builds the 3D CoGrid from the 2D grid and the rasters table.

        Args:
            co_grid_2d: The 2D UGrid used to create the 3D UGrid.
            rasters (pd.DataFrame): The rasters table data frame.
            target_location (str): 'Cell tops and bottoms', or 'Points'.
            minimum_layer_thickness (float): Minimum thickness that every layer must have.

        Returns:
            The 3D CoGrid, or None (after logging an error) if the table would produce no layers.
        """
        # Remember the inputs; the individual steps read them from self
        self._minimum_layer_thickness = minimum_layer_thickness
        self._target_location = target_location
        self._rasters = rasters
        self._co_grid_2d = co_grid_2d

        self._get_layer_count_and_raster_sheets()
        if self._total_layer_count < 1:
            if self._logger:
                self._logger.error('No layers would be created.')
            return None

        # The steps depend on each other's results, so the order matters
        steps = (
            self._fill_and_clip,
            self._fix_negative_thicknesses,
            self._set_sublayer_elevations,
            self._honor_minimum_thickness,
            self._extrude_3d_grid,
            self._set_grid_elevations,
        )
        for step in steps:
            step()
        return self._co_grid_3d

    def _get_layer_count_and_raster_sheets(self):
        """Stores the total layer count and the raster sheet indexes computed from the rasters table."""
        layer_count, raster_sheets = compute_total_num_layers(self._rasters)
        self._total_layer_count = layer_count
        self._raster_sheets = raster_sheets

    def _extrude_3d_grid(self):
        """Extrudes the 2D grid into the 3D grid, one entry of 1.0 per layer.

        The real elevations are applied afterwards by _set_grid_elevations().
        """
        per_layer_values = [1.0] * self._total_layer_count  # presumably placeholder layer thicknesses
        self._co_grid_3d = extrude_grid(self._co_grid_2d, per_layer_values)

    def _get_2d_locations(self, ugrid_2d):
        """Returns the locations we're interested in - either cell centers or points."""
        if self._target_location == 'Points':
            locations = ugrid_2d.locations
        else:
            locations = [ugrid_2d.get_cell_centroid(i)[1] for i in range(ugrid_2d.cell_count)]
        return locations

    def _fill_and_clip(self):
        """Allocates the elevations array and sets the raster sheet elevations from the rasters.

        Rows are visited from the bottom of the table to the top. A filled row copies its raster elevations
        into its sheet; a clipping row then pushes the filled sheets below it down to its raster.
        """
        self._locs_2d = self._get_2d_locations(self._co_grid_2d.ugrid)  # Either point or cell center locations
        self._sheet_count = self._total_layer_count + 1
        self._elevs_3d = np.zeros((self._sheet_count, len(self._locs_2d)))

        raster_files = self._rasters[Columns.RASTER].to_list()
        fill_flags = self._rasters[Columns.FILL].to_list()
        clip_flags = self._rasters[Columns.CLIP].to_list()
        bottom_up_rows = zip(
            reversed(raster_files), reversed(fill_flags), reversed(clip_flags), reversed(self._raster_sheets)
        )
        for rows_from_bottom, (raster_path, do_fill, do_clip, sheet_idx) in enumerate(bottom_up_rows):
            if not (do_fill or do_clip):
                continue  # The raster is not needed, so don't read it
            raster_elevs = self._get_raster_elevations_at_points(raster_path)
            if do_fill:
                self._elevs_3d[sheet_idx] = raster_elevs  # Set elevations from raster
            if do_clip:
                self._clip(rows_from_bottom, raster_elevs)

    def _clip(self, reverse_idx: int, elevs_2d: np.array) -> None:
        """Adjusts elevations at all filled raster sheets below reverse_idx to be at or below elevs_2d.

        Only rows whose Fill flag is set are adjusted. Each sheet is updated with one boolean-mask
        assignment instead of a Python loop over every location; a strict '>' comparison is used so
        NaN values on either side are left untouched.

        Args:
            reverse_idx (int): Index into self._raster_sheets from the end.
            elevs_2d (np.array): One sheet of elevations from a raster.
        """
        ceiling = np.asarray(elevs_2d)
        fill_flags = self._rasters[Columns.FILL]
        forward_idx = len(self._raster_sheets) - reverse_idx - 1  # Index from beginning (reverse_idx is from end)
        # From the raster sheet below to the bottom raster sheet
        for raster_idx in range(forward_idx + 1, len(self._raster_sheets)):
            if not fill_flags.iloc[raster_idx]:
                continue
            sheet_idx = self._raster_sheets[raster_idx]
            too_high = self._elevs_3d[sheet_idx] > ceiling
            self._elevs_3d[sheet_idx, too_high] = ceiling[too_high]

    def _get_raster_elevations_at_points(self, raster_file: str) -> np.array:
        """Samples the raster at every 2D location and returns the values.

        The 2D locations are passed through transform_points_from_wkt() with the display projection wkt and
        the raster's wkt before sampling. Values are read with interpolate=False.

        Args:
            raster_file (str): Path of raster file.

        Returns:
            (np.array): Array of elevations, one per 2D location.
        """
        raster = RasterInput(raster_file)
        raster_locs = gu.transform_points_from_wkt(self._locs_2d, self._default_wkt, raster.wkt)
        sampled_values = (
            raster.get_raster_value_at_loc(loc[0], loc[1], interpolate=False) for loc in raster_locs
        )
        return np.fromiter(sampled_values, dtype='float')

    def _fix_negative_thicknesses(self):
        """Fixes raster layers with negative thicknesses to have zero thickness."""
        # From bottom to top
        for raster_idx in range(len(self._raster_sheets) - 2, -1, -1):
            raster_sheet = self._raster_sheets[raster_idx]
            if raster_sheet == -1:
                continue
            for raster_idx_below in range(raster_idx + 1, len(self._raster_sheets)):
                raster_sheet_below = self._raster_sheets[raster_idx_below]
                if raster_sheet_below == -1:
                    continue
                for i in range(self._elevs_3d.shape[1]):
                    if self._elevs_3d[raster_sheet][i] < self._elevs_3d[raster_sheet_below][i]:
                        self._elevs_3d[raster_sheet][i] = self._elevs_3d[raster_sheet_below][i]

    def _set_sublayer_elevations(self):
        """Sets the elevations on all the sublayers.

        For every filled row with more than one sublayer, the sheets between that row's raster sheet and
        the next raster sheet below are placed either by the row's proportions or, when no usable
        proportions are given, at equal spacing. Each sublayer sheet is computed for all locations at once
        instead of looping over every location in Python.
        """
        table = self._rasters
        rows = zip(
            table[Columns.FILL].to_list(), table[Columns.SUBLAYERS].to_list(), table[Columns.PROPORTIONS].to_list()
        )
        # From top to bottom
        for raster_idx, (fill, sublayers, proportions) in enumerate(rows):
            if not (fill and sublayers > 1):
                continue
            proportions_list, accumulated, total = self._get_proportions_info(proportions, raster_idx, sublayers)
            top_raster_sheet = self._raster_sheets[raster_idx]
            bottom_raster_sheet = next((s for s in self._raster_sheets[raster_idx + 1:] if s != -1), -1)
            if bottom_raster_sheet == -1:  # pragma no cover - shouldn't happen but just in case
                continue

            top = self._elevs_3d[top_raster_sheet]
            thickness = top - self._elevs_3d[bottom_raster_sheet]  # New array, unaffected by the writes below
            if proportions_list:
                base_thickness = thickness / total
            else:  # Space them equally
                base_thickness = thickness / sublayers
            for n in range(1, sublayers):
                multiplier = accumulated[n - 1] if proportions_list else n
                self._elevs_3d[top_raster_sheet + n] = top - (multiplier * base_thickness)

    def _honor_minimum_thickness(self):
        """Makes layers at least as thick as the minimum thickness."""
        if self._minimum_layer_thickness == 0.0:
            return

        # From top to bottom, preserved top and move elevations down as needed
        for sheet in range(self._elevs_3d.shape[0] - 1):
            for i in range(self._elevs_3d.shape[1]):
                if self._elevs_3d[sheet][i] - self._elevs_3d[sheet + 1][i] < self._minimum_layer_thickness:
                    self._elevs_3d[sheet + 1][i] = self._elevs_3d[sheet][i] - self._minimum_layer_thickness

    def _get_proportions_info(self, proportions_str: str, row_idx: int,
                              sublayers: int) -> tuple[list[int], list[int], int]:
        """Returns the information needed to handle proportions, given the proportions string.

        If the string cannot be used (it fails to parse, has the wrong length, or sums to zero) an error
        is logged when a logger is available and empty results are returned, which callers treat as
        "space the sublayers evenly".

        Args:
            proportions_str (str): Space delimited string like '1 1 2'
            row_idx (int): Row in the rasters table.
            sublayers (int): Number of sublayers.

        Returns:
            (tuple[list[int], list[int], int]): Proportions as int list, the same but summed, total of the proportions.
        """
        if not proportions_str:
            return [], [], 0

        errors: dict[str, str] = {}
        key = 'rasters_table'
        proportions_list = get_proportions_list(proportions_str, row_idx, sublayers, errors, key)
        if key in errors:
            if self._logger:  # The logger is optional, as in create_3d_co_grid()
                self._logger.error(errors[key])
            return proportions_list, [], 0
        if not proportions_list:
            return proportions_list, [], 0

        total = sum(proportions_list)
        if total == 0:
            # Dividing a layer's thickness by a zero total would produce inf/NaN elevations
            if self._logger:
                self._logger.error(
                    f'Row {row_idx + 1}: Proportions must not sum to zero. Sublayers will be evenly spaced.'
                )
            return [], [], 0
        accumulated = np.cumsum(proportions_list).tolist()  # Plain ints, matching the annotated return type
        return proportions_list, accumulated, total

    def _compute_model_on_off_cells(self) -> list[int]:
        """Computes model on/off cells vector marking 0 thickness cells as pass through."""
        elevs_3d_flat = self._elevs_3d.ravel()  # Returns a view
        self._ugrid_3d = self._co_grid_3d.ugrid  # Only want to do this once
        cell_count_3d = self._ugrid_3d.cell_count
        model_on_off_cells = [-1] * cell_count_3d  # init all to "pass through"
        if self._target_location == 'Points':
            for cell_idx in range(cell_count_3d):
                cell_points = self._ugrid_3d.get_cell_points(cell_idx)
                half_point_count = len(cell_points) // 2
                for i in range(half_point_count):
                    if elevs_3d_flat[cell_points[i]] - elevs_3d_flat[cell_points[i + half_point_count]] > 0.0:
                        model_on_off_cells[cell_idx] = 1  # It has some thickness. Make it an "on" cell
                        break
        else:
            count_per_sheet = len(self._locs_2d)
            for cell_idx in range(cell_count_3d):
                if elevs_3d_flat[cell_idx] - elevs_3d_flat[cell_idx + count_per_sheet] > 0.0:
                    model_on_off_cells[cell_idx] = 1  # It has some thickness. Make it an "on" cell
        return model_on_off_cells

    def _set_grid_elevations(self):
        """Sets the point elevations."""
        model_on_off_cells = self._compute_model_on_off_cells()
        self._co_grid_3d.model_on_off_cells = model_on_off_cells
        elevs_3d_flat = self._elevs_3d.ravel()  # Returns a view
        if self._target_location == 'Points':
            self._co_grid_3d.delete_cell_tops_and_bottoms()
            self._co_grid_3d.custom_point_elevations = elevs_3d_flat
        else:
            self._co_grid_3d.set_cell_tops_and_bottoms(
                elevs_3d_flat[:-self._elevs_3d.shape[1]], elevs_3d_flat[self._elevs_3d.shape[1]:]
            )

    @staticmethod
    def get_uuid():
        """Returns a random uuid string (or maybe not so random if testing).

        Returns:
            (str): uuid string.
        """
        return str(uuid.uuid4())


def compute_total_num_layers(df: pd.DataFrame) -> tuple[int, list[int]]:
    """Returns the total number of layers we will need, and the raster sheet indexes list.

    Each filled row gets the index of the sheet its raster defines (the number of layers above it) and
    adds its sublayer count to the total. Rows that are not filled get -1.

    Args:
        df (pd.DataFrame): The rasters table data frame.

    Returns:
        (tuple[int, list[int]]): Total layer count, and one sheet index (or -1) per table row.
    """
    count = 0
    raster_sheets = []
    for _index, row in df.iterrows():
        if row[Columns.FILL]:
            raster_sheets.append(count)
            # iterrows() does not preserve dtypes across a row, so the value may arrive as a float or a
            # numpy scalar. Keep the running count a plain int so it can size lists and index arrays.
            count += int(row[Columns.SUBLAYERS])
        else:
            raster_sheets.append(-1)
    return count, raster_sheets


def get_proportions_list(proportions_str: str, row_idx: int, sublayers: int, errors: dict[str, str],
                         key: str) -> list[int]:
    """Returns an integer list given a space delimited string list, adding errors to errors arg if there is trouble.

    A blank string (empty or only whitespace) means "evenly spaced" and yields an empty list with no error.

    Args:
        proportions_str (str): Space delimited string like '1 1 2'.
        row_idx (int): Table row (0-based; reported as row_idx + 1 in error messages).
        sublayers (int): Number of sublayers.
        errors (dict[str, str]): Dictionary of errors keyed by argument name.
        key (str): Key to use in errors dict.

    Returns:
        (list[int]): The proportions, or an empty list if the string is blank or invalid.
    """
    if not proportions_str or not proportions_str.strip():
        return []

    try:
        int_list = [int(n) for n in proportions_str.split()]
    except ValueError:
        errors[key] = (
            f'Row {row_idx + 1}: Proportions must be a space delimited list of integers'
            ' (e.g. "1 1 2"), or left blank to indicate sublayers will be evenly spaced.'
        )
        return []

    if len(int_list) != sublayers:
        errors[key] = (
            f'Row {row_idx + 1}: Proportions list must be the same length as the number of'
            ' sublayers.'
        )
        return []
    return int_list


def write_sublayers_file(filepath: Path | str, df: pd.DataFrame) -> None:
    """Creates a file with two columns: Layer, Sublayers.

    Only rows whose Fill value equals 1 are considered. They are numbered from 1 in table order and the
    last of them is left out. If there are no such rows, a file containing only the header is written.

    Args:
        filepath: Path of file that will be created.
        df: The dataframe from the tool.
    """
    # Don't include rows where we are not filling. drop=True keeps the old index out of the columns,
    # so a table that already has an 'index' column does not make reset_index() fail.
    filled = df[df[Columns.FILL] == 1].reset_index(drop=True)
    filled['Layer'] = filled.index + 1  # Add Layer column from 1 to num rows
    filled = filled.iloc[:-1]  # Don't include the last row; unlike index[-1], this is safe when empty
    filled.to_csv(filepath, columns=['Layer', Columns.SUBLAYERS], index=False)
