"""Class to write a WaveWatch3 grid namelist file."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
from io import StringIO
import logging
import shutil

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules


class WW3GridNamelistWriter:
    """Class to write a WaveWatch3 grid namelist file (ww3_grid.nml).

    All sections are accumulated in an in-memory ``StringIO`` buffer and copied to disk in one pass.
    """
    def __init__(self, xms_data, lateral, ocean):
        """Constructor.

        Args:
            xms_data (:obj:`XmsData`): Simulation data retrieved from SMS
            lateral: List of lateral nodes (stored on the instance; not used when writing ww3_grid.nml)
            ocean: List of input boundary node ids; each one becomes the X index of an INBND_POINT entry
        """
        self._ss = StringIO()
        self._logger = logging.getLogger('xms.wavewatch3')
        self._bound_process = None
        self._grid_process = None
        self._xms_data = xms_data
        self._namelists_nml = 'namelists.nml'
        self.lateral = lateral
        self.ocean = ocean

    @staticmethod
    def _fortran_string(text):
        """Return ``text`` as an apostrophe-delimited Fortran character constant.

        Embedded apostrophes are doubled, which is how namelist / list-directed input represents a literal
        apostrophe inside an apostrophe-delimited string. Text without apostrophes is simply wrapped.

        Args:
            text: Value to quote; converted with ``str()``.

        Returns:
            (:obj:`str`): The quoted constant, e.g. ``'Bob''s mesh'``.
        """
        return "'" + str(text).replace("'", "''") + "'"

    def _write_nonzero_parameters(self, attrs, entries):
        """Write one namelist assignment for every parameter whose value is non-zero.

        Zero values are skipped; the header comments written before each namelist list 0 as the default
        for these entries, so omitting them leaves the WW3 default in effect.

        Args:
            attrs: Parameter group exposing ``parameter(name).value``.
            entries: Sequence of ``(parameter_name, line_prefix)`` pairs. A written line is
                ``line_prefix`` followed by the value and a newline.
        """
        for param_name, prefix in entries:
            value = attrs.parameter(param_name).value
            if value != 0:
                self._ss.write(f'{prefix}{value}\n')

    def _write_grid_namelist_file(self):
        """Writes the namelist file (ww3_grid.nml in the current working directory)."""
        file_w_path = "ww3_grid.nml"

        # Start from an empty buffer so calling write() more than once does not duplicate every namelist.
        self._ss = StringIO()
        self._write_grid_header()
        self._write_spectrum_parameterization_namelist()
        self._write_run_namelist()
        self._write_timesteps_namelist()
        self._write_grid_namelist()
        self._write_unstructured_grid_namelist()
        self._write_inbound_namelists()
        self._write_end_comments()
        self._flush(file_w_path)

    def _write_grid_header(self):
        """Writes the header comments for the grid namelist file."""
        self._ss.write(
            '! -------------------------------------------------------------------- !\n'
            '! WAVEWATCH III - ww3_grid.nml - Grid pre-processing                   !\n'
            '! -------------------------------------------------------------------- !\n'
            '\n\n'
        )

    def _write_spectrum_parameterization_namelist(self):
        """Writes the SPECTRUM_NML namelist from the 'parameters' group (non-zero values only)."""
        self._ss.write(
            '! -------------------------------------------------------------------- !\n'
            '! Define the spectrum parameterization via SPECTRUM_NML namelist\n'
            '!\n'
            '! * namelist must be terminated with /\n'
            '! * definitions & defaults:\n'
            '!     SPECTRUM%XFR         = 0.            ! frequency increment\n'
            '!     SPECTRUM%FREQ1       = 0.            ! first frequency (Hz)\n'
            '!     SPECTRUM%NK          = 0             ! number of frequencies (wavenumbers)\n'
            '!     SPECTRUM%NTH         = 0             ! number of direction bins\n'
            '!     SPECTRUM%THOFF       = 0.            ! relative offset of first direction [-0.5,0.5]\n'
            '! -------------------------------------------------------------------- !\n'
        )
        attrs = self._xms_data.sim_data_model_control.group('parameters')
        self._ss.write('&SPECTRUM_NML\n')
        self._write_nonzero_parameters(attrs, (
            ('XFR', "    SPECTRUM%XFR         = "),
            ('FREQ1', "    SPECTRUM%FREQ1       = "),
            ('NK', "    SPECTRUM%NK          = "),
            ('NTH', "    SPECTRUM%NTH         = "),
            ('THOFF', "    SPECTRUM%THOFF       = "),
        ))
        self._ss.write('/\n\n\n')

    def _write_run_namelist(self):
        """Writes the RUN_NML namelist; each non-zero flag in the 'consts' group is written as T."""
        self._ss.write(
            "! -------------------------------------------------------------------- !\n"
            "! Define the run parameterization via RUN_NML namelist\n"
            "!\n"
            "! * namelist must be terminated with /\n"
            "! * definitions & defaults:\n"
            "!     RUN%FLDRY            = F             ! dry run (I/O only, no calculation)\n"
            "!     RUN%FLCX             = F             ! x-component of propagation\n"
            "!     RUN%FLCY             = F             ! y-component of propagation\n"
            "!     RUN%FLCTH            = F             ! direction shift\n"
            "!     RUN%FLCK             = F             ! wavenumber shift\n"
            "!     RUN%FLSOU            = F             ! source terms\n"
            "! -------------------------------------------------------------------- !\n"
        )
        attrs = self._xms_data.sim_data_model_control.group('consts')
        self._ss.write('&RUN_NML\n')
        # Flags left at 0 are omitted so the default F listed above applies.
        run_flags = (
            ('FLDRY', "    RUN%FLDRY           = T\n"),
            ('FLCX', "    RUN%FLCX            = T\n"),
            ('FLCY', "    RUN%FLCY            = T\n"),
            ('FLCTH', "    RUN%FLCTH           = T\n"),
            ('FLCK', "    RUN%FLCK            = T\n"),
            ('FLSOU', "    RUN%FLSOU           = T\n"),
        )
        for param_name, line in run_flags:
            if attrs.parameter(param_name).value != 0:
                self._ss.write(line)
        self._ss.write('/\n\n\n')

    def _write_timesteps_namelist(self):
        """Writes the TIMESTEPS_NML namelist from the 'parameters' group (non-zero values only)."""
        self._ss.write(
            "! -------------------------------------------------------------------- !\n"
            "! Define the timesteps parameterization via TIMESTEPS_NML namelist\n"
            "!\n"
            "! * It is highly recommended to set up time steps which are multiple \n"
            "!   between them. \n"
            "!\n"
            "! * The first time step to calculate is the maximum CFL time step\n"
            "!   which depend on the lowest frequency FREQ1 previously set up and the\n"
            "!   lowest spatial grid resolution in meters DXY.\n"
            "!   reminder : 1 degree=60minutes // 1minute=1mile // 1mile=1.852km\n"
            "!   The formula for the CFL time is :\n"
            "!   Tcfl = DXY / (G / (FREQ1*4*Pi) ) with the constants Pi=3,14 and G=9.8m/s²;\n"
            "!   DTXY  ~= 90% Tcfl\n"
            "!   DTMAX ~= 3 * DTXY   (maximum global time step limit)\n"
            "!\n"
            "! * The refraction time step depends on how strong can be the current velocities\n"
            "!   on your grid :\n"
            "!   DTKTH ~= DTMAX / 2   ! in case of no or light current velocities\n"
            "!   DTKTH ~= DTMAX / 10  ! in case of strong current velocities\n"
            "!\n"
            "! * The source terms time step is usually defined between 5s and 60s.\n"
            "!   A common value is 10s.\n"
            "!   DTMIN ~= 10\n"
            "!\n"
            "! * namelist must be terminated with /\n"
            "! * definitions & defaults:\n"
            "!     TIMESTEPS%DTMAX      = 0.         ! maximum global time step (s)\n"
            "!     TIMESTEPS%DTXY       = 0.         ! maximum CFL time step for x-y (s)\n"
            "!     TIMESTEPS%DTKTH      = 0.         ! maximum CFL time step for k-th (s)\n"
            "!     TIMESTEPS%DTMIN      = 0.         ! minimum source term time step (s)\n"
            "! -------------------------------------------------------------------- !\n"
        )
        attrs = self._xms_data.sim_data_model_control.group('parameters')
        self._ss.write('&TIMESTEPS_NML\n')
        self._write_nonzero_parameters(attrs, (
            ('maxGlobalDt', "    TIMESTEPS%DTMAX         =  "),
            ('maxCFLXY', "    TIMESTEPS%DTXY          =  "),
            ('maxCFLKTheta', "    TIMESTEPS%DTKTH         =  "),
            ('minSourceDt', "    TIMESTEPS%DTMIN         =  "),
        ))
        self._ss.write('/\n\n\n')

    def _write_grid_namelist(self):
        """Writes out the GRID_NML namelist (unstructured, spherical, no closure)."""
        self._ss.write(
            "! -------------------------------------------------------------------- !\n"
            "! Define the grid to preprocess via GRID_NML namelist\n"
            "!\n"
            "! * the tunable parameters for source terms, propagation schemes, and \n"
            "!    numerics are read using namelists. \n"
            "! * Any namelist found in the folowing sections is temporarily written\n"
            "!   to param.scratch, and read from there if necessary. \n"
            "! * The order of the namelists is immaterial.\n"
            "! * Namelists not needed for the given switch settings will be skipped\n"
            "!   automatically\n"
            "!\n"
            "! * grid type can be : \n"
            "!    'RECT' : rectilinear\n"
            "!    'CURV' : curvilinear\n"
            "!    'UNST' : unstructured (triangle-based)\n"
            "!\n"
            "! * coordinate system can be : \n"
            "!    'SPHE' : Spherical (degrees)\n"
            "!    'CART' : Cartesian (meters)\n"
            "!\n"
            "! * grid closure can only be applied in spherical coordinates\n"
            "!\n"
            "! * grid closure can be : \n"
            "!    'NONE' : No closure is applied\n"
            "!    'SMPL' : Simple grid closure. Grid is periodic in the\n"
            "!           : i-index and wraps at i=NX+1. In other words,\n"
            "!           : (NX+1,J) => (1,J). A grid with simple closure\n"
            "!           : may be rectilinear or curvilinear.\n"
            "!    'TRPL' : Tripole grid closure : Grid is periodic in the\n"
            "!           : i-index and wraps at i=NX+1 and has closure at\n"
            "!           : j=NY+1. In other words, (NX+1,J<=NY) => (1,J)\n"
            "!           : and (I,NY+1) => (NX-I+1,NY). Tripole\n"
            "!           : grid closure requires that NX be even. A grid\n"
            "!           : with tripole closure must be curvilinear.\n"
            "!\n"
            "! * The coastline limit depth is the value which distinguish the sea \n"
            "!   points to the land points. All the points with depth values (ZBIN)\n"
            "!   greater than this limit (ZLIM) will be considered as excluded points\n"
            "!   and will never be wet points, even if the water level grows over.\n"
            "!   It can only overwrite the status of a sea point to a land point.\n"
            "!   The value must have a negative value under the mean sea level\n"
            "!\n"
            "! * The minimum water depth allowed to compute the model is the absolute\n"
            "!   depth value (DMIN) used in the model if the input depth is lower to \n"
            "!   avoid the model to blow up.\n"
            "!\n"
            "! * namelist must be terminated with /\n"
            "! * definitions & defaults:\n"
            "!     GRID%NAME             = 'unset'            ! grid name (30 char)\n"
            "!     GRID%NML              = 'namelists.nml'    ! namelists filename\n"
            "!     GRID%TYPE             = 'unset'            ! grid type\n"
            "!     GRID%COORD            = 'unset'            ! coordinate system\n"
            "!     GRID%CLOS             = 'unset'            ! grid closure\n"
            "!\n"
            "!     GRID%ZLIM             = 0.        ! coastline limit depth (m)\n"
            "!     GRID%DMIN             = 0.        ! abs. minimum water depth (m)\n"
            "! -------------------------------------------------------------------- !\n"
        )
        attrs = self._xms_data.sim_data_model_control.group('parameters')
        grid_name = self._fortran_string(self._xms_data.do_ugrid.name)
        nml_name = self._fortran_string(self._namelists_nml)
        self._ss.write('&GRID_NML\n')
        self._ss.write(f"    GRID%NAME             = {grid_name}\n")
        self._ss.write(f"    GRID%NML              = {nml_name}\n")
        self._ss.write("    GRID%TYPE             = 'UNST'\n")  # Hard coded for unstructured grid
        self._ss.write("    GRID%COORD            = 'SPHE'\n")  # Hard coded for spherical
        self._ss.write("    GRID%CLOS             = 'NONE'\n")  # Hard coded for non-closing grid
        self._ss.write(f"    GRID%ZLIM             = {attrs.parameter('ZLIM').value}\n")
        self._ss.write(f"    GRID%DMIN             = {attrs.parameter('DMIN').value}\n")
        self._ss.write('/\n\n\n')

    def _write_unstructured_grid_namelist(self):
        """Write out the UNST_NML namelist pointing at '<grid name>.msh'."""
        self._ss.write(
            "! -------------------------------------------------------------------- !\n"
            "! Define the unstructured grid type via UNST_NML namelist\n"
            "! - only for UNST grids -\n"
            "!\n"
            "! * The minimum grid size is 3x3.\n"
            "!\n"
            "! * &MISC namelist must be removed\n"
            "!\n"
            "! * The depth value must have negative values under the mean sea level\n"
            "!\n"
            "! * The map value must be set as :\n"
            "!    -2 : Excluded boundary point (covered by ice)\n"
            "!    -1 : Excluded sea point (covered by ice)\n"
            "!     0 : Excluded land point\n"
            "!     1 : Sea point\n"
            "!     2 : Active boundary point\n"
            "!     3 : Excluded grid point\n"
            "!     7 : Ice point\n"
            "!\n"
            "! * the file must be a GMESH grid file containing node and element lists.\n"
            "!\n"
            "! * Extra open boundary list file with UGOBCFILE in namelist &UNST\n"
            "!   An example is given in regtest ww3_tp2.7\n"
            "!\n"
            "! * value <= scale_fac * value_read\n"
            "!\n"
            "! * IDLA : Layout indicator :\n"
            "!                  1   : Read line-by-line bottom to top. (default)\n"
            "!                  2   : Like 1, single read statement.\n"
            "!                  3   : Read line-by-line top to bottom.\n"
            "!                  4   : Like 3, single read statement.\n"
            "! * IDFM : format indicator :\n"
            "!                  1   : Free format. (default)\n"
            "!                  2   : Fixed format.\n"
            "!                  3   : Unformatted.\n"
            "! * FORMAT : element format to read :\n"
            "!               '(....)'  : auto detected (default)\n"
            "!               '(f10.6)' : float type\n"
            "!\n"
            "! * Example :\n"
            "!      IDF  SF   IDLA  IDFM   FORMAT       FILENAME\n"
            "!      20  -1.   4     2     '(20f10.2)'  'ngug.msh'\n"
            "!\n"
            "! * namelist must be terminated with /\n"
            "! * definitions & defaults:\n"
            "!     UNST%SF             = 1.       ! unst scale factor\n"
            "!     UNST%FILENAME       = 'unset'  ! unst filename\n"
            "!     UNST%IDF            = 20       ! unst file unit number\n"
            "!     UNST%IDLA           = 1        ! unst layout indicator\n"
            "!     UNST%IDFM           = 1        ! unst format indicator\n"
            "!     UNST%FORMAT         = '(....)' ! unst formatted read format\n"
            "!\n"
            "!     UNST%UGOBCFILE      = 'unset'  ! additional boundary list file\n"
            "! -------------------------------------------------------------------- !\n"
        )
        mesh_file = self._fortran_string(f'{self._xms_data.do_ugrid.name}.msh')
        self._ss.write('&UNST_NML\n')
        # Assuming free format (IDFM = 1 -- the default) so no need to set UNST%FORMAT.
        # Using IDLA = 4 (top to bottom, read all at once)
        # Using scaling factor of -1.
        self._ss.write("    UNST%SF             = -1.0\n")
        self._ss.write(f"    UNST%FILENAME       = {mesh_file}\n")
        self._ss.write("    UNST%IDLA           = 4\n")
        self._ss.write('/\n\n\n')

    def _write_inbound_namelists(self):
        """Write out the INBND_COUNT_NML / INBND_POINT_NML namelists for ``self.ocean``.

        Nothing but the comment header is written when ``self.ocean`` is empty or None.
        """
        self._ss.write(
            "! -------------------------------------------------------------------- !\n"
            "! Define the input boundary points via INBND_COUNT_NML and\n"
            "!                                      INBND_POINT_NML namelist\n"
            "! - for RECT, CURV and UNST grids -\n"
            "!\n"
            "! * If no mask defined, INBOUND can be used\n"
            "!\n"
            "! * If the actual input data is not defined in the actual wave model run\n"
            "!   the initial conditions will be applied as constant boundary conditions.\n"
            "!\n"
            "! * The number of points is defined by INBND_COUNT\n"
            "!\n"
            "! * The points must start from index 1 to N\n"
            "!\n"
            "! * Each line contains:\n"
            "!     Discrete grid counters (IX,IY) of the active point and a\n"
            "!     connect flag. If this flag is true, and the present and previous\n"
            "!     point are on a grid line or diagonal, all intermediate points\n"
            "!     are also defined as boundary points.\n"
            "!\n"
            "! * Included point :\n"
            "!     grid points from segment data\n"
            "!     Defines as lines identifying points at which\n"
            "!     input boundary conditions are to be defined. \n"
            "!\n"
            "! * namelist must be terminated with /\n"
            "! * definitions & defaults:\n"
            "!     INBND_COUNT%N_POINT     = 0        ! number of segments\n"
            "!\n"
            "!     INBND_POINT(I)%X_INDEX  = 0        ! x index included point\n"
            "!     INBND_POINT(I)%Y_INDEX  = 0        ! y index included point\n"
            "!     INBND_POINT(I)%CONNECT  = F        ! connect flag\n"
            "!\n"
            "! OR\n"
            "!     INBND_POINT(I)          = 0 0 F    ! included point\n"
            "! -------------------------------------------------------------------- !\n"
        )
        if self.ocean:
            # Write the INBND_COUNT_NML namelist
            self._ss.write('&INBND_COUNT_NML\n')
            self._ss.write(f'  INBND_COUNT%N_POINT    = {len(self.ocean)}\n')
            self._ss.write('/\n\n')

            # Write the INBND_POINT_NML namelist. Indices are 1-based; each point is written as
            # X_INDEX = node id, Y_INDEX = 1, CONNECT = F.
            self._ss.write('&INBND_POINT_NML\n')
            for point_number, node_id in enumerate(self.ocean, start=1):
                self._ss.write(f'  INBND_POINT({point_number})         = {node_id}  1  F\n')
            self._ss.write('/')
        self._ss.write('\n\n\n')

    def _write_end_comments(self):
        """Writes out the comments on the bottom of the grid namelist file."""
        self._ss.write(
            "! -------------------------------------------------------------------- !\n"
            "! WAVEWATCH III - end of namelist                                      !\n"
            "! -------------------------------------------------------------------- !\n"
        )

    def _flush(self, file_w_path):
        """Writes the StringIO previously processed to a file.

        The file is written as UTF-8 so the non-ASCII characters in the header comments encode the same way
        regardless of the platform's locale, and it is closed even if the copy raises.

        Args:
            file_w_path (:obj:`str`):  String of the filename to write to.
        """
        self._ss.seek(0)
        with open(file_w_path, 'w', encoding='utf-8') as f:
            shutil.copyfileobj(self._ss, f, 100000)

    def write(self):
        """Top-level entry point for the WaveWatch3 grid input file writer."""
        self._write_grid_namelist_file()
