"""A writer for CMS-Wave energy files."""

# 1. Standard Python modules
import copy
import datetime
from enum import Enum
from io import StringIO
import math
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.data_objects.parameters import RectilinearGrid

# 4. Local modules
from xms.cmswave.data.simulation_data import convert_time_into_seconds
import xms.cmswave.file_io.dataset_writer

# ------------------------------------------------------------------------------
# Code adapted from GetAndTransformSpectralVals() in SMS
# ------------------------------------------------------------------------------

# Absolute tolerance (passed as ``abs_tol`` to math.isclose) used throughout this module when comparing angles,
# frequencies, grid rotation angles and axis spacings.
TOLERANCE = 0.000001


class PlaneTypes(Enum):
    """
    Plane types a spectral grid can use.

    HALF marks a half-plane grid; every other member is treated as a full plane by the code in this module
    (for example, ``SpectralParams.get_num_angs``). GLOBAL tables are direction-reversed relative to the others
    during conversion. NOTE(review): the physical meaning of LOCAL and BOTH is not visible here; confirm with the
    callers that create SpectralParams.
    """
    # Values 0..3 in declaration order: LOCAL=0, GLOBAL=1, HALF=2, BOTH=3.
    LOCAL, GLOBAL, HALF, BOTH = range(4)


def get_2d_list(a_1d_data, a_grid_params):
    """
    Reshapes flat spectral values into a [frequency][angle] table.

    The flat data is read angle-major: the value for frequency ``f`` and angle ``a`` is at index
    ``a * num_freqs + f``. For any plane type other than HALF, column 0 of every row is then overwritten with the
    row's last column.

    Args:
        a_1d_data (:obj:`list[float]`): The flat values to reshape.
        a_grid_params (:obj:`SpectralParams`): Spectral parameters that supply the frequencies, angles and plane type.

    Returns:
        (:obj:`list[list[float]]`): One row per frequency, one column per entry of ``a_grid_params.angles``.
    """
    n_freq = len(a_grid_params.freqs)
    n_ang = len(a_grid_params.angles)  # the real list length, regardless of plane type
    table = []
    for f in range(n_freq):
        row = [a_1d_data[a * n_freq + f] for a in range(n_ang)]
        if a_grid_params.plane_type != PlaneTypes.HALF:
            row[0] = row[n_ang - 1]  # full planes: keep the first column equal to the last one
        table.append(row)
    return table


def reverse_directions(a_2d_list):
    """
    Flips the angle order of every row of a [frequency][angle] table, in place.

    The outer list object is kept (its contents are replaced), so every reference to it sees the change. The
    frequency order is not changed. Returns None.

    Args:
        a_2d_list (:obj:`list` of :obj:`list` of float): Table whose rows (angles) are reversed.
    """
    a_2d_list[:] = [list(reversed(row)) for row in a_2d_list]


def angles_match(a_grid1, a_grid2):
    """
    Checks whether two spectral grids have the same angles.

    The grids must have the same number of angles, the same constant/variable spacing flag, and ang_size values
    within TOLERANCE. With constant spacing only the first angles are compared; otherwise every pair of angles is
    compared. All float comparisons use an absolute tolerance of TOLERANCE.

    Args:
        a_grid1 (:obj:`SpectralParams`): First grid
        a_grid2 (:obj:`SpectralParams`): Second grid

    Returns:
        (:obj:`bool`): True if the angles match, False if they do not.
    """
    angs1 = a_grid1.angles
    angs2 = a_grid2.angles
    same_layout = (
        len(angs1) == len(angs2)
        and a_grid1.const_angle == a_grid2.const_angle
        and math.isclose(a_grid1.ang_size, a_grid2.ang_size, abs_tol=TOLERANCE)
    )
    if not same_layout:
        return False
    if a_grid1.const_angle:
        # Count, step and starting angle fully determine a constant-spaced axis.
        return not angs1 or math.isclose(angs1[0], angs2[0], abs_tol=TOLERANCE)
    return all(math.isclose(first, second, abs_tol=TOLERANCE) for first, second in zip(angs1, angs2))


def freqs_match(a_grid1, a_grid2):
    """
    Checks whether two spectral grids have the same frequencies.

    The grids must have the same number of frequencies, the same constant/variable spacing flag, and freq_size
    values within TOLERANCE. With constant spacing only the first frequencies are compared; otherwise every pair of
    frequencies is compared. All float comparisons use an absolute tolerance of TOLERANCE.

    Args:
        a_grid1 (:obj:`SpectralParams`): First grid
        a_grid2 (:obj:`SpectralParams`): Second grid

    Returns:
        (:obj:`bool`): True if the frequencies match, False if they do not.
    """
    freqs1 = a_grid1.freqs
    freqs2 = a_grid2.freqs
    same_layout = (
        len(freqs1) == len(freqs2)
        and a_grid1.const_freq == a_grid2.const_freq
        and math.isclose(a_grid1.freq_size, a_grid2.freq_size, abs_tol=TOLERANCE)
    )
    if not same_layout:
        return False
    if a_grid1.const_freq:
        # Count, step and starting frequency fully determine a constant-spaced axis.
        return not freqs1 or math.isclose(freqs1[0], freqs2[0], abs_tol=TOLERANCE)
    return all(math.isclose(first, second, abs_tol=TOLERANCE) for first, second in zip(freqs1, freqs2))


class SpectralParams:
    """
    Frequency/direction axes of one spectral grid, plus the dataset and time step that hold its values.

    Frequencies lie along the grid's i (x) axis and angles along its j (y) axis. When the frequency or angle list
    is empty, it is generated from the grid origin and cell sizes.
    """
    def __init__(self, plane_type, grid, func, ts_idx, freqs, angles, grid_proj, time=None):
        """
        Defines a single spectra.

        Args:
            plane_type (:obj:`PlaneType`): Enum indicating half or full plane
            grid (:obj:`RectilinearGrid` or :obj:`xms.constraint.RectilinearGrid2d`): The Cartesian grid the spectra is
                defined on
            func (:obj:`data_objects.parameters.Dataset`): The spectral dataset
            ts_idx (:obj:`int`): Time step index of the spectra in the spectral dataset
            freqs (:obj:`list`): List of the frequencies
            angles (:obj:`list`): List of the angles
            grid_proj (:obj:`data_objects.parameters.Projection`): Projection of the spectral grid
            time (:obj:`data_objects.parameters.DateTimeLiteral`): Timestamp of the spectra

        """
        # NOTE: freqs/angles are stored by reference. load_frequencies()/load_angles() append to them in place when
        # they are empty, so an empty list passed in by the caller is filled as well.
        self.freqs = freqs
        self.angles = angles  # use len(self.angles) when C++ code uses size of angle vector
        self.plane_type = plane_type
        self.proj = grid_proj
        if isinstance(grid, RectilinearGrid):  # api RectilinearGrid
            origin = grid.origin
            self.grid_origin_x = origin.x
            self.grid_origin_y = origin.y
            self.grid_angle = grid.angle
            self.grid_i_sizes = list(grid.i_sizes)
            self.grid_j_sizes = list(grid.j_sizes)
        else:  # xms.constraint.RectilinearGrid2d
            self.grid_origin_x = grid.origin[0]
            self.grid_origin_y = grid.origin[1]
            self.grid_angle = grid.angle
            # Convert offsets from the origin to i/j sizes
            self.grid_i_sizes = self._steps(grid.locations_x)
            self.grid_j_sizes = self._steps(grid.locations_y)
        self.func = func  # api Dataset
        self.ts_idx = ts_idx
        self.const_angle = True  # True when every angle step is within TOLERANCE of the first step
        self.const_freq = True  # True when every frequency step is within TOLERANCE of the first step
        self.ang_size = 0.0  # either the minimum delta or constant angle
        self.freq_size = 0.0  # either the minimum delta or constant frequency
        self.time = time
        self.load_angles()
        self.load_frequencies()
        self.orig_angs = copy.deepcopy(self.angles)  # don't touch these, used for deep copying
        self.orig_freqs = copy.deepcopy(self.freqs)
        self.from_type = copy.deepcopy(self.plane_type)

    @staticmethod
    def _steps(values):
        """
        Returns the differences between consecutive entries of an indexable sequence.

        Args:
            values (:obj:`list[float]`): The ordered coordinates.

        Returns:
            (:obj:`list[float]`): ``values[k + 1] - values[k]`` for each k; empty for fewer than two values.
        """
        return [values[k + 1] - values[k] for k in range(len(values) - 1)]

    @staticmethod
    def _spacing(values):
        """
        Classifies the spacing of an ordered axis.

        Args:
            values (:obj:`list[float]`): The ordered axis coordinates.

        Returns:
            (:obj:`tuple`): ``(is_constant, size)``. ``is_constant`` is True when every step is within TOLERANCE of
            the first step. ``size`` is then that first step; otherwise it is the smallest step over the WHOLE axis.
            For fewer than two values the result is ``(True, None)``.
        """
        steps = SpectralParams._steps(values)
        if not steps:
            return True, None
        first = steps[0]
        if all(math.isclose(first, step, abs_tol=TOLERANCE) for step in steps):
            return True, first
        return False, min(steps)

    @staticmethod
    def _constant_range(a_min, a_max, a_delta):
        """
        Builds ``a_min, a_min + a_delta, ...`` up to ``a_max`` (inclusive within TOLERANCE).

        Args:
            a_min (:obj:`float`): First value.
            a_max (:obj:`float`): Last allowed value.
            a_delta (:obj:`float`): Step between values.

        Returns:
            (:obj:`list[float]`): The generated values. If ``a_delta`` is not positive, the sequence could never
            advance to ``a_max``, so at most ``[a_min]`` is returned (instead of looping forever).
        """
        def _in_range(value):
            return value < a_max or math.isclose(value, a_max, abs_tol=TOLERANCE)

        if a_delta <= 0.0:
            return [a_min] if _in_range(a_min) else []
        values = []
        current = a_min
        while _in_range(current):
            values.append(current)
            current += a_delta  # accumulated the same way as the original loop so values stay bit-identical
        return values

    def reset(self):
        """Resets angles, freqs, and plane_type to the original values used for construction."""
        self.angles = copy.deepcopy(self.orig_angs)
        self.freqs = copy.deepcopy(self.orig_freqs)
        self.plane_type = copy.deepcopy(self.from_type)

    def get_num_angs(self):  # make sure to use this function when C++ code calls SpecGridDef::GetNumAngles()
        """
        Returns the appropriate number of angles depending on plane_type.

        A HALF plane uses every angle. Any other plane type does not count the last entry.

        Returns:
            (:obj:`int`): Number of angles.

        """
        if self.plane_type == PlaneTypes.HALF:
            return len(self.angles)
        return len(self.angles) - 1

    def load_angles(self):
        """
        Determines and loads the correct value for ang_size from angles.

        If no angles were supplied, they are generated from the grid's y origin and j sizes. With at least two
        angles, const_angle and ang_size are derived from the forward steps between consecutive angles.

        Changes Members:
            ang_size
            angles
            const_angle
        """
        if not self.angles:  # load default angles if none provided
            self.angles.append(self.grid_origin_y)
            for j_size in self.grid_j_sizes:
                self.angles.append(self.angles[-1] + j_size)
        if len(self.angles) > 1:  # compute constant/min angle
            self.const_angle, self.ang_size = self._spacing(self.angles)

    def load_frequencies(self):
        """
        Determines and loads the correct value for freq_size from freqs.

        If no frequencies were supplied, they are generated from the grid's x origin and i sizes. With at least two
        frequencies, const_freq and freq_size are derived from the forward steps between consecutive frequencies.

        Changes Members:
            freq_size
            const_freq
            freqs
        """
        if not self.freqs:  # load default frequencies if none provided
            self.freqs.append(self.grid_origin_x)
            for i_size in self.grid_i_sizes:
                self.freqs.append(self.freqs[-1] + i_size)
        if len(self.freqs) > 1:  # compute constant/min frequency
            self.const_freq, self.freq_size = self._spacing(self.freqs)

    def set_angles(self, angles):
        """
        Loads angles using list provided.

        const_angle is reset to True. With fewer than two angles, ang_size is left unchanged.

        Args:
            angles (:obj:`list[float]`): list of angles.
        """
        self.const_angle = True
        self.angles = list(angles) if angles else []
        if len(self.angles) > 1:
            self.const_angle, self.ang_size = self._spacing(self.angles)

    def set_angles_const(self, a_min, a_max, a_delta):
        """
        Loads a set of angles with constant delta.

        Args:
            a_min (:obj:`float`): minimum angle to start with
            a_max (:obj:`float`): maximum angle to end with
            a_delta (:obj:`float`): constant delta; a non-positive delta yields at most ``[a_min]``
        """
        self.angles = self._constant_range(a_min, a_max, a_delta)
        self.const_angle = True
        self.ang_size = a_delta

    def set_freqs(self, a_freqs):
        """
        Loads frequencies using list provided.

        const_freq is reset to True. With fewer than two frequencies, freq_size is left unchanged.

        Args:
            a_freqs (:obj:`list[float]`): list of frequencies.
        """
        self.const_freq = True
        self.freqs = list(a_freqs) if a_freqs else []
        if len(self.freqs) > 1:
            self.const_freq, self.freq_size = self._spacing(self.freqs)

    def set_freqs_const(self, a_min, a_max, a_delta):
        """
        Loads a set of frequencies with constant delta.

        Args:
            a_min (:obj:`float`): minimum frequency to start with
            a_max (:obj:`float`): maximum frequency to end with
            a_delta (:obj:`float`): constant delta; a non-positive delta yields at most ``[a_min]``
        """
        self.freqs = self._constant_range(a_min, a_max, a_delta)
        self.const_freq = True
        self.freq_size = a_delta


class CMSWAVECase:
    """Holds the time and forcing values for one CMS-Wave case."""
    def __init__(self, case_time, wind_dir, wind_speed, tidal_surge):
        """
        Stores the case values unchanged.

        Args:
            case_time (:obj:`float`): A julian double of the time.
            wind_dir (:obj:`float`): The wind direction in degrees.
            wind_speed (:obj:`float`): The wind speed.
            tidal_surge (:obj:`float`): The height of the tidal surge.
        """
        # Attribute names are camelCase to match the rest of the module's consumers.
        self.time, self.windDir, self.windSpeed, self.tidalSurge = case_time, wind_dir, wind_speed, tidal_surge


class SpectralGridConverter:
    """A class for converting spectral grids."""
    def __init__(self, params1, params2, global_params, case_time):
        """Constructor."""
        self.params1 = params1
        self.params2 = params2
        self.globalParams = global_params
        self.time = case_time
        self.dTime1 = 1.0
        self.dTime2 = 0.0
        self.finalVals = []
        self.vals1 = []
        self.vals2 = []
        self.twoGrids = False
        if params2:
            self.twoGrids = True
        # self.debugger = open("debug.txt", "a")

    def transform_spectral_vals(self):
        """
        Converts the source spectra onto the global grid and stores the result in finalVals.

        Steps:
            1. Read each source dataset's current time step into a [freq][angle] table.
            2. Reverse tables in the GLOBAL plane type so that all tables share one direction ordering.
            3. Rotate onto the global grid orientation and, where possible, extract the half plane directly.
            4. With two sources, compute the time interpolation weights.
            5. Interpolate/resample, then reverse the result if the global grid is GLOBAL.

        Returns early, leaving finalVals untouched, when the global grid has no angles/frequencies or a source
        table comes back empty.
        """
        glob = self.globalParams
        if not (glob.angles and glob.freqs):
            return  # no target axes: nothing to do

        first = self.params1
        first.func.ts_idx = first.ts_idx
        self.vals1 = get_2d_list(first.func.data, first)
        if not self.vals1:  # nothing to do
            return

        if self.twoGrids:
            second = self.params2
            second.func.ts_idx = second.ts_idx
            self.vals2 = get_2d_list(second.func.data, second)
            if not self.vals2:  # datasets don't match up
                return

        # make sure everything is counter-clockwise
        if first.plane_type == PlaneTypes.GLOBAL:
            reverse_directions(self.vals1)
        if self.twoGrids and self.params2.plane_type == PlaneTypes.GLOBAL:
            reverse_directions(self.vals2)

        self.rotate()

        for is_first in (True, False):
            if (is_first or self.twoGrids) and self.need_simple_full_to_half(is_first):
                self.full_to_half_only(is_first)

        # Linear time weights: dTime2 is the fraction of the way from params1.time to params2.time.
        if self.twoGrids:
            start_time = self.params1.time
            end_time = self.params2.time
            total_timespan = end_time - start_time
            elapsed = self.time - start_time
            remaining = end_time - self.time
            if total_timespan:
                self.dTime1 = 1 - (elapsed / total_timespan)
                self.dTime2 = 1 - (remaining / total_timespan)

        self.interpolate_data()

        if glob.plane_type == PlaneTypes.GLOBAL:
            reverse_directions(self.finalVals)

    def rotate(self):
        """
        Rotates each source grid that is not already aligned with the global grid.

        A source needs rotating when its grid angle differs from the global grid angle (beyond TOLERANCE), or when
        it and the global grid have different plane types and one of the two is GLOBAL.
        """
        target = self.globalParams

        def _needs_rotation(source):
            aligned = math.isclose(source.grid_angle, target.grid_angle, abs_tol=TOLERANCE)
            global_involved = PlaneTypes.GLOBAL in (source.plane_type, target.plane_type)
            return not aligned or (global_involved and source.plane_type != target.plane_type)

        if _needs_rotation(self.params1):
            self.rotate_fp(True)
        if self.twoGrids and _needs_rotation(self.params2):
            self.rotate_fp(False)

    def rotate_fp(self, is_first_grid):
        """
        Rotates one source table onto the global grid's orientation (helper for rotate).

        The rotation is split into a whole number of direction bins (``irot``) plus a fractional part. Each output
        bin is a linear blend of the two source bins that straddle the rotated position, weighted by
        ``wt1``/``wt2``. The last bin (index ``num_ang``) is kept equal to bin 0.

        Args:
            is_first_grid (:obj:`bool`): True to use params1/vals1, False to use params2/vals2.
        """
        # get the correct data
        if is_first_grid:
            params = self.params1
            from_vals = self.vals1
        else:
            params = self.params2
            from_vals = self.vals2
        is_global = params.plane_type == PlaneTypes.GLOBAL
        from_angle = params.grid_angle
        num_freq = len(params.freqs)
        num_ang = params.get_num_angs()

        if num_ang <= 0 or not from_vals:
            return  # nothing to rotate; also avoids dividing by a non-positive bin count below

        global_params = self.globalParams
        to_angle = global_params.grid_angle
        if not is_global and global_params.plane_type == PlaneTypes.GLOBAL:
            to_angle = -90.0
        if is_global and global_params.plane_type != PlaneTypes.GLOBAL:
            to_angle += 90.0

        conv_factor = math.pi / 180.0

        to_conv = to_angle * conv_factor
        from_conv = from_angle * conv_factor
        # atan2(sin(to - from), cos(to - from)) expanded with the angle-difference identities:
        # the signed rotation in radians, in [-pi, pi].
        rotate_factor = math.atan2(
            math.sin(to_conv) * math.cos(from_conv) - math.cos(to_conv) * math.sin(from_conv),
            math.cos(to_conv) * math.cos(from_conv) + math.sin(to_conv) * math.sin(from_conv)
        )
        dth = (360.0 / num_ang) * conv_factor  # width of one direction bin in radians
        irot = int(rotate_factor / dth)  # whole bins to shift, truncated toward zero
        wt1 = 1.0 - abs(math.fmod(rotate_factor, dth) / dth)  # weight of source bin j
        wt2 = 1.0 - wt1  # weight of the neighbouring source bin
        rotated_vals = [[0.0 for _ in range(len(from_vals[0]))] for _ in range(len(from_vals))]
        # Choose the branch from the sign of the FULL rotation. irot truncates toward zero, so a negative rotation
        # smaller than one bin has irot == 0. Testing irot >= 0 sent it down the positive branch, which blends with
        # bin j + 1 instead of j - 1 and so rotated the fractional part the wrong way.
        # NOTE(review): for a HALF plane get_num_angs() == len(angles), so index j + 1 / num_ang falls outside the
        # row on this path; confirm half-plane sources never need rotating.
        positive_rotation = rotate_factor >= 0.0
        for i in range(num_freq):
            if positive_rotation:
                for j in range(num_ang):
                    temp_ang = j - irot
                    if temp_ang < 0:
                        temp_ang += num_ang
                    rotated_vals[i][temp_ang] = (from_vals[i][j] * wt1) + (from_vals[i][j + 1] * wt2)
                rotated_vals[i][num_ang] = rotated_vals[i][0]  # keep the last bin equal to bin 0
            else:
                for j in range(num_ang, 0, -1):
                    temp_ang = j - irot
                    if temp_ang > num_ang:
                        temp_ang -= num_ang
                    rotated_vals[i][temp_ang] = (from_vals[i][j] * wt1) + (from_vals[i][j - 1] * wt2)
                rotated_vals[i][0] = rotated_vals[i][num_ang]  # keep bin 0 equal to the last bin

        # TODO: I believe there is a potential optimization we could do here. In the SMS 12.3 code, see
        #       ConvertSpectralGridVals::RotateFP. We set the angle of the spectral grid to be the global grid
        #       angle. I think this is so we don't have to translate this grid again. This broke though when we
        #       switched to CoGrid because setting the angle on a data_objects Rectilinear grid does nothing.
        #       To get it working, I think we need to do some resetting because doing this is only valid in some
        #       cases. In the SMS 12.3 code, see GetAndTransformSpectralVals(). Note the logic for transforming
        #       vs. resetting of the spectral params.
        if is_first_grid:
            self.vals1 = rotated_vals
        else:
            self.vals2 = rotated_vals

    def need_simple_full_to_half(self, is_first_grid):
        """Decides whether a source table can be cut down to the global half plane by direct copying.

        This requires a HALF global grid, constant angle spacing on both grids with matching step sizes, and a
        source angle equal (within TOLERANCE) to the global grid's first angle. A negative first angle is shifted
        by +360 before the comparison.

        Args:
            is_first_grid (:obj:`bool`): True to use params1, False to use params2.

        Returns:
            (:obj:`bool`):True if conversion is needed, False if not.

        """
        glob = self.globalParams
        if glob.plane_type != PlaneTypes.HALF:
            return False

        source = self.params1 if is_first_grid else self.params2
        steps_agree = math.isclose(source.ang_size, glob.ang_size, abs_tol=TOLERANCE)
        if not (source.const_angle and glob.const_angle and steps_agree):
            return False

        start_angle = glob.angles[0]
        if start_angle < 0.0:
            start_angle += 360.0
        return any(math.isclose(start_angle, angle, abs_tol=TOLERANCE) for angle in source.angles)

    def full_to_half_only(self, is_first_grid):
        """
        Copies the global half-plane angles out of a full-plane source table.

        For each global angle (negative angles shifted by +360), the matching source column is copied. Matching
        uses the same TOLERANCE as need_simple_full_to_half; previously it used exact equality, so float noise left
        those columns at 0.0. Global angles with no matching source angle stay at 0.0. Afterwards the source params
        are switched to HALF with the global grid's constant angles.

        Args:
            is_first_grid (:obj:`bool`): True to use params1, False to use params2.
        """
        params = self.params1 if is_first_grid else self.params2
        full_vals = self.vals1 if is_first_grid else self.vals2
        full_angles = params.angles
        num_freq = len(params.freqs)
        half_angles = self.globalParams.angles
        half_vals = [[0.0 for _ in range(self.globalParams.get_num_angs())] for _ in range(num_freq)]

        for i, half_angle in enumerate(half_angles):
            target = half_angle + 360.0 if half_angle < 0.0 else half_angle
            # First source column whose angle matches the (shifted) global angle.
            j_index = next(
                (j for j, full_angle in enumerate(full_angles) if math.isclose(full_angle, target, abs_tol=TOLERANCE)),
                None,
            )
            if j_index is None:
                continue
            for k in range(num_freq):
                half_vals[k][i] = full_vals[k][j_index]

        if is_first_grid:
            self.vals1 = half_vals
        else:
            self.vals2 = half_vals
        params.plane_type = PlaneTypes.HALF
        params.set_angles_const(half_angles[0], half_angles[-1], self.globalParams.ang_size)

    def interpolate_data(self):
        """
        Builds finalVals from the source table(s).

        With one source, the table is used as-is when its axes already match the global grid. Otherwise it is
        super-sampled and resampled onto the global axes first.

        With two sources, both are super-sampled in angle (and in frequency when the axes differ). The peak
        direction of each is found, and a peak direction is interpolated between them using dTime2. Each table is
        shifted so its peak moves to that direction, and the shifted tables are blended with dTime1/dTime2. The
        blend is then resampled onto the global axes.
        """
        glob = self.globalParams
        first = self.params1

        # Single source: either accept it directly or sample it onto the global axes.
        if not self.twoGrids:
            freqs_ok = freqs_match(first, glob)
            angs_ok = angles_match(first, glob)
            if not (freqs_ok and angs_ok):
                self.super_sample(True, not freqs_ok, not angs_ok)
                self.resample(True, not freqs_ok, not angs_ok)
            self.finalVals = self.vals1
            return

        # Two sources: bring both onto a common super-sampled grid.
        freqs_ok = freqs_match(first, self.params2) and freqs_match(self.params2, glob)
        self.super_sample(True, not freqs_ok, True)
        self.super_sample(False, not freqs_ok, True)

        # Peak value and direction of each table. After super-sampling both sources share the same angles, so
        # params1.angles are used for both.
        best_1 = 0.0
        best_2 = 0.0
        peak_dir_1 = glob.angles[0]
        peak_dir_2 = glob.angles[0]
        n_peak_angs = first.get_num_angs()
        for i in range(len(first.freqs)):
            row_1 = self.vals1[i]
            row_2 = self.vals2[i]
            for j in range(n_peak_angs):
                if row_1[j] > best_1:
                    best_1 = row_1[j]
                    peak_dir_1 = first.angles[j]
                if row_2[j] > best_2:
                    best_2 = row_2[j]
                    peak_dir_2 = first.angles[j]

        # Shortest signed difference between the peaks, then the interpolated peak direction.
        diff = peak_dir_1 - peak_dir_2
        if diff > 180.0:
            diff -= 360.0
        elif diff < -180.0:
            diff += 360.0
        new_peak_dir = peak_dir_1 - diff * self.dTime2
        # Index shifts for each table (angles are treated as unit-spaced indices here).
        shift_1 = int(peak_dir_1 - new_peak_dir)
        shift_2 = int(peak_dir_2 - new_peak_dir)

        wrap_count = first.get_num_angs()
        wraps = glob.plane_type != PlaneTypes.HALF

        def _wrapped(index):
            # Full planes wrap once around the circle; half planes are left unwrapped (out of range -> 0.0 below).
            if not wraps:
                return index
            if index < 0:
                return index + wrap_count
            if index > wrap_count:
                return index - wrap_count
            return index

        n_angles = len(first.angles)
        n_freqs = len(first.freqs)
        blended = [[0.0] * n_angles for _ in range(n_freqs)]
        for i in range(n_angles):
            src_1 = _wrapped(i + shift_1)
            src_2 = _wrapped(i + shift_2)
            in_range_1 = 0 <= src_1 < n_angles
            in_range_2 = 0 <= src_2 < n_angles
            for j in range(n_freqs):
                part_1 = self.vals1[j][src_1] if in_range_1 else 0.0
                part_2 = self.vals2[j][src_2] if in_range_2 else 0.0
                blended[j][i] = (part_1 * self.dTime1) + (part_2 * self.dTime2)

        self.vals1 = blended
        self.resample(True, not freqs_ok, True)
        self.finalVals = self.vals1

    def super_sample(self, a_is_first_grid, a_do_freqs, a_do_angs):
        """
        Performs Super Sampling for frequencies and/or angles.

        Args:
            a_is_first_grid (:obj:`bool`): True to use params1, False to use params2
            a_do_freqs (:obj:`bool`): True to super_sample_freqs, False to skip
            a_do_angs (:obj:`bool`): True to super_sample_angles, False to skip
        """
        if a_is_first_grid:
            vals = self.vals1
            params = self.params1
        else:
            vals = self.vals2
            params = self.params2
        if a_do_angs:
            sampled_angs = self.super_sample_angles(vals, params)
            if a_is_first_grid:
                self.vals1 = sampled_angs
            else:
                self.vals2 = sampled_angs
        if a_do_freqs:
            sampled_freqs = self.super_sample_freqs(vals, params)
            if a_is_first_grid:
                self.vals1 = sampled_freqs
            else:
                self.vals2 = sampled_freqs

    def super_sample_angles(self, a_vals, a_params):
        """
        Takes a sampled list from vals and sets new data for params angles and plane_type.

        The global grid's angular range (first to last global angle) is resampled in fixed 1-degree steps. For each
        target angle, the source angles are scanned for a consecutive pair that brackets it, and the value is
        linearly interpolated between that pair. The scan resumes where the previous target stopped and may wrap
        around once. Targets with no bracketing pair stay at 0.0. Finally ``a_params`` is updated to the new angles
        and the global plane type.

        Args:
            a_vals (:obj:`list[list[float]]`): The data to be sampled from, indexed [freq][angle].
            a_params (:obj:`SpectralParams`): The SpectralParams to write new data for.

        Returns:
            (:obj:`list[list[float]]`): The sample data, indexed [freq][angle]; empty if the global grid has no angles.
        """
        if not self.globalParams.angles:  # nothing to do
            return []

        # set stuff up: target angles in 1-degree steps across the global range
        super_angles = []
        min_angle = self.globalParams.angles[0]
        max_angle = self.globalParams.angles[len(self.globalParams.angles) - 1]
        while min_angle < max_angle or math.isclose(min_angle, max_angle, abs_tol=TOLERANCE):
            super_angles.append(min_angle)
            min_angle += 1
        sampled_angs = [[0.0 for _ in range(len(super_angles))] for _ in range(len(a_params.freqs))]
        half_to_full = False
        full_to_half = False
        if self.globalParams.plane_type != PlaneTypes.HALF and a_params.plane_type == PlaneTypes.HALF:
            half_to_full = True
        elif self.globalParams.plane_type == PlaneTypes.HALF and a_params.plane_type != PlaneTypes.HALF:
            full_to_half = True
        # Indices of the candidate bracketing pair; kept across targets so the scan resumes where it stopped.
        old_angle_1 = 0
        old_angle_2 = 1
        for i in range(len(super_angles)):
            super_angle_val = super_angles[i]
            if full_to_half and super_angle_val < 0.0:
                super_angle_val += 360.0
            found = False
            use_two = False
            restarted = False
            angle_val_1 = 0.0
            angle_val_2 = 0.0

            # we have to start over if we reached the end
            # this happens when going from full to half
            if old_angle_1 == a_params.get_num_angs():
                old_angle_1 = 0
                old_angle_2 = 1
                restarted = True

            while old_angle_1 < a_params.get_num_angs():
                angle_val_1 = a_params.angles[old_angle_1]
                angle_1_was_below_zero = False
                if half_to_full and angle_val_1 < 0.0:
                    angle_val_1 += 360.0
                    angle_1_was_below_zero = True
                angle_val_2 = angle_val_1
                angle_2_was_below_zero = False
                # Decide per candidate pair whether a second angle exists. Previously use_two was never reset
                # inside this loop, so a stale True with an out-of-range old_angle_2 led to a 0/0 weight or an
                # IndexError below.
                use_two = old_angle_2 < len(a_params.angles)
                if use_two:
                    angle_val_2 = a_params.angles[old_angle_2]
                    if half_to_full and angle_val_2 < 0.0:
                        angle_val_2 += 360.0
                        angle_2_was_below_zero = True

                if (angle_val_1 < super_angle_val or math.isclose(angle_val_1, super_angle_val, abs_tol=TOLERANCE)) \
                        and (angle_val_2 > super_angle_val or math.isclose(angle_val_2, super_angle_val,
                                                                           abs_tol=TOLERANCE)):
                    found = True
                    break

                # if we are going from half to full, we might miss the zero mark
                # in cases like if we are looking between angles 355 and 5.
                # Zero is between those, but the check above won't catch it.
                if half_to_full and angle_1_was_below_zero != angle_2_was_below_zero:
                    if (angle_val_1 - 360.0 < super_angle_val or math.isclose(angle_val_1, super_angle_val,
                                                                              abs_tol=TOLERANCE)) \
                            and (angle_val_2 > super_angle_val or math.isclose(angle_val_2, super_angle_val,
                                                                               abs_tol=TOLERANCE)):
                        found = True
                        angle_val_1 -= 360.0
                        break
                    elif (angle_val_1 < super_angle_val or math.isclose(angle_val_1, super_angle_val,
                                                                        abs_tol=TOLERANCE)) \
                            and (angle_val_2 + 360.0 > super_angle_val or math.isclose(angle_val_2, super_angle_val,
                                                                                       abs_tol=TOLERANCE)):
                        found = True
                        angle_val_2 += 360.0
                        break
                # Wrap the scan around to the start once per target.
                if old_angle_1 == a_params.get_num_angs() - 1 and not restarted:
                    old_angle_1 = -1
                    old_angle_2 = 0
                    restarted = True
                old_angle_1 += 1
                old_angle_2 += 1

            if not found:
                continue  # no bracketing pair: this target column stays 0.0
            # Linear interpolation weights between the bracketing angles.
            d1 = 1.0
            d2 = 0.0
            if use_two:
                d2 = (super_angle_val - angle_val_1) / (angle_val_2 - angle_val_1)
                d1 = 1 - d2

            val2 = 0.0
            for j in range(len(a_params.freqs)):
                val1 = a_vals[j][old_angle_1]
                if use_two:
                    val2 = a_vals[j][old_angle_2]
                sampled_angs[j][i] = (val1 * d1) + (val2 * d2)

        a_params.set_angles(super_angles)
        a_params.plane_type = self.globalParams.plane_type
        return sampled_angs

    def super_sample_freqs(self, a_vals, a_params):
        """
        Takes a sampled list from vals and sets new data for params freqs.

        The values are linearly interpolated onto a frequency list that runs from the first to the last global
        frequency in steps of one third of ``globalParams.freq_size``. Super-sampled frequencies that do not fall
        between two entries of ``a_params.freqs`` keep rows of zeros. On return ``a_params.freqs`` is replaced with
        the super-sampled frequencies.

        Args:
            a_vals (:obj:`list[list[float]]`): The data to be sampled from, indexed [freq][angle] to match a_params.
            a_params (:obj:`SpectralParams`): The SpectralParams to write new data for.

        Returns:
            (:obj:`list[list[float]]`): The sample data, indexed [super freq][angle]. Empty (and ``a_params`` left
            unchanged) if either frequency list is empty or no super-sampled frequency could be generated.
        """
        src_freqs = a_params.freqs
        glob_freqs = self.globalParams.freqs
        if not src_freqs or not glob_freqs:
            return []
        num_src = len(src_freqs)
        num_angs = len(a_params.angles)

        # Build the super-sampled frequency list: first to last global frequency in thirds of the global step.
        super_delta_freq = self.globalParams.freq_size / 3.0
        cur_freq = glob_freqs[0]
        max_freq = glob_freqs[-1]
        super_freqs = []
        if super_delta_freq > 0.0:
            while cur_freq < max_freq or math.isclose(cur_freq, max_freq, abs_tol=TOLERANCE):
                super_freqs.append(cur_freq)
                cur_freq += super_delta_freq
        else:
            # A non-positive step never advances toward max_freq (the loop above would never terminate),
            # so only the first global frequency is sampled.
            super_freqs.append(cur_freq)
        if not super_freqs:
            return []  # the last global frequency lies below the first one

        sampled_freqs = [[0.0 for _ in range(num_angs)] for _ in range(len(super_freqs))]

        # In case the new frequencies begin before the "from" ones, advance to where the data will line up.
        # The bounds check keeps this from running off the end when no new frequency reaches the first "from" one.
        first_super = 0
        while first_super < len(super_freqs) and super_freqs[first_super] < src_freqs[0]:
            first_super += 1

        # The bracket search position carries over between super frequencies; both lists are searched forward.
        old_freq_1 = 0
        for super_freq in range(first_super, len(super_freqs)):
            super_freq_val = super_freqs[super_freq]
            found = False
            use_two = False
            old_freq_2 = old_freq_1 + 1
            freq_val_1 = freq_val_2 = src_freqs[0]
            while old_freq_1 < num_src:
                old_freq_2 = old_freq_1 + 1
                freq_val_1 = src_freqs[old_freq_1]
                # Decide per bracket whether an upper neighbor exists; a value left over from an earlier
                # bracket would index past the end of the source lists.
                use_two = old_freq_2 < num_src
                freq_val_2 = src_freqs[old_freq_2] if use_two else freq_val_1
                freq_1_close = math.isclose(freq_val_1, super_freq_val, abs_tol=TOLERANCE)
                freq_2_close = math.isclose(freq_val_2, super_freq_val, abs_tol=TOLERANCE)
                if (freq_val_1 < super_freq_val or freq_1_close) and (freq_val_2 > super_freq_val or freq_2_close):
                    found = True
                    break
                old_freq_1 += 1
            if not found:
                continue
            # Linear interpolation weights; duplicate bracket frequencies fall back to the lower value
            # instead of dividing by zero.
            d1 = 1.0
            d2 = 0.0
            if use_two and freq_val_2 != freq_val_1:
                d2 = (super_freq_val - freq_val_1) / (freq_val_2 - freq_val_1)
                d1 = 1 - d2
            for i in range(num_angs):
                val1 = a_vals[old_freq_1][i]
                val2 = a_vals[old_freq_2][i] if use_two else 0.0
                sampled_freqs[super_freq][i] = (val1 * d1) + (val2 * d2)
        a_params.set_freqs(super_freqs)
        return sampled_freqs

    def resample(self, a_is_first_grid, a_do_freqs, a_do_angs):
        """
        Resample frequencies and/or angles into "vals" members.

        Angles are resampled first. ``resample_angles`` replaces ``params.angles``, so when it succeeds (returns a
        non-empty list) the frequency pass operates on the angle-resampled values. That keeps the data columns
        consistent with ``params.angles``.

        Args:
            a_is_first_grid (:obj:`bool`): True to use vals1/params1, False to use vals2/params2.
            a_do_freqs (:obj:`bool`): True to resample frequencies.
            a_do_angs (:obj:`bool`): True to resample angles.
        """
        if a_is_first_grid:
            vals, params = self.vals1, self.params1
        else:
            vals, params = self.vals2, self.params2
        if a_do_angs:
            final_angs = self.resample_angles(vals, params)
            if a_is_first_grid:
                self.vals1 = final_angs
            else:
                self.vals2 = final_angs
            if final_angs:
                # Feed the angle-resampled data into the frequency pass; the original values no longer
                # match params.angles.
                vals = final_angs
        if a_do_freqs:
            final_freqs = self.resample_freqs(vals, params)
            if a_is_first_grid:
                self.vals1 = final_freqs
            else:
                self.vals2 = final_freqs

    def resample_angles(self, a_vals, a_params):
        """
        Samples the angle values and sets the angles for params object.

        Each global angle is matched (within 0.5) to an entry of ``a_params.angles``, searching forward from the
        previous match. The result for that angle is the mean of the source columns in a window around the match.
        On each side the window extends by half the gap to the neighbouring global angle, measured in multiples of
        ``a_params.ang_size`` and truncated to an integer.

        Args:
            a_vals (:obj:`list[list[float]]`): The data to be sampled, indexed [freq][source angle].
            a_params (:obj:`SpectralParams`): The object to set angles for.

        Returns:
            (:obj:`list[list[float]]`): 2D list of final angles, indexed [freq][global angle]. Empty (and
            ``a_params`` left unchanged) as soon as a global angle has no match.
        """
        target = self.globalParams.angles
        last_target = len(target) - 1
        source = a_params.angles
        step = a_params.ang_size
        num_rows = len(a_params.freqs)
        result = [[0.0] * len(target) for _ in range(num_rows)]
        src = 0
        for tgt, tgt_angle in enumerate(target):
            # window half-widths, in source samples, toward the previous and next global angle
            below = int((tgt_angle - target[tgt - 1]) / (2 * step)) if tgt > 0 else 0
            above = int((target[tgt + 1] - tgt_angle) / (2 * step)) if tgt < last_target else 0

            while src < len(source) and not math.isclose(source[src], tgt_angle, abs_tol=0.5):
                src += 1
            if src == len(source):
                return []

            window = range(src - below, src + above + 1)
            for row in range(num_rows):
                total = 0.0
                for col in window:
                    total += a_vals[row][col]
                result[row][tgt] = float(total / len(window))
        a_params.set_angles(target)
        return result

    def resample_freqs(self, a_vals, a_params):
        """
        Samples the frequency values and sets the freqs for params object.

        For each global frequency, the matching entry of ``a_params.freqs`` (within TOLERANCE, searching forward from
        the previous match) is located. For every angle of ``a_params``, the rows in a window around that entry are
        averaged.

        Args:
            a_vals (:obj:`list[list[float]]`): The data to be sampled, indexed [freq][angle] to match ``a_params``.
            a_params (:obj:`SpectralParams`): The object to set frequencies for.

        Returns:
            (:obj:`list[list[float]]`): 2D list of final freqs with one row per global frequency. Empty (and
            ``a_params`` left unchanged) as soon as a global frequency has no match in ``a_params.freqs``.
        """
        glob_freqs = self.globalParams.freqs
        num_glob = len(glob_freqs)
        num_src = len(a_params.freqs)
        # One row per *global* frequency, so the result lines up with a_params.freqs after set_freqs() below.
        final_freqs = [[0.0 for _ in range(len(self.globalParams.angles))] for _ in range(num_glob)]
        super_freq = 0
        for final_freq, ff in enumerate(glob_freqs):
            # NOTE(review): unlike resample_angles, these half-widths are not divided by a sample step. For frequency
            # gaps below 2.0 they truncate to 0, so each value comes from the single matching row. Confirm whether
            # a window scaled by the super-sampled step was intended.
            half_before = 0
            half_after = 0
            if final_freq != 0:
                half_before = int((ff - glob_freqs[final_freq - 1]) / 2)
            if final_freq != num_glob - 1:
                half_after = int((glob_freqs[final_freq + 1] - ff) / 2)

            while super_freq < num_src and not math.isclose(a_params.freqs[super_freq], ff, abs_tol=TOLERANCE):
                super_freq += 1
            if super_freq >= num_src:
                return []

            window = range(super_freq - half_before, super_freq + half_after + 1)
            for i in range(len(a_params.angles)):
                d_tot = 0.0
                for j in window:
                    d_tot += a_vals[j][i]
                final_freqs[final_freq][i] = float(d_tot / len(window))
        a_params.set_freqs(glob_freqs)
        return final_freqs


class EngWriter:
    """A class for writing a CMS-Wave energy (.eng) file."""
    def __init__(
        self,
        grid_name,
        pt_map,
        point_params,
        global_params,
        cases,
        using_wind_d_set,
        simref_time,
        time_units,
        angle_convention,
        date_format='12 digits'
    ):
        """
        Constructor.

        Args:
            grid_name (:obj:`str`): Output path without extension; ".eng" is appended when writing.
            pt_map (:obj:`dict`): {pointId: Point}; points need ``x`` and ``y`` attributes.
            point_params (:obj:`dict`): {pointId: [SpectralParams]} - one SpectralParams per case per point.
            global_params (:obj:`SpectralParams`): Simulation grid definition the output is written on.
            cases (:obj:`list`): Cases to write; each needs time, windSpeed, windDir and tidalSurge.
            using_wind_d_set (:obj:`bool`): When True, the wind direction in point headers is written as 0.0.
            simref_time (:obj:`datetime.datetime`): Reference time that case times are offset from.
            time_units: Units of the case times, passed to ``convert_time_into_seconds``.
            angle_convention (:obj:`str`): 'Cartesian', 'Meteorologic' or 'Oceanographic'; any other value means
                wind directions are already shore normal.
            date_format (:obj:`str`): Format passed to ``get_time_string`` for the point header times.
        """
        self.gridName = grid_name
        self.ptMap = pt_map  # {pointId: Point} - would require an api change to use Point object as key in dict below
        self.pointParams = point_params  # {pointId: [SpectralParams]} - one SpectralParams per case per point
        self.globalParams = global_params  # simulation grid definition
        self.cases = cases  # [cases]
        self.usingWindDset = using_wind_d_set
        self.ss = StringIO()  # whole file is buffered here and dumped at the end of write()
        self.ref_time = simref_time
        self.time_units = time_units
        self.angle_convention = angle_convention
        self.date_format = date_format
        self.zero_spectrum = False  # when True, energy values are written as 0.0

    def write(self):
        """
        Writes fortran header, frequencies, point header, and data vals to file.

        The output file is ``self.gridName + ".eng"``. Nothing is written when there are no cases, no points or no
        point spectra.
        """
        if not self.cases or not self.ptMap or not self.pointParams:
            return  # no data to write
        self.write_eng_header()
        self.write_frequencies()
        for case in self.cases:  # loop through cases
            cur_time = convert_time_into_seconds(self.time_units, float(case.time))
            dt = self.ref_time + datetime.timedelta(seconds=float(cur_time))
            for key, val in self.pointParams.items():  # loop through each point for this case
                params_at, params_after = self.find_spec_dset_for_time(val, dt)
                if not params_at:
                    continue  # need at least one grid for this point
                params_at.reset()
                if params_after:
                    params_after.reset()
                converter = SpectralGridConverter(params_at, params_after, self.globalParams, dt)
                converter.transform_spectral_vals()
                self.write_point_header(case, key, converter.finalVals)
                if self.zero_spectrum:
                    # Write out 0.0 energy values with the same shape as the converted spectrum
                    vals = [[0.0] * len(row) for row in converter.finalVals]
                else:
                    vals = converter.finalVals
                self.write_data_vals(vals)

        # dump string stream to file; the context manager closes the file even if the copy fails
        with open(self.gridName + ".eng", "w") as out:
            self.ss.seek(0)
            shutil.copyfileobj(self.ss, out, 100000)

    @staticmethod
    def find_spec_dset_for_time(a_pt_params, a_time):
        """
        Finds the two SpectralParams for matching time and greater than a_time.

        Args:
            a_pt_params (:obj:`list[SpectralParams]`): Spectra for one point; each has a ``time`` attribute.
            a_time (:obj:`datetime`): datetime time value.

        Returns:
            (:obj:`tuple(SpectralParams)`):
                First index is the entry matching a_time, or the last entry before it (the first entry if none
                precede it). Second index is the first entry after a_time, and only when a first index precedes
                a_time; otherwise None. (None, None) for an empty list.
        """
        if not a_pt_params:
            return None, None
        params1 = a_pt_params[0]
        params2 = None
        for idx, param in enumerate(a_pt_params):
            ts_time = param.time
            if ts_time == a_time:
                return param, None
            if ts_time > a_time:
                if idx > 0:
                    params2 = param
                break
            params1 = param
        return params1, params2

    def write_eng_header(self):
        """Writes the header line: frequency count, angle count, point count and grid angle."""
        num_pts = len(self.ptMap)
        g_angle = self.globalParams.grid_angle
        num_freqs = len(self.globalParams.freqs)
        num_ang = self.globalParams.get_num_angs()

        self.ss.write(f'{num_freqs} {num_ang} {num_pts} {g_angle:.4f}\n')

    def write_frequencies(self):
        """Writes globalParam frequency data to file, ten values per line."""
        freqs = self.globalParams.freqs
        last_idx = len(freqs) - 1
        for idx, freq in enumerate(freqs):
            self.ss.write(f"{freq:.5f}")
            # break the line after every 10th value and after the final value
            self.ss.write("\n" if (idx + 1) % 10 == 0 or idx == last_idx else " ")

    def write_point_header(self, a_case, a_point_id, vals):
        """
        Writes point header to file.

        The line holds the case time, wind speed, wind direction (converted to shore normal), peak frequency,
        tidal surge, point coordinates and ``hs``. Here ``hs = 4 * sqrt(sum(E * df) * dth)`` over ``vals``. The peak
        frequency is the frequency of the largest value in ``vals``.

        Args:
            a_case (:obj:`CMSWAVECase`): Case with the data to be written.
            a_point_id (:obj:`int`): Point ID for indexing ptMap.
            vals (:obj:`list[list[float]]`): Values indexed [freq][angle] on the globalParams grid.
        """
        freqs = self.globalParams.freqs
        num_freqs = len(freqs)
        num_angs = len(self.globalParams.angles)
        rad = math.pi / 180.0
        dth = 180.0 / float(num_angs + 1.0)
        dth *= rad
        # NOTE(review): freqs[1] is read for the first bin, so a single-frequency grid raises IndexError.
        hs = 0.0
        for idx in range(num_freqs):
            if idx == 0:
                delta_freq = freqs[1] - freqs[0]
            elif idx == num_freqs - 1:
                delta_freq = freqs[idx] - freqs[idx - 1]
            else:
                delta_freq = (freqs[idx + 1] - freqs[idx - 1]) / 2.0
            for j in range(num_angs):
                hs += vals[idx][j] * delta_freq
        hs = 4.0 * math.sqrt(hs * dth)
        peak_freq = 0.0
        peak_eng = -99999999.0
        for i in range(num_freqs):
            for j in range(num_angs):
                if vals[i][j] > peak_eng:
                    peak_eng = vals[i][j]
                    peak_freq = freqs[i]

        cur_time = convert_time_into_seconds(self.time_units, float(a_case.time))
        case_time = self.ref_time + datetime.timedelta(seconds=float(cur_time))
        point = self.ptMap[a_point_id]
        # convert the spectral grids angle (in cartesian convention) to shore normal
        if self.usingWindDset:
            shore_normal_dir = 0.0  # only need to transform the angle if data is spatially constant
        else:
            angle = self.globalParams.grid_angle
            if self.angle_convention == 'Cartesian':
                shore_normal_dir = xms.cmswave.file_io.dataset_writer.cart_to_shore_normal(float(a_case.windDir), angle)
            elif self.angle_convention == 'Meteorologic':
                shore_normal_dir = xms.cmswave.file_io.dataset_writer.meteor_to_shore_normal(
                    float(a_case.windDir), angle
                )
            elif self.angle_convention == 'Oceanographic':
                shore_normal_dir = xms.cmswave.file_io.dataset_writer.ocean_to_shore_normal(
                    float(a_case.windDir), angle
                )
            else:
                shore_normal_dir = float(a_case.windDir)  # Shore normal
        ref_time_string = xms.cmswave.file_io.dataset_writer.get_time_string(case_time, self.date_format)
        # should be: "ref_time windSpeed windDir peakFreq tidalSurge ptX ptY hs"
        self.ss.write(
            f'{ref_time_string} {float(a_case.windSpeed):.6f} {float(shore_normal_dir):.6f} '
            f'{peak_freq:.6f} {float(a_case.tidalSurge):.6f} {point.x:.6f} {point.y:.6f} {hs:.6f}\n'
        )

    def write_data_vals(self, a_vals):
        """
        Writes the data to file.

        Each frequency's values are written seven per line. Every frequency starts on a new line.

        Args:
            a_vals(:obj:`list[list[float]]`): Data to be written, in the shape of globalParams.
        """
        num_angs = self.globalParams.get_num_angs()
        for i in range(len(self.globalParams.freqs)):
            for j in range(num_angs):
                self.ss.write(f"{a_vals[i][j]:>11.5f} ")
                if (j + 1) % 7 == 0:
                    self.ss.write('\n')
            if num_angs % 7 != 0:
                self.ss.write("\n")  # terminate a partially filled last line
