"""Class to read a WaveWatch3 grid namelist file."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import shlex
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query

# 4. Local modules
from xms.wavewatch3.data.model import get_model
from xms.wavewatch3.file_io.io_util import READ_BUFFER_SIZE


def _get_true_false_as_int(value):
    """Convert a Fortran namelist style logical value to an integer.

    The comparison is case-insensitive, and the value may be wrapped in single quotes (e.g. ``'T'``).
    Recognized spellings:

    - true: ``.TRUE.``, ``.T.``, ``TRUE``, ``T``, ``1``
    - false: ``.FALSE.``, ``.F.``, ``FALSE``, ``F``, ``0``

    Args:
        value (:obj:`str`):  The string from the file.

    Returns:
        (:obj:`int`):  1 if some form of true, 0 if some form of false.

    Raises:
        ValueError: If ``value`` is not one of the recognized logical spellings.
    """
    normalized = value.strip().upper()
    # Accept quoted logicals such as 'T' or '.FALSE.' by removing one pair of surrounding single quotes.
    if len(normalized) >= 2 and normalized[0] == "'" and normalized[-1] == "'":
        normalized = normalized[1:-1].strip()
    if normalized in ('.FALSE.', '.F.', 'FALSE', 'F', '0'):
        return 0
    if normalized in ('.TRUE.', '.T.', 'TRUE', 'T', '1'):
        return 1
    raise ValueError(f"Unable to determine value of {value}.")


class GridNmlReader:
    """Class to read a WaveWatch3 grid nml file.

    Recognized model parameters are written into the simulation's global values. Grid information and inbound
    points that are read are kept on the reader.
    """
    def __init__(self, filename='', sim_data=None):
        """Constructor.

        Args:
            filename (:obj:`str`): Path to the nml file. If not provided (not testing or control file read), will
                retrieve from Query.
            sim_data(:obj:`xms.wavewatch3.data.SimData`):  The simulation data to edit. Although it defaults to
                None, it is required: its ``global_values`` are restored here and it is committed by ``read()``.
        """
        self._filename = filename
        self._query = None
        self._sim_data = sim_data
        self._sim_comp_uuid = str(uuid.uuid4())
        self._setup_query()
        self._lines = []
        self._current_line = 0  # Index of the next entry of self._lines that _parse_next_line will look at
        self._logger = logging.getLogger('xms.wavewatch3')

        self._grid_name = ''
        self._grid_namelist = ''
        self._grid_type = ''
        self._grid_coord = ''
        self._grid_clos = ''
        self._gmsh_filename = ''
        self._inbound_count = 0
        self._inbound_points = {}  # Dict of {point_id: [x_index, y_index, connect]}
        self._global_values = get_model().global_parameters
        self._global_values.restore_values(self._sim_data.global_values)

    def _setup_query(self):
        """Setup the xmsapi Query for sending data to SMS and get the import filename."""
        if not self._filename:  # pragma: no cover - slow to setup Query for the filename
            self._query = Query()
            self._filename = self._query.read_file

    def _parse_next_line(self, shell=False):
        """Parse the next line of text from the file.

        Skips empty lines and lines starting with '!' (comments).

        Args:
            shell (:obj:`bool`): If True will parse line using shlex. Slower but convenient for quoted tokens.

        Returns:
            (:obj:`list[str]`): The next line of text split on whitespace, or None if there are no more lines.
        """
        while self._current_line < len(self._lines):
            line = self._lines[self._current_line].strip()
            self._current_line += 1
            if line and not line.startswith('!'):
                return shlex.split(line, posix=False) if shell else line.split()
        return None

    @staticmethod
    def _is_closed_quoted_token(token):
        """Return True if a token that starts with a single quote also ends its quoted string."""
        # A lone "'" opens a string but does not close it, so require at least two characters.
        return len(token) >= 2 and token.endswith("'")

    @staticmethod
    def _join_quoted_value(tokens, start):
        """Rejoin a single-quoted string value that whitespace splitting broke into several tokens.

        Args:
            tokens (:obj:`list[str]`):  All tokens of the namelist.
            start (:obj:`int`):  Index of the token holding the opening quote.

        Returns:
            (:obj:`str`):  The string, quotes included, with its parts separated by single spaces.

        Raises:
            ValueError: If no token closes the quoted string.
        """
        parts = [tokens[start].rstrip(',')]
        for token in tokens[start + 1:]:
            part = token.rstrip(',')
            parts.append(part)
            if part.endswith("'"):
                return ' '.join(parts)
        raise ValueError(f'Unterminated quoted string starting at {tokens[start]}.')

    @staticmethod
    def _get_inbound_point_triplet(tokens, start, card):
        """Get the X_INDEX, Y_INDEX and CONNECT values that follow an ``INBND_POINT(point_id) =`` card.

        Args:
            tokens (:obj:`list[str]`):  All tokens of the namelist.
            start (:obj:`int`):  Index of the first value token after the '='.
            card (:obj:`str`):  The card being read, used in the error message.

        Returns:
            (:obj:`list[str]`):  The three values with trailing commas removed.

        Raises:
            ValueError: If fewer than three values precede the end of the namelist.
        """
        triplet = [token.rstrip(',') for token in tokens[start:start + 3]]
        if len(triplet) < 3 or '/' in triplet:
            raise ValueError(f'Expected X_INDEX, Y_INDEX, and CONNECT values for {card}.')
        return triplet

    def _get_namelist_cards_and_values(self, firstline_data):
        """Reads the entire namelist until the closing / is found.  Stores cards and values.

        Handles cases where the data is on multiple lines, or on a single line.

        Args:
            firstline_data (:obj:`list[str]`):  List of data parsed on the opening line. Not modified.

        Returns:
            (:obj:`tuple(list[str], list)`):  The cards read, and their values at matching positions. A value is a
            string, except for ``INBND_POINT(point_id)`` cards, whose value is ``[x_index, y_index, connect]``.

        Raises:
            ValueError: If the file ends before the closing '/', or a value is malformed.
        """
        # Copy so the caller's list is not extended with the continuation lines.
        tokens = list(firstline_data)
        if '/' not in tokens:
            # We haven't found the end of the namelist yet.  Read more lines as necessary.
            while True:
                line_tokens = self._parse_next_line()
                if line_tokens is None:
                    name = firstline_data[0] if firstline_data else ''
                    raise ValueError(f"Unexpected end of file while reading namelist {name}: missing closing '/'.")
                tokens.extend(line_tokens)
                if '/' in line_tokens:
                    break

        # Now, we have the entire namelist as a list of tokens.  Each card/value pair surrounds a standalone '='.
        cards = []
        values = []
        last_index = len(tokens) - 1
        for i, token in enumerate(tokens):
            # An '=' needs a token on both sides to form a card/value pair.
            if token != '=' or not 0 < i < last_index:
                continue
            card = tokens[i - 1].rstrip(',')
            cards.append(card)
            first_value = tokens[i + 1].rstrip(',')
            if first_value.startswith("'") and not self._is_closed_quoted_token(first_value):
                # Opening quote without the end, something like:  value = 'abc def'
                values.append(self._join_quoted_value(tokens, i + 1))
            elif 'INBND_POINT' in card and card.endswith(')'):
                # INBND_POINT(point_id) line carrying all three values:  X_INDEX, Y_INDEX, and CONNECT
                values.append(self._get_inbound_point_triplet(tokens, i + 1, card))
            else:
                values.append(first_value)

        # Return the cards and values found throughout the namelist read
        return cards, values

    @staticmethod
    def _assign_parameters(group, cards, values, card_map):
        """Assign converted namelist values to parameters of a model parameter group.

        Args:
            group: The parameter group (from ``global_parameters.group(...)``) to assign into.
            cards (:obj:`list[str]`):  The cards read from the namelist.
            values (:obj:`list`):  The values matching ``cards`` by position.
            card_map: Sequence of ``(card, parameter_name, converter)`` tuples. Cards that were not read leave
                their parameter unchanged; if a card repeats, its first occurrence is used.
        """
        for card, parameter_name, convert in card_map:
            if card in cards:
                group.parameter(parameter_name).value = convert(values[cards.index(card)])

    def _read_grid_nml_file(self):
        """Reads the various namelists in the grid nml file.

        Raises:
            ValueError: If a namelist identifier is not recognized.
        """
        skip = self._read_past_unsupported_namelist
        # Checked in order, by substring match against the first token of a namelist's opening line.
        # Namelists mapped to skip are valid but not read; they are parsed past so reading can continue.
        handlers = (
            ('&SPECTRUM_NML', self._read_spectrum_parameterization_namelist),
            ('&RUN_NML', self._read_run_parameterization_namelist),
            ('&TIMESTEPS_NML', self._read_timesteps_parameterization_namelist),
            ('&GRID_NML', self._read_grid_namelist),
            ('&RECT_NML', skip),
            ('&CURV_NML', skip),
            ('&UNST_NML', self._read_unstructured_grid_namelist),
            ('&SMC_NML', skip),
            ('&DEPTH_NML', skip),
            ('&MASK_NML', skip),
            ('&OBST_NML', skip),
            ('&SLOPE_NML', skip),
            ('&SED_NML', skip),
            ('&INBND_COUNT_NML', self._read_inbound_count_namelist),
            ('&INBND_POINT_NML', self._read_inbound_point_namelist),
            ('&EXCL_COUNT_NML', skip),
            ('&EXCL_POINT_NML', skip),
            ('&EXCL_BODY_NML', skip),
            ('&OUTBND_COUNT_NML', skip),
            ('&OUTBND_LINE_NML', skip),
        )
        while True:
            data = self._parse_next_line()
            if not data:
                break
            namelist_id = data[0].strip()
            for name, handler in handlers:
                if name in namelist_id:
                    handler(data)
                    break
            else:
                raise ValueError(f'Unrecognized namelist {data}')

    def _read_spectrum_parameterization_namelist(self, data):
        """Read the spectrum parameterization SPECTRUM_NML namelist.

        Args:
            data(:obj:`list[str]`):  The current line (including the namelist ID) that may contain more data.
        """
        params = self._global_values.group('parameters')
        self._logger.info('Reading SPECTRUM_NML namelist...')
        cards, values = self._get_namelist_cards_and_values(data)

        def float_to_int(text):
            # Counts may be written with a decimal point (e.g. 32.), so go through float first.
            return int(float(text))

        self._assign_parameters(params, cards, values, (
            ('SPECTRUM%XFR', 'XFR', float),
            ('SPECTRUM%FREQ1', 'FREQ1', float),
            ('SPECTRUM%NK', 'NK', float_to_int),
            ('SPECTRUM%NTH', 'NTH', float_to_int),
            ('SPECTRUM%THOFF', 'THOFF', float),
        ))

    def _read_run_parameterization_namelist(self, data):
        """Read the run parameterization RUN_NML namelist.

        Args:
            data(:obj:`list[str]`):  The current line (including the namelist ID) that may contain more data.
        """
        consts = self._global_values.group('consts')
        self._logger.info('Reading RUN_NML namelist...')
        cards, values = self._get_namelist_cards_and_values(data)
        flags = ('FLDRY', 'FLCX', 'FLCY', 'FLCTH', 'FLCK', 'FLSOU')
        self._assign_parameters(
            consts, cards, values, [(f'RUN%{flag}', flag, _get_true_false_as_int) for flag in flags]
        )

    def _read_timesteps_parameterization_namelist(self, data):
        """Read the timesteps parameterization TIMESTEPS_NML namelist.

        Args:
            data(:obj:`list[str]`):  The current line (including the namelist ID) that may contain more data.
        """
        params = self._global_values.group('parameters')
        self._logger.info('Reading TIMESTEPS_NML namelist...')
        cards, values = self._get_namelist_cards_and_values(data)
        self._assign_parameters(params, cards, values, (
            ('TIMESTEPS%DTMAX', 'maxGlobalDt', float),
            ('TIMESTEPS%DTXY', 'maxCFLXY', float),
            ('TIMESTEPS%DTKTH', 'maxCFLKTheta', float),
            ('TIMESTEPS%DTMIN', 'minSourceDt', float),
        ))

    def _read_grid_namelist(self, data):
        """Read the grid GRID_NML namelist.

        Args:
            data(:obj:`list[str]`):  The current line (including the namelist ID) that may contain more data.
        """
        params = self._global_values.group('parameters')
        self._logger.info('Reading GRID_NML namelist...')
        cards, values = self._get_namelist_cards_and_values(data)
        # String settings are stored as read (quotes included).
        if 'GRID%NAME' in cards:
            self._grid_name = values[cards.index('GRID%NAME')]
        if 'GRID%NML' in cards:
            self._grid_namelist = values[cards.index('GRID%NML')]
        if 'GRID%TYPE' in cards:
            self._grid_type = values[cards.index('GRID%TYPE')]
        if 'GRID%COORD' in cards:
            self._grid_coord = values[cards.index('GRID%COORD')]
        if 'GRID%CLOS' in cards:
            self._grid_clos = values[cards.index('GRID%CLOS')]
        self._assign_parameters(params, cards, values, (
            ('GRID%ZLIM', 'ZLIM', float),
            ('GRID%DMIN', 'DMIN', float),
        ))

    def _read_past_unsupported_namelist(self, data):
        """Read past an unsupported namelist if we run into one, so we can continue on the next NML.

        Args:
            data(:obj:`list[str]`):  The current line (including the namelist ID) that may contain more data.
        """
        self._logger.info(f'Skipping namelist {data[0].strip()}...')
        # Parse (and discard) the namelist so the line position moves past its closing '/'.
        self._get_namelist_cards_and_values(data)

    def _read_unstructured_grid_namelist(self, data):
        """Read the grid UNST_NML namelist.

        Args:
            data(:obj:`list[str]`):  The current line (including the namelist ID) that may contain more data.
        """
        self._logger.info('Reading UNST_NML namelist...')
        cards, values = self._get_namelist_cards_and_values(data)
        if 'UNST%FILENAME' in cards:
            self._gmsh_filename = values[cards.index('UNST%FILENAME')]

    def _read_inbound_count_namelist(self, data):
        """Read the grid INBND_COUNT_NML namelist.

        Args:
            data(:obj:`list[str]`):  The current line (including the namelist ID) that may contain more data.
        """
        self._logger.info('Reading INBND_COUNT_NML namelist...')
        cards, values = self._get_namelist_cards_and_values(data)
        if 'INBND_COUNT%N_POINT' in cards:
            self._inbound_count = int(values[cards.index('INBND_COUNT%N_POINT')])

    def _read_inbound_point_namelist(self, data):
        """Read the grid INBND_POINT_NML namelist.

        Points not fully specified keep the defaults ``[-1, -1, 'F']`` for the missing fields.

        Args:
            data(:obj:`list[str]`):  The current line (including the namelist ID) that may contain more data.
        """
        self._logger.info('Reading INBND_POINT_NML namelist...')
        cards, values = self._get_namelist_cards_and_values(data)
        for card, value in zip(cards, values):
            open_paren = card.find('(')
            close_paren = card.rfind(')')
            if open_paren < 0 or close_paren < open_paren:
                self._logger.warning(f'Skipping INBND_POINT_NML card without a point id: {card}')
                continue
            point_id = int(card[open_paren + 1:close_paren])
            point = self._inbound_points.setdefault(point_id, [-1, -1, 'F'])
            if 'X_INDEX' in card:
                # INBND_POINT(point_id)%X_INDEX line
                point[0] = int(value)
            elif 'Y_INDEX' in card:
                # INBND_POINT(point_id)%Y_INDEX line
                point[1] = int(value)
            elif 'CONNECT' in card:
                # INBND_POINT(point_id)%CONNECT line
                point[2] = value
            elif card.endswith(')') and isinstance(value, list):
                # INBND_POINT(point_id) line with 3 values for X_INDEX, Y_INDEX, CONNECT
                point[0] = int(value[0])
                point[1] = int(value[1])
                point[2] = value[2]

    def get_inbound_points(self):
        """Gets the inbound points, if any, read by the INBND_POINT_NML namelist.

        Returns:
            (:obj:`dict`):  ``{point_id: [x_index, y_index, connect]}``.
        """
        return self._inbound_points

    def read(self):
        """Top-level entry point for the WaveWatch3 namelists nml input file reader.

        Raises:
            Exception: Any error raised while reading is logged (with the line number) and re-raised.
        """
        try:
            self._logger.info('Parsing ASCII text from file...')
            with open(self._filename, 'r', buffering=READ_BUFFER_SIZE) as f:
                self._lines = f.readlines()

            self._read_grid_nml_file()
            self._logger.info('Committing changes....')
            self._sim_data.global_values = self._global_values.extract_values()
            self._sim_data.commit()
            self._logger.info('Finished!')
        except Exception:
            # _current_line has already been advanced past the last line read, so it equals that line's
            # 1-based number.
            self._logger.exception(
                'Unexpected error in grid nml preprocessor file '
                f'(line {max(self._current_line, 1)}).'
            )
            raise
