"""BcMapAtt class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import logging
from pathlib import Path
from typing import Sequence

# 2. Third party modules
from numpy import datetime64

# 3. Aquaveo modules
from xms.coverage.xy import xy_util
from xms.coverage.xy.xy_series import XySeries
from xms.data_objects.parameters import Arc
from xms.gmi.data.coverage_data import CoverageData
from xms.gmi.data.generic_model import Group

# 4. Local modules
from xms.hgs.data.bc_generic_model import InputType
from xms.hgs.data.domains import Domains
from xms.hgs.file_io import file_io_util
from xms.hgs.file_io.section import IndentedSection
from xms.hgs.mapping import set_maker
from xms.hgs.mapping.map_att import _feature_type, MapAtt

# Type aliases
# Time value table: rows of [time, value], paired with a flag that says whether to write 'interpolate'.
TimeValueTable = tuple[list[list[float]], bool]
# Time file table: rows of (time, file name) written with the grok 'time file table' command.
TimeFileTable = list[tuple[float, str]]
# Time raster table: rows of (time, path) written with the grok 'time raster table' command.
TimeRasterTable = list[tuple[float, str]]


class BcMapAtt(MapAtt):
    """A coverage attribute item (i.e. BC, observation, hydrograph) with attributes and intersected grid components."""

    # BC attribute type -> kind of grid component the BC set is made of.
    _GRID_COMPONENT_TYPES = {
        'Head': 'node',
        'Flux': 'face',
        'Rain': 'face',
        'Flux nodal': 'node',
        'Potential evapotranspiration': 'face',
        'Simple river': 'node',
        'Simple drain': 'node',
        'Critical depth': 'segment',
    }

    # BC attribute types whose values are interpolated along an arc (start -> end) and written as time files.
    _ARC_INTERPOLATED_TYPES = frozenset({'Head', 'Flux'})

    def __init__(self, feature, att_type: str, group: Group, coverage_data: CoverageData) -> None:
        """Initializes the class.

        Args:
            feature: A point, arc, or polygon.
            att_type (str): Attribute type (e.g. 'Flux nodal')
            group (Group): The BC as a generic model group.
            coverage_data: CoverageData class (for XySeries data).
        """
        super().__init__(feature, att_type, group)
        self._coverage_data = coverage_data

        self._grok_filepath: Path | None = None  # Member variable cause otherwise we have to pass it down many levels
        self._log = logging.getLogger('xms.hgs')

    def write(
        self,
        section,
        set_id_strings: set[str] | None = None,
        bc_names: set[str] | None = None,
        *args,
        **kwargs
    ) -> None:
        """Writes this MapAtt object to grok as a 'boundary condition' section.

        The set of grid components is written first (via set_maker), then any time file table, which is written
        next to the grok file.

        Args:
            section (Type[SectionBase]): The section we are writing to.
            set_id_strings (set[str] | None): Set of grid component set id strings so we can ensure they are unique.
            bc_names (set[str] | None): Set of BC names so we can ensure they are unique.
            *args: Ignored by this implementation.
            **kwargs: Ignored by this implementation.
        """
        # Stored on the instance so the time-file writers can place their files next to the grok file.
        self._grok_filepath = Path(section.file_data.file.name)
        set_name, set_id_str = set_maker.make_set(self._grok_filepath, self, set_id_strings, section)

        time_file_table, time_value_table, time_raster_table = self._make_table(set_id_str)

        with IndentedSection(section.file_data, 'boundary condition', append_blank=True) as bc_section:
            bc_section.write_value('type', self.att_type.lower(), append_blank=True)
            unique_name = self._get_name(bc_names)
            bc_section.write_value('name', unique_name, append_blank=True)
            bc_section.write_value(f'{self.get_grid_component_type()} set', set_name, append_blank=True)
            # At most one kind of table is produced by _make_table.
            if time_value_table is not None:
                table, interpolate = time_value_table
                bc_section.write_value('time value table', table, append_blank=True)
                if interpolate:
                    bc_section.write_string('interpolate')
            elif time_file_table is not None:
                bc_section.write_value('time file table', time_file_table, append_blank=True)
            elif time_raster_table is not None:
                # Column 1 holds file paths.
                bc_section.write_value('time raster table', time_raster_table, append_blank=True, path_columns=[1])
            bc_section.write_value('tecplot output')

    def _get_name(self, bc_names: set[str]) -> str:
        """Returns the BC name, providing a unique one if blank or a duplicate.

        Args:
            bc_names (set[str]): Names already used by other BCs; passed to file_io_util.make_string_unique.

        Returns:
            (str): See the description.
        """
        default_name = self.att_type.lower().replace(' ', '-')
        name = self.value('name', default_name)
        new_name = default_name if not name else name
        new_name = file_io_util.make_string_unique(new_name, bc_names)
        if new_name != name:
            self.set_value('name', new_name)  # This is only local and won't update the coverage
        return new_name

    def get_domain(self) -> str:
        """Returns the domain the BcMapAtt belongs to.

        BC types without a 'domain' parameter ('Rain', 'Potential evapotranspiration', 'Simple river',
        'Simple drain', 'Critical depth') can only be in the OLF domain. Yes, 'Potential evapotranspiration' is in
        the OLF domain.

        Returns:
            (str): See description.
        """
        if 'domain' in self.group.parameter_names:
            return self.value('domain')
        return Domains.OLF

    def get_grid_component_type(self):
        """Returns the type of grid component ('node', 'face', 'segment') from the att type.

        Raises:
            KeyError: If the att type is not a known BC type.
        """
        return self._GRID_COMPONENT_TYPES[self.att_type]

    def _make_table(self,
                    set_id_str: str) -> tuple[TimeFileTable | None, TimeValueTable | None, TimeRasterTable | None]:
        """Makes and returns the time file table, time value table, or time raster table.

        At most one of the returned tables is not None.

        Args:
            set_id_str (str): String that uniquely identifies the bc.

        Returns:
            (tuple[TimeFileTable | None, TimeValueTable | None, TimeRasterTable | None]): See description.
        """
        time_file_table, time_value_table, time_raster_table = None, None, None
        if isinstance(self.feature, Arc) and self._interpolate_along_arc():
            time_file_table = self._write_time_files(set_id_str, self.intersection['t_values'])
        elif self.value('input_type') == InputType.TIME_RASTERS:
            time_raster_table = self._make_time_raster_table()
        else:
            time_value_table = self._make_time_value_table()
        return time_file_table, time_value_table, time_raster_table

    def _interpolate_along_arc(self) -> bool:
        """Returns True if we should interpolate along the arc.

        Returns:
            (bool): See description.
        """
        return self.att_type in self._ARC_INTERPOLATED_TYPES

    def _get_xy_series_values(self, xy_series_id: int | None,
                              step_function: bool) -> tuple[Sequence['float | datetime64'], Sequence[float]] | None:
        """Returns the x and y values from the xy series.

        Args:
            xy_series_id (int | None): The xy series ID. None or <= 0 means no series was assigned.
            step_function (bool): If true, the series should be treated as a step function.

        Returns:
            (tuple[list[float], list[float]] | None): The x and y values, or None (after logging an error) if the
            series is missing.
        """
        if xy_series_id is None or xy_series_id <= 0:
            self._log.error(f'Time series missing on {_feature_type(self.feature)} with ID {self.feature.id}.')
            return None
        xy = self._coverage_data.get_curve(xy_series_id, False)
        if step_function:
            return xy_util.get_step_function(xy[0], xy[1])
        return xy[0], xy[1]

    def _get_start_and_end_xy_series(self) -> tuple['XySeries | None', 'XySeries | None']:
        """Returns the starting and ending xy series as XySeries.

        Returns:
            (tuple[XySeries|None, XySeries|None]): See description. Both are None if either series is missing.
        """
        interpolate_start, xy_id_start, interpolate_end, xy_id_end = None, None, None, None
        rv = self.value('xy_series_start')
        if rv and len(rv) == 2:
            interpolate_start, xy_id_start = rv[0], rv[1]
        rv = self.value('xy_series_end')
        if rv and len(rv) == 2:
            interpolate_end, xy_id_end = rv[0], rv[1]
        # A series that is not interpolated is treated as a step function.
        xy_values_start = self._get_xy_series_values(xy_id_start, not interpolate_start)
        xy_values_end = self._get_xy_series_values(xy_id_end, not interpolate_end)
        if not xy_values_start or not xy_values_end:
            self._log_missing_xy_series(xy_values_start, xy_values_end)
            return None, None

        # Index 0 holds the x (time) values and index 1 holds the y (BC) values.
        xy_start = XySeries(xy_values_start[0], xy_values_start[1])
        xy_end = XySeries(xy_values_end[0], xy_values_end[1])
        return xy_start, xy_end

    def _log_missing_xy_series(self, xy_values_start, xy_values_end):
        """Logs an error for each of the start/end xy series that is missing."""
        if not xy_values_start:
            self._log.error(f'Start time series missing on {_feature_type(self.feature)} with ID {self.feature.id}.')
        if not xy_values_end:
            self._log.error(f'End time series missing on {_feature_type(self.feature)} with ID {self.feature.id}.')

    def _interpolate_xy_series_at_t_values(self, t_values: list[float], xy_start: XySeries, xy_end: XySeries):
        """Returns a dict of t_values and their interpolated XySeries.

        All XySeries will have the same number of x,y values, but that number may not match xy_start or xy_end.

        Args:
            t_values (list[float]): T values (interpolation weights).
            xy_start (XySeries): XySeries at the beginning of the arc.
            xy_end (XySeries): XySeries at the end of the arc.

        Returns:
            (dict[float, XySeries]): See description.
        """
        return {t_value: xy_util.interpolate(xy_start, xy_end, t=t_value) for t_value in t_values}

    def _get_point_values_at_time(self, t_values: list[float], t_value_xy_series, time_idx: int):
        """Returns a list of the value at each point at the time_idx.

        Args:
            t_values (list[float]): T values (interpolation weights). Only not None for arcs.
            t_value_xy_series (dict[float, XySeries]): Dict of t value -> interpolated XySeries
            time_idx (int): Current time index.

        Returns:
            (list[float]): See description.
        """
        return [t_value_xy_series[t_value].y[time_idx] for t_value in t_values]

    def _make_time_raster_table(self) -> TimeRasterTable | None:
        """Returns a time raster table.

        Returns:
            (TimeRasterTable | None): Table (2d list) of the times and the raster paths, or None if the
            'time_raster_table' value is missing.
        """
        values = self.value('time_raster_table')
        if values is None:
            return None
        return list(zip(*values))  # Transpose from column-wise to row-wise

    def _make_time_value_table(self) -> TimeValueTable | None:
        """Returns a time value table.

        Returns:
            (TimeValueTable | None): Table (2d list) of the times and the values, and the interpolate flag. None if
            the input is a time series that is missing.
        """
        input_type = self.value('input_type')
        time_value_table = None
        if input_type == 'Constant':
            constant = self.value('constant')
            time_value_table = [[0.0, constant]], False
        else:  # 'Time series'
            rv = self.value('xy_series')
            if rv and len(rv) == 2:
                interpolate, xy_id = rv[0], rv[1]
                xy_values_tuple = self._get_xy_series_values(xy_id, step_function=False)
                if xy_values_tuple:
                    table_values = [[x, y] for x, y in zip(xy_values_tuple[0], xy_values_tuple[1])]
                    time_value_table = table_values, interpolate
        return time_value_table

    def _write_time_files(self, set_id_str: str, t_values: list[float]) -> TimeFileTable | None:
        """Writes the time file tables for the bc and returns a table of the times and the files.

        See 'time file table' command in .grok file. Writes one file per time. Each file contains n+1 values where n is
        len(point_idxs). The first line is the number of values (hence the +1). Values are the y values from the xy
        series.

        Args:
            set_id_str (str): String that uniquely identifies the bc.
            t_values (list[float]): T values (interpolation weights). Only not None for arcs.

        Returns:
            (TimeFileTable | None): Table (2d list) of the times and the files written, or None if the time series
            are missing or there are no t values to interpolate at.
        """
        input_type = self.value('input_type')
        if input_type == 'Constant':
            constant_start = self.value('constant_start')
            constant_end = self.value('constant_end')
            diff = constant_end - constant_start
            # Linear interpolation between the start and end constants at each point along the arc.
            point_values = [constant_start + (diff * t_value) for t_value in t_values]
            file_path = self._write_time_file(set_id_str, 0, point_values)
            return [(0.0, file_path.name)]

        # 'Time series'
        xy_start, xy_end = self._get_start_and_end_xy_series()
        if not xy_start or not xy_end:
            return None

        t_value_xy_series = self._interpolate_xy_series_at_t_values(t_values, xy_start, xy_end)
        if not t_value_xy_series:
            return None  # No points along the arc, so there are no times to write

        time_files = []
        # All interpolated series share the same times, so use the first one to drive the loop.
        first_xy_series = next(iter(t_value_xy_series.values()))
        for time_idx, x in enumerate(first_xy_series.x):
            point_values = self._get_point_values_at_time(t_values, t_value_xy_series, time_idx)
            file_path = self._write_time_file(set_id_str, time_idx, point_values)
            time_files.append((x, file_path.name))
        return time_files

    def _write_time_file(self, set_id_str: str, i: int, point_values: list[float]) -> Path:
        """Writes a time file table.

        See 'time file table' command in .grok file. Writes point_count+1 values (first line is number of values).

        Args:
            set_id_str (str): String that uniquely identifies the bc.
            i (int): Index of time in the table.
            point_values (list[float]): Value at each point at time[i].

        Returns:
            (Path): Path of the file written (next to the grok file).
        """
        file_path = self._grok_filepath.with_name(f'{set_id_str}-time-{i}.txt')
        with file_path.open('w') as file:
            file.write(f'{len(point_values)}\n')
            file.writelines(f'{value}\n' for value in point_values)
        return file_path
