"""Removes adverse slopes from the stream arcs of a GSSHA boundary condition coverage."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from logging import Logger
from pathlib import Path
import shutil
from typing import MutableSequence, Sequence

# 2. Third party modules
from geopandas import GeoDataFrame
from shapely import LineString

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.grid.geometry import geometry

# 4. Local modules
from xms.gssha.components import dmi_util, gmi_util
from xms.gssha.components.bc_coverage_component import BcCoverageComponent
from xms.gssha.data import bc_util
from xms.gssha.data.bc_util import feature_and_index_from_id_and_type, NodeArcs
from xms.gssha.data.data_util import gt_tol, lt_tol, lteq_tol

# Constants
Z = 2  # Index of the z coordinate within a point (x is 0, y is 1, z is 2)
MIN_ELEV_DIFF = 0.01  # How far the last profile point is dropped below the profile minimum
ITER_MAX = 20  # Max number of sweeps when removing solitary high spots in the profile
ZERO_TOL = 1e-6  # Comparison tolerance (XM_ZERO_TOL from xms)


def run(query: Query, bc_cov: GeoDataFrame, new_name: str, logger: Logger) -> GeoDataFrame:
    """Builds a copy of the BC coverage whose stream arcs always flow downhill (no adverse slopes).

    Convenience entry point that delegates all of the work to AdverseSlopeRemover.

    Args:
        query: Object for communicating with XMS
        bc_cov: BC coverage
        new_name: Name for the new coverage
        logger: The logger.

    Returns:
        The new coverage, or None when no changes were needed or an error occurred.
    """
    return AdverseSlopeRemover(query, bc_cov, new_name, logger).run()


class AdverseSlopeRemover:
    """Create a new coverage in which the stream arc elevations have no adverse slopes (always flow downhill).

    The BC coverage and its component are copied first. The stream arcs of the *copy* are smoothed, and the copy
    is added to XMS only when changes were made. When no changes are needed, or on error, the copied component
    is deleted instead.
    """
    def __init__(self, query: Query, bc_cov: GeoDataFrame, new_name: str, logger: Logger) -> None:
        """Initializer.

        Args:
            query: Object for communicating with XMS
            bc_cov: BC coverage
            new_name: Name for the new coverage
            logger: The logger.
        """
        self._query = query
        self._bc_cov = bc_cov
        self._new_name = new_name
        self._logger = logger

        # Component of the input coverage (None when no coverage was given)
        self._bc_comp: BcCoverageComponent | None = dmi_util.get_bc_coverage_component(
            self._bc_cov.attrs['uuid'], self._query
        ) if (self._bc_cov is not None) else None
        self._new_bc_cov: 'GeoDataFrame | None' = None  # The copy that is modified and returned
        self._new_bc_comp: BcCoverageComponent | None = None  # Component of the copy
        self._arcs: list[tuple] = []  # Stream arcs of the copy, as (id, type) tuples
        self._node_arcs: NodeArcs | None = None  # Node -> attached arcs, for arc connectivity
        self._changes_made: bool = False

    def run(self) -> 'GeoDataFrame | None':
        """Creates a new coverage in which the stream arc elevations have no adverse slopes (always flow downhill).

        See iRemoveStreamAdverseSlopes() in creategrid.cpp in WMS.

        Returns:
            The new coverage. Returns None if no adverse slopes were found, or if an error occurred; in the
            error case the error is logged and the copied coverage component is deleted.
        """
        try:
            self._copy_coverage_and_component()
            self._get_stream_arcs()
            self._remove_adverse_slopes()
            self._add_to_query_if_changes()
        except Exception as error:  # RuntimeError for expected failures; anything else is unexpected
            self._delete_copies()
            if self._logger:
                self._logger.error(str(error))
        return self._new_bc_cov

    def _copy_coverage_and_component(self) -> None:
        """Copies the coverage and the coverage component, and gives the copy the new name."""
        self._logger.info('Copying the coverage...')
        rv = gmi_util.copy_coverage_and_component(self._bc_cov, self._bc_comp, self._query)
        self._new_bc_cov, self._new_bc_comp = rv[0], rv[1]
        self._new_bc_cov.attrs['name'] = self._new_name

    def _get_stream_arcs(self) -> None:
        """Gets the stream arcs (and node connectivity) of the new coverage.

        Raises:
            RuntimeError: If there are no stream arcs.
        """
        self._logger.info('Getting stream arcs...')
        self._arcs, self._node_arcs = bc_util.get_stream_arcs(self._query, self._new_bc_cov, self._new_bc_comp)
        if not self._arcs:
            raise RuntimeError('No stream arcs found. Aborting.')

    def _remove_adverse_slopes(self) -> None:
        """Removes the adverse slopes from the stream arcs of the new (copied) coverage."""
        self._logger.info('Removing adverse slopes...')
        # Smooth the copy, never the input: the stream arcs were read from the copy, and the copy is what
        # gets added to XMS and returned.
        self._changes_made = _remove_stream_adverse_slopes(self._new_bc_cov, self._arcs, self._node_arcs)

    def _add_to_query_if_changes(self) -> None:
        """Add new coverage and component to the query if we made changes, or warn and cleanup if we didn't."""
        if self._changes_made:
            # Add to query
            self._logger.info('Adding coverage...')
            unique_name = 'BcCoverageComponent'
            coverage_type = 'Boundary Conditions'
            gmi_util.add_to_query(self._new_bc_cov, self._new_bc_comp, unique_name, coverage_type, 'GSSHA', self._query)
        else:
            self._logger.warning('No adverse slopes found. No changes made.')
            self._delete_copies()

    def _delete_copies(self) -> None:
        """Delete the copied component's directory (if it exists) and forget the copied coverage."""
        if self._new_bc_comp and self._new_bc_comp.main_file and Path(self._new_bc_comp.main_file).exists():
            shutil.rmtree(Path(self._new_bc_comp.main_file).parent)  # Delete the new cov comp directory
        self._new_bc_cov = None


def _remove_stream_adverse_slopes(coverage: GeoDataFrame, arcs: list[tuple], node_arcs: NodeArcs) -> bool:
    """Removes adverse slopes on stream arcs.

    Adapted from iRemoveStreamAdverseSlopes() in creategrid.cpp in WMS.

    Starting from each most-upstream arc, the arcs are followed downstream to form one flow path. Each flow path
    is then smoothed in turn.

    Args:
        coverage: The coverage
        arcs: List of all stream arcs.
        node_arcs: Node ID -> arcs attached to that node, for arc connectivity.

    Returns:
        True if an adverse slope was found and changes were made on any flow path.
    """
    changes = False
    for upstream_arc in _find_upstream_arcs(coverage, arcs, node_arcs):
        flow_path = _get_all_downstream_arcs(coverage, upstream_arc, node_arcs)  # In order, first one is upstream_arc
        # Always process the path first, then accumulate. A later path that needs no changes must not reset
        # the flag set by an earlier one.
        if _remove_adverse_slopes_along_line(coverage, flow_path):
            changes = True
    return changes


def _find_upstream_arcs(coverage: GeoDataFrame, arcs: list[tuple], node_arcs: NodeArcs) -> list[tuple]:
    """Returns the most upstream stream arcs.

    Adapted from iFindUpstreamArcs() in creategrid.cpp in WMS.

    An arc counts as most upstream when its start node has exactly one attached arc (the arc itself).

    Args:
        coverage: The coverage
        arcs: List of all stream arcs.
        node_arcs: Node ID -> arcs attached to that node, for arc connectivity.

    Returns:
        The qualifying arcs, in the same order as `arcs`.
    """
    def _starts_at_dangling_node(candidate: tuple) -> bool:
        feature, _index = feature_and_index_from_id_and_type(coverage, candidate[0], candidate[1])
        return len(node_arcs[feature.start_node]) == 1

    return [candidate for candidate in arcs if _starts_at_dangling_node(candidate)]


def _get_all_downstream_arcs(coverage: GeoDataFrame, arc: tuple, node_arcs: NodeArcs) -> list[tuple]:
    """Returns the given arc followed by every arc downstream of it, in flow order.

    Adapted from iGetAllDownstreamArcs() in creategrid.cpp in WMS.

    Args:
        coverage: The coverage
        arc: The arc.
        node_arcs: Node ID -> arcs attached to that node, for arc connectivity.

    Returns:
        List of arcs from `arc` down to the last downstream arc. If the walk comes back to an arc already in the
        list (a loop in the network), it stops instead of looping forever.
    """
    downstream_arcs = []
    cur_arc = arc
    # Membership is checked against the list (equality only), so the arc tuples need not be hashable.
    while cur_arc is not None and cur_arc not in downstream_arcs:
        downstream_arcs.append(cur_arc)
        cur_arc = _get_downstream_arc(coverage, cur_arc, node_arcs)
    return downstream_arcs


def _get_downstream_arc(coverage: GeoDataFrame, arc: tuple, node_arcs: NodeArcs) -> tuple | None:
    """Returns the arc downstream from given arc (or None).

    Adapted from feDownstreamArc() in fline.cpp in WMS.

    The downstream arc is the first arc, other than `arc`, that is attached to the end node of `arc` and starts
    at that node.

    Args:
        coverage: The coverage.
        arc: The arc, as an (id, type) tuple.
        node_arcs: Node ID -> arcs attached to that node. Each entry is read through its `id`,
            `geometry_types` and `start_node` attributes.

    Returns:
        The downstream arc as an (id, geometry_types) tuple, or None if there is none.
    """
    feature_arc, _ = feature_and_index_from_id_and_type(coverage, arc[0], arc[1])
    end_node = feature_arc.end_node
    for connected_arc in node_arcs[end_node]:
        candidate = (connected_arc.id, connected_arc.geometry_types)
        # Compare in the same (id, type) form as `arc`. Comparing the attached-arc record itself against the
        # 2-tuple is presumably never equal, which would let a closed arc (start node == end node) return itself.
        if candidate != arc and connected_arc.start_node == end_node:
            return candidate
    return None


def _remove_adverse_slopes_along_line(
    coverage: GeoDataFrame, arcs: list[tuple], max_smooth_passes: int = 10000
) -> bool:
    """Removes the adverse slopes.

    Adapted from gdRemoveAdverseSlopesAlongLine and iSmoothStreamElevations in GdGsshaUtils.cpp.

    The arcs are flattened into one profile of (2D length along the line, elevation). If that profile has any
    adverse (flat or uphill) slope, three things happen. The first point is raised to the profile maximum. The
    last point is dropped MIN_ELEV_DIFF below the profile minimum. The profile is then smoothed until it is
    strictly decreasing. Finally, the smoothed elevations are written back to the arcs in `coverage`.

    Args:
        coverage: The coverage. Its arc geometries are replaced in place when changes are made.
        arcs: List of arcs from upstream to downstream. Each arc's first point is assumed to coincide with the
            previous arc's last point, so it is stored only once in the profile.
        max_smooth_passes: Max number of smoothing passes before giving up.

    Returns:
        True if an adverse slope was found and changes were made.

    Raises:
        RuntimeError: If the profile is still not strictly decreasing after `max_smooth_passes` passes.
    """
    if not arcs:
        return False

    lengths: list[float] = []
    zs: list[float] = []
    last_pt = None
    # Convert the 3D points to (length along the curve, elevation)
    for arc in arcs:
        feature_arc, _ = feature_and_index_from_id_and_type(coverage, arc[0], arc[1])
        arc_pts = list(feature_arc.geometry.coords)

        # The very first point starts the profile at zero length
        if not lengths:
            lengths.append(0.0)
            zs.append(arc_pts[0][Z])
            last_pt = arc_pts[0]

        # Skip each arc's first point: it is the previous arc's last point (or was just added above)
        for pt in arc_pts[1:]:
            last_pt_tuple = last_pt[0], last_pt[1], last_pt[Z]
            pt_tuple = pt[0], pt[1], pt[Z]
            lengths.append(lengths[-1] + geometry.distance_2d(last_pt_tuple, pt_tuple))
            zs.append(pt[Z])
            last_pt = pt

    # Check the ORIGINAL profile. The endpoint adjustment below can remove the adverse slopes by itself, and
    # those edits must still be written back and reported.
    if not _adverse_slopes_exist(zs):
        return False

    # Smooth the points
    zs[0] = max(zs)
    zs[-1] = min(zs) - MIN_ELEV_DIFF
    passes = 0
    while _adverse_slopes_exist(zs):
        if passes >= max_smooth_passes:
            # Some profiles (e.g. duplicate consecutive vertices, which give zero-length segments) can never become
            # strictly decreasing under distance-weighted smoothing; fail instead of looping forever.
            raise RuntimeError(
                'Unable to remove adverse slopes: stream elevation smoothing did not converge. '
                'Check the stream arcs for problems such as duplicate consecutive vertices.'
            )
        _smooth_stream_data(lengths, zs)
        passes += 1

    # Now return the smoothed points to the original arcs
    z_offset = 0
    for arc in arcs:
        feature_arc, index = feature_and_index_from_id_and_type(coverage, arc[0], arc[1])
        old_arc_pts = list(feature_arc.geometry.coords)
        new_arc = LineString([(x, y, zs[z_offset + i]) for i, (x, y, _) in enumerate(old_arc_pts)])
        coverage.at[index, 'geometry'] = new_arc
        z_offset += len(old_arc_pts) - 1  # The shared first point of the next arc reuses this arc's last z

    return True


def _adverse_slopes_exist(zs: Sequence[float]) -> bool:
    """Returns True if adverse slopes exist in the list of points.

    Adapted from iAdverseSlopesExist() in GdGsshaUtils.cpp in xmsutl.

    Args:
        zs: z elevations

    Returns:
        See description.
    """
    for i in range(1, len(zs)):
        if zs[i - 1] - zs[i] <= 0.0:
            return True
    return False


def _smooth_stream_data(lengths: MutableSequence[float], zs: MutableSequence[float]) -> None:
    """Smooths the elevations in place.

    Adapted from gdSmoothStreamData() in GdGsshaUtils.cpp in xmsutl.

    Only interior points are modified; the first and last elevations are never touched. Two stages run:

    1. Up to ITER_MAX sweeps remove solitary high spots. A point that is higher than its upstream neighbor
       (per gt_tol with ZERO_TOL) is replaced by the mean of its two neighbors. Sweeping stops early once a
       sweep changes nothing.
    2. One sweep of a three point moving average. Each neighbor is weighted by the distance to the *other*
       neighbor, which gives the linear interpolation between the neighbors at the point's position. The new
       value is kept only if it is below the upstream neighbor (lt_tol) or not above the current value
       (lteq_tol).

    Updates happen in place during each sweep, so each point sees the already-updated value of the point
    before it. Raises ZeroDivisionError if a point and both its neighbors share the same length (zero total
    distance).

    Args:
        lengths: lengths from first point
        zs: z elevations
    """
    interior_end = len(zs) - 1

    # Stage 1: remove solitary high spots in the profile by iterating
    for _sweep in range(ITER_MAX):
        modified = False
        for idx in range(1, interior_end):
            if gt_tol(zs[idx], zs[idx - 1], ZERO_TOL):
                zs[idx] = 0.5 * (zs[idx + 1] + zs[idx - 1])
                modified = True
        if not modified:
            break

    # Stage 2: three point moving average weighted by the distances to the adjacent points
    for idx in range(1, len(lengths) - 1):
        up_len = lengths[idx] - lengths[idx - 1]  # upstream length
        down_len = lengths[idx + 1] - lengths[idx]  # downstream length
        span = up_len + down_len
        blended = zs[idx - 1] * down_len / span + zs[idx + 1] * up_len / span
        keep = lt_tol(blended, zs[idx - 1], ZERO_TOL) or lteq_tol(blended, zs[idx], ZERO_TOL)
        if keep:
            zs[idx] = blended
