"""PrecipFileWriter class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from dataclasses import dataclass
from datetime import datetime
import logging
from pathlib import Path

# 2. Third party modules
from numpy import ndarray

# 3. Aquaveo modules
from xms.gmi.data.generic_model import GenericModel
from xms.gmi.data.sim_data import SimData

# 4. Local modules
from xms.gssha.file_io import io_util


def write(
    gssha_file_path: Path, sim_data: SimData, generic_model: GenericModel, start_date_time: datetime
) -> Path | None:
    """Convenience entry point: write the precipitation input (.gag) file.

    Builds a PrecipFileWriter from the arguments and runs it once.

    Args:
        gssha_file_path: .gssha file path. The .gag file path is this path with its suffix replaced by '.gag'.
        sim_data: SimData object.
        generic_model: The generic model; its global 'precipitation' parameter group supplies the data.
        start_date_time: The starting date/time.

    Returns:
        The path of the .gag file that was written, or None if the precipitation data could not be obtained.
    """
    return PrecipFileWriter(gssha_file_path, sim_data, generic_model, start_date_time).write()


@dataclass
class PrecipitationData:
    """Precipitation data."""
    x_vals: 'ndarray | None' = None
    y_vals: 'ndarray | None' = None
    average_depth: float = 0.0


class PrecipFileWriter:
    """Writes the precipitation input (.gag) file."""
    def __init__(
        self,
        gssha_file_path: Path | str,
        sim_data: SimData,
        generic_model: GenericModel,
        start_date_time: datetime,
    ) -> None:
        """Initializes the class.

        Args:
            gssha_file_path (str | Path): .gssha file path. The .gag file path is derived from it by replacing
                the suffix with '.gag'.
            sim_data: SimData object (stored on the instance).
            generic_model: The generic model; its global 'precipitation' parameter group is read.
            start_date_time: The starting date/time.
        """
        super().__init__()
        # Coerce to Path so that a plain string path (allowed by the documented contract) works as well.
        self._gag_file_path: Path = Path(gssha_file_path).with_suffix('.gag')
        self._sim_data = sim_data
        self._group = generic_model.global_parameters.group('precipitation')
        self._start_date_time = start_date_time

        self._log = logging.getLogger('xms.gssha')

    def write(self) -> Path | None:
        """Writes the precipitation input (.gag) file and returns the file path.

        Returns:
            The .gag file path, or None if the precipitation data is not usable (nothing is written then).
        """
        self._log.info('Writing .gag file...')
        data = self._get_precipitation_data()
        if data is None:  # The problem has already been logged by _get_precipitation_data()
            return None

        with open(self._gag_file_path, 'w') as file:
            file.write(f'#HYETOGRAPH {data.average_depth}\n')
            file.write('EVENT\n')
            # The gage count and coordinates are hard-coded: one gage at 0.0 0.0.
            file.write('NRGAG 1\n')
            file.write(f'NRPDS {len(data.x_vals)}\n')
            file.write('COORD 0.0 0.0\n')
            for x, y in zip(data.x_vals, data.y_vals):
                # NOTE(review): presumably x is a time offset from the start date/time - confirm in io_util.
                date_time_str = io_util.get_time_string(self._start_date_time, x)
                # Each Y value of the series is scaled by the average depth.
                file.write(f'ACCUM {date_time_str} {y * data.average_depth}\n')

        return self._gag_file_path

    def _get_precipitation_data(self) -> PrecipitationData | None:
        """Returns the precipitation data we need, or None if there's a problem.

        A problem is logged as an error when the hyetograph XY series has fewer than two points, or when its
        X and Y sequences differ in length.
        """
        # Get XY series
        xs, ys = self._group.parameter('hyetograph_xy').value
        if len(xs) < 2:
            self._log.error('Precipitation hyetograph XY series not defined.')
            return None
        if len(xs) != len(ys):
            # Otherwise NRPDS (len of X) would disagree with the number of ACCUM lines (zip truncates).
            self._log.error('Precipitation hyetograph XY series has a different number of X and Y values.')
            return None

        # Get other data
        return PrecipitationData(
            x_vals=xs,
            y_vals=ys,
            average_depth=self._group.parameter('hyetograph_avg_depth').value,
        )
