"""StationData class."""

# 1. Standard Python modules
import os

# 2. Third party modules
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem import filesystem as io_util

# 4. Local modules
from xms.adcirc.__version__ import version
from xms.adcirc.data.adcirc_data import UNINITIALIZED_COMP_ID

# File name for the recording station component's main NetCDF file.
# NOTE(review): not referenced in the visible part of this module - presumably consumed by the owning
# component; confirm against callers.
STATION_MAIN_FILE = 'station_comp.nc'


class StationData(XarrayBase):
    """Manages the NetCDF data file holding recording station attributes.

    The file contains an 'info' dataset (file type, version, coverage UUID and the next free component id) and a
    'stations' dataset with one row of 'elevation', 'velocity' and 'wind' values per component id.
    """
    # Names of the per-station variables stored in the 'stations' dataset.
    _STATION_VARIABLES = ('elevation', 'velocity', 'wind')

    def __init__(self, data_file):
        """Initializes the data class.

        Args:
            data_file (:obj:`str`): The netcdf file (with path) associated with this instance data. Probably the owning
                component's main file.

        """
        # Initialize member variables before calling super so they are available for commit() call
        self._filename = data_file
        self._info = None
        self._stations = None
        # Create the default file before calling super because we have our own attributes to write.
        self._get_default_datasets(data_file)
        super().__init__(data_file)

    def update_station(self, comp_id, new_atts):
        """Update the recording station attributes of a point.

        Args:
            comp_id (:obj:`int`): Component id of the station to update
            new_atts (:obj:`xarray.Dataset`): The new attributes for the recording station. Must provide the
                'elevation', 'velocity' and 'wind' variables.

        """
        stations = self.stations
        for name in self._STATION_VARIABLES:
            # Label-based, in-place assignment on the row that belongs to comp_id.
            stations[name].loc[{'comp_id': [comp_id]}] = new_atts[name]

    @property
    def stations(self):
        """Load the stations dataset from disk.

        The dataset is read on first access and cached on the instance afterwards.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the recording stations dataset in the main file

        """
        if self._stations is None:
            self._stations = self.get_dataset('stations', False)
        return self._stations

    @stations.setter
    def stations(self, dset):
        """Setter for the stations dataset.

        Args:
            dset (:obj:`xarray.Dataset`): The new stations dataset. ``None`` is ignored so the currently cached
                dataset is kept.

        """
        # Compare against None explicitly: the truthiness of a dataset reflects its contents, not whether a
        # dataset was supplied.
        if dset is not None:
            self._stations = dset

    def add_station_atts(self, dset=None):
        """Add the recording stations attribute dataset for a point.

        Args:
            dset (:obj:`xarray.Dataset`): The attribute dataset to concatenate. If not provided, a new
                Dataset of default attributes will be generated. If provided, its 'comp_id' coordinate is
                overwritten in place with the newly generated component id.

        Returns:
            (:obj:`int`): The newly generated component id, or UNINITIALIZED_COMP_ID if the attributes could not
            be added.

        """
        try:
            # The attribute may be a plain int (as written by _get_default_datasets) or a NumPy scalar; int()
            # handles both, whereas .item() only exists on the latter.
            new_comp_id = int(self.info.attrs['next_comp_id'])
            if dset is None:  # Generate a new default Dataset
                dset = self._get_new_station_atts(new_comp_id)
            else:  # Update the component id of an existing Dataset
                dset.coords['comp_id'] = [new_comp_id] * len(dset.coords['comp_id'])
            self._stations = xr.concat([self.stations, dset], 'comp_id')
            # Only consume the id once the station was actually added.
            self.info.attrs['next_comp_id'] += 1
            return new_comp_id
        except Exception:
            # Best effort by design: callers receive the sentinel id instead of an exception.
            return UNINITIALIZED_COMP_ID

    def _get_default_datasets(self, data_file):
        """Create default datasets if needed.

        Args:
            data_file (:obj:`str`): Name of the data file. If it doesn't exist, it will be created.
        """
        if os.path.isfile(data_file):  # False for missing paths as well as for non-file paths
            return

        info = {
            'FILE_TYPE': 'ADCIRC_RECORDING_STATIONS',
            'VERSION': version,
            'cov_uuid': '',
            'next_comp_id': 0,
            'native_import': 0,  # Need to set component ids if read from a model native import.
        }
        self._info = xr.Dataset(attrs=info)

        # Start with an empty table: one (empty) variable per station attribute, indexed by comp_id.
        station_table = {name: ('comp_id', []) for name in self._STATION_VARIABLES}
        self._stations = xr.Dataset(data_vars=station_table, coords={'comp_id': []})

        self.commit()

    @staticmethod
    def _get_new_station_atts(comp_id):
        """Get a new dataset with default attributes for a recording station.

        Args:
            comp_id (:obj:`int`): The unique XMS component id to assign to the new station row.

        Returns:
            (:obj:`xarray.Dataset`): A new single-row dataset with 'elevation', 'velocity' and 'wind' all set to 0.
            Can later be concatenated to the persistent dataset.

        """
        station_table = {name: ('comp_id', [0]) for name in StationData._STATION_VARIABLES}
        return xr.Dataset(data_vars=station_table, coords={'comp_id': [comp_id]})

    def concat_station_points(self, station_data):
        """Adds the station point attributes from station_data to this instance of StationData.

        The 'comp_id' coordinate of station_data's stations dataset is renumbered in place, starting at this
        instance's next free component id.

        Args:
            station_data (:obj:`StationData`): another StationData instance

        Returns:
            (:obj:`dict`): The old ids of the station_data as key and the new ids as the data. Empty if
            station_data has no stations.

        """
        new_station_points = station_data.stations
        num_concat_points = new_station_points.sizes['comp_id']
        if not num_concat_points:
            return {}

        next_comp_id = self.info.attrs['next_comp_id']
        # Remember the old ids before the coordinate is overwritten.
        old_comp_ids = new_station_points.coords['comp_id'].data.astype('i4').tolist()
        # Reassign component id coordinates.
        new_station_points.coords['comp_id'] = [next_comp_id + idx for idx in range(num_concat_points)]
        self.info.attrs['next_comp_id'] = next_comp_id + num_concat_points
        self._stations = xr.concat([self.stations, new_station_points], 'comp_id')
        new_comp_ids = new_station_points.coords['comp_id'].data.astype('i4').tolist()
        return dict(zip(old_comp_ids, new_comp_ids))

    def commit(self):
        """Save current in-memory component parameters to data file."""
        super().commit()  # Recreates the NetCDF file if vacuuming
        if self._stations is not None:
            self._stations.close()
            # Drop the stale group first, then append the current stations dataset to the file.
            self._drop_h5_groups(['stations'])
            self._stations.to_netcdf(self._filename, group='stations', mode='a')

    def vacuum(self):
        """Rewrite all station data to a new/wiped file to reclaim disk space.

        The 'info' and 'stations' datasets are loaded into memory if they are not already, the existing file is
        deleted, and everything is written back with commit().

        """
        # Everything that must survive has to be in memory before the file is removed.
        if self._info is None:
            self._info = self.get_dataset('info', False)
        if self._stations is None:
            self._stations = self.get_dataset('stations', False)
        io_util.removefile(self._filename)  # Delete the existing NetCDF file
        self.commit()  # Rewrite all datasets
