"""Class for writing XMDF data set files."""

# 1. Standard Python modules
import datetime
import os
import pathlib
from typing import List, Optional, Sequence, Tuple, Union
import uuid

# 2. Third party modules
import h5py
import numpy as np

# 3. Aquaveo modules
from xms.core.filesystem.filesystem import temp_filename
from xms.core.time import datetime_to_julian, julian_to_datetime

# 4. Local modules
from xms.datasets.dataset_io import DSET_NULL_VALUE
from xms.datasets.dataset_metadata import XMDF_DATA_LOCATIONS
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.numpy_util import ensure_sequence_is_numpy_array
from xms.datasets.vectors import find_vector_mins_and_maxs


class DatasetWriter:
    """Class for writing XMDF data set files."""
    def __init__(
        self,
        h5_filename: Optional[str] = None,
        name: str = 'Dataset',
        dset_uuid: Optional[str] = None,
        geom_uuid: str = '',
        num_components: int = 1,
        ref_time: Optional[Union[float, datetime.datetime]] = None,
        null_value: Optional[float] = None,
        time_units: str = 'Days',
        units: str = '',
        dtype: str = 'f',
        use_activity_as_null: bool = False,
        location: str = 'points',
        overwrite: bool = True,
        geom_path: Optional[str] = '',
        h5_handle: Optional[h5py.File] = None
    ) -> None:
        """Constructor.

        Args:
            h5_filename (:obj:`str`): Path to the H5 file to write. If None, a '<name>.h5' path inside a folder
                named by temp_filename() is used and that folder is created.
            name (:obj:`str`): Name of the dataset (will be used to build the dataset's group path). Must be ASCII.
            dset_uuid (:obj:`str`): UUID of the dataset. A random UUID is generated if empty or None.
            geom_uuid (:obj:`str`): UUID of the dataset's geometry
            num_components (:obj:`int`): The number of data components (1=scalar, 2=vector, 3=3D vector)
            ref_time (:obj:`float, datetime.datetime`): The dataset's reference time. Either a Julian float or a
                Python datetime.datetime
            null_value (:obj:`float`): The null value of the dataset, if there is one
            time_units (:obj:`str`): The dataset's time units. One of: 'Seconds', 'Minutes', 'Hours', 'Days',
                'Years', 'None' (case-insensitive)
            units (:obj:`str`): Units of the dataset values
            dtype (:obj:`str`): Data type of the dataset's values. One of: 'f', 'd', 'i'
            use_activity_as_null (:obj:`bool`): If True, inactive values (0) will be treated as null values when
                computing timestep mins and maxs. Implies that activity array is on same location as data values
                (i.e. has same shape).
            location (:obj:`str`): Location of the dataset values. One of the XMDF_DATA_LOCATIONS keys. Note that this
                does not usually need to be set; XMS is going to ignore it in most cases. XMS will try to determine
                the dataset location based on the geometries currently loaded (number of nodes, number of points).
                Here for historical reasons.
            overwrite (:obj:`bool`): If True and the file already exists, it will be overwritten
            geom_path (:obj:`str`): Path to the dataset's geometry group. If provided, dataset will be written to this
                location. Allows for datasets in folders. None is treated the same as an empty string.
            h5_handle (:obj:`h5py.File`): If provided, uses an existing file handle instead of opening and closing the
                file.

        Raises:
            ValueError: If name is not ASCII, or if ref_time, time_units, location or dtype fail validation.
        """
        # Validate the name before touching the filesystem so a bad name does not leave a temp folder behind.
        if name.encode('ascii', 'ignore').decode('ascii') != name:
            raise ValueError('Dataset name must be ASCII.')
        if h5_filename is None:
            h5_filename = os.path.join(temp_filename(), f'{name}.h5')
            pathlib.Path(h5_filename).parent.mkdir(parents=True, exist_ok=True)
        self.h5_filename = h5_filename
        self.name = name
        # Strip off all separators and then add our own. Be flexible with user provided path (None means no path).
        self.geom_path = (geom_path or '').strip('/\\')
        self.uuid = dset_uuid if dset_uuid else str(uuid.uuid4())
        self.geom_uuid = geom_uuid
        self.num_components = num_components
        self.null_value = null_value
        self.units = units
        self.use_activity_as_null = use_activity_as_null
        self.num_values = None  # Not initialized until a time step is written
        self.num_activity_values = None  # Not initialized until a time step is written
        self.dtype = dtype  # Validated by the property setter
        self._overwrite = overwrite
        self._time_units = ''
        self._location = ''
        self._ref_time = None
        self._h5file = h5_handle  # File handle to the H5 file
        self._close_handle = h5_handle is None  # Only close handles this object opened itself
        self._multi_group = None  # h5py.Group of the XMDF MultiDatasets group
        self._dset_group = None
        self._added_data = False
        self._active_timestep = -1  # index of the active timestep

        # Initialize attributes through property setters.
        self.ref_time = ref_time
        self.time_units = time_units
        self.location = location

        # Used when building dataset by appending timesteps.
        self._values_dset = None
        self._activity_dset = None
        self._times: List[float] = []
        self._mins: List[float] = []
        self._maxs: List[float] = []

        # Set manually computed timestep mins and maxs. Useful for when activity array is at different location than
        # the dataset values.
        self.timestep_mins = None
        self.timestep_maxs = None

        # If point data with cell activity set this. Must define calc(np.ndarray) that returns a np.ndarray activity
        # mask that matches the number of data values. See CellToPointActivityCalculator class in xmsconstraint.
        self.activity_calculator = None

    def _validate_location(self, data_loc: str) -> str:
        """Normalize a data location string and check it against XMDF_DATA_LOCATIONS.

        Args:
            data_loc (:obj:`str`): The location string to verify (compared case-insensitively)

        Returns:
            (:obj:`str`): The lowercased location string

        Raises:
            ValueError if the lowercased string is not one of the XMDF_DATA_LOCATIONS keys
        """
        valid_locations = XMDF_DATA_LOCATIONS.keys()
        normalized = data_loc.lower()
        if normalized in valid_locations:
            return normalized
        raise ValueError(f'Invalid data location: {data_loc}. Must be one of: {XMDF_DATA_LOCATIONS.keys()}')

    def _validate_time_units(self, time_units: str) -> str:
        """Ensure time units string is valid.

        Args:
            time_units (:obj:`str`): The time units string to verify

        Returns:
            (:obj:`str`): The cleaned time units string if valid

        Raises:
            ValueError if time units string is invalid
        """
        time_units_clean = time_units.title()
        if time_units_clean not in ['Seconds', 'Minutes', 'Hours', 'Days', 'Years', 'None']:
            raise ValueError(
                f'Invalid time units: {time_units}. Must be one of: "Seconds", "Minutes", "Hours", '
                f'"Days", "Years", "None"'
            )
        return time_units_clean

    def _validate_dtype(self, dtype: str) -> str:
        """Ensure dtype string is valid.

        Args:
            dtype (:obj:`str`): The dtype string to verify

        Returns:
            (:obj:`str`): The cleaned dtype string if valid

        Raises:
            ValueError if dtype string is invalid
        """
        dtype_clean = dtype.lower()
        if dtype_clean not in ['f', 'd', 'i']:
            raise ValueError(f'Invalid dtype: {dtype}. Must be one of: "f", "d", "i"')
        return dtype_clean

    @property
    def location(self) -> str:
        """The location of the dataset values.

        May not be accurate. Used for compatibility with older and obscure models. XMS will determine data location
        based on currently loaded geometries.
        """
        return self._location

    @location.setter
    def location(self, data_loc: str) -> None:
        """The location of the dataset values."""
        self._location = self._validate_location(data_loc)

    @property
    def time_units(self) -> str:
        """The dataset's time units."""
        return self._time_units

    @time_units.setter
    def time_units(self, time_units: str) -> None:
        """The dataset's time units."""
        self._time_units = self._validate_time_units(time_units)

    @property
    def ref_time(self) -> Union[datetime.datetime, None]:
        """The dataset reference time."""
        return self._ref_time

    @ref_time.setter
    def ref_time(self, reference_time: Optional[Union[float, datetime.datetime]]) -> None:
        """The dataset value's numpy data type."""
        ref_time_type = type(reference_time)
        if ref_time_type in [datetime.datetime, type(None)]:
            self._ref_time = reference_time
        elif ref_time_type == float:  # If float value passed, assume we were given a Julian time
            self._ref_time = julian_to_datetime(reference_time)
        else:
            raise ValueError('Reference time must be a Julian float, a datetime.datetime object, or None.')

    @property
    def dtype(self) -> str:
        """The dataset value's numpy data type."""
        return self._dtype

    @dtype.setter
    def dtype(self, data_type: str) -> None:
        """The dataset value's numpy data type."""
        self._dtype = self._validate_dtype(data_type)

    @property
    def h5file(self) -> h5py.File:
        """File handle to the H5 file."""
        if self._h5file is None:
            if not self.h5_filename:
                raise RuntimeError('H5 Filename must be set before writing.')
            self._h5file = h5py.File(self.h5_filename, 'w' if self._overwrite else 'a')
        return self._h5file

    @property
    def multi_group(self) -> h5py.Group:
        """h5py.Group of the XMDF MultiDatasets group."""
        if self._multi_group is None:
            self._multi_group = self._create_xmdf_multi_datasets_group()
        return self._multi_group

    @property
    def dset_group(self) -> h5py.Group:
        """h5py.Group of the XMDF Dataset."""
        if self._dset_group is None:
            self._dset_group = self._create_xmdf_dataset_group()
        return self._dset_group

    @property
    def values(self) -> h5py.Dataset:
        """h5py.Dataset of the XMDF Dataset values."""
        if self._values_dset is None:
            is_scalar = self.num_components == 1
            shape = (1, self.num_values) if is_scalar else (1, self.num_values, self.num_components)
            max_shape = (None, self.num_values) if is_scalar else (None, self.num_values, self.num_components)
            self._values_dset = self.dset_group.create_dataset(
                'Values', dtype=self.dtype, shape=shape, maxshape=max_shape
            )
        return self._values_dset

    @property
    def activity(self) -> h5py.Dataset:
        """h5py.Dataset of the XMDF Dataset activity array."""
        if self._activity_dset is None:
            self._activity_dset = self.dset_group.create_dataset(
                'Active', dtype='u1', shape=(1, self.num_activity_values), maxshape=(None, self.num_activity_values)
            )
        return self._activity_dset

    @property
    def group_path(self):
        """Returns the H5 group path to the dataset.

        Group may not exist yet. If no data has been written to the H5 file, this is where it will be written.
        """
        if self.geom_path:
            return f'{self.geom_path}/Datasets/{self.name}'
        return f'Datasets/{self.name}'

    @property
    def active_timestep(self):
        """Returns the active timestep index.

        Returns:
            (:obj:`int`): See description.
        """
        return self._active_timestep

    @active_timestep.setter
    def active_timestep(self, timestep):
        """Sets the active timestep index.

        Args:
            timestep (:obj:`int`): The active timestep index.
        """
        self._active_timestep = timestep

    def _replace_inactive_with_nan(self, npdata: np.ndarray, npactivity: Union[np.ndarray,
                                                                               None]) -> Union[np.ndarray, None]:
        """Replace inactive dataset values with nan for numpy operations.

        Args:
            npdata (:obj:`numpy.ndarray`): The dataset values
            npactivity (:obj:`numpy.ndarray`): The activity array, if it exists

        Returns:
            (:obj:`numpy.ndarray`): The original values before applying the activity array, or None if no acitvity array
            provided.
        """
        original_data = None  # Used to restore values after applying activity
        if self.null_value is not None:  # Convert null values to nan for numpy operations
            npdata[npdata == self.null_value] = np.nan
        elif npactivity is not None and self.use_activity_as_null:  # Treat inactive values as null values
            shape = npdata.shape if self.num_components == 1 else npdata.shape[:-1]
            if shape != npactivity.shape and self.activity_calculator is None:  # scalar
                raise ValueError(
                    'shape of data values does not match shape of activity array. To use activity as null values, '
                    'the activity array must have the same dimensions as the dataset values.'
                )
            original_data = npdata.copy()
            npdata[npactivity == 0] = np.nan
        return original_data

    def _replace_nan_with_inactive(self, npdata: np.ndarray, original_data: Union[np.ndarray, None]) -> np.ndarray:
        """Replace nan dataset values with null value or original data (if using an activity array).

        Args:
            npdata (:obj:`numpy.ndarray`): The dataset values
            original_data (:obj:`numpy.ndarray`): The original dataset values before applying the activity array,
                if using one

        Returns:
            (:obj:`numpy.ndarray`): The original values
        """
        if original_data is not None:
            return original_data  # Using an activity array, return the original dataset values.
        if self.null_value is not None:  # Convert nan values back to null value
            npdata[np.isnan(npdata)] = self.null_value
        return npdata

    def _compute_timestep_min_max(self, npdata: Sequence) -> Tuple[np.ndarray, np.ndarray]:
        """Compute mins and maxs of n-number of timesteps.

        Args:
            npdata (:obj:`Sequence`): list-like array of timestep data.

        Returns:
            (:obj:`tuple(numpy.ndarray, numpy.ndarray)`): The timestep mins and maxs. If dataset is a vector, mins and
            maxs of the vector magnitudes.
        """
        if self.num_components == 1:  # scalar
            mins = np.nanmin(npdata, 1)
            maxs = np.nanmax(npdata, 1)
        else:  # 2D vector
            mins, maxs = find_vector_mins_and_maxs(npdata)
        return mins, maxs

    def _create_xmdf_multi_datasets_group(self) -> h5py.Group:
        """Gets the 'Datasets' group for this dataset's geometry path, creating it if it does not exist.

        A newly created group is tagged with a 'MULTI DATASETS' Grouptype attribute and a 'Guid' dataset holding the
        geometry UUID. The file-level 'File Type' and 'File Version' datasets are added at that time if missing.

        Returns:
            (:obj:`h5py.Group`): The existing or newly created 'Datasets' group
        """
        path = 'Datasets' if not self.geom_path else f'{self.geom_path}/Datasets'
        h5_root = self.h5file
        if path in h5_root:
            return h5_root[path]

        # Create the MultiDatasets folder since it doesn't exist.
        datasets_group = h5_root.create_group(path)
        datasets_group.attrs.create('Grouptype', data=[b'MULTI DATASETS'], shape=(1, ), dtype='S15')
        geom_guid = self.geom_uuid.encode("ascii", "ignore")
        datasets_group.create_dataset('Guid', shape=(1, ), dtype='S37', data=[geom_guid])
        if 'File Type' not in h5_root:
            h5_root.create_dataset('File Type', shape=(1, ), dtype='S5', data=[b'Xmdf'])
            h5_root.create_dataset('File Version', data=99.99, dtype='f')
        return datasets_group

    def _create_xmdf_dataset_group(self) -> h5py.Group:
        """Creates the XMDF dataset group (scalar or vector) and writes its metadata.

        Writes the version, data type, compression, location, group type, time units and dataset units attributes,
        the reference time attribute if one is set, and a PROPERTIES sub-group holding the null value and the
        dataset UUID when those are set.

        Returns:
            (:obj:`h5py.Group`): The HDF5 group created.
        """
        # Write data set metadata required for XMDF to recognize the file.
        group = self.multi_group.create_group(self.name)
        attributes = group.attrs
        for attr_name, attr_value in (('Version', 1), ('Data Type', 0), ('DatasetCompression', -1)):
            attributes.create(attr_name, dtype='i4', data=[attr_value])

        location_code = XMDF_DATA_LOCATIONS.get(self.location, 1)
        if self.num_components == 1:
            attributes.create('DatasetLocation', dtype='i4', data=[location_code])
            group_type = b'DATASET SCALAR'
        else:  # vector
            attributes.create('DatasetLocationI', dtype='i4', data=[location_code])
            attributes.create('DatasetLocationJ', dtype='i4', data=[location_code])
            group_type = b'DATASET VECTOR'
        attributes.create('Grouptype', data=[group_type], shape=(1, ), dtype='S15')

        # String attributes are sized one byte longer than the text, e.g. 'S5' for 'Days'.
        time_units_text = self.time_units
        attributes.create(
            'TimeUnits',
            data=[time_units_text.encode("ascii", "ignore")],
            shape=(1, ),
            dtype=f'S{len(time_units_text) + 1}'
        )
        units_text = self.units
        attributes.create(
            'DatasetUnits', data=[units_text.encode("ascii", "ignore")], shape=(1, ), dtype=f'S{len(units_text) + 1}'
        )

        if self.ref_time is not None:  # Write reftime if provided
            # Use default 64-bit float for reftime
            attributes.create('Reftime', data=[datetime_to_julian(self.ref_time)])

        properties = None
        if self.null_value is not None:  # Write null value to PROPERTIES group if provided
            properties = self._ensure_properties_group_exists(group, properties)
            properties.create_dataset('nullvalue', data=[self.null_value], dtype='f')

        if self.uuid:
            properties = self._ensure_properties_group_exists(group, properties)
            guid_text = self.uuid
            properties.create_dataset(
                'GUID', data=[guid_text.encode("ascii", "ignore")], dtype=f'S{len(guid_text) + 1}'
            )
        return group

    def _ensure_properties_group_exists(self, dataset_group, prop_grp):
        """Creates the PROPERTIES group in the h5 file if it doesn't exist.

        Args:
            dataset_group: Handle to the datasets group.
            prop_grp: Handle to the properties group.

        Returns:
            The properties group.
        """
        if not prop_grp:  # Create the PROPERTIES group if we haven't already
            prop_grp = dataset_group.create_group('PROPERTIES')
            ascii_list = ['PROPERTIES'.encode("ascii", "ignore")]
            prop_grp.attrs.create('Grouptype', data=ascii_list, dtype='S11')
        return prop_grp

    def write_xmdf_dataset(self, times: Sequence, data: Sequence, activity: Optional[Sequence] = None) -> None:
        """Write an entire in-memory dataset to an XMDF formatted file.

        Args:
            times (:obj:`numpy.ndarray`): 1-D array of float time step offsets
            data (:obj:`numpy.ndarray`): The dataset values organized in XMDF structure. Rows are timesteps and columns
                are node/cell values. If a vector dataset, outer dimensions contain the additional components.
            activity (:obj:`numpy.ndarray`): The activity array, if it exists
        """
        # Ensure we haven't already been appending timesteps.
        if self._added_data:
            raise RuntimeError('Trying to write an entire in-memory dataset after appending timesteps.')
        self._added_data = True

        # Allow caller to pass Python list/tuple
        npdata = ensure_sequence_is_numpy_array(data)
        npactivity = activity if activity is None else ensure_sequence_is_numpy_array(activity)

        original_data = None
        # Compute timestep mins and maxs

        minmax_activity = npactivity
        if self.activity_calculator is not None and npactivity is not None:
            if npactivity.shape != npdata.shape:
                minmax_activity = np.resize(minmax_activity, npdata.shape)

            for i in range(npactivity.shape[0]):
                minmax_activity[i] = self.activity_calculator.calc(npactivity[i])
        if self.timestep_mins is not None and self.timestep_maxs is not None:  # User provided mins and maxs
            mins = self.timestep_mins
            maxs = self.timestep_maxs
        else:
            # Replace inactive values with nan for numpy operations
            original_data = self._replace_inactive_with_nan(npdata, minmax_activity)
            mins, maxs = self._compute_timestep_min_max(npdata)

        # Replace nan values with original data
        npdata = self._replace_nan_with_inactive(npdata, original_data)

        # Write the entire dataset to file
        self.dset_group.create_dataset('Maxs', dtype=self.dtype, data=maxs)
        self.dset_group.create_dataset('Mins', dtype=self.dtype, data=mins)
        self.dset_group.create_dataset('Times', dtype='f8', data=times)
        self.dset_group.create_dataset('Values', dtype=self.dtype, data=npdata)
        if npactivity is not None:  # Write the activity array if provided
            self.dset_group.create_dataset('Active', dtype='u1', data=npactivity)
        if self._active_timestep != -1:
            self.dset_group.create_dataset('Active Function', data=[self._active_timestep], dtype='i')
        if self._close_handle:
            self.h5file.close()
            self._h5file = None

    def append_timestep(self, time: float, data: Sequence, activity: Optional[Sequence] = None) -> None:
        """Append a timestep to the dataset file.

        The timestep min and max are computed here unless both timestep_mins and timestep_maxs have been provided,
        which is the same condition appending_finished() uses to decide which mins and maxs to write.

        Args:
            time (:obj:`float`): Offset from the reference time for this timestep
            data (:obj:`Sequence`): The dataset values for this timestep
            activity (:obj:`Sequence`): The activity array for this timestep

        Raises:
            RuntimeError if the file has already been written, or if some timesteps have an activity array and
            others do not
            ValueError if the number of values or activity values differs from the first timestep
        """
        # Allow caller to pass Python list/tuple
        npdata = ensure_sequence_is_numpy_array(data)
        npactivity = activity if activity is None else ensure_sequence_is_numpy_array(activity)

        if self.num_values is None:  # This is the first timestep, determine number of dataset values.
            # Ensure that we have not written to file yet.
            if self._added_data:
                raise RuntimeError(
                    'Trying to append a timestep to file that has already been written. Create a new '
                    'DatasetWriter object to overwrite or append the file.'
                )

            self.num_values = npdata.shape[0]
            if npactivity is not None:
                self.num_activity_values = npactivity.shape[0]
        else:  # Not the first timestep, resize the datasets
            # Check for jagged activity arrays across timesteps.
            activity_without_activity_dset = activity is not None and self._activity_dset is None
            activity_dset_without_activity = activity is None and self._activity_dset is not None
            if activity_without_activity_dset or activity_dset_without_activity:
                raise RuntimeError(
                    'Incompatible activity arrays found across timesteps. All timesteps must have '
                    'an activity array or all timesteps must not have an activity array.'
                )

            # Check sizes before growing the datasets so a bad timestep does not leave a half-written row behind.
            if npdata.shape[0] != self.num_values:
                raise ValueError(
                    f'Timestep has {npdata.shape[0]} values but previous timesteps have {self.num_values}.'
                )
            if npactivity is not None and npactivity.shape[0] != self.num_activity_values:
                raise ValueError(
                    f'Timestep has {npactivity.shape[0]} activity values but previous timesteps have '
                    f'{self.num_activity_values}.'
                )

            self.values.resize(self.values.shape[0] + 1, axis=0)
            if activity is not None:
                self.activity.resize(self.activity.shape[0] + 1, axis=0)
        self._added_data = True

        original_data = None
        self._times.append(time)
        minmax_activity = npactivity
        if self.activity_calculator is not None and npactivity is not None:
            minmax_activity = self.activity_calculator.calc(npactivity)
        # Compute timestep min and max unless the user has overridden both of them.
        if self.timestep_mins is None or self.timestep_maxs is None:
            original_data = self._replace_inactive_with_nan(npdata, minmax_activity)
            mins, maxs = self._compute_timestep_min_max([npdata])
            if self.null_value is not None:
                # A timestep made up entirely of null values has nan min/max; store 0.0 instead.
                mins[np.isnan(mins)] = 0.0
                maxs[np.isnan(maxs)] = 0.0
            self._mins.append(mins[0])
            self._maxs.append(maxs[0])

        self.values[-1] = self._replace_nan_with_inactive(npdata, original_data)
        if activity is not None:  # Append the activity array if provided.
            self.activity[-1] = npactivity

    def appending_finished(self) -> None:
        """Call once finished appending timesteps to flush to file.

        Writes the 'Times', 'Maxs' and 'Mins' datasets (using the user provided mins and maxs if both are set,
        otherwise the ones computed while appending) and 'Active Function' if an active timestep has been set. If
        this object opened the H5 file itself, the file is closed afterwards, even when writing fails.
        """
        try:
            group = self.dset_group
            group.create_dataset('Times', dtype='f8', data=self._times)
            if self.timestep_mins is not None and self.timestep_maxs is not None:  # User provided mins/maxs
                group.create_dataset('Maxs', dtype=self.dtype, data=self.timestep_maxs)
                group.create_dataset('Mins', dtype=self.dtype, data=self.timestep_mins)
            else:  # Computed timestep mins/maxs
                group.create_dataset('Maxs', dtype=self.dtype, data=self._maxs)
                group.create_dataset('Mins', dtype=self.dtype, data=self._mins)
            if self._active_timestep != -1:
                group.create_dataset('Active Function', data=[self._active_timestep], dtype='i')
        finally:
            # Close the handle only if we own it and it was actually opened. Use the private attribute so a file
            # that was never opened is not created here just to be closed.
            if self._close_handle and self._h5file is not None:
                self._h5file.close()
                self._h5file = None

    def duplicate_from_reader(self, dataset_reader: DatasetReader):
        """Copies the data set from a DatasetReader using this writer's settings.

        Delegates to the reader's duplicate_to_new_geometry(), passing this writer's geometry UUID, H5 filename,
        dataset name and dataset UUID.

        Args:
            dataset_reader (:obj:`DatasetReader`): The reader to copy from.
        """
        target_geom_uuid = self.geom_uuid
        target_file = self.h5_filename
        dataset_reader.duplicate_to_new_geometry(target_geom_uuid, target_file, self.name, self.uuid)

    @staticmethod
    def get_reasonable_null_value():
        """Returns DSET_NULL_VALUE, the null value constant imported from dataset_io."""
        null_value = DSET_NULL_VALUE
        return null_value
