"""SCHISM files writer."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path
from typing import Sequence, TextIO

# 2. Third party modules

# 3. Aquaveo modules
from xms.grid.ugrid import UGrid

# 4. Local modules
from .fort14_file import Boundary, Fort14File


def write_fort14(file: Fort14File, path: Path | str):
    """
    Serialize a fort14 file to disk.

    The grid (header, node locations, and cell connectivity) is always written. The open and closed boundary sections
    are appended only when the file has open boundaries (``file.open_boundaries`` is not ``None``).

    Args:
        file: The in-memory file to serialize.
        path: Destination of the written file.
    """
    with HorizontalGridWriter(path) as grid_writer:
        grid_writer.write_dataset(file.ugrid.ugrid, file.dataset)

        open_arcs = file.open_boundaries
        if open_arcs is None:
            # No boundary information, so the grid sections are the whole file.
            return

        # Boundaries are present, so this is the hgrid.
        grid_writer.write_open_boundaries(open_arcs)
        grid_writer.write_closed_boundaries(file.closed_boundaries)


class HorizontalGridWriter:
    """
    Context manager that writes a SCHISM horizontal grid to a text file.

    Entering the context opens the target file and writes the file's name as its first line. Leaving the context
    closes the file. The ``write_*`` methods need the file to be open, so call them inside the ``with`` block.
    """
    def __init__(self, path: str | Path):
        """
        Remember the destination. The file is not opened until the context is entered.

        Args:
            path: Where to write the file to.
        """
        self._path = Path(path)
        self._file: TextIO

    def __enter__(self):
        """
        Open the output file and write its name as the first line.

        Returns:
            This writer.
        """
        handle = open(self._path, 'w')
        self._file = handle
        handle.write(f'{self._path.name}\n')
        return self

    def __exit__(self, _exc_type, _exc_value, _exc_tb):
        """Close the output file. Exceptions raised inside the ``with`` block are not suppressed."""
        self._file.close()

    def write_dataset(self, domain: UGrid, dataset):
        """
        Write the grid part of the file: the count line, then the nodes, then the cells.

        Args:
            domain: The geometry to write.
            dataset: One value per node, written as the last column of each node line.
        """
        counts_line = f'{domain.cell_count} {domain.point_count}  ! elements, nodes\n'
        self._file.write(counts_line)
        self._write_locations(domain, dataset)
        self._write_connectivity(domain)

    def _write_locations(self, domain: UGrid, dataset):
        """
        Write one line per node: its number (counting from 1), its first two coordinates, and its dataset value.

        Locations and values are paired in order, and writing stops as soon as either runs out.

        Args:
            domain: Geometry whose node locations are written.
            dataset: Values to write for each node.
        """
        paired = zip(domain.locations, dataset)
        self._file.writelines(
            f'{number} {xyz[0]} {xyz[1]} {z}\n' for number, (xyz, z) in enumerate(paired, start=1)
        )

    def _write_connectivity(self, domain: UGrid):
        """
        Write one line per cell: its number (counting from 1), its point count, and its point indices plus one.

        Args:
            domain: Geometry whose cells are written.
        """
        out = self._file
        for zero_based in range(domain.cell_count):
            corners = domain.get_cell_points(zero_based)
            # Point indices from the grid are shifted up by one before being written.
            shifted = ' '.join(str(corner + 1) for corner in corners)
            out.write(f'{zero_based + 1} {len(corners)} {shifted}\n')

    def _write_nodes(self, nodes):
        """
        Write each node of a boundary on its own line, exactly as stored.

        Args:
            nodes: The nodes of a single boundary.
        """
        self._file.writelines(f'{node}\n' for node in nodes)

    def write_open_boundaries(self, boundaries: Sequence[Boundary]):
        """
        Write the open boundaries section of the file.

        The section starts with the number of boundaries and the total number of nodes across them. Each boundary
        then gets a line with its own node count, followed by its nodes.

        Args:
            boundaries: Open boundaries to write.
        """
        sizes = [len(boundary.nodes) for boundary in boundaries]
        out = self._file
        out.write(f'{len(boundaries)} = Number of open boundaries\n')
        out.write(f'{sum(sizes)} = Total number of open boundary nodes\n')
        for number, (boundary, size) in enumerate(zip(boundaries, sizes), start=1):
            out.write(f'{size} = Number of nodes for open boundary {number}\n')
            self._write_nodes(boundary.nodes)

    def write_closed_boundaries(self, boundaries: Sequence[Boundary]):
        """
        Write the closed (land) boundaries section of the file.

        The layout matches the open boundaries section, except that each boundary's line also carries the
        boundary's type as an integer after the node count.

        Args:
            boundaries: Closed boundaries to write.
        """
        sizes = [len(boundary.nodes) for boundary in boundaries]
        out = self._file
        out.write(f'{len(boundaries)} = number of land boundaries\n')
        out.write(f'{sum(sizes)} = Total number of land boundary nodes\n')
        for number, (boundary, size) in enumerate(zip(boundaries, sizes), start=1):
            kind = int(boundary.boundary_type)
            out.write(f'{size} {kind} = Number of nodes for land boundary {number}\n')
            self._write_nodes(boundary.nodes)
