"""Class to retrieve all the data needed to export a fort.13 (standalone or full simulation export)."""

# 1. Standard Python modules
import os

# 2. Third party modules
import orjson

# 3. Aquaveo modules
from xms.api.tree import tree_util
from xms.core.filesystem import filesystem as xfs
from xms.datasets.dataset_reader import DatasetReader

# 4. Local modules
from xms.adcirc.data.sim_data import SimData
from xms.adcirc.feedback.xmlog import XmLog

# Name of the JSON file written by the fort.15 simulation export script and read back for the fort.13 export.
FORT13_EXPORT_ARGS_JSON = 'fort13_input_data.json'

# Nodal attribute names. The position of a name in this list is the index into the parallel NA_DSET_ATTRS and
# NA_DSET_UNITS tables below.
NA_DSET_NAMES = [
    'surface_submergence_state',
    'surface_directional_effective_roughness_length',
    'surface_canopy_coefficient',
    'bottom_roughness_length',
    'sea_surface_height_above_geoid',  # const value instead of data set (special case, see GEOID_OFFSET_IDX)
    'wave_refraction_in_swan',
    'average_horizontal_eddy_viscosity_in_sea_water_wrt_depth',
    'primitive_weighting_in_continuity_equation',
    'quadratic_friction_coefficient_at_sea_floor',
    # NOTE(review): 'paramenters' spelling is kept as-is; presumably it must match the name used by SimData - verify
    # before correcting.
    'bridge_pilings_friction_paramenters',
    'mannings_n_at_sea_floor',
    'chezy_friction_coefficient_at_sea_floor',
    'elemental_slope_limiter',
    'advection_state',
    'initial_river_elevation',
]

# GeoidOffset is a special case. Only attribute that is a constant, not a dataset.
GEOID_OFFSET_IDX = NA_DSET_NAMES.index('sea_surface_height_above_geoid')

# Attributes whose values live in more than one simulation-data entry. Every other attribute has exactly one entry,
# keyed by the attribute name itself.
_MULTI_DSET_ATTRS = {
    # z0land_000, z0land_030, ..., z0land_330 (steps of 30; presumably one per direction).
    'surface_directional_effective_roughness_length': [f'z0land_{angle:03d}' for angle in range(0, 360, 30)],
    'bridge_pilings_friction_paramenters': ['BK', 'BAlpha', 'BDelX', 'POAN'],
}

# Simulation-data attribute keys for each nodal attribute, parallel to NA_DSET_NAMES.
NA_DSET_ATTRS = [list(_MULTI_DSET_ATTRS.get(name, [name])) for name in NA_DSET_NAMES]

# Units for each nodal attribute, parallel to NA_DSET_NAMES.
NA_DSET_UNITS = [
    'Unitless',  # surface_submergence_state
    'm',  # surface_directional_effective_roughness_length
    'Unitless',  # surface_canopy_coefficient
    'm',  # bottom_roughness_length
    'm',  # sea_surface_height_above_geoid
    'Unitless',  # wave_refraction_in_swan
    'm**2/s',  # average_horizontal_eddy_viscosity_in_sea_water_wrt_depth
    'Unitless',  # primitive_weighting_in_continuity_equation
    'Unitless',  # quadratic_friction_coefficient_at_sea_floor
    'Vary',  # bridge_pilings_friction_paramenters
    'User',  # mannings_n_at_sea_floor
    'User',  # chezy_friction_coefficient_at_sea_floor
    'm/m',  # elemental_slope_limiter
    'm',  # advection_state
    'm',  # initial_river_elevation
]


class Fort13DataGetter:
    """Class to retrieve all the data needed to export a fort.13 (standalone or full simulation export).

    Retrieved values are stored in the XMS data dict given to the constructor (also reachable via ``xms_data``).
    """
    def __init__(self, query, xms_data=None):
        """Constructor.

        Args:
            query (:obj:`Query`): Object for requesting data from XMS
            xms_data (:obj:`Optional[dict]`): Dict of XMS data to fill. If None, a new empty dict is created; it can
                be retrieved afterwards through the ``xms_data`` property.
        """
        self._query = query
        # Honor the Optional contract: without this, None crashes on the first membership test / item assignment.
        self._xms_data = {} if xms_data is None else xms_data

    @property
    def xms_data(self):
        """dict: The XMS data dict this getter fills (the caller's dict, or the one created when None was passed)."""
        return self._xms_data

    def _retrieve_sim_data(self):
        """Retrieve the simulation data if this is not a full simulation export (haven't already got it).

        Fills 'domain_name', 'num_nodes' and 'sim_data' in the XMS data dict. Does nothing if 'sim_data' is already
        present.

        Raises:
            RuntimeError: If the data could not be retrieved from XMS. The underlying error is chained as the cause.
        """
        if 'sim_data' in self._xms_data:
            return  # Already have it (e.g. supplied by a full simulation export).

        try:
            # Get the name of the mesh and its number of nodes. If we are executing this code, it implies from the
            # Partial Export simulation component command.
            sim_uuid = self._query.parent_item_uuid()
            sim_item = tree_util.find_tree_node_by_uuid(self._query.project_tree, sim_uuid)
            linked_meshes = tree_util.descendants_of_type(
                sim_item, xms_types=['TI_MESH2D_PTR'], allow_pointers=True, recurse=False
            )
            domain_mesh = linked_meshes[0]  # Should be one and only one
            self._xms_data['domain_name'] = domain_mesh.name
            self._xms_data['num_nodes'] = domain_mesh.num_points

            # Get the simulation component's data.
            sim_comp = self._query.item_with_uuid(sim_uuid, model_name='ADCIRC', unique_name='Sim_Component')
            self._xms_data['sim_data'] = SimData(sim_comp.main_file)
        except Exception as err:
            raise RuntimeError('Unable to retrieve ADCIRC simulation data from SMS.') from err

    def _retrieve_datasets(self):
        """Retrieve the nodal attributes datasets.

        Sets 'att_names' and appends one entry per enabled attribute to the 'att_units' and 'att_dsets' lists. Both
        lists must already exist in the XMS data dict; retrieve_data() creates them. Each 'att_dsets' entry is a
        list of the items the query returns for that attribute's dataset UUIDs. The exception is GeoidOffset, whose
        entry is ``[[constant_value]]``.

        Raises:
            RuntimeError: If the data could not be retrieved. The underlying error is chained as the cause.
        """
        XmLog().instance.info('Retrieving ADCIRC nodal attribute datasets from SMS...')
        try:
            sim_data = self._xms_data['sim_data']
            # Ask the simulation data for the enabled nodal attributes.
            self._xms_data['att_names'] = sim_data.get_enabled_nodal_atts()

            for att_name in self._xms_data['att_names']:
                # An unknown attribute name raises ValueError here, which is reported as RuntimeError below.
                att_idx = NA_DSET_NAMES.index(att_name)
                self._xms_data['att_units'].append(NA_DSET_UNITS[att_idx])  # Don't think this matters much
                if att_idx == GEOID_OFFSET_IDX:
                    # Special case for GeoidOffset. Get constant value from simulation data.
                    att_dsets = [[float(sim_data.nodal_atts.attrs[NA_DSET_ATTRS[att_idx][0]])]]
                else:
                    # Look up the dataset UUIDs for this attribute (there can be several per attribute), then fetch
                    # each dataset by UUID.
                    att_dsets = [
                        self._query.item_with_uuid(sim_data.nodal_atts.attrs[att_dset_attr])
                        for att_dset_attr in NA_DSET_ATTRS[att_idx]
                    ]
                self._xms_data['att_dsets'].append(att_dsets)
        except Exception as err:
            raise RuntimeError('Unable to retrieve ADCIRC nodal attribute datasets from SMS.') from err

    def retrieve_data(self):
        """Retrieve data required to export fort.13 from XMS.

        Resets 'att_dsets' and 'att_units' to empty lists, then fills the XMS data dict.

        Raises:
            RuntimeError: If the simulation data or nodal attribute datasets could not be retrieved.
        """
        # Get the simulation data and domain mesh attributes
        XmLog().instance.info('Retrieving ADCIRC simulation data from SMS...')
        self._xms_data['att_dsets'] = []
        self._xms_data['att_units'] = []
        self._retrieve_sim_data()
        self._retrieve_datasets()


def write_fort13_data_json(xms_data):
    """Write a JSON file containing data needed to write the fort.13 (only applicable to full simulation exports).

    The file is named FORT13_EXPORT_ARGS_JSON and is written to the current working directory. This is the location
    read_fort13_data_json() reads it from.

    Args:
        xms_data (:obj:`dict`): The XMS data required to export the fort.13. Missing keys are written with
            empty/default values.
    """
    # DatasetReader objects can't be serialized, so store each one as its (H5 filename, group path) pair.
    # Constant-valued attributes (GeoidOffset) are already single-element lists and are written unchanged.
    att_dsets = [
        [dset if isinstance(dset, list) else (dset.h5_filename, dset.group_path) for dset in nodal_att]
        for nodal_att in xms_data.get('att_dsets', [])
    ]
    json_data = {  # Don't need sim_data because it is only used to get other data from XMS.
        'domain_name': xms_data.get('domain_name', ''),
        'num_nodes': xms_data.get('num_nodes', 0),
        'att_names': xms_data.get('att_names', []),
        'att_dsets': att_dsets,
        'att_units': xms_data.get('att_units', []),
        'template': xms_data.get('template', False)
    }
    # Write the JSON file (orjson produces bytes, hence binary mode).
    with open(os.path.join(os.getcwd(), FORT13_EXPORT_ARGS_JSON), 'wb') as file:
        file.write(orjson.dumps(json_data))


def read_fort13_data_json():
    """Load, then delete, the JSON file of XMS input data that the fort.15 script writes during a full simulation export.

    Returns:
        (:obj:`dict`): The XMS data required to export the fort.13. Each 'att_dsets' entry of length one is a
        constant value and is returned as-is. Every other entry is rebuilt as a DatasetReader from its stored
        H5 filename and group path.
    """
    json_path = os.path.join(os.getcwd(), FORT13_EXPORT_ARGS_JSON)
    with open(json_path, 'rb') as json_file:
        raw_bytes = json_file.read()
    xms_data = orjson.loads(raw_bytes)
    xfs.removefile(json_path)

    def _to_dataset(entry):
        # Single-element entries are constants (GeoidOffset); the rest are [h5_filename, group_path] pairs.
        if len(entry) == 1:
            return entry
        return DatasetReader(h5_filename=entry[0], group_path=entry[1])

    rebuilt = []
    for nodal_att in xms_data.get('att_dsets', []):
        rebuilt.append([_to_dataset(entry) for entry in nodal_att])
    xms_data['att_dsets'] = rebuilt
    return xms_data
