"""Class to read a WaveWatch3 spectra netCDF file."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
import datetime
import logging
import os
import shlex
import uuid

# 2. Third party modules
import netCDF4
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.coverage.spectral import PLANE_TYPE_enum, SpectralCoverage, SpectralGrid
from xms.data_objects.parameters import Coverage, datetime_to_julian, Point, Projection, RectilinearGrid

# 4. Local modules
from xms.wavewatch3.file_io.io_util import GEOGRAPHIC_WKT, READ_BUFFER_SIZE


class WW3SpecListReader:
    """Reader for a WaveWatch3 spec.list file and the spectra netCDF files it lists.

    Each non-blank, non-comment line of the spec.list file is treated as the name of a spectra netCDF file located
    next to the spec.list file. Every netCDF file that is read successfully adds one entry to
    ``spectral_coverages``.
    """
    # Variables that must all be present in a spectra netCDF file, in the order they are read.
    _REQUIRED_VARIABLES = ('efth', 'direction', 'frequency', 'time', 'latitude', 'longitude')
    # Time values in the netCDF file are interpreted as days since this date.
    _EPOCH = datetime.datetime(1990, 1, 1)

    def __init__(self, filename=''):
        """Constructor.

        Args:
            filename (:obj:`str`): Path to the spec.list file. If not provided (not testing or control file read),
                will retrieve from Query.
        """
        self._filename = filename
        self._reftime = None  # Set from the first time value of the most recently read netCDF file
        self._query = None
        self._setup_query()
        self._lines = []
        self._current_line = 0  # Number of lines of self._lines consumed so far
        self._logger = logging.getLogger('xms.wavewatch3')
        self.spectral_coverages = []

    def _setup_query(self):
        """Setup the xmsapi Query for sending data to SMS and get the import filename."""
        if not self._filename:  # pragma: no cover - slow to setup Query for the filename
            self._query = Query()
            self._filename = self._query.read_file

    def _parse_next_line(self, shell=False):
        """Parse the next line of text from the file.

        Skips empty lines and lines starting with '!'.

        Args:
            shell (:obj:`bool`): If True will parse line using shlex. Slower but convenient for quoted tokens.

        Returns:
            (:obj:`list[str]`): The tokens of the next line (a single-item list holding the stripped line when
            shell is False), or None if the end of the file has been reached.
        """
        line = None
        while not line or line.startswith('!'):  # blank lines and control file identifier
            if self._current_line >= len(self._lines):
                return None  # End of file
            line = self._lines[self._current_line].strip()
            self._current_line += 1
        if shell:
            return shlex.split(line, posix=False)
        return [line]

    def _read_spec_list_file(self):
        """Read the spec.list file, reading the spectra netCDF file named on each line."""
        while True:
            data = self._parse_next_line()
            if not data:
                break
            self.read_spectra_nc_file(data[0])

    def _get_netcdf_data(self, root_grp, name):
        """Read netCDF4 data from the Dataset.

        Args:
            root_grp (:obj:`netCDF4.Dataset`):  The netCDF4 root dataset.
            name (:obj:`str`): The name of the data to read.

        Returns:
            The full contents of the variable, or None if the dataset has no variable with that name.
        """
        return root_grp[f'/{name}'][:] if name in root_grp.variables else None

    @staticmethod
    def _efth_to_grid_values(freq_dir_values, roll_count):
        """Reorder one time step of efth values into the flat layout passed to the spectral grid.

        Each row is reversed, circularly shifted by roll_count, and has its first value repeated at the end. The
        rows are then transposed and flattened.

        Args:
            freq_dir_values (:obj:`list`): 2D values for one time step and station. Presumably one row per
                frequency and one column per direction, matching the (time, station, frequency, direction)
                layout of efth.
            roll_count (:obj:`int`): Number of positions to circularly shift each reversed row.

        Returns:
            (:obj:`list`): The reordered values as a flat list.
        """
        rows = []
        for row in freq_dir_values:
            rolled = np.roll(np.flip(row), roll_count)  # Flip the values to match their convention
            rows.append(np.append(rolled, rolled[0]))  # Add a value to match the grid storage
        # Transpose the 2D data to get it in the right order, then put in a 1D list for the spectral grid
        return [value for column in np.transpose(rows) for value in column]

    def read_spectra_nc_file(self, nc_file):
        """Read a spectra netCDF file and append the resulting spectral coverage to ``spectral_coverages``.

        Args:
            nc_file(:obj:`str`):  The current line from the spec.list file, containing an nc file name.

        Returns:
            (:obj:`bool`): True if a spectral coverage was built, False if the file is missing any of the
            required variables.
        """
        filename = os.path.join(os.path.dirname(self._filename), nc_file)
        self._logger.info(f'Reading spectra file {nc_file}...')
        # Slicing with [:] copies the values into memory, so the file can be closed as soon as everything has
        # been read. The context manager also closes it if reading raises.
        with netCDF4.Dataset(filename, "r", format="NETCDF4_CLASSIC") as root_grp:
            data = {name: self._get_netcdf_data(root_grp=root_grp, name=name) for name in self._REQUIRED_VARIABLES}

        # Check for valid data
        missing = [name for name, values in data.items() if values is None]
        if missing:
            self._logger.warning(f'Spectra file {nc_file} is missing required data ({", ".join(missing)}). Skipping.')
            return False

        efth_data = data['efth']  # (time, station, frequency, direction)
        time_data = data['time']
        longitude_data = data['longitude']
        latitude_data = data['latitude']

        self._reftime = self._EPOCH + datetime.timedelta(days=time_data[0])
        julian_dates = self._get_julian_dates_ww3(time_data)

        # Convert some of the data from ndarray to lists, making note that lat/lon might be a 0 dimension array
        longitude_data = longitude_data if longitude_data.ndim > 0 else [longitude_data[()]]
        latitude_data = latitude_data if latitude_data.ndim > 0 else [latitude_data[()]]
        direction_data = data['direction'].tolist()
        frequency_data = data['frequency'].tolist()
        frequency_delta = [upper - lower for lower, upper in zip(frequency_data, frequency_data[1:])]

        # Build coverage geometry
        point_list = []
        for pt_id, (lon, lat) in enumerate(zip(longitude_data, latitude_data), start=1):
            # Make a point for each latitude longitude location found
            pt = Point(lon, lat)
            pt.id = pt_id
            point_list.append(pt)
            break  # We don't support multiple spectral points in a single .nc file. This will change in the future.
        cov_geom = Coverage()
        cov_geom.set_points(point_list)
        # Set coverage name to filename without the directory or extension
        cov_geom.name = os.path.splitext(os.path.basename(nc_file))[0]
        cov_geom.projection = Projection(wkt=GEOGRAPHIC_WKT)
        cov_geom.uuid = str(uuid.uuid4())
        cov_geom.complete()

        # Build the spectral coverage
        spec_cov = SpectralCoverage()
        spec_cov.m_cov = cov_geom

        # Build the spectral grids for each point
        spec_grids = []
        for pt_idx, pt in enumerate(point_list):
            ref_julian = datetime_to_julian(self._reftime)
            spec_grid = SpectralGrid(ref_julian)
            spec_grids.append((pt, spec_grid))

            # Make a rectilinear grid, using the frequency and direction data read
            rect_grid = RectilinearGrid()
            rect_grid.origin = Point(frequency_data[0], 0.0, 0.0)
            rect_grid.angle = 90.0 - direction_data[0]  # direction_data[0] oceanographic to cartesian, so 90 -> 0
            # NOTE: requires at least two direction values with a nonzero spacing between the first two.
            delta_dir = abs(direction_data[0] - direction_data[1])
            roll_count = int(-90 / delta_dir + 1)
            rect_grid.set_sizes(frequency_delta, [delta_dir] * len(direction_data))
            rect_grid.complete()
            spec_grid.m_rectGrid = rect_grid
            spec_grid.m_timeUnits = 'Days'
            spec_grid.m_planeType = PLANE_TYPE_enum.FULL_GLOBAL_PLANE

            # Add the frequency and direction values from the efth data
            for time_idx, julian in enumerate(julian_dates):
                ts_offset = julian - ref_julian  # Offset from the reference time
                ts_values = self._efth_to_grid_values(efth_data[time_idx][pt_idx].tolist(), roll_count)
                spec_grid.add_timestep(ts_offset, ts_values)

        for pt, spec_grid in spec_grids:
            spec_cov.AddSpectralGrid(pt.id, spec_grid)
        self.spectral_coverages.append(spec_cov)
        return True

    def _get_julian_dates_ww3(self, time_data):
        """Gets the days passed in as julian dates, referenced to 1 Jan 1990.

        Args:
            time_data (:obj:`list[float]`):  List of time values in days, as floats.

        Returns:
            (:obj:`list[float]`):  The times converted to julian dates.
        """
        return [datetime_to_julian(self._EPOCH + datetime.timedelta(days=time_val)) for time_val in time_data]

    def read(self):
        """Top-level entry point for the WaveWatch3 spec.list file reader.

        Reads the lines of the spec.list file, then reads every spectra netCDF file it lists. Any exception is
        logged along with the spec.list line being processed and then re-raised.
        """
        try:
            self._logger.info('Parsing ASCII text from file...')
            with open(self._filename, 'r', buffering=READ_BUFFER_SIZE) as f:
                self._lines = f.readlines()

            self._read_spec_list_file()
            self._logger.info('Committing changes....')
            self._logger.info('Finished!')
        except Exception:
            # self._current_line has already been advanced past the line being processed, so it is that line's
            # 1-based number. It is still 0 if the failure happened before any line was parsed.
            self._logger.exception('Unexpected error in spec.list file '
                                   f'(line {max(self._current_line, 1)}).')
            raise
