"""Classes used by BridgeFootprint class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.ugrid_boundaries import UGridBoundaries
from xms.constraint.ugrid_builder import UGridBuilder
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.tool.algorithms.ugrids.ugrid_2d_merger import UGrid2dMerger


class PierBase:
    """Base class for wall piers.

    Subclasses set ``self.pts`` to a unit outline (coordinates within [-0.5, 0.5]) and then call
    ``_generate_points_cells`` to scale that outline and wrap it with layers of quad elements.
    """

    def __init__(self, length=1.0, width=1.0, wrap_width=0.0, number_of_side_elements=1, number_layers=1,
                 growth_factor=1.5):
        """Initialize the class.

        Args:
            length (float): length of the pier; must be > 0.0 and >= width
            width (float): width of the pier; must be > 0.0
            wrap_width (float): width of first wrap around the pier; values <= 0.0 default to width / 2
            number_of_side_elements (int): number of elements down the side of a wall pier
            number_layers (int): number of layers to wrap around the pier
            growth_factor (float): factor for growth of element layers

        Raises:
            RuntimeError: if length or width is not positive, or if width exceeds length
        """
        # Validate each dimension on its own first so a non-positive length is reported as such rather than
        # as a width/length mismatch.
        if length <= 0.0:
            raise RuntimeError('pier length (or diameter) must be > 0.0')
        if width <= 0.0:
            raise RuntimeError('pier width must be > 0.0')
        if width > length:
            raise RuntimeError('pier width must be <= pier length')
        if wrap_width <= 0.0:
            wrap_width = width / 2.0
        self.pts = []
        self.cells = []
        self.outer_layer_thick = []
        self._length = length
        self._width = width
        self._wrap_width = wrap_width
        self._outer_poly_slice = 0  # start of the outermost ring in self.pts (used when _outer_poly_idx is empty)
        self._outer_poly_idx = []  # explicit outer polygon point indexes; takes precedence over the slice
        self._number_of_layers = number_layers
        self._number_of_side_elements = number_of_side_elements
        self._growth_factor = growth_factor
        self._end_pt_indexes = None
        self._min_bridge_width = None  # optional override for min_bridge_width()

    def min_bridge_width(self):
        """Calculate the minimum bridge width for the pier.

        Returns:
            float: the override value when one has been set, otherwise the pier length plus one wrap width at
            each end
        """
        if self._min_bridge_width is not None:
            return self._min_bridge_width
        return self._length + (2 * self._wrap_width)

    def _create_cells(self):
        """Build the quad cell stream that connects each ring of points to the next one out.

        ``self.pts`` must hold ``self._number_of_layers + 1`` rings of equal size, innermost first. Each cell
        is appended as ``[9, 4, i0, i1, i2, i3]``.
        """
        self.cells = []
        pts_per_layer = len(self.pts) // (self._number_of_layers + 1)
        out_start = pts_per_layer
        # The last ring of points is the outer polygon.
        self._outer_poly_slice = len(self.pts) - pts_per_layer

        quad = 9  # see XMU_QUAD
        start = 0
        for _ in range(self._number_of_layers):
            for i, idx in enumerate(range(start, out_start - 1)):
                self.cells.extend([quad, 4, idx, out_start + i, out_start + i + 1, idx + 1])
            # Close the ring: join the last inner/outer points back to the first ones.
            self.cells.extend([quad, 4, out_start - 1, out_start + pts_per_layer - 1, out_start, start])
            start = out_start
            out_start += pts_per_layer

    def _generate_points_cells(self):
        """Scale the unit outline in ``self.pts`` and build the wrapped layers of points and cells.

        The unit outline is scaled by the pier width (x) and length (y). Layer k (k >= 1) places point j at
        ``base_j + unit_j * 2 * offset_k``, where ``offset_k = wrap_width * (1 + g + ... + g**(k - 1))`` and g
        is the growth factor. For every layer, ``wrap_width * (g**k - 1)`` is appended to
        ``self.outer_layer_thick``.
        """
        layer_pts = {0: [[p[0] * self._width, p[1] * self._length, 0.0] for p in self.pts]}
        # Unit coordinates span +/-0.5, so doubling the wrap width gives an offset of one wrap width.
        new_width = self._wrap_width * 2
        for i in range(self._number_of_layers):
            layer_pts[i + 1] = [[p[0] * new_width + layer_pts[0][j][0],
                                 p[1] * new_width + layer_pts[0][j][1],
                                 0.0] for j, p in enumerate(self.pts)]
            old_width = new_width
            new_width = new_width + self._growth_factor**(i + 1) * (self._wrap_width * 2)
            self.outer_layer_thick.append(((new_width - old_width) / 2) - self._wrap_width)
        self.pts = []
        for v in layer_pts.values():
            self.pts.extend(v)
        self._create_cells()

    def rotate_points(self, angle):
        """Rotate all points about the origin.

        Args:
            angle (float): rotation angle in degrees, measured from north; positive angles rotate the points
                counter-clockwise
        """
        rad = math.radians(angle)
        self.rotate_points_radians(rad)

    def rotate_points_radians(self, angle_radians):
        """Rotate all points about the origin.

        Args:
            angle_radians (float): rotation angle in radians, measured from north; positive angles rotate the
                points counter-clockwise
        """
        cos_rad = math.cos(angle_radians)
        sin_rad = math.sin(angle_radians)
        for p in self.pts:
            rotate_point(p, cos_rad, sin_rad)

    def translate_points(self, new_origin):
        """Translate all points in x and y by the values in new_origin.

        Args:
            new_origin (iterable): x, y of new origin
        """
        for p in self.pts:
            translate_point(p, new_origin)

    def outer_polygon(self):
        """Return the outer polygon of this pier.

        Returns:
            list: x, y, z coordinates of the outer polygon. These are taken from the explicit index list when one
            was set, otherwise from the last ring of points.
        """
        if self._outer_poly_idx:
            return [self.pts[i] for i in self._outer_poly_idx]
        return self.pts[self._outer_poly_slice:]

    def end_pt_indexes(self):
        """Return the indexes of the end points (points that could be snapped to the bridge footprint).

        Returns:
            list[list[int]] | None: two lists of indexes into ``outer_polygon()``, one per end of the pier, or
            None if no end points have been assigned
        """
        return self._end_pt_indexes


class SquareUnitWallPier(PierBase):
    """Wall pier with a rectangular (square-ended) outline."""

    def __init__(self, length=1.0, width=1.0, wrap_width=0.0, number_of_side_elements=1, number_layers=1,
                 growth_factor=1.5):
        """Initialize the class.

        Args:
            length (float): length of the pier
            width (float): width of the pier
            wrap_width (float): width of first wrap around the pier
            number_of_side_elements (int): number of elements down the side of a wall pier; must be >= 1
            number_layers (int): number of layers to wrap around the pier
            growth_factor (float): factor for growth of element layers

        Raises:
            RuntimeError: if the dimensions are invalid (see PierBase) or number_of_side_elements < 1
        """
        super().__init__(length, width, wrap_width, number_of_side_elements, number_layers, growth_factor)
        if number_of_side_elements < 1:
            raise RuntimeError('number of side elements must be >= 1')
        dy = 1.0 / number_of_side_elements
        # Unit outline: down the left side (top to bottom), then up the right side (bottom to top).
        left = [[-0.5, 0.5 - (dy * i), 0.0] for i in range(number_of_side_elements + 1)]
        right = [[0.5, -0.5 + (dy * i), 0.0] for i in range(number_of_side_elements + 1)]
        self.pts = left + right
        self._generate_points_cells()
        self.number_end_nodes = 2
        idx = number_of_side_elements
        poly = self.outer_polygon()
        # Top end: last right point and first left point; bottom end: last left point and first right point.
        self._end_pt_indexes = [[len(poly) - 1, 0], [idx, idx + 1]]


class PointedUnitWallPier(PierBase):
    """Wall pier whose ends taper to a single point (a diamond when length equals width)."""

    def __init__(self, length=1.0, width=1.0, wrap_width=0.0, number_of_side_elements=1, number_layers=1,
                 growth_factor=1.5):
        """Initialize the class.

        Args:
            length (float): length of the pier
            width (float): width of the pier
            wrap_width (float): width of first wrap around the pier
            number_of_side_elements (int): number of elements down the side of a wall pier; must be >= 1
                unless length equals width, in which case it is ignored
            number_layers (int): number of layers to wrap around the pier
            growth_factor (float): factor for growth of element layers

        Raises:
            RuntimeError: if the dimensions are invalid (see PierBase) or, for a non-square footprint,
                number_of_side_elements < 1
        """
        super().__init__(length, width, wrap_width, number_of_side_elements, number_layers, growth_factor)
        if length == width:
            # Diamond: top, left, bottom, right.
            self._number_of_side_elements = 0
            left = [[-0.5, 0.0, 0.0]]
            right = [[0.5, 0.0, 0.0]]
            self.number_end_nodes = 1
            self._end_pt_indexes = [[0], [2]]
        else:
            if number_of_side_elements < 1:
                raise RuntimeError('number of side elements must be >= 1')
            # The straight sides end where the pointed caps (half a width tall in unit coordinates) begin.
            y_start = 0.5 * (length - width) / length
            dy = 2 * y_start / number_of_side_elements
            left = [[-0.5, y_start - (dy * i), 0.0] for i in range(number_of_side_elements + 1)]
            right = [[0.5, -y_start + (dy * i), 0.0] for i in range(number_of_side_elements + 1)]
        top = [[0.0, 0.5, 0.0]]
        bot = [[0.0, -0.5, 0.0]]
        self.pts = top + left + bot + right
        self._generate_points_cells()
        if self._end_pt_indexes is None:
            self.number_end_nodes = 3
            idx = number_of_side_elements + 1
            poly = self.outer_polygon()
            # Each end: the tip plus its neighbouring side points.
            self._end_pt_indexes = [[len(poly) - 1, 0, 1], [idx, idx + 1, idx + 2]]


class RoundUnitWallPier(PierBase):
    """Wall pier with rounded ends, each cap approximated by two points (a hexagon when length equals width)."""

    def __init__(self, length=1.0, width=1.0, wrap_width=0.0, number_of_side_elements=1, number_layers=1,
                 growth_factor=1.5):
        """Initialize the class.

        Args:
            length (float): length of the pier
            width (float): width of the pier
            wrap_width (float): width of first wrap around the pier
            number_of_side_elements (int): number of elements down the side of a wall pier; must be >= 1
                unless length equals width, in which case it is ignored
            number_layers (int): number of layers to wrap around the pier
            growth_factor (float): factor for growth of element layers

        Raises:
            RuntimeError: if the dimensions are invalid (see PierBase) or, for a non-circular footprint,
                number_of_side_elements < 1
        """
        super().__init__(length, width, wrap_width, number_of_side_elements, number_layers, growth_factor)
        if length == width:
            # Hexagon: two top points, left, two bottom points, right.
            self._number_of_side_elements = 0
            left = [[-0.5, 0.0, 0.0]]
            right = [[0.5, 0.0, 0.0]]
            self.number_end_nodes = 2
            self._end_pt_indexes = [[0, 1], [3, 4]]
        else:
            if number_of_side_elements < 1:
                raise RuntimeError('number of side elements must be >= 1')
            # The straight sides end where the rounded caps (half a width tall in unit coordinates) begin.
            y_start = 0.5 * (length - width) / length
            dy = 2 * y_start / number_of_side_elements
            left = [[-0.5, y_start - (dy * i), 0.0] for i in range(number_of_side_elements + 1)]
            right = [[0.5, -y_start + (dy * i), 0.0] for i in range(number_of_side_elements + 1)]

        top = [[0.25, 0.5, 0.0], [-0.25, 0.5, 0.0]]
        bot = [[-0.25, -0.5, 0.0], [0.25, -0.5, 0.0]]
        self.pts = top + left + bot + right
        self._generate_points_cells()
        if self._end_pt_indexes is None:
            self.number_end_nodes = 4
            idx = number_of_side_elements + 2
            poly = self.outer_polygon()
            # Each end: the two cap points plus the neighbouring side points.
            self._end_pt_indexes = [[len(poly) - 1, 0, 1, 2], [idx, idx + 1, idx + 2, idx + 3]]


class PierGroup(PierBase):
    """Group of circular piers spaced along the y axis, optionally wrapped by additional element layers."""

    def __init__(self, number_piers=1, pier_spacing=2.0, diameter=1.0, wrap_width=0.0, number_of_layers=1,
                 growth_factor=1.5, num_side_elements=-1):
        """Initialize the class.

        Args:
            number_piers (int): number of piers in group
            pier_spacing (float): spacing between piers
            diameter (float): diameter of each pier
            wrap_width (float): width of first wrap around the pier
            number_of_layers (int): number of layers to wrap around the pier group
            growth_factor (float): factor for growth of element layers
            num_side_elements (int): number of desired side elements; values <= 0 use one side element per gap
                between adjacent piers (minimum 1)
        """
        super().__init__(length=diameter, width=diameter, wrap_width=wrap_width, number_layers=number_of_layers,
                         growth_factor=growth_factor)
        self._num_piers = number_piers
        self._spacing = pier_spacing
        self._num_side_elements = num_side_elements
        self._calc_piers()
        # Without extra wrapping layers the footprint spans every pier center plus half a diameter and one wrap
        # width at each end. The base-class default would cover only a single pier. _calc_wrapping() replaces
        # this value when it adds outer layers.
        self._min_bridge_width = self._spacing * (self._num_piers - 1) + self._length + 2 * self._wrap_width
        self._calc_wrapping()
        self.number_end_nodes = 4
        if number_piers == 1:
            self.number_end_nodes = 2
            self._end_pt_indexes = [[0, 1], [3, 4]]
        else:
            poly = self.outer_polygon()
            # For an outline of 2n + 6 points this is n + 2, matching RoundUnitWallPier's bottom end nodes.
            j = int(3 + (len(poly) - 8) / 2)
            self._end_pt_indexes = [[len(poly) - 1, 0, 1, 2], [j, j + 1, j + 2, j + 3]]

    def _calc_piers(self):
        """Create the points and cells for the individual piers and the quads joining neighbouring piers.

        Each pier is a single-layer RoundUnitWallPier with 6 inner and 6 outer points. Piers are placed from
        y = +total_length / 2 downward in steps of the pier spacing.
        """
        total_length = self._spacing * (self._num_piers - 1)
        loc = (0.0, total_length / 2, 0.0)
        all_pts = []
        all_cells = []
        quad = 9  # see XMU_QUAD
        num_new_cells = 9  # 6 cells of the pier itself + 3 connecting quads
        for _ in range(self._num_piers):
            p = RoundUnitWallPier(length=self._length, width=self._width, wrap_width=self._wrap_width,
                                  number_of_side_elements=1, number_layers=1)
            p.translate_points(loc)
            num_pts = len(all_pts)
            if num_pts > 0:
                # Join the previous pier's outer left/bottom/right points (negative, relative to this pier's
                # first point) to this pier's outer left/top/right points (local indexes 6, 7, 8, 11).
                connect = [quad, 4, -4, 8, 7, -3,
                           quad, 4, -3, 7, 6, -2,
                           quad, 4, -2, 6, 11, -1]
                p.cells.extend(connect)
                # Shift every point index of this pier's cells past the points already collected.
                for c in range(num_new_cells):
                    idx = (c * 6) + 2
                    for j in range(4):
                        p.cells[idx + j] += num_pts
            all_pts.extend(p.pts)
            all_cells.extend(p.cells)
            loc = (0.0, loc[1] - self._spacing, 0.0)
        self.pts = all_pts
        self.cells = all_cells
        self._calc_outer_poly()

    def _calc_outer_poly(self):
        """Find the boundary loop around the group and store its point indexes (reversed) as the outer polygon.

        The loop is identified by containing the first pier's outer points 6, 7, 8 and 11. If no loop matches,
        ``self._outer_poly_idx`` is left unchanged.
        """
        ug = UGrid(self.pts, self.cells)
        ugb = UGridBoundaries(ug)
        loops = ugb.get_loops()
        for loop in loops.values():
            set_ids = set(loop['id'])
            if all(x in set_ids for x in [6, 7, 8, 11]):
                self._outer_poly_idx = list(loop['id'])
                # NOTE(review): presumably reversed to match the winding of the wall-pier outlines — confirm.
                self._outer_poly_idx.reverse()
                break

    def _calc_wrapping(self):
        """Wrap the group in the remaining layers, shaped like a round wall pier around all the piers.

        Does nothing when fewer than 2 layers were requested. When a custom number of side elements is
        requested, the wall-pier grid is merged with UGrid2dMerger. Otherwise the wall pier's innermost ring is
        stitched directly onto the group's outer polygon.
        """
        if self._number_of_layers < 2:
            return

        start_width = (0.5 * self._width) + self._wrap_width
        wrap_width = self._wrap_width * 2
        w = start_width + self._growth_factor * (0.5 * wrap_width)
        if w < wrap_width:
            w = wrap_width * 1.1
        length = self._spacing * (self._num_piers - 1) + w
        side_elements = max(1, self._num_piers - 1)
        use_merger = False
        if 0 < self._num_side_elements != side_elements:
            side_elements = self._num_side_elements
            use_merger = True
            w = start_width * 2.0
            length = self._spacing * (self._num_piers - 1) + w
        p = RoundUnitWallPier(length=length, width=w, number_of_side_elements=side_elements,
                              number_layers=self._number_of_layers - 1, growth_factor=self._growth_factor)
        self.outer_layer_thick = p.outer_layer_thick
        self._min_bridge_width = p.min_bridge_width()
        if use_merger:  # merge the circular piers grid with the wall pier grid
            self._merge_grids(p)
        else:  # merge the pts and cells
            num_pts = len(self.pts)
            num_shared = len(self._outer_poly_idx)
            # The wall pier's innermost ring maps onto the group's existing outer polygon. Its remaining points
            # are appended after the group's points.
            old_to_new_idx = self._outer_poly_idx.copy()
            old_to_new_idx.extend(range(num_pts, num_pts + len(p.pts) - num_shared))

            self.pts.extend(p.pts[num_shared:])
            # Each cell is [type, 4, i0, i1, i2, i3]; renumber its four point indexes.
            for cell_start in range(0, len(p.cells), 6):
                for k in range(cell_start + 2, cell_start + 6):
                    p.cells[k] = old_to_new_idx[p.cells[k]]
            self.cells.extend(p.cells)
            pts_per_layer = len(p.pts) // (p._number_of_layers + 1)
            self._outer_poly_slice = len(self.pts) - pts_per_layer
            self._outer_poly_idx = []

    def _merge_grids(self, wall_pier):
        """Merge the circular piers grid with the wall pier grid.

        Replaces ``self.pts`` and ``self.cells`` with the merged grid and maps the wall pier's outer polygon onto
        the merged point indexes. ``self._end_pt_indexes`` is copied from the wall pier; PierGroup.__init__
        recomputes it afterwards.

        Args:
            wall_pier (RoundUnitWallPier): the wall pier
        """
        b = UGridBuilder()
        b.set_is_2d()
        b.set_ugrid(UGrid(self.pts, self.cells))
        ug_circular_piers = b.build_grid()
        b.set_ugrid(UGrid(wall_pier.pts, wall_pier.cells))
        ug_wall_pier = b.build_grid()
        merger = UGrid2dMerger(ug_circular_piers, ug_wall_pier)
        merged_grid = merger.merge_grids()
        self.pts = [[p[0], p[1], p[2]] for p in merged_grid.locations]
        self.cells = merged_grid.cellstream
        self._end_pt_indexes = wall_pier.end_pt_indexes()
        outer_poly = wall_pier.outer_polygon()
        # NOTE(review): exact float match on (x, y); raises KeyError if the merger moved an outer point.
        pt_lookup = {(p[0], p[1]): i for i, p in enumerate(self.pts)}
        self._outer_poly_idx = [pt_lookup[(p[0], p[1])] for p in outer_poly]


def rotate_point(pt, cos_rad, sin_rad):
    """Rotate a point in place about the origin.

    Applies the standard 2D rotation matrix, so a positive angle turns the point counter-clockwise. Only the
    x and y components are modified.

    Args:
        pt (list): mutable [x, y, ...] location, updated in place
        cos_rad (float): cosine of the rotation angle
        sin_rad (float): sine of the rotation angle
    """
    x, y = pt[0], pt[1]
    pt[0], pt[1] = cos_rad * x - sin_rad * y, sin_rad * x + cos_rad * y


def translate_point(pt, new_origin):
    """Shift a point in place by the x and y offsets held in new_origin.

    Args:
        pt (list): mutable [x, y, ...] point, updated in place
        new_origin (sequence): offsets; only the first two components (x, y) are used
    """
    dx, dy = new_origin[0], new_origin[1]
    pt[0] = pt[0] + dx
    pt[1] = pt[1] + dy
