"""TUFLOWFV NetCDF solution file importer."""
# 1. Standard python modules
import datetime
import logging
import os
import uuid
import warnings

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.tuflowfv.file_io import io_util


class NetcdfSolutionReader:
    """Class for reading TUFLOWFV NetCDF solutions.

    Each variable in a solution file that is not listed in ``IGNORED_NETCDF_VARIABLES`` is written to its own XMDF
    file through a ``DatasetWriter``. Variables named ``*_x``/``*_y`` are paired into a single 2-component dataset.
    """
    IGNORED_NETCDF_VARIABLES = {  # Datasets in the NetCDF solution file we don't care about.
        'ResTime',
        'MaxNumCellVert',
        'cell_Nvert',
        'cell_node',
        'NL',
        'idx2',
        'idx3',
        'cell_X',
        'cell_Y',
        'cell_Zb',
        'cell_A',
        'node_X',
        'node_Y',
        'node_Zb',
        'layerface_Z',
        'stat',
    }
    UNMASKED_DATASETS = {  # Datasets we should not apply the wet/dry activity array to
        'ZB',
    }

    def __init__(self, filenames, ugrid_item):
        """Constructor.

        Args:
            filenames (list[str]): Absolute paths to the NetCDF solution files
            ugrid_item (TreeNode): The linked Mesh2D or UGrid module geometry tree item. May be None/empty if no
                geometry is linked; in that case read() logs an error and imports nothing.
        """
        self._logger = logging.getLogger('xms.tuflowfv')
        self._filenames = filenames
        self._ugrid_item = ugrid_item
        # Guard against a missing geometry so the explanatory error in _check_for_linked_ugrid() can be reported
        # instead of failing here with an AttributeError.
        self._ugrid_uuid = ugrid_item.uuid if ugrid_item else None
        self._zero_time = None  # Reference datetime of the file currently being read, if it defines one
        self._temp_dir = XmEnv.xms_environ_temp_directory()
        self._builders = []  # [DatasetWriter] - the imported solution datasets

    def _check_for_linked_ugrid(self):
        """Checks if the linked geometry is a UGrid module object.

        Logs an error describing how to fix the problem if it is not.

        Returns:
            bool: True if the linked geometry is a UGrid, False if a Mesh2D
        """
        tree_type = self._ugrid_item.item_typename if self._ugrid_item else ''
        if tree_type != 'TI_UGRID_SMS':
            self._logger.error('Cannot read NetCDF format solution file because no UGrid module geometry is linked '
                               'to the simulation. NetCDF format solution files are cell-centered and can only be '
                               'loaded onto a UGrid module object. To read this solution, right-click on the domain '
                               'mesh used with this solution\'s simulation and select "Convert > Mesh->UGrid". Link '
                               'the new UGrid module object to the solution\'s simulation, right-click on the '
                               'simulation, and select "Read Solution".')
            return False
        return True

    def _check_for_refdate(self, time_dset):
        """Check if a reference time is defined for this dataset.

        Sets ``self._zero_time`` to the parsed reference datetime, or to None if the file does not define one or
        the date cannot be parsed.

        Args:
            time_dset (xr.Dataset): The ResTime dataset
        """
        # Always start clean so a reference date parsed from a previously read file is never applied to this one.
        self._zero_time = None
        long_name = time_dset.attrs.get('long_name', '')
        reftime_card = 'output time relative to '
        if reftime_card not in long_name:
            return
        date_text = long_name.replace(reftime_card, '').strip()
        try:
            self._zero_time = datetime.datetime.strptime(date_text, '%d/%m/%Y %H:%M:%S')
        except ValueError:
            # A malformed reference date should not prevent the datasets themselves from being imported.
            self._logger.warning(f'Unable to parse reference date "{date_text}" in NetCDF solution file. Datasets '
                                 f'will be imported without a reference time.')

    def _read_dataset(self, filename):
        """Build all the NetCDF datasets.

        Args:
            filename (str): Absolute path to the NetCDF solution file
        """
        self._logger.info(f'Reading NetCDF solution file: "{io_util.logging_filename(filename)}"...')
        # Don't decode times. TUFLOWFV does not write the CF units correctly. They write the reference date to the
        # 'long_name' attr of the 'ResTime' variable if using ISODATE format. This actually works in our favor as we
        # would have to convert the timestamps to offsets from the first timestep if xarray decoded the times.
        with xr.open_dataset(filename, decode_times=False) as ds:
            times = ds.ResTime.data
            if len(times) < 1:
                self._logger.error(f'Empty NetCDF solution file: {io_util.logging_filename(filename)}')
                return
            self._check_for_refdate(ds.ResTime)

            activity_mask = ds.stat.data < 0  # Switch from 0=inactive, -1=active to 0=inactive, 1=active
            for variable in ds.variables:
                dset_name = str(variable)
                if dset_name in self.IGNORED_NETCDF_VARIABLES or dset_name.endswith('_y'):
                    continue  # Not a solution dataset or the y-component of a vector dataset (read with x-component)
                self._logger.info(f'Reading NetCDF dataset from file: "{dset_name}"...')
                dset_name, num_components, values = self._get_dataset_info_and_values(ds, dset_name)
                # Compare the name after the vector suffix has been stripped off.
                activity = None if dset_name in self.UNMASKED_DATASETS else activity_mask
                self._write_xmdf_dataset(dset_name, num_components, times, values, activity)

    def _get_dataset_info_and_values(self, dset, dset_name):
        """Get the tree item name, number of components, and values for a NetCDF dataset.

        Args:
            dset (xr.Dataset): The xarray representation of the NetCDF file
            dset_name (str): Name of the dataset in the file. Return value may contain a prettier one for the tree item.

        Returns:
            tuple(str, int, np.ndarray): The tree item name for the dataset, the number of components (1=scalar,
                2=vector), and the dataset values read from the NetCDF file
        """
        if dset_name.endswith('_x'):  # Special case for vectors, read *_y and create a vector dataset.
            num_components = 2
            vx = dset[dset_name].data
            dset_name = dset_name[:-2]  # Drop the '_x' suffix to get the vector dataset's name
            vy = dset[f'{dset_name}_y'].data
            values = np.stack([vx, vy], axis=2)  # Components become the new third axis
        else:  # Scalar dataset
            values = dset[dset_name].data
            num_components = 1
        return dset_name, num_components, values

    def _write_xmdf_dataset(self, dset_name, num_components, times, values, activity_mask):
        """Write a dataset to an XMDF file for SMS consumption.

        The writer is appended to the list of imported datasets returned by read().

        Args:
            dset_name (str): Name to give the dataset tree item
            num_components (int): 1=scalar, 2=vector
            times (Sequence): 1-D array of timestep offsets in hours
            values (Sequence): The dataset values. If scalar, shape=(num_times, num_cells). If vector,
                shape=(num_times, num_cells, 2)
            activity_mask (Sequence): Cell-based activity array. shape=(num_times, num_cells). None if the dataset
                should not be masked.
        """
        self._logger.info('Writing imported dataset to XMDF file...')
        dset_uuid = str(uuid.uuid4())
        # Each dataset goes to its own file in the XMS temp directory, named after the dataset's UUID.
        h5_filename = f'{os.path.join(self._temp_dir, dset_uuid)}.h5'
        writer = DatasetWriter(h5_filename=h5_filename, name=dset_name, dset_uuid=dset_uuid,
                               geom_uuid=self._ugrid_uuid, num_components=num_components, location='cells',
                               use_activity_as_null=True, ref_time=self._zero_time, time_units='Hours')
        writer.write_xmdf_dataset(times, values, activity=activity_mask)
        self._builders.append(writer)

    def read(self):
        """Import the TUFLOWFV NetCDF solution datasets.

        A failure while reading one file is logged and does not stop the remaining files from being read.

        Returns:
            list[DatasetWriter]: The imported datasets
        """
        # NetCDF format is cell-centered. Don't read if linked geometry is a Mesh2D module object.
        if not self._check_for_linked_ugrid():
            return []

        # Ignore numpy warning about all-nan slices (cells that are dry across all timesteps).
        with warnings.catch_warnings():
            warnings.filterwarnings(action='ignore', message='All-NaN slice encountered')
            for filename in self._filenames:
                try:  # If reading one barfs, continue reading other files.
                    self._read_dataset(filename)
                except Exception as e:
                    self._logger.error(f'Errors reading NetCDF file: {str(e)}')
        return self._builders
