"""Read solution datasets."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
from dataclasses import dataclass
import logging
import math
import os
from pathlib import Path
import sys
from typing import Optional

# 2. Third party modules
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.funwave.dmi.xms_data import XmsData


@dataclass
class _DatasetFile:
    """One solution file holding the values of a single dataset time step."""
    # File the values are read from.
    path: Path
    # Optional mask file; when set, values where the mask reads below 0.99 are replaced with NaN on read.
    mask_path: Optional[Path]
    # Time associated with the file's values (build_datasets recomputes it for the averaged datasets).
    time: float
    # Factor the values are multiplied by after reading (1.0 leaves them unchanged).
    multiplier: float


# Steady-state output files as (file name, dataset name, multiplier) tuples. A file in the output folder whose
# name matches exactly is read as a one-time-step dataset at time 0.0, with its values scaled by the multiplier.
# The same file may be listed more than once to produce several datasets from it.
STEADY_STATE_FILE_TYPES = [
    ("dep.out", "Depth", 1.0),
    ("dep.out", "Elevation from Depth", -1.0),
    ("cd_breakwater.out", "Cd Breakwater", 1.0),
]


# Transient output files, keyed by file name prefix. Files whose stem starts with '<prefix>_' are grouped into
# one dataset; the value is the description that is title-cased to form the dataset name, so changing a
# description changes the name of the dataset that is produced.
TRANSIENT_FILE_TYPES = {
    "eta": "surface elevation",
    "etamean": "average surface elevation",
    "u": "x velocity",
    "v": "y velocity",
    "umean": "average x velocity",
    "vmean": "average y velocity",
    "mask": "wetting mask",
    # "mask9": "9 point wetting mask",
    "hmax": "maximum surface elevation",
    "hmin": "minimum surface elevation",
    "Havg": "average wave height",
    "Hrms": "root mean square wave height",
    "Hsig": "significant wave height",
    "umax": "maximum velocity",
    "MFmax": "maximum momentum flux",
    "VORmax": "maximum vertical vorticity",
    "time": "tsunami arrival time",
    "p": "x volume flux",
    "q": "y volume flux",
    "nubrk": "breaking induced eddy viscosity",
    "etat": "eta T",
    "age": "breaking age",
    "roller": "roller-induced mass flux",
    "U_undertow": "x roller-induced extra undertow",
    "V_undertow": "y roller-induced extra undertow",
    "Pstorm": "air pressure",
    "Ustorm": "x wind velocity",
    "Vstorm": "y wind velocity",
    "Fves": "vessel mass flux gradient",
    "Pves": "vessel pressure source",
    "C": "sediment concentration",
    "Pick": "pickup rate",
    "Depo": "deposition rate",
    "DchgS": "depth change in suspended load",
    "DchgB": "depth change in bed load",
    "BedFx": "x bedload flux",
    "BedFy": "y bedload flux",
}


def read_solution(output_folder: Path, value_count: int, geom_uuid: str, binary: bool,
                  xms_data: XmsData) -> list[DatasetWriter]:
    """Read the FUNWAVE solution datasets found in an output folder.

    After the datasets are built, the station files in the folder are also processed
    (see ``process_station_files``).

    Args:
        output_folder(:obj:`Path`): The output folder.
        value_count(:obj:`int`): The number of values each time step.
        geom_uuid(:obj:`str`): Grid UUID.
        binary(:obj:`bool`): Are the values stored as binary?
        xms_data (XmsData): Xms Data of the simulation

    Returns:
        (:obj:`list[DatasetWriter]`)List of dataset writers.
    """
    solution_times = get_solution_times(output_folder)
    solution_files = get_solution_data_files(get_folder_files(output_folder), solution_times)
    writers = build_datasets(solution_files, geom_uuid, value_count, binary, xms_data)
    process_station_files(output_folder)
    return writers


def rotate_dataset_by_angle(x_data, y_data, rotation_angle, null_value):
    """Rotate paired X/Y vector components about the Z axis, passing null entries through.

    Args:
        x_data: Iterable of x components.
        y_data: Iterable of y components, paired element-wise with ``x_data``.
        rotation_angle (float): Angle in degrees to rotate each vector by.
        null_value: Value marking missing data. When either component of a pair equals it, both
            output components are set to it instead of being rotated.

    Returns:
        (tuple[list, list]): The rotated x components and the rotated y components.
    """
    # The angle is the same for every pair, so evaluate the trigonometry once instead of once per value.
    # x' = x * cos(t) - y * sin(t)
    # y' = x * sin(t) + y * cos(t)
    theta = math.radians(rotation_angle)
    cs = math.cos(theta)
    sn = math.sin(theta)

    new_x, new_y = [], []
    for x, y in zip(x_data, y_data):
        if x == null_value or y == null_value:
            new_x.append(null_value)
            new_y.append(null_value)
            continue
        new_x.append(x * cs - y * sn)
        new_y.append(x * sn + y * cs)
    return new_x, new_y


def rotate_vector_by_angle(vx, vy, rotation_angle):
    """Rotates the XY vector passed in by the rotation angle on the Z axis.

    Args:
        vx (float): x component of the vector.
        vy (float): y component of the vector.
        rotation_angle (float): angle in degrees to rotate the vector by.

    Returns:
        (list of float): list of rotated [vx, vy] vector values.
    """
    # Rotate vector about Z Axis by the grid angle:
    # Matrix multiplication for rotation about Z, where t = angle in radians:
    # |cos(t), -sin(t), 0|       |x|     |cos(t) * x + -sin(t) * y + 0 * 0|
    # |sin(t), cos(t),  0|   *   |y|  =  |sin(t) * x + cos(t) * y + 0 * 0 |
    # |0,      0,       1|       |0|     |0 * x + 0 * y + 1 * 0           |
    # x' = vx * cos(t) - vy * sin(t)
    # y' = vx * sin(t) + vy * cos(t)
    theta = math.radians(rotation_angle)
    cs = math.cos(theta)
    sn = math.sin(theta)
    return [vx * cs - vy * sn, vx * sn + vy * cs]


def get_folder_files(folder: Path) -> list[Path]:
    """List the regular files directly inside a folder.

    Sub-folders are not descended into and are not part of the result.

    Args:
        folder(:obj:`Path`): The folder.

    Returns:
        (:obj:`list[Path]`)A list of files in the folder.
    """
    found = []
    for entry in Path(folder).iterdir():
        if entry.is_file():
            found.append(entry)
    return found


def build_datasets(data_files: dict[str, list[_DatasetFile]], geom_uuid: str, value_count: int,
                   binary: bool, xms_data: XmsData) -> list[DatasetWriter]:
    """Build one dataset writer per FUNWAVE solution dataset.

    When both components of a velocity pair are present they are combined into a single two-component
    dataset ('Velocity' / 'Average Velocity') whose vectors are rotated by the grid angle. ``data_files`` is
    modified in place to do so: the component entries are removed and the combined entry is added.

    Time steps of the averaged datasets are stamped with STEADY_TIME + T_INTV_mean * step number (steps
    counted from 1) instead of the time stored with the file.

    Args:
        data_files(:obj:`dict[str, list[_DatasetFile]]`): The data files to read.
        geom_uuid(:obj:`str`): Grid UUID.
        value_count(:obj:`int`): The number of values each time step.
        binary(:obj:`bool`): Are the values stored as binary?
        xms_data (XmsData): Xms Data of the simulation

    Returns:
        (:obj:`list[DatasetWriter]`):List of dataset writers.
    """
    log = logging.getLogger('xms.funwave')

    if xms_data.sim_uuid == '':
        xms_data.set_sim(xms_data.sim_item.uuid, xms_data.sim_item.name)
    time_group = xms_data.sim_data_model_control.group('Time')
    mean_interval = time_group.parameter('T_INTV_mean').value
    mean_start = time_group.parameter('STEADY_TIME').value

    # Fold each X/Y component pair into one vector entry; the grid angle is only needed when a pair exists.
    rotation = 0.0
    vector_sources = (
        ('Velocity', 'X Velocity', 'Y Velocity'),
        ('Average Velocity', 'Average X Velocity', 'Average Y Velocity'),
    )
    for vector_name, x_key, y_key in vector_sources:
        if x_key not in data_files or y_key not in data_files:
            continue
        rotation = xms_data.cogrid.angle
        data_files[vector_name] = {
            'X Velocity': copy.copy(data_files[x_key]),
            'Y Velocity': copy.copy(data_files[y_key]),
        }
        del data_files[x_key]
        del data_files[y_key]

    averaged_scalars = ('Average Surface Elevation', 'Average Wave Height', 'Root Mean Square Wave Height',
                        'Significant Wave Height')
    writers = []
    for name, dataset_files in data_files.items():
        if name in ('Velocity', 'Average Velocity'):
            writer = DatasetWriter(name=name, num_components=2, geom_uuid=geom_uuid, location='cells',
                                   time_units='Seconds')
            x_files = dataset_files['X Velocity']
            y_files = dataset_files['Y Velocity']
            for step, x_file in enumerate(x_files, start=1):
                log.info(f'Reading {name} time step {step}')
                x_values = read_data_from_file(x_file, value_count, binary, step, writer)
                y_values = read_data_from_file(y_files[step - 1], value_count, binary, step, writer)

                rotated_x, rotated_y = rotate_dataset_by_angle(x_values, y_values, rotation, writer.null_value)
                vectors = np.array([rotated_x, rotated_y], dtype=float).transpose()

                if name == 'Average Velocity':
                    step_time = mean_start + mean_interval * step
                else:
                    step_time = x_file.time
                writer.append_timestep(step_time, vectors)
        else:
            writer = DatasetWriter(name=name, geom_uuid=geom_uuid, location='cells', time_units='Seconds')
            for step, dataset_file in enumerate(dataset_files, start=1):
                log.info(f'Reading {name} time step {step}')
                values = read_data_from_file(dataset_file, value_count, binary, step, writer)
                if name in averaged_scalars:
                    step_time = mean_start + mean_interval * step
                else:
                    step_time = dataset_file.time
                writer.append_timestep(step_time, values)

        writer.appending_finished()
        writers.append(writer)
    return writers


def read_data_from_file(dataset_file, value_count, binary, index, dataset_writer):
    """
    Read the values of one time step, applying the file's multiplier and optional mask.

    Nothing is appended to the dataset writer here. The writer is only touched for a masked dataset, where
    its null value is set to -999.0 on the first time step.

    Args:
        dataset_file: The dataset file to read from.
        value_count: The number of values to read.
        binary: Whether the file is in binary format.
        index: The time step index (1 for the first time step).
        dataset_writer: The dataset writer that will receive the data.

    Returns:
        The read data, scaled by the multiplier and with NaN wherever the mask value is below 0.99.
    """
    values = read_data_file(dataset_file.path, value_count, binary)
    scale = dataset_file.multiplier
    if scale != 1.0:
        values = scale * values

    mask_file = dataset_file.mask_path
    if mask_file is None:
        return values

    if index == 1:
        dataset_writer.null_value = -999.0
    mask_values = read_data_file(mask_file, value_count, binary)
    values[mask_values < 0.99] = float('nan')
    return values


def get_solution_times(output_folder: Path) -> Optional[np.ndarray]:
    """Read FUNWAVE solution times.

    The times are the first column of 'time_dt.out', which is looked for in the parent of the output folder.

    Args:
        output_folder(:obj:`Path`): The output folder.

    Returns:
        (:obj:`numpy.ndarray`)Optional 1-D array of absolute times, or None when the times file is missing.
    """
    times_file = output_folder.parent / 'time_dt.out'
    if not times_file.is_file():
        return None
    # ndmin=2 keeps a file with a single row (or a single column) two-dimensional. Without it loadtxt
    # returns a 1-D array and taking "the first column" would yield a scalar instead of the times.
    times_values = np.loadtxt(times_file, ndmin=2)
    return times_values[:, 0]


def get_solution_data_files(directory_files: list[Path], times: Optional[np.ndarray]) -> dict[str, list[_DatasetFile]]:
    """Group the FUNWAVE solution files of a folder by dataset.

    Steady-state files are matched by exact file name. Transient files are matched by '<prefix>_' at the start
    of the file stem and are only collected when ``times`` is available; the number after the prefix is used as
    the index into ``times``. When both surface elevation and wetting mask files are found, the mask files are
    attached to the surface elevation entries by position and are not returned as a dataset of their own.

    Args:
        directory_files(:obj:`list[Path]`): The files to look through.
        times(:obj:`numpy.ndarray`): The absolute time of each time step.

    Returns:
        (:obj:`dict[str, list[_DatasetFile]]`):Dictionary of solution files to read with dataset name to a list of
        (file, time step).
    """
    found = {}

    # Steady-state files: a single entry at time 0.0 for every listed file name that is present.
    for file_name, dataset_name, scale in STEADY_STATE_FILE_TYPES:
        for candidate in directory_files:
            if candidate.name == file_name and candidate.is_file():
                found[dataset_name] = [_DatasetFile(candidate, None, 0.0, scale)]

    # Transient files can only be placed in time when the list of times is known.
    if times is not None:
        for prefix, label in TRANSIENT_FILE_TYPES.items():
            stem_start = prefix + '_'
            matches = sorted(f for f in directory_files if f.stem.startswith(stem_start))
            if not matches:
                continue
            entries = []
            for match in matches:
                # 'eta_00001' -> '_00001' -> '00001' -> '1'; an all-zero suffix means time index 0.
                digits = match.stem.removeprefix(prefix)[1:].lstrip('0')
                step = int(digits) if digits != '' else 0
                entries.append(_DatasetFile(match, None, times[step], 1.0))
            found[label.title()] = entries

    if 'Surface Elevation' in found and 'Wetting Mask' in found:
        masks = found.pop('Wetting Mask')
        elevations = found['Surface Elevation']
        for position, mask_entry in enumerate(masks):
            elevations[position].mask_path = mask_entry.path
    return found


def read_data_file(file: Path, value_count: int, binary: bool) -> np.ndarray:
    """Read data from FUNWAVE solution file.

    Binary files are read as little-endian single precision first. If that yields more than ``value_count``
    values, the file is taken to be double precision and is read again from the start. Text files are read
    with ``numpy.loadtxt`` and flattened; ``value_count`` is not used for them.

    Args:
        file(:obj:`Path`): The file.
        value_count(:obj:`int`): The number of values each time step.
        binary(:obj:`bool`): Is data binary?

    Returns:
        (:obj:`np.ndarray`):1-D numpy array of the data.
    """
    if binary:
        with open(file, mode='rb') as f:
            # first try single precision
            data = np.fromfile(f, dtype=np.dtype('<f4'))
            if len(data) > value_count:
                # Too many values for single precision, so try double precision. The first read left the
                # file position at the end of the file; rewind, otherwise the second read returns nothing.
                f.seek(0)
                data = np.fromfile(f, dtype=np.dtype('<f8'))
    else:
        data = np.loadtxt(file)
        data = data.flatten()
    return data


def process_station_files(output_folder):
    """Read the station files and write them in a format that SMS will use.

    Station files are named 'sta_0001', 'sta_0002', ... and are read in sequence until the first missing
    number. From each line with at least four tokens the first four are taken as time, water surface
    elevation, x velocity and y velocity. Lines that do not parse as numbers are skipped, as are lines whose
    time is not greater than the last time kept. Each station is written as a tab separated file
    'Output_MISC/case_PT<n>.dat' inside the output folder; the 'Output_MISC' folder is created even when no
    station file is found.

    Args:
        output_folder: The output folder.

    """
    output_folder = Path(output_folder)

    # Read Existing data
    stations = []
    count = 1
    while True:
        station_file = output_folder / f'sta_{count:04d}'
        if not station_file.is_file():
            break
        last_timestep = -sys.float_info.max
        timesteps, vel_x, vel_y, vel_mag, wse = [], [], [], [], []
        with open(station_file) as f:
            for line in f:
                tokens = line.split()
                if len(tokens) <= 3:
                    continue
                try:
                    # Convert every value before storing any of them, so a line with one bad token cannot
                    # leave the columns with different lengths.
                    timestep, wse_val, vel_x_val, vel_y_val = (float(token) for token in tokens[:4])
                except ValueError:
                    # Not a data line (for example a header); skip it.
                    continue
                if not timestep > last_timestep:
                    # Keep times strictly increasing; written this way so that a NaN time is dropped too.
                    continue
                last_timestep = timestep
                timesteps.append(timestep)
                wse.append(wse_val)
                vel_x.append(vel_x_val)
                vel_y.append(vel_y_val)
                vel_mag.append((vel_x_val * vel_x_val + vel_y_val * vel_y_val) ** 0.5)
        data_dict = {'Time_s': np.array(timesteps), 'Vel_mag_mps': np.array(vel_mag),
                     'X_vel_mps': np.array(vel_x), 'Y_vel_mps': np.array(vel_y),
                     'WSE_m': np.array(wse)}
        stations.append(pd.DataFrame(data_dict))
        count += 1

    # Write data
    misc_folder = output_folder / 'Output_MISC'
    misc_folder.mkdir(parents=True, exist_ok=True)
    for number, station in enumerate(stations, start=1):
        station.to_csv(misc_folder / f'case_PT{number}.dat', sep='\t', index=False)
