r"""Reads the `CHAN_INPUT <ci_>`_ (.cif) file, which describes the stream (channel) network.

.. _ci: https://www.gsshawiki.com/Surface_Water_Routing:Channel_Routing

::

    GSSHA_CHAN
    ALPHA       1.000000
    BETA        1.000000
    THETA       1.000000
    LINKS       5
    MAXNODES    89
    CONNECT    4    5    0
    CONNECT    1    3    0
    CONNECT    2    3    0
    CONNECT    3    5    2    1    2
    CONNECT    5    0    2    4    3

    LINK           4
    DX             27.545402
    TRAPEZOID
    NODES          68
    NODE 1
    X_Y  459545.266545 4499466.101199
    ELEV 2167.939697
    XSEC
    MANNINGS_N     0.040000
    BOTTOM_WIDTH   1.000000
    BANKFULL_DEPTH 1.000000
    SIDE_SLOPE     1.000000
    NODE 2
    X_Y  459525.449644 4499453.475290
    ELEV 2168.067627
    .
    .
    .
"""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from dataclasses import dataclass, field
import logging
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules
from xms.coverage import coverage_builder
from xms.coverage.xy import xy_util
from xms.coverage.xy.xy_series import XySeries
from xms.data_objects.parameters import Coverage

# 4. Local modules
from xms.gssha.data.bc_generic_model import ChannelType
from xms.gssha.file_io import io_util

# Type aliases
Connect = dict[int, dict[str, int | list[int]]]  # CONNECT data: link -> dict[down link, list[up links]]
Nodes = list[list[float]]  # List of polyline (arc) xyz points


@dataclass
class Link:
    """One link (an arc) of the stream network, filled in from a LINK section of the .cif file."""
    number: int = field(default=0)  # Value of the LINK card
    channel_type: ChannelType = field(default=ChannelType.TRAPEZOIDAL)  # Set by TRAPEZOID / BREAKPOINT cards
    mannings_n: float = field(default=0.0)  # MANNINGS_N card
    bottom_width: float = field(default=0.0)  # BOTTOM_WIDTH card
    bankfull_depth: float = field(default=0.0)  # BANKFULL_DEPTH card
    side_slope: float = field(default=0.0)  # SIDE_SLOPE card
    nodes: Nodes = field(default_factory=list)  # One [x, y, elev] point per NODE, in file order
    # ID of the XySeries (in the reader's xy dict) built from the NPAIRS and X1 cards of a BREAKPOINT
    # cross section; stays -1 when the link has no NPAIRS card.
    xy_id: int = field(default=-1)


def read(file_path: Path) -> 'tuple[Coverage | None, dict[int, Link], dict[int, XySeries]]':
    """Convenience wrapper that reads a CHAN_INPUT (.cif) file in a single call.

    Args:
        file_path: Path of the .cif file to read.

    Returns:
        A (coverage, links, xy_series) tuple: the coverage built from the link polylines, a dict of
        link number -> Link, and the XySeries dict filled in from any NPAIRS/X1 cards.
    """
    return ChannelInputFileReader(file_path).read()


class ChannelInputFileReader:
    """Reads the channel input (.cif) file."""
    def __init__(self, file_path: Path) -> None:
        """Initializes the class.

        Args:
            file_path (str | Path): .cif file path. A str is converted to a Path.
        """
        super().__init__()
        # Normalize to Path so a str argument works too (_build_coverage uses Path.stem).
        self._file_path: Path = Path(file_path)

        self._log = logging.getLogger('xms.gssha')
        self._file = None  # File object opened by _read_links; the helper readers pull lines from it
        self._xy_dict: dict[int, XySeries] = {}

    def read(self) -> 'tuple[Coverage | None, dict[int, Link], dict[int, XySeries]]':
        """Reads the CHAN_INPUT (.cif) file and returns: coverage, links dict, and xy series dict."""
        self._log.info('Reading .cif file...')

        links = self._read_links()
        coverage: Coverage = self._build_coverage(links)
        return coverage, links, self._xy_dict

    def _read_links(self) -> dict[int, Link]:
        """Reads all the links, returns the link dict, and fills in self._xy_dict.

        CONNECT cards are not read.

        Returns:
            Dict of link number -> Link.

        Raises:
            ValueError: If a TRAPEZOID, BREAKPOINT, or NODES card appears before the first LINK card.
        """
        links: dict[int, Link] = {}  # link number -> Link
        link: Link | None = None  # The link whose cards are currently being read
        with open(self._file_path, 'r') as self._file:
            for line in self._file:
                card, card_value = io_util.get_card_and_value(line.rstrip('\n'))
                if card == 'LINK':
                    link = Link(number=int(card_value))
                    links[link.number] = link
                elif card == 'TRAPEZOID':
                    self._require_link(link, card).channel_type = ChannelType.TRAPEZOIDAL
                elif card == 'BREAKPOINT':
                    self._require_link(link, card).channel_type = ChannelType.CROSS_SECTION
                elif card == 'NODES':
                    # _read_nodes consumes lines from self._file up to the last node's ELEV card
                    self._read_nodes(int(card_value), self._require_link(link, card))
        return links

    def _require_link(self, link: 'Link | None', card: str) -> Link:
        """Returns link, or raises ValueError if no LINK card has been read yet.

        Args:
            link: The current link, or None if no LINK card has been seen yet.
            card: The card being processed (used in the error message).

        Returns:
            The link.
        """
        if link is None:
            raise ValueError(f'{card} card found before any LINK card in "{self._file_path}".')
        return link

    def _read_npairs(self, npairs: int) -> int:
        """Reads the xy series defined by NPAIRS and X1 cards and returns the new XY series ID.

        Lines that are not X1 cards are skipped until npairs X1 cards have been read.

        Args:
            npairs: Number of xy pairs in the series.

        Returns:
            The ID returned by xy_util.add_or_match (which also sets the XySeries id).

        Raises:
            ValueError: If the file ends before npairs X1 cards have been read.
        """
        x: list[float] = []
        y: list[float] = []
        while len(x) < npairs:
            line = next(self._file, None)
            if line is None:  # End of file reached before all the pairs were read
                raise ValueError(f'Expected {npairs} X1 cards but found only {len(x)} in "{self._file_path}".')
            card, x_y_values = io_util.get_card_and_value(line.rstrip('\n'))
            if card == 'X1':
                words = x_y_values.split()
                x.append(float(words[0]))
                y.append(float(words[1]))
        xy_series = XySeries(x, y)
        return xy_util.add_or_match(xy_series, self._xy_dict)  # This also sets the XySeries id

    def _read_nodes(self, num_nodes: int, link: Link) -> None:
        """Reads the nodes and cross section data of the current link into link.

        Appends an [x, y, elev] point to link.nodes for each NODE. Sets the cross section values
        from the MANNINGS_N, BOTTOM_WIDTH, BANKFULL_DEPTH, and SIDE_SLOPE cards. Sets link.xy_id when
        an NPAIRS card is found. Stops after the ELEV card of node number num_nodes.

        Args:
            num_nodes: Number of nodes (value of the NODES card).
            link: The current link being read.
        """
        if num_nodes <= 0:
            # No node can ever match, so reading on would swallow the cards of all the following links.
            return

        x: float = 0.0
        y: float = 0.0
        node: int = 0
        for line in self._file:
            card, card_value = io_util.get_card_and_value(line.rstrip('\n'))
            if card == 'NODE':
                node = int(card_value)
            elif card == 'X_Y':
                words = card_value.split()
                x, y = float(words[0]), float(words[1])
            elif card == 'ELEV':
                link.nodes.append([x, y, float(card_value)])
                if node == num_nodes:
                    break
            elif card == 'MANNINGS_N':
                link.mannings_n = float(card_value)
            elif card == 'BOTTOM_WIDTH':
                link.bottom_width = float(card_value)
            elif card == 'BANKFULL_DEPTH':
                link.bankfull_depth = float(card_value)
            elif card == 'SIDE_SLOPE':
                link.side_slope = float(card_value)
            elif card == 'NPAIRS':
                link.xy_id = self._read_npairs(int(card_value))

    def _build_coverage(self, links: dict[int, Link]) -> Coverage:
        """Builds and returns the coverage from the link polylines.

        Args:
            links: Dict of link number -> Link.

        Returns:
            The coverage, named after the .cif file stem.
        """
        # We assume that the shared xyz locations at the start and end of connected links are identical
        polylines = [link.nodes for link in links.values()]
        return coverage_builder.build(polylines=polylines, name=self._file_path.stem)
