"""Code dealing with arc direction."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import math
from typing import Sequence

# 2. Third party modules

# 3. Aquaveo modules
from xms.data_objects.parameters import Arc, FilterLocation

# 4. Local modules


def arcs_have_compatible_directions(arc1: Arc, arc2: Arc) -> bool:
    """Checks the directions of arcs to see if they are 'parallel' or 'anti-parallel'.

    Both arcs are sampled at matching parametric positions. arc2 is scored once as-is and once reversed.
    Each score is a sum of squared xy gaps between the paired samples. Whichever orientation brings the
    samples closer together is taken as the arcs' relative direction.

    Args:
        arc1: data object.
        arc2: data object.

    Returns:
        (bool): True if the arcs are 'parallel' (arc2 as-is scores strictly closer than arc2 reversed);
        False if they are 'anti-parallel' or both orientations score exactly the same.
    """
    pts1 = [(pt.x, pt.y, pt.z) for pt in arc1.get_points(FilterLocation.PT_LOC_ALL)]
    pts2 = [(pt.x, pt.y, pt.z) for pt in arc2.get_points(FilterLocation.PT_LOC_ALL)]

    # gap between parametric locations with arc2 in its own direction
    same_direction_sum = _sum_distances_between_arc_points(pts1, pts2)

    # gap with arc2 walked backwards (a reversed copy; pts2 itself is left untouched)
    reversed_direction_sum = _sum_distances_between_arc_points(pts1, pts2[::-1])

    # the arcs have compatible direction when the un-reversed pairing is the tighter fit
    return same_direction_sum < reversed_direction_sum


def point_on_arc(arc: Arc, t: float) -> Sequence[float]:
    """Returns the point at parametric value t along the arc.

    t is the fraction of the arc's total length, measured in 3D through all points returned by
    arc.get_points(FilterLocation.PT_LOC_ALL). t=0.0 is at the first of those points and t=1.0 at the last.
    t=0.5 gives the middle point, which is where the label would be.

    Args:
        arc: The arc.
        t: parametric (0.0 - 1.0) value

    Returns:
        The (x, y, z) location on the arc at t.
    """
    arc_points = [(pt.x, pt.y, pt.z) for pt in arc.get_points(FilterLocation.PT_LOC_ALL)]
    cum_len, total_len = _cumulative_and_total_lengths(arc_points)
    return _point_on_arc(cum_len, total_len, arc_points, t)


def reverse_arc_direction(arc: Arc) -> None:
    """Reverses the arc direction by swapping the start/end nodes and reversing the list of vertices.

    The arc is modified in place.

    Args:
        arc: The arc.
    """
    arc.start_node, arc.end_node = arc.end_node, arc.start_node
    # reversed() yields a one-shot iterator, not a sequence; materialize it so the vertices setter
    # receives a list rather than an iterator that can only be consumed once.
    arc.vertices = list(reversed(arc.vertices))


def _sum_distances_between_arc_points(pts1: Sequence[Sequence[float]], pts2: Sequence[Sequence[float]]) -> float:
    """Samples both arcs at matching parametric values and sums the squared 2D (x, y) gaps between the samples.

    Each arc is sampled at the 11 parametric values t = 0.0, 0.1, ..., 1.0. Here t is the fraction of that
    arc's total length, and that length is measured in 3D. For every t, the squared xy distance between the
    sample on arc 1 and the sample on arc 2 is added to the result. The value is therefore a sum of squared
    distances, not a true distance. It is meant for comparing how well two pairings line up.

    Args:
        pts1: list of (x, y, z) points in arc 1.
        pts2: list of (x, y, z) points in arc 2.

    Returns:
        (float): sum of squared xy distances between the paired samples.
    """
    cum_len1, total_len1 = _cumulative_and_total_lengths(pts1)
    cum_len2, total_len2 = _cumulative_and_total_lengths(pts2)

    num_steps = 10
    sum_len = 0.0
    # Derive t from an integer counter rather than repeatedly adding 0.1. Accumulated float error would
    # otherwise decide whether the final (t ~ 1.0) sample passes a `t <= 1.0` check.
    for step in range(num_steps + 1):
        t = step / num_steps
        # find the point on arc1 and arc2 at parametric value t
        x1, y1, _z1 = _point_on_arc(cum_len1, total_len1, pts1, t)
        x2, y2, _z2 = _point_on_arc(cum_len2, total_len2, pts2, t)
        dx = x2 - x1
        dy = y2 - y1
        sum_len += dx * dx + dy * dy

    return sum_len


def _point_on_arc(cum_len: Sequence[float], total_len: float, pts: Sequence[Sequence[float]],
                  t: float) -> tuple[float, float, float]:
    """Computes a point at parameter value t on the arc.

    Args:
        cum_len: list of lengths to each point in the arc.
        total_len: total length of the arc.
        pts: points in the arc. Only x and y are used (the first 2 values of each point).
        t: parametric (0.0 - 1.0) value

    Returns:
        (x, y ,z): location on the arc.
    """
    len_t = t * total_len
    # find the segment containing len_t
    i = 1
    while cum_len[i] < len_t:
        i += 1
    percent = (len_t - cum_len[i - 1]) / (cum_len[i] - cum_len[i - 1])
    x = pts[i][0] * percent + pts[i - 1][0] * (1.0 - percent)
    y = pts[i][1] * percent + pts[i - 1][1] * (1.0 - percent)
    z = pts[i][2] * percent + pts[i - 1][2] * (1.0 - percent)
    return x, y, z


def _cumulative_and_total_lengths(arc_points: Sequence[Sequence[float]]) -> tuple[Sequence[float], float]:
    """Returns a list of cumulative segment lengths, and the total length.

    Args:
        arc_points: list of points on arc.

    Returns:
        See description.
    """
    total_len = 0.0
    cum_len = [0.0]
    for i in range(len(arc_points) - 1):
        dx = arc_points[i + 1][0] - arc_points[i][0]
        dy = arc_points[i + 1][1] - arc_points[i][1]
        dz = arc_points[i + 1][2] - arc_points[i][2]
        seg_len = math.sqrt(dx * dx + dy * dy + dz * dz)
        total_len += seg_len
        cum_len.append(total_len)
    return cum_len, total_len
