"""Read files used to store STWAVE simualtions as part of old SMS projects."""

# 1. Standard Python modules
import math
import os
import uuid

# 2. Third party modules
import h5py
import numpy
import pandas

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint.rectilinear_geometry import Numbering, Orientation
from xms.constraint.rectilinear_grid_builder import RectilinearGridBuilder
from xms.core.filesystem import filesystem as io_util
from xms.data_objects.parameters import (Component, Coverage, julian_to_datetime, Point, RectilinearGrid, Simulation,
                                         UGrid)
from xms.guipy.time_format import ISO_DATETIME_FORMAT

# 4. Local modules
from xms.stwave.data import simulation_data
from xms.stwave.data import stwave_consts as const
from xms.stwave.data.simulation_data import SimulationData
from xms.stwave.file_io.eng_reader import get_coords_from_ij


class XmsReader:
    """A class for reading STWAVE simulations from files that are part of old SMS projects."""

    # Cell type values stored in the legacy 'CellTypes' dataset (1 is the default cell type).
    _CELL_TYPE_MONITOR = 2  # monitoring point
    _CELL_TYPE_NEST = 5  # nesting point
    _CELL_TYPE_STATION = 8  # output station point

    def __init__(self):
        """Constructor."""
        self.query = Query()
        # Nesting/monitoring points read from every grid in the project (each grid numbers its points from 1).
        self.nest_pts = []
        self.monitor_pts = []
        # True when the grid currently being read uses one of the JONSWAP friction options.
        self.using_jonswap = False
        self._link_items = []  # (sim_uuid, item_uuid) pairs linked after everything has been added
        self.free_dsets = {}  # simulation UUID -> datasets to add for that simulation
        self.data = None  # SimulationData of the grid currently being read

    def read(self):
        """Reads the file, builds grid, coverage, loads data and sends it all to SMS.

        One STWAVE simulation (with its own SimulationData component) is created for every grid under the
        'Cart2DModule' group of the project's H5 file. Point coverages derived from legacy cell attributes are
        created per simulation, then everything is linked and sent through ``self.query``.
        """
        in_proj_filename = self.query.read_file

        # Get the SMS temp directory.
        comp_dir = os.path.join(self.query.xms_temp_directory, 'Components')
        os.makedirs(comp_dir, exist_ok=True)

        with h5py.File(in_proj_filename, 'r') as f:
            grid_names = list(f['Cart2DModule'].keys())

        sim_data = {}  # simulation UUID -> (SimulationData, Simulation, Component)
        grids = {}  # simulation UUID -> UGrid rebuilt from the legacy grid
        point_sets = {}  # simulation UUID -> (projection, nesting pts, monitoring pts, station pts)

        # add each simulation we will need
        for grid_name in grid_names:
            sim = Simulation(model='STWAVE', name=grid_name, sim_uuid=str(uuid.uuid4()))
            sim_uuid = sim.uuid

            comp_uuid = str(uuid.uuid4())
            sim_comp_dir = os.path.join(comp_dir, comp_uuid)
            os.makedirs(sim_comp_dir, exist_ok=True)
            sim_mainfile = os.path.join(sim_comp_dir, 'simulation_comp.nc')
            comp = Component(comp_uuid=comp_uuid, main_file=sim_mainfile, model_name='STWAVE',
                             unique_name='Sim_Component')
            self.data = SimulationData(sim_mainfile)  # This will create a default model control

            # Copy the H5 file so we can feed it to data_objects RectilinearGrid reader code.
            in_filename = self._grid_copy_filename(self.query.xms_temp_directory, grid_name, in_proj_filename)
            io_util.copyfile(in_proj_filename, in_filename)

            self.free_dsets[sim_uuid] = []
            grid_path = 'Cart2DModule/' + grid_name + '/'
            base_path = grid_path + 'PROPERTIES/Model Params/'

            with h5py.File(in_proj_filename, 'r') as f:
                grid_uuid = self._read_grid_uuid(f, grid_path)
                self.data.info.attrs['grid_uuid'] = grid_uuid

                # Model options must be read first: they decide whether friction is JONSWAP or Manning.
                self._read_model_options(f, base_path)
                self._read_boundary_options(f, base_path)
                self._read_dataset_uuids(f, base_path, grid_uuid)
                self._read_coverage_links(f, base_path, sim_uuid)
                self._read_case_times(f, base_path)
                self._read_spectral_params(f, base_path + 'Spectral Grid/')

                grid, ugrid = self._build_ugrid(f, in_filename, grid_path, grid_uuid, sim_uuid)
                self._link_items.append((sim_uuid, grid_uuid))
                grids[sim_uuid] = ugrid

                # read in coverages that used to be defined as cell attributes
                nest, monitor, station = self._read_cell_points(f, base_path, grid)
                self.nest_pts.extend(nest)
                self.monitor_pts.extend(monitor)
                point_sets[sim_uuid] = (grid.projection, nest, monitor, station)

            sim_data[sim_uuid] = self.data, sim, comp

        for sim_uuid, (cur_sim_data, sim, comp) in sim_data.items():
            # add the simulation
            self.query.add_simulation(sim, [comp])

            # add the simulation's CGrid (disabled; the simulation is linked to the grid UUID below instead)
            # if sim_uuid in grids:
            #     self.query.add_ugrid(grids[sim_uuid])

            # add the datasets
            for dset in self.free_dsets[sim_uuid]:
                self.query.add_dataset(dset)

            # add optional takes - one coverage per point type, holding only this simulation's points
            projection, nest, monitor, station = point_sets[sim_uuid]
            optional_covs = (
                ('nesting', 'STWAVE Nesting Points', nest),
                ('monitoring', 'STWAVE Monitoring Cells', monitor),
                ('output_stations', 'STWAVE Station Points', station),
            )
            for attr, cov_name, pts in optional_covs:
                if pts:
                    cur_sim_data.info.attrs[attr] = 1
                    cur_sim_data.info.attrs[attr + '_uuid'] = self._add_point_coverage(cov_name, pts, projection)
            cur_sim_data.commit()

        for sim_uuid, item_uuid in self._link_items:
            self.query.link_item(sim_uuid, item_uuid)

        self.query.send()  # send the data to SMS

    @staticmethod
    def _grid_copy_filename(temp_dir, grid_name, proj_filename):
        """Return the path of the per-grid copy of the project file inside temp_dir.

        Args:
            temp_dir (str): Directory the copy is placed in.
            grid_name (str): Name of the grid the copy is made for.
            proj_filename (str): Path to the original project H5 file.

        Returns:
            str: '<temp_dir>/<grid_name>_<basename of proj_filename>'.
        """
        return os.path.join(temp_dir, f'{grid_name}_{os.path.basename(proj_filename)}')

    @staticmethod
    def _read_grid_uuid(f, grid_path):
        """Return the UUID stored for the grid, or a freshly generated one if the project has none.

        Args:
            f: Open H5 file (or mapping) containing the grid's group.
            grid_path (str): Path of the grid's group, ending in '/'.
        """
        for key in ('PROPERTIES/GUID', 'PROPERTIES/Guid'):  # 'Guid' is the old camelcase spelling
            if grid_path + key in f:
                return f[grid_path + key][0].decode('UTF-8')
        return str(uuid.uuid4())

    @staticmethod
    def _cell_edges(sizes):
        """Return cumulative cell edge locations, starting at 0.0, for a sequence of cell sizes."""
        edges = [0.0]
        for size in sizes:
            edges.append(edges[-1] + size)
        return edges

    def _read_model_options(self, f, base_path):
        """Read the model control options (plane, source terms, friction, solver, ...) into self.data.

        Also sets ``self.using_jonswap`` for the current grid.

        Args:
            f: Open H5 project file.
            base_path (str): Path of the grid's 'Model Params' group, ending in '/'.
        """
        attrs = self.data.info.attrs

        model_type = int(f[base_path + 'ModelType'][0])
        attrs['plane'] = const.PLANE_TYPE_HALF if model_type == 0 else const.PLANE_TYPE_FULL

        source_type = int(f[base_path + 'SourceType'][0])
        attrs['source_terms'] = const.SOURCE_PROP_AND_TERMS if source_type == 0 else const.SOURCE_PROP_ONLY

        depth_option = const.DEP_OPT_NONTRANSIENT
        if base_path + 'DepthType' in f:
            depth_type = int(f[base_path + 'DepthType'][0])
            if depth_type == 1:
                depth_option = const.DEP_OPT_TRANSIENT
            elif depth_type == 2:
                depth_option = const.DEP_OPT_COUPLED
        attrs['depth'] = depth_option

        current_type = int(f[base_path + 'CurrentType'][0])
        attrs['current_interaction'] = const.OPT_DSET if current_type == 1 else const.OPT_NONE

        fric_type = int(f[base_path + 'FrictionType'][0])
        fric_options = {
            1: const.FRIC_OPT_JONSWAP_CONST,
            2: const.FRIC_OPT_JONSWAP_DSET,
            3: const.FRIC_OPT_MANNING_CONST,
            4: const.FRIC_OPT_MANNING_DSET,
        }
        attrs['friction'] = fric_options.get(fric_type, const.OPT_NONE)
        # Reassigned for every grid so a JONSWAP grid does not leak its setting into later Manning grids.
        self.using_jonswap = fric_type in (1, 2)

        break_type = int(f[base_path + 'BreakType'][0])
        break_options = {1: const.BREAK_OPT_WRITE, 2: const.BREAK_OPT_CALCULATE}
        attrs['breaking_type'] = break_options.get(break_type, const.BREAK_OPT_NONE)

        attrs['rad_stress'] = 1 if int(f[base_path + 'RadiationType'][0]) == 1 else 0

        if base_path + 'C2ShoreType' in f:
            attrs['c2shore'] = 1 if int(f[base_path + 'C2ShoreType'][0]) == 1 else 0

        attrs['surge'] = const.OPT_DSET if int(f[base_path + 'SurgeType'][0]) == 1 else const.OPT_CONST
        attrs['wind'] = const.OPT_DSET if int(f[base_path + 'WindType'][0]) == 1 else const.OPT_CONST

        if base_path + 'IceType' in f:
            attrs['ice'] = const.OPT_DSET if int(f[base_path + 'IceType'][0]) == 1 else const.OPT_NONE

        if base_path + 'ComputerIProcs' in f:
            solver_params = (
                ('processors_i', 'ComputerIProcs', int),
                ('processors_j', 'ComputerJProcs', int),
                ('max_init_iters', 'InitIterations', int),
                ('init_iters_stop_value', 'StopInitValue', float),
                ('init_iters_stop_percent', 'StopInitPercent', float),
                ('max_final_iters', 'FinalIterations', int),
                ('final_iters_stop_value', 'StopFinalValue', float),
                ('final_iters_stop_percent', 'StopFinalPercent', float),
            )
            for attr, key, cast in solver_params:
                attrs[attr] = cast(f[base_path + key][0])

        if base_path + 'Friction Value' in f:
            fric_attr = 'JONSWAP' if self.using_jonswap else 'manning'
            attrs[fric_attr] = float(f[base_path + 'Friction Value'][0])

        if base_path + 'Ice Threshold' in f:
            attrs['ice_threshold'] = float(f[base_path + 'Ice Threshold'][0])

    def _read_boundary_options(self, f, base_path):
        """Read side boundary types, interpolation, and boundary source options into self.data.

        Args:
            f: Open H5 project file.
            base_path (str): Path of the grid's 'Model Params' group, ending in '/'.
        """
        attrs = self.data.info.attrs

        bc_types = None
        for key in ('BCTypes', 'BCTypes2'):
            if base_path + key in f:
                bc_types = numpy.array(f[base_path + key])
                break

        if bc_types is not None and len(bc_types) == 4:
            for side, raw in enumerate(bc_types, start=1):
                # Positive legacy values are shifted up by one before being mapped.
                bc_int = int(raw) + 1 if raw > 0 else int(raw)
                if bc_int == 0:
                    side_type = const.I_BC_ZERO
                elif bc_int == 3:
                    side_type = const.I_BC_LATERAL
                else:
                    side_type = const.I_BC_SPECIFIED
                attrs[f'side{side}'] = side_type

        if base_path + 'InterpType' in f:
            interp_type = int(f[base_path + 'InterpType'][0])
            attrs['interpolation'] = const.INTERP_OPT_MORPHIC if interp_type == 1 else const.INTERP_OPT_LINEAR

        bc_source = 0
        if base_path + 'BcSource' in f:
            bc_source = int(f[base_path + 'BcSource'][0])
            if bc_source > 0:
                bc_source -= 1
        elif base_path + 'BcSource2' in f:
            bc_source = int(f[base_path + 'BcSource2'][0])
        attrs['boundary_source'] = const.SPEC_OPT_COV if bc_source == 0 else const.OPT_NONE

    def _read_dataset_uuids(self, f, base_path, grid_uuid):
        """Resolve the datasets referenced by name in the project to their UUIDs in the project tree.

        References that cannot be found under the grid's tree item are skipped. Legacy projects can
        contain stale names, and the simulation then simply has no dataset selected for that input.

        Args:
            f: Open H5 project file.
            base_path (str): Path of the grid's 'Model Params' group, ending in '/'.
            grid_uuid (str): UUID of the grid whose tree item holds the datasets.
        """
        grid_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, grid_uuid)
        friction_attr = 'JONSWAP_uuid' if self.using_jonswap else 'manning_uuid'  # JONSWAP or Manning's N
        func_attrs = (
            ('CurrentFunc', 'current_uuid'),
            ('DepthFunc', 'depth_uuid'),
            ('FrictionFunc', friction_attr),
            ('IceFunc', 'ice_uuid'),
            ('SurgeFunc', 'surge_uuid'),
            ('WindFunc', 'wind_uuid'),
        )
        for key, attr in func_attrs:
            if base_path + key not in f:
                continue
            dset_name = f[base_path + key][0].decode('UTF-8')
            child = None
            if grid_item is not None:
                child = tree_util.first_descendant_with_name(grid_item, dset_name)
            if child is not None:
                self.data.info.attrs[attr] = child.uuid

    def _read_coverage_links(self, f, base_path, sim_uuid):
        """Store references to the spectral coverage and spectral subset (location) coverage, if present.

        Args:
            f: Open H5 project file.
            base_path (str): Path of the grid's 'Model Params' group, ending in '/'.
            sim_uuid (str): UUID of the simulation the spectral coverage is linked to.
        """
        attrs = self.data.info.attrs
        # get the spectral coverage's uuid if there is one
        if base_path + 'Spectral Coverage 1' in f:
            spec_uuid = f[base_path + 'Spectral Coverage 1'][0].decode('UTF-8')
            attrs['spectral_uuid'] = spec_uuid
            self._link_items.append((sim_uuid, spec_uuid))

        # hook up link to spectral subset coverage if one
        if base_path + 'Spectral Subset' in f:
            attrs['location_coverage'] = 1
            attrs['location_coverage_uuid'] = f[base_path + 'Spectral Subset'][0].decode('UTF-8')

    def _read_case_times(self, f, base_path):
        """Read the case table (times, constant wind direction/speed, constant surge) into self.data.

        Case times are stored as absolute Julian dates. They are converted to hour offsets from the
        first case time, which becomes the simulation's reference time.

        Args:
            f: Open H5 project file.
            base_path (str): Path of the grid's 'Model Params' group, ending in '/'.
        """
        def read_vec(name):
            key = base_path + name
            return numpy.array(f[key], dtype=numpy.float64) if key in f else None

        # Timesteps stored as julians in TimeVec - no need for reftime/units
        case_times = read_vec('TimeVec')
        if case_times is None:
            case_times = numpy.array([])
        wind_dirs = read_vec('WindDirVec')  # const wind dir
        wind_mags = read_vec('WindSpdVec')  # const wind mag
        tide_lvls = read_vec('TideVec')  # const surge

        # Convert absolute times to offsets in hours (the default units)
        if len(case_times) > 0:
            # Set the first case time as the reference time
            ref_dt = julian_to_datetime(case_times[0])
            self.data.info.attrs['reftime'] = ref_dt.strftime(ISO_DATETIME_FORMAT)
            # Day offsets -> hour offsets, discarding excess precision
            case_times = numpy.around((case_times - case_times[0]) * 24.0, 3)

        # Found some dirty projects in the wild that do not have case times for all rows.
        if wind_dirs is not None and len(wind_dirs) > len(case_times):
            case_times = numpy.concatenate((case_times, numpy.zeros(len(wind_dirs) - len(case_times))))

        # If times only but not wind dir, wind mags, and surge, clean up
        if len(case_times) > 0:
            num_cases = len(case_times)
            if wind_dirs is None:
                wind_dirs = numpy.zeros(num_cases)
            if wind_mags is None:
                wind_mags = numpy.zeros(num_cases)
            if tide_lvls is None:
                tide_lvls = numpy.zeros(num_cases)

        case_time_data = simulation_data.case_data_table(case_times, wind_dirs, wind_mags, tide_lvls)
        self.data.case_times = pandas.DataFrame(case_time_data).to_xarray()

    def _read_spectral_params(self, f, spec_path):
        """Read the constant spectral grid frequency settings into self.data.

        The number of frequencies is only derived when the minimum, delta, and maximum frequency are all present
        (and the minimum and delta are non-zero).

        Args:
            f: Open H5 project file.
            spec_path (str): Path of the grid's 'Spectral Grid' group, ending in '/'.
        """
        attrs = self.data.info.attrs
        min_freq = None
        delta_freq = None
        if spec_path + 'MinFreq' in f:
            min_freq = float(f[spec_path + 'MinFreq'][0])
            attrs['min_frequency'] = min_freq
        if spec_path + 'DeltaFreq' in f:
            delta_freq = float(f[spec_path + 'DeltaFreq'][0])
            attrs['delta_frequency'] = delta_freq
        if min_freq and delta_freq and spec_path + 'MaxFreq' in f:
            max_freq = float(f[spec_path + 'MaxFreq'][0])
            attrs['num_frequencies'] = math.ceil((max_freq - min_freq) / delta_freq)

    def _build_ugrid(self, f, in_filename, grid_path, grid_uuid, sim_uuid):
        """Rebuild the legacy rectilinear grid, with depths as cell elevations, as a UGrid.

        Args:
            f: Open H5 project file.
            in_filename (str): Per-grid copy of the project file read by RectilinearGrid.
            grid_path (str): Path of the grid's group, ending in '/'.
            grid_uuid (str): UUID given to the rebuilt UGrid.
            sim_uuid (str): UUID of the owning simulation (keeps the written grid file name unique).

        Returns:
            tuple: (RectilinearGrid read from the legacy file, rebuilt UGrid).
        """
        # TODO: Fix this for the Z. May have to create a new RectilinearGrid
        # Setup Z for the grid. Need to switch impls for a minute.
        grid = RectilinearGrid(in_filename, grid_path)
        i_sizes = grid.i_sizes  # this makes us a SerialImpl
        j_sizes = grid.j_sizes
        # Only the first row of the legacy depth dataset is used for the cell elevations.
        cell_zs = f[grid_path + 'Datasets/Depth/Values'][:].tolist()[0]

        builder = RectilinearGridBuilder()
        builder.angle = grid.angle
        builder.origin = (grid.origin.x, grid.origin.y)
        builder.numbering = Numbering.kji
        builder.orientation = (Orientation.x_increase, Orientation.y_increase)
        builder.is_2d_grid = True
        builder.is_3d_grid = False
        builder.locations_x = self._cell_edges(i_sizes)
        builder.locations_y = self._cell_edges(j_sizes)
        rect_grid = builder.build_grid()
        rect_grid.cell_elevations = cell_zs  # already in elevations

        # One file per simulation so a later grid does not overwrite the file behind an earlier UGrid.
        cogrid_file = os.path.join(self.query.process_temp_directory, f'stwave_domain_{sim_uuid}.xmc')
        rect_grid.write_to_file(cogrid_file, True)
        ugrid = UGrid(cogrid_file, name=grid.name)
        ugrid.projection = grid.projection
        ugrid.uuid = grid_uuid
        return grid, ugrid

    def _read_cell_points(self, f, base_path, grid):
        """Convert legacy cell attributes (the 'CellTypes' dataset) to point lists for this grid.

        Args:
            f: Open H5 project file.
            base_path (str): Path of the grid's 'Model Params' group, ending in '/'.
            grid: RectilinearGrid the cell types belong to.

        Returns:
            tuple: (nesting points, monitoring points, output station points). Each list is numbered from 1.
        """
        points = {self._CELL_TYPE_NEST: [], self._CELL_TYPE_MONITOR: [], self._CELL_TYPE_STATION: []}
        i_sizes = grid.i_sizes
        j_sizes = grid.j_sizes
        if len(i_sizes) > 0 and len(j_sizes) > 0 and base_path + 'CellTypes' in f:
            # set up variables needed to convert i,j to x,y (cell size taken from the first column/row)
            dx = i_sizes[0]
            dy = j_sizes[0]
            origin = grid.origin
            angle = grid.angle
            cell_types = numpy.array(f[base_path + 'CellTypes'])
            num_i = len(i_sizes)
            for i in range(num_i):
                for j in range(len(j_sizes)):
                    pts = points.get(int(cell_types[j * num_i + i]))
                    if pts is None:  # default or unsupported cell type
                        continue
                    ptx, pty = get_coords_from_ij(i + 1, j + 1, angle, origin.x, origin.y, dx, dy)
                    cov_pt = Point(ptx, pty)
                    cov_pt.id = len(pts) + 1
                    pts.append(cov_pt)
        return points[self._CELL_TYPE_NEST], points[self._CELL_TYPE_MONITOR], points[self._CELL_TYPE_STATION]

    def _add_point_coverage(self, name, points, projection):
        """Create a point coverage, queue it to be sent to SMS, and return its UUID.

        Args:
            name (str): Name of the coverage.
            points (list): Points to put in the coverage.
            projection: Projection assigned to the coverage.

        Returns:
            str: The new coverage's UUID.
        """
        cov = Coverage()
        cov.name = name
        cov.set_points(points)
        cov.projection = projection
        cov.uuid = str(uuid.uuid4())
        cov.complete()
        self.query.add_coverage(cov)
        return cov.uuid
