"""Class to write a WaveWatch3 prnc nml file."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
from io import StringIO
import logging
import shutil

# 2. Third party modules
import netCDF4
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.constraint import read_grid_from_file
from xms.guipy.time_format import string_to_datetime

# 4. Local modules
from xms.wavewatch3.file_io.namelists_defaults_util import WW3NamelistsDefaultsUtil
from xms.wavewatch3.gui.gui_util import get_formatted_date_string


class WW3PrncNmlWriter:
    """Class to write a WaveWatch3 prnc nml file (NetCDF input field preprocessor).

    Writes ``ww3_prnc_<field>.nml`` (relative to the current working directory) and, when the selected field is
    configured to come from an external forcing file, the companion ``ww3_forcing_<field>.nc`` NetCDF file that the
    namelist's FILE_NML section references.
    """

    # Model-control option value meaning "this field is read from an external forcing file".
    _EXTERNAL_FORCING_OPTION = 'T: external forcing file'

    # Supported forcing fields, keyed by the forcing field name passed to the constructor:
    #   (model-control group, model-control parameter, FORCING_NML flag, NetCDF variable name per component)
    # Every flag lives under FORCING%FIELD%: the FORCING_NML namelist only declares the FORCING variable
    # (INPUT%FORCING% names belong to ww3_shel's INPUT_NML and are not valid here).
    _FORCING_FIELDS = {
        'water_levels': ('run_control', 'water_levels', 'FORCING%FIELD%WATER_LEVELS', ('waterLevel',)),
        'currents': ('run_control', 'currents', 'FORCING%FIELD%CURRENTS', ('currentU', 'currentV')),
        'winds': ('run_control', 'winds', 'FORCING%FIELD%WINDS', ('windU', 'windV')),
        'atm_momentum': (
            'run_control', 'define_atm_momentum', 'FORCING%FIELD%ATM_MOMENTUM', ('momentumU', 'momentumV')
        ),
        'air_density': ('run_control', 'air_density', 'FORCING%FIELD%AIR_DENSITY', ('airDensity',)),
        'ice_concentration': ('ice_and_mud', 'concentration', 'FORCING%FIELD%ICE_CONC', ('iceConcentration',)),
        'ice_parameter1': ('ice_and_mud', 'param_1', 'FORCING%FIELD%ICE_PARAM1', ('iceThickness',)),
        'ice_parameter2': ('ice_and_mud', 'param_2', 'FORCING%FIELD%ICE_PARAM2', ('iceViscosity',)),
        'ice_parameter3': ('ice_and_mud', 'param_3', 'FORCING%FIELD%ICE_PARAM3', ('iceDensity',)),
        'ice_parameter4': ('ice_and_mud', 'param_4', 'FORCING%FIELD%ICE_PARAM4', ('iceModulus',)),
        'ice_parameter5': ('ice_and_mud', 'param_5', 'FORCING%FIELD%ICE_PARAM5', ('iceFlowMeanDiam',)),
        'mud_density': ('ice_and_mud', 'mud_density', 'FORCING%FIELD%MUD_DENSITY', ('mudDensity',)),
        'mud_thickness': ('ice_and_mud', 'mud_thickness', 'FORCING%FIELD%MUD_THICKNESS', ('mudThickness',)),
        'mud_viscosity': ('ice_and_mud', 'mud_viscosity', 'FORCING%FIELD%MUD_VISCOSITY', ('mudViscosity',)),
    }

    def __init__(self, xms_data, forcing_field_name, forcing_uuid):
        """Constructor.

        Args:
            xms_data (:obj:`XmsData`): Simulation data retrieved from SMS.
            forcing_field_name (:obj:`str`): The forcing field to process (a key of ``_FORCING_FIELDS``).
            forcing_uuid (:obj:`str`): The forcing field uuid.
        """
        self._ss = StringIO()
        self._logger = logging.getLogger('xms.wavewatch3')
        self._bound_process = None
        self._grid_process = None
        self._xms_data = xms_data
        self._forcing_field = forcing_field_name
        self._forcing_uuid = forcing_uuid
        self._defaults = WW3NamelistsDefaultsUtil()

    def _write_prnc_nml_file(self):
        """Builds the prnc namelist in memory and writes it to ``ww3_prnc_<field>.nml``."""
        file_w_path = f"ww3_prnc_{self._forcing_field}.nml"

        self._write_prnc_header()
        self._write_forcing_namelist()
        self._write_end_comments()
        self._flush(file_w_path)

    def _write_prnc_header(self):
        """Writes the header comments for the prnc namelist file."""
        self._ss.write(
            '! -------------------------------------------------------------------- !\n'
            '! WAVEWATCH III - ww3_prnc.nml - Field preprocessor                    !\n'
            '! -------------------------------------------------------------------- !\n'
            '\n\n'
        )

    @staticmethod
    def _quoted_nml_date(run_control, param_name):
        """Returns a run-control date parameter, formatted and single-quoted for use in a namelist.

        Args:
            run_control: The 'run_control' model-control group.
            param_name (:obj:`str`): Name of the date parameter to read (e.g. 'starting_date').

        Returns:
            (:obj:`str`): The quoted date string.
        """
        date = string_to_datetime(run_control.parameter(param_name).value)
        return f"'{get_formatted_date_string(date)}'"

    def _write_forcing_namelist(self):
        """Writes the FORCING_NML namelist for the forcing field being preprocessed.

        The explanatory comment block is always written. The FORCING_NML namelist itself (followed by FILE_NML and
        the NetCDF forcing file) is only written when the selected forcing field is known and its model-control
        parameter is set to an external forcing file.
        """
        self._ss.write(
            "! -------------------------------------------------------------------- !\n"
            "! Define the forcing fields to preprocess via FORCING_NML namelist\n"
            "!\n"
            "! * only one FORCING%FIELD can be set at true\n"
            "! * only one FORCING%grid can be set at true\n"
            "! * tidal constituents FORCING%tidal is only available on grid%asis with FIELD%level or FIELD%current\n"
            "!\n"
            "! * namelist must be terminated with /\n"
            "! * definitions & defaults:\n"
            "!     FORCING%TIMESTART            = '19000101 000000'  ! Start date for the forcing field\n"
            "!     FORCING%TIMESTOP             = '29001231 000000'  ! Stop date for the forcing field\n"
            "!\n"
            "!     FORCING%FIELD%ICE_PARAM1     = F           ! Ice thickness                      (1-component)\n"
            "!     FORCING%FIELD%ICE_PARAM2     = F           ! Ice viscosity                      (1-component)\n"
            "!     FORCING%FIELD%ICE_PARAM3     = F           ! Ice density                        (1-component)\n"
            "!     FORCING%FIELD%ICE_PARAM4     = F           ! Ice modulus                        (1-component)\n"
            "!     FORCING%FIELD%ICE_PARAM5     = F           ! Ice floe mean diameter             (1-component)\n"
            "!     FORCING%FIELD%MUD_DENSITY    = F           ! Mud density                        (1-component)\n"
            "!     FORCING%FIELD%MUD_THICKNESS  = F           ! Mud thickness                      (1-component)\n"
            "!     FORCING%FIELD%MUD_VISCOSITY  = F           ! Mud viscosity                      (1-component)\n"
            "!     FORCING%FIELD%WATER_LEVELS   = F           ! Level                              (1-component)\n"
            "!     FORCING%FIELD%CURRENTS       = F           ! Current                            (2-components)\n"
            "!     FORCING%FIELD%WINDS          = F           ! Wind                               (2-components)\n"
            "!     FORCING%FIELD%WIND_AST       = F           ! Wind and air-sea temp. dif.        (3-components)\n"
            "!     FORCING%FIELD%ATM_MOMENTUM   = F           ! Atmospheric momentum               (2-components)\n"
            "!     FORCING%FIELD%AIR_DENSITY    = F           ! Air density                        (1-component)\n"
            "!     FORCING%FIELD%ICE_CONC       = F           ! Ice concentration                  (1-component)\n"
            "!     FORCING%FIELD%ICE_BERG       = F           ! Icebergs and sea ice concentration (2-components)\n"
            "!     FORCING%FIELD%DATA_ASSIM     = F           ! Data for assimilation              (1-component)\n"
            "!\n"
            "!     FORCING%GRID%ASIS            = F           ! Transfer field 'as is' on the model grid\n"
            "!     FORCING%GRID%LATLON          = F           ! Define field on regular lat/lon or cartesian grid\n"
            "!\n"
            "!     FORCING%TIDAL                = 'unset'     ! Set the tidal constituents [FAST | VFAST | "
            "'M2 S2 N2']\n"
            "! -------------------------------------------------------------------- !\n"
        )
        spec = self._FORCING_FIELDS.get(self._forcing_field)
        if spec is None:
            return  # Unknown forcing field: nothing to preprocess.
        group_name, param_name, nml_flag, _ = spec

        model_control = self._xms_data.sim_data_model_control
        if model_control.group(group_name).parameter(param_name).value != self._EXTERNAL_FORCING_OPTION:
            return  # Only fields read from a forcing file need preprocessing.

        run_control = model_control.group('run_control')
        self._ss.write('&FORCING_NML\n')
        self._ss.write(f'     FORCING%TIMESTART = {self._quoted_nml_date(run_control, "starting_date")}\n')
        self._ss.write(f'     FORCING%TIMESTOP = {self._quoted_nml_date(run_control, "end_date")}\n')
        self._ss.write(f'     {nml_flag} = T\n')
        # We are only doing datasets on the mesh/grid, not cgrid, so the field is transferred as-is.
        self._ss.write('     FORCING%GRID%ASIS = T\n')
        self._ss.write('/\n\n')

        self._write_file_namelist()

    def _write_file_namelist(self):
        """Writes the FILE_NML namelist describing the forcing field input file.

        For a known forcing field this also lists one FILE%VAR(i) per field component and writes the referenced
        NetCDF file via :meth:`_write_forcing_field_to_netcdf`.
        """
        self._ss.write(
            "! -------------------------------------------------------------------- !\n"
            "! Define the content of the input file via FILE_NML namelist\n"
            "!\n"
            "! * input file must respect netCDF format and CF conventions\n"
            "! * input file must contain :\n"
            "!      -dimension : time, name expected to be called time\n"
            "!      -dimension : longitude/latitude, names can defined in the namelist\n"
            "!      -variable : time defined along time dimension\n"
            "!      -attribute : time with attributes units written as ISO8601 convention\n"
            "!      -attribute : time with attributes calendar set to standard as CF convention\n"
            "!      -variable : longitude defined along longitude dimension\n"
            "!      -variable : latitude defined along latitude dimension\n"
            "!      -variable : field defined along time,latitude,longitude dimensions\n"
            "! * FILE%VAR(I) must be set for each field component\n"
            "!\n"
            "! * namelist must be terminated with /\n"
            "! * definitions & defaults:\n"
            "!     FILE%FILENAME      = 'unset'           ! relative path input file name\n"
            "!     FILE%LONGITUDE     = 'unset'           ! longitude/x dimension name\n"
            "!     FILE%LATITUDE      = 'unset'           ! latitude/y dimension name\n"
            "!     FILE%VAR(I)        = 'unset'           ! field component\n"
            "!     FILE%TIMESHIFT     = '00000000 000000' ! shift the time value to 'YYYYMMDD HHMMSS'\n"
            "! -------------------------------------------------------------------- !\n"
        )
        nc_filename = f'ww3_forcing_{self._forcing_field}.nc'
        self._ss.write('&FILE_NML\n')
        self._ss.write(f"FILE%FILENAME = '{nc_filename}'\n")
        self._ss.write("FILE%LONGITUDE = 'longitude'\n")
        self._ss.write("FILE%LATITUDE = 'latitude'\n")
        spec = self._FORCING_FIELDS.get(self._forcing_field)
        if spec is not None:
            variables = list(spec[3])
            for var_num, var in enumerate(variables, start=1):
                self._ss.write(f"FILE%VAR({var_num}) = '{var}'\n")
            self._write_forcing_field_to_netcdf(nc_filename, variables)
        self._ss.write('/\n\n')

    def _write_end_comments(self):
        """Writes out the end comments at the end of the file."""
        self._ss.write(
            '! -------------------------------------------------------------------- !\n'
            '! WAVEWATCH III - end of namelist                                      !\n'
            '! -------------------------------------------------------------------- !\n'
            '\n'
        )

    def _write_forcing_field_to_netcdf(self, filename, field_name_list):
        """Writes out the forcing field data to a netCDF4 file.

        Component ``i`` of a vector dataset is written to ``field_name_list[i]``; a scalar dataset is written to
        every listed variable. Each variable is dimensioned (time, values), with one value per grid location.

        Args:
            filename (:obj:`str`): The filename to write to.
            field_name_list (:obj:`list[str]`): The variable names to write to.
        """
        # Load everything from XMS before creating the file, so a failed query does not leave a partial file behind.
        query = Query()
        dataset_reader = query.item_with_uuid(self._forcing_uuid)
        geom_item = query.item_with_uuid(dataset_reader.geom_uuid)
        grid = read_grid_from_file(geom_item.cogrid_file)
        locations = grid.ugrid.locations
        longitudes = np.array([point[0] for point in locations])
        latitudes = np.array([point[1] for point in locations])
        dataset_times = np.array(list(dataset_reader.times))
        # NOTE(review): presumably (time, location) for scalars and (time, location, component) for vectors.
        dataset_values = np.array([cur_time.tolist() for cur_time in dataset_reader.values])

        # The context manager closes the dataset even if a write below raises.
        with netCDF4.Dataset(filename=filename, mode='w', format='NETCDF4') as root_grp:
            root_grp.description = 'Created by xmswavewatch3'

            # Make some dimensions with the appropriate sizes
            root_grp.createDimension('time', len(dataset_times))
            root_grp.createDimension('latitude', len(latitudes))
            root_grp.createDimension('longitude', len(longitudes))
            root_grp.createDimension('values', len(longitudes))

            # Make some variables to store the forcing field data
            time_var = root_grp.createVariable('time', 'f8', ('time', ))
            longitude_var = root_grp.createVariable('longitude', 'f8', ('longitude', ))
            latitude_var = root_grp.createVariable('latitude', 'f8', ('latitude', ))
            data_vars = [root_grp.createVariable(name, 'f8', ('time', 'values')) for name in field_name_list]

            # Store the forcing field data in the variables
            time_var[:] = dataset_times
            latitude_var[:] = latitudes
            longitude_var[:] = longitudes
            is_vector = dataset_values.ndim > 2
            for idx, data_var in enumerate(data_vars):
                data_var[:] = dataset_values[:, :, idx] if is_vector else dataset_values

    def _flush(self, file_w_path):
        """Writes the StringIO previously processed to a file.

        Args:
            file_w_path (:obj:`str`):  String of the filename to write to.
        """
        self._ss.seek(0)
        with open(file_w_path, 'w') as f:  # closes the file even if the copy fails
            shutil.copyfileobj(self._ss, f, 100000)

    def write(self):
        """Top-level entry point for NetCDF input field preprocessor file writer."""
        self._write_prnc_nml_file()
