"""A writer for spectral coverages for STWAVE."""

# 1. Standard Python modules
from collections import OrderedDict
import datetime
import logging

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.core.filesystem import filesystem as io_util
from xms.coverage.spectral import PLANE_TYPE_enum
from xms.data_objects.parameters import Dataset, FilterLocation, julian_to_datetime
from xms.guipy.time_format import ISO_DATETIME_FORMAT

# 4. Local modules
from xms.stwave.data import stwave_consts as const
from xms.stwave.data.simulation_data import SimulationData
import xms.stwave.file_io.eng_writer


class SpectralCoverageWriter:
    """A class for writing the spectral coverage linked to an STWAVE simulation to an .eng file."""

    def __init__(self, query=None, filename='', sim_export=False, xms_data=None):
        """Store export settings; no simulation data is read until export_coverage() is called.

        Args:
            query (Optional[Query]): The XMS interprocess communicator. If not provided, implies a simulation export.
                Should be provided if called from the partial export dialog.
            filename (Optional[str]): Path to filename to export. Should be provided if called from the partial export
                dialog.
            sim_export (bool): Exporting entire simulation
            xms_data (dict or None): dict of data from xms. Recognized keys: 'query', 'logger', 'sim_uuid',
                'sim_comp'.
        """
        self.data = None  # SimulationData; loaded by export_coverage()
        self._filename = filename
        self._xms_data = xms_data if xms_data is not None else {}
        if query:
            self._query = query
        elif 'query' in self._xms_data:
            self._query = self._xms_data['query']
        else:
            # Only construct a Query when none was supplied. Passing Query() as the dict.get default would
            # build one on every call, even when xms_data already holds a query.
            self._query = Query()
        self._logger = self._xms_data.get('logger', logging.getLogger('xms.stwave'))
        self._sim_export = sim_export

    def write_spec_cov(self, a_sim_grid, a_grid_name, a_times, a_wind_dirs, a_wind_mags, a_surges, a_using_wind_d_set):
        """
        Writes Spectral Coverage to file.

        Requires self.data to be loaded (export_coverage() does this before calling).

        Args:
            a_sim_grid (:obj:`data_objects.Parameters.Spatial.SpatialVector.RectilinearGrid`): The grid.
            a_grid_name (str): Name of the grid.
            a_times (:obj:`list` of double): list of times in julian double format
            a_wind_dirs (:obj:`list` of double): list of wind directions.
            a_wind_mags (:obj:`list` of double): list of wind speeds.
            a_surges (:obj:`list` of double): list of water levels.
            a_using_wind_d_set (bool): True to write wind data to file, False to not.

        Raises:
            RuntimeError: If the spectral coverage linked to the simulation cannot be retrieved.
        """
        eng_writer = xms.stwave.file_io.eng_writer
        self._logger.info('Writing spectral coverage...')
        self._logger.info('Building spectral parameters and case times...')
        # One case per entry; zip() stops at the shortest of the four input lists.
        cases = [
            eng_writer.STWAVECase(time, wind_dir, wind_mag, w_lvl)
            for time, wind_dir, wind_mag, w_lvl in zip(a_times, a_wind_dirs, a_wind_mags, a_surges)
        ]

        global_params = self._build_global_params(a_sim_grid)
        pt_id_map, pt_param_map = self._collect_point_spectra()

        attrs = self.data.info.attrs
        reftime = datetime.datetime.strptime(attrs['reftime'], ISO_DATETIME_FORMAT)
        writer = eng_writer.EngWriter(a_grid_name, pt_id_map, pt_param_map, global_params, cases, a_using_wind_d_set,
                                      reftime, attrs['reftime_units'], attrs['angle_convention'])
        writer.filename = self._filename  # Didn't want to add yet another argument to the constructor
        writer.write()

    def _build_global_params(self, a_sim_grid):
        """Build the SpectralParams describing the simulation's computational spectral grid.

        Args:
            a_sim_grid: The simulation grid item; its ``cogrid_file`` and ``projection`` attributes are read.

        Returns:
            xms.stwave.file_io.eng_writer.SpectralParams: Parameters with constant frequencies and angles set from
            the simulation settings.
        """
        eng_writer = xms.stwave.file_io.eng_writer
        attrs = self.data.info.attrs
        half_plane = attrs['plane'] == const.PLANE_TYPE_HALF
        plane_type = eng_writer.PlaneTypes.HALF if half_plane else eng_writer.PlaneTypes.LOCAL

        num_freqs = attrs['num_frequencies']
        delta_freq = attrs['delta_frequency']
        min_freq = attrs['min_frequency']
        # Evenly spaced frequencies: the last one is min + (count - 1) * step.
        max_freq = ((num_freqs - 1) * delta_freq) + min_freq

        self._logger.info('Building simulation grid definition...')
        dummy_d_set = Dataset()
        sim_grid = read_grid_from_file(a_sim_grid.cogrid_file)
        global_params = eng_writer.SpectralParams(plane_type, sim_grid, dummy_d_set, 0, [], [],
                                                  a_sim_grid.projection)
        global_params.set_freqs_const(min_freq, max_freq, delta_freq)
        # Half-plane simulations get a -85..85 angle range; all others get 0..360 (step argument 5 in both).
        if half_plane:
            global_params.set_angles_const(-85, 85, 5)
        else:
            global_params.set_angles_const(0, 360, 5)
        return global_params

    def _collect_point_spectra(self):
        """Read the linked spectral coverage and gather each point's spectra in ascending time order.

        Returns:
            tuple(dict, dict): (point id -> coverage point, point id -> list of SpectralParams sorted by time).

        Raises:
            RuntimeError: If the spectral coverage linked to the simulation cannot be retrieved.
        """
        eng_writer = xms.stwave.file_io.eng_writer
        self._logger.info('Retrieving spectral coverage from SMS...')
        spec_cov = self._query.item_with_uuid(self.data.info.attrs['spectral_uuid'], generic_coverage=True)
        if not spec_cov:
            raise RuntimeError('Unable to retrieve spectral coverage linked to the simulation.')

        pt_id_map = {}
        pt_param_map = {}
        for spec_pt in spec_cov.m_cov.get_points(FilterLocation.PT_LOC_DISJOINT):
            pt_id = spec_pt.id
            pt_id_map[pt_id] = spec_pt
            # Julian time (rounded to 7 places) -> SpectralParams, accumulated across all of this point's grids.
            params_by_time = {}
            for spec_grid in spec_cov.GetSpectralGrids(pt_id):
                # Fetch the dataset once per grid (not per time step) so only one file is made per dataset.
                spec_dset = spec_grid.get_dataset(io_util.temp_filename())
                # The plane type depends only on the grid, so resolve it once, outside the time loop.
                plane_value = spec_grid.m_planeType.value
                if plane_value == PLANE_TYPE_enum.FULL_GLOBAL_PLANE.value:
                    plane_type = eng_writer.PlaneTypes.GLOBAL
                elif plane_value == PLANE_TYPE_enum.HALF_PLANE.value:
                    plane_type = eng_writer.PlaneTypes.HALF
                else:
                    plane_type = eng_writer.PlaneTypes.LOCAL
                for i in range(spec_dset.num_times):
                    spec_dset.ts_idx = i  # select time step i before reading ts_time
                    ts_time = julian_to_datetime(spec_dset.ts_time)
                    params_by_time[round(spec_dset.ts_time, 7)] = eng_writer.SpectralParams(
                        plane_type, spec_grid.m_rectGrid, spec_dset, i, [], [], spec_grid.m_rectGrid.projection,
                        ts_time)
            # Sort once after every grid has been read. Sorting and appending inside the grid loop re-added
            # the spectra of earlier grids, so they appeared more than once in the output.
            pt_param_map[pt_id] = self._times_sorted(params_by_time)
        return pt_id_map, pt_param_map

    @staticmethod
    def _times_sorted(params_by_time):
        """Return the values of ``params_by_time`` ordered by ascending key (time)."""
        return [params_by_time[time_key] for time_key in sorted(params_by_time)]

    def export_coverage(self):
        """Write the simulation's spectral coverage to the .eng file.

        Loads the simulation data. Returns without writing when the boundary source option is OPT_NONE.
        Otherwise gathers case times, wind and water levels and passes them to write_spec_cov().
        """
        # Get the simulation tree item and its hidden component
        self._logger.info('Retrieving simulation data from SMS...')
        if self._sim_export:
            # Explicit check: a dict.get default would call into XMS even when 'sim_uuid' was supplied.
            if 'sim_uuid' in self._xms_data:
                sim_uuid = self._xms_data['sim_uuid']
            else:
                sim_uuid = self._query.current_item_uuid()
            sim_comp = self._xms_data.get('sim_comp', None)
            if sim_comp is None:
                sim_comp = self._query.item_with_uuid(sim_uuid, model_name='STWAVE', unique_name='Sim_Component')
        else:
            sim_uuid = self._query.parent_item_uuid()
            sim_comp = self._query.current_item()
        sim_item = tree_util.find_tree_node_by_uuid(self._query.project_tree, sim_uuid)
        self.data = SimulationData(sim_comp.main_file)
        attrs = self.data.info.attrs

        # Check if we need to write the .eng file
        if attrs['boundary_source'] == const.OPT_NONE:
            self._logger.info('Boundary condition source set to constant, no .eng file will be exported.')
            return

        # get case times for simulation
        times = self.data.times_in_seconds().tolist()
        num_times = len(times)

        # get the simulation grid
        self._logger.info('Retrieving domain grid from SMS...')
        do_grid = self._query.item_with_uuid(attrs['grid_uuid'])

        # get the wind data for the cases
        using_wind_d_set = False
        if attrs['source_terms'] == const.SOURCE_PROP_ONLY:  # propagation only
            wind_dirs = [0.0] * num_times
            wind_mags = [0.0] * num_times
        elif attrs['wind'] == const.OPT_CONST:  # constant wind values
            wind_dirs = self.data.case_times['Wind Direction']
            wind_mags = self.data.case_times['Wind Magnitude']
        else:  # write wind dataset to file
            using_wind_d_set = True
            wind_dirs = [0.0] * num_times
            wind_mags = [0.0] * num_times

        # get the water levels for the cases
        if attrs['surge'] == const.OPT_CONST:  # const tidal surge values
            water_lvl = self.data.case_times['Water Level']
        else:  # get tidal data from dataset
            water_lvl = [0.0] * num_times

        # write the spectral energy file
        self.write_spec_cov(do_grid, sim_item.name, times, wind_dirs, wind_mags, water_lvl, using_wind_d_set)
