"""Converts HEC-RAS solutions to XMDF datasets suitable for loading into XMS."""
# 1. Standard python modules
import datetime
import logging
import os
import re
import uuid

# 2. Third party modules
import h5py
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint.ugrid_builder import UGridBuilder
from xms.data_objects.parameters import Projection as DoProjection, UGrid as DoUGrid
from xms.datasets.dataset_writer import DatasetWriter
from xms.gdal.rasters import RasterInput
from xms.grid.ugrid import UGrid as XmUGrid

# 4. Local modules
from xms.hecras._OrientationChecker import OrientationChecker


class HecrasSolutionReader:
    """Importer for HEC-RAS solution."""

    RESULTS_GROUP_PATH = 'Results/Unsteady/Output/Output Blocks/Base Output/Unsteady Time Series/2D Flow Areas/'
    WSE_NAME = 'Water Surface'  # Special case

    def __init__(self, filename=None):
        """Set up the reader state and, when no file is given, connect to XMS.

        Args:
            filename (str): Path to the HEC-RAS solution file. If not provided (not testing), the path and the
                other XMS-supplied settings are retrieved from XMS.
        """
        self._logger = logging.getLogger('xms.hecras')

        # Inputs: either passed in (testing) or filled in by _get_xms_data()
        self._filename = filename
        self._query = None
        self._sim_uuid = None  # UUID of the active HEC-RAS simulation, if one exists. Retrieved from XMS
        self._temp_dir = None  # Retrieved from XMS if not testing, persists after process ends
        self._process_temp_dir = None  # Retrieved from XMS if not testing, deleted by XMS after process ends
        self._dset_uuids = None  # Hard-coded UUIDs can be supplied when testing, otherwise randomized
        self.project_tree = None  # Can be set in a test, otherwise retrieved from the Query

        # State that applies to the whole solution file
        self._h5_file = None  # h5py File handle to the H5 solution file
        self._projection = None  # data_objects Projection in case we need to build geometry

        # State that applies to the geometry currently being read (see _reset())
        self._geom_uuid = str(uuid.uuid4())
        self._geom_name = ''
        self._geometry_results_group = None  # H5 group holding the geometry's results (not the geometry itself)
        self._reftime = None  # Read from the solution file
        self._timesteps = None  # Read from the solution file
        self._activity_array = None  # Derived from WSE
        self._degenerate_polys = set()  # Indices of skipped polygons, so their dataset values can be skipped too
        self._read_vector_dsets = set()  # Names of vector datasets that have already been read

        # Outputs
        self.do_ugrids = []  # data_object UGrids for the 2D unstructured solution grid
        self.do_datasets = []  # Solution data_object datasets

        # Only talks to XMS when no filename was supplied
        self._get_xms_data()

    def _get_xms_data(self):
        """Fetch the import settings from XMS when no solution file was passed to the constructor.

        Does nothing if a filename is already set. Otherwise a Query is created and the solution file path,
        temp directories, current item UUID and project tree are read from it.
        """
        if not self._filename:  # pragma: no cover
            self._logger.info('Retrieving data from SMS...')
            query = Query()
            self._query = query
            self._filename = query.read_file
            self._temp_dir = query.xms_temp_directory
            self._process_temp_dir = query.process_temp_directory
            self._sim_uuid = query.current_item_uuid()  # May or may not have one
            self.project_tree = query.project_tree

    def _next_dataset_uuid(self):
        """Return the UUID for the next dataset.

        Preset UUIDs (used when testing) are consumed from the end of the list first; once none remain,
        or if none were supplied, a random UUID string is generated.
        """
        if self._dset_uuids:
            return self._dset_uuids.pop()
        return str(uuid.uuid4())

    def _reset(self):
        """Restore the per-geometry state to its defaults before the next geometry is read."""
        # New identity for the next geometry
        self._geom_name = ''
        self._geom_uuid = str(uuid.uuid4())
        self._geometry_results_group = None
        # Values read or derived from the previous geometry's results
        self._reftime = self._timesteps = self._activity_array = None
        # Bookkeeping sets
        self._degenerate_polys, self._read_vector_dsets = set(), set()

    def _h5_str_to_python(self, h5_bytes):
        """Convert a string value read from the H5 file into a plain python str.

        The value is passed through ``str()`` rather than decoded (the original author found that no codec
        worked on HEC-RAS strings), then the ``b'...'`` wrapper that ``str()`` puts around bytes is peeled
        off. Because this is the repr of the bytes, backslashes and non-printable bytes stay as escape
        sequences in the result.

        Args:
            h5_bytes (bytes): The bytes string value as read from the H5 file

        Returns:
            str: The text with the bytes wrapper and surrounding whitespace removed
        """
        text = str(h5_bytes)
        # Drop the leading bytes specifier, if any, then a trailing quote left on what remains.
        prefix_length = 2 if text.startswith('b\'') else 0
        text = text[prefix_length:]
        if text.endswith('\''):
            text = text[:-1]
        return text.strip()

    def _find_compatible_ugrid(self, num_cells, num_points):
        """Look for an existing UGrid in SMS that we can read this solution onto.

        Notes:
            1 - Check if there is an active HEC-RAS simulation in SMS
                1a - If there is, check for a linked UGrid. If it has the right number of cells and points, use it
            2 - Check all UGrids currently loaded in SMS. If we find one with right number of cells and points, use it
            3 - If no compatible geometry is found, return False so the caller builds the UGrid in the solution file

            When a compatible UGrid is found, ``self._geom_uuid`` is set to its UUID.

        Args:
            num_cells (int): The number of cells in the solution datasets
            num_points (int): The number of nodes in the solution datasets

        Returns:
            bool: True if we found an existing geometry to read the datasets onto
        """
        if not self.project_tree:
            return False

        def _is_compatible(item):
            """Returns True if the tree item has the expected cell and point counts."""
            return item.num_cells == num_cells and item.num_points == num_points

        self._logger.info('Looking for a compatible UGrid in SMS...')
        # Check for a UGrid linked to the active simulation first
        if self._sim_uuid:
            sim_item = tree_util.find_tree_node_by_uuid(self.project_tree, self._sim_uuid)
            ugrid_item = None
            # NOTE(review): presumably the lookup yields None when the UUID is not in the tree - confirm.
            if sim_item is not None:
                ugrid_item = tree_util.descendants_of_type(sim_item, xms_types=['TI_UGRID_PTR'], allow_pointers=True,
                                                           recurse=False, only_first=True)
            if ugrid_item and _is_compatible(ugrid_item):
                self._logger.info(
                    'Found a compatible UGrid linked to the active HEC-RAS simulation. Solution data sets will be '
                    'read onto this geometry.'
                )
                self._geom_uuid = ugrid_item.uuid
                return True

        # Check for any compatible UGrid currently loaded in SMS.
        ugrid_items = tree_util.descendants_of_type(self.project_tree, xms_types=['TI_UGRID', 'TI_UGRID_SMS'])
        for ugrid_item in ugrid_items:
            if _is_compatible(ugrid_item):
                # Only report success once a match has actually been confirmed.
                self._logger.info(
                    'Found a compatible UGrid in SMS. Solution data sets will be read onto this geometry.'
                )
                self._geom_uuid = ugrid_item.uuid
                return True

        # No compatible UGrid exists, build the one in the solution file.
        self._logger.info(
            'No compatible UGrid found in SMS. A UGrid will be created from data in the HEC-RAS solution file.'
        )
        return False

    def _read_solutions(self, results_group):
        """Read the geometry and datasets of every entry in the results group.

        For each entry the geometry name is recorded, its results group is looked up under
        RESULTS_GROUP_PATH, a UGrid is found or built, the datasets are read, and the per-geometry
        state is cleared again.

        Args:
            results_group (h5py.Group): The H5 group whose members are named after the solution geometries
        """
        for geometry_key in results_group:
            self._logger.info('Reading solution geometry metadata from file...')
            # The member name doubles as the geometry name
            self._geom_name = self._h5_str_to_python(geometry_key)
            group_path = f'{self.RESULTS_GROUP_PATH}{self._geom_name}'
            self._geometry_results_group = self._h5_file[group_path]
            self._read_solution_geometry_if_needed()
            self._read_solution_datasets()
            self._reset()

    def _read_solution_datasets(self):
        """Read all supported solution datasets for the current geometry.

        Order matters: time steps come first, then the water surface dataset (the activity array is
        derived from it), then the remaining cell datasets, and finally the legacy nodal velocity.
        """
        self._read_timesteps()

        results = self._geometry_results_group
        # WSE goes first so the activity array exists before any other dataset is written.
        if self.WSE_NAME in results:
            self._read_cell_data(self.WSE_NAME)

        legacy_cell_outputs = ('Depth', 'Shear Stress')
        for group_key in results:
            name = self._h5_str_to_python(group_key)
            # Legacy cell outputs have fixed names, optional cell outputs share the 'Cell ' prefix
            if name in legacy_cell_outputs or name.startswith('Cell '):
                self._read_cell_data(name)

        # Old node-based velocity solution dataset, only in v5
        self._read_node_velocity()

    def _read_solution_geometry_if_needed(self):
        """Find a UGrid for the current geometry, building one from the solution file if necessary.

        The point and cell definitions are always read because the cellstream step records the
        degenerate polygons and the resulting counts are needed to look for a compatible UGrid.
        """
        self._logger.info('Checking for existing compatible UGrid...')
        geometry_group = self._h5_file[f'Geometry/2D Flow Areas/{self._geom_name}']

        # Vertex coordinates: only x and y are read, z is set to zero
        xy_coords = np.array(geometry_group['FacePoints Coordinate'])
        points = [(xy[0], xy[1], 0.0) for xy in xy_coords]

        # Polygon definitions
        cell_defs = np.array(geometry_group['Cells FacePoint Indexes'])
        cellstream = self._create_cellstream(cell_defs, points)

        # Degenerate polygons are not part of the UGrid, so they don't count as cells.
        num_cells = len(cell_defs) - len(self._degenerate_polys)
        found_existing = self._find_compatible_ugrid(num_cells, len(points))
        if not found_existing:
            self._write_solution_ugrid(points, cellstream)

    def _create_cellstream(self, cell_dset, points):
        """Creates the XmUGrid cellstream definition.

        Also records the indices of the skipped (degenerate) polygons in ``self._degenerate_polys`` as an
        integer numpy array so their dataset values can be masked out later.

        Args:
            cell_dset (np.ndarray): The cell definitions in HEC-Ras solution file format. Each row holds the
                point indices of one polygon, with unused slots set to -1.
            points (list): The node location coordinates

        Returns:
            list: The UGrid cell definitions in XmUGrid format (cell type, point count, point indices, ...).
        """
        self._logger.info('Generating UGrid cell definitions...')
        cellstream = []
        degenerate_polys = set()
        order_checker = OrientationChecker()
        for idx, rawrow in enumerate(cell_dset):
            poly_pts = [vertex for vertex in rawrow if vertex != -1]
            # Meshes created in HEC-RAS will have two element polygons along the boundary. Skip them.
            num_poly_pts = len(poly_pts)
            if num_poly_pts < 3:
                # HEC-RAS uses degenerate polygons to define the boundary, ignore them.
                degenerate_polys.add(idx)
                continue
            # Ensure the points are in tri or quad order (can be clockwise or counter-clockwise).
            result = order_checker.is_clockwise([points[vertex] for vertex in poly_pts], 2.0)
            if result is None:  # None=8-shaped, False=counter-clockwise, True=clockwise
                # Swapping the first two points untangles the self-crossing ordering.
                poly_pts = [poly_pts[1], poly_pts[0]] + poly_pts[2:]
            if num_poly_pts == 3:
                cellstream.append(XmUGrid.cell_type_enum.TRIANGLE)
            elif num_poly_pts == 4:
                cellstream.append(XmUGrid.cell_type_enum.QUAD)
            else:
                cellstream.append(XmUGrid.cell_type_enum.POLYGON)
            cellstream.append(num_poly_pts)
            cellstream.extend(poly_pts)
        # Convert set to a numpy array for convenience with later operations. The integer dtype must be
        # explicit: an empty list would otherwise become a float64 array, which numpy rejects as an index
        # when the array is later used to mask dataset values.
        self._degenerate_polys = np.array(sorted(degenerate_polys), dtype=int)
        return cellstream

    def _write_solution_ugrid(self, points, cellstream):
        """Builds a 2D unstructured Grid that XMS can load cell-centered datasets onto.

        The grid is written to a file in the process temp directory and a data_objects UGrid that
        references that file is appended to ``self.do_ugrids``.

        Args:
            points (list): The node location coordinates
            cellstream (list): The XmUGrid cellstream definition
        """
        self._logger.info('Writing a UGrid geometry from data in the solution file...')
        # Create a temp file in a place that will get cleaned up. The name is unique per geometry because
        # the files are only handed to XMS after every geometry has been read; a shared name would let a
        # later geometry overwrite an earlier one.
        cogrid_temp_file = os.path.join(self._process_temp_dir, f'{self._geom_uuid}.cogrid')

        # Build the unconstrained UGrid
        xmugrid = XmUGrid(points, cellstream)
        terrain_file = self._find_terrain_file()
        cell_elevations = self._interpolate_raster_elevations(terrain_file, xmugrid)
        co_builder = UGridBuilder()
        co_builder.set_unconstrained()
        co_builder.set_ugrid(xmugrid)
        cogrid = co_builder.build_grid()
        if len(cell_elevations) > 0:
            cogrid.cell_elevations = cell_elevations
        else:  # Warn if we couldn't find cell elevations
            self._logger.warning(
                'Unable to extract cell elevations from raster. Solution UGrid Z will be uninitialized.'
            )
        # Write the constrained grid file
        os.makedirs(os.path.dirname(cogrid_temp_file), exist_ok=True)
        cogrid.write_to_file(cogrid_temp_file, True)

        # Build the api UGrid object using the projection read from the solution file.
        # NOTE(review): the original comment suggests the display projection is expected to match - confirm.
        self.do_ugrids.append(DoUGrid(cogrid_temp_file, name=self._geom_name, uuid=self._geom_uuid,
                                      projection=self._projection))

    def _read_timesteps(self):
        """Read the time step offsets and the reference time from the solution file.

        Nothing is returned; the results are stored in ``self._timesteps`` and ``self._reftime``. The
        reference time is parsed from the first entry of the 'Time Date Stamp' dataset.
        """
        self._logger.info('Reading solution time steps...')
        time_series = self._h5_file['Results/Unsteady/Output/Output Blocks/Base Output/Unsteady Time Series']
        # Time offsets of every time step
        self._timesteps = time_series['Time'][:]
        # The first time stamp serves as the reference time
        first_stamp = time_series['Time Date Stamp'][0]
        self._reftime = self._parse_reftime(first_stamp.decode('utf-8'))

    def _parse_reftime(self, reftime):
        """Parses a timestamp from a string in the HEC-RAS solution file format.

        Args:
            reftime (str): The timestamp string as read from the solution file
                Expected format: 'DDMMMYYYY HH:MM:SS'
                                 '12FEB2021 10:00:00'
                Leading whitespace and extra whitespace between the date and the time are tolerated.

        Returns:
            datetime.datetime: The solution dataset reference timestamp

        Raises:
            ValueError: If the string does not match the expected format.
        """
        pattern = re.compile(
            r'\s*([0-9]+)([a-zA-Z]+)([0-9]+)\s+([0-9]+):([0-9]+):([0-9]+)'
        )
        match = pattern.match(reftime)
        if match is None:
            raise ValueError(f'Unrecognized HEC-RAS time stamp: {reftime!r}')
        day_str, month_str, year_str, hour_str, minute_str, sec_str = match.groups()

        # Parse the calendar date portion
        date = datetime.datetime(year=int(year_str), month=self._month_str_to_int(month_str), day=int(day_str))
        # Parse the time of day portion. It is added as an offset rather than passed to the datetime
        # constructor so that an hour value of 24 rolls over to the next day instead of raising.
        time_of_day = datetime.timedelta(hours=int(hour_str), minutes=int(minute_str), seconds=int(sec_str))
        return date + time_of_day

    def _month_str_to_int(self, month_str):
        """Converts a three letter month abbreviation to its corresponding integer.

        The comparison is case-insensitive and only looks at the first three letters, so full English
        month names are accepted as well.

        Args:
            month_str (str): Three letter abbreviation of the month

        Returns:
            int: The month integer [1-12]

        Raises:
            ValueError: If the text is not a recognized month.
        """
        abbreviations = ('JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC')
        key = month_str.strip().upper()[:3]
        try:
            return abbreviations.index(key) + 1
        except ValueError:
            # Fail loudly instead of silently reporting December for unknown text.
            raise ValueError(f'Unrecognized month abbreviation: {month_str!r}') from None

    def _read_projection(self):
        """Store the projection defined in the solution file, if there is one.

        The file's 'Projection' attribute is used as the WKT of a data_objects Projection. When the
        attribute is missing the stored projection is left untouched.
        """
        if 'Projection' in self._h5_file.attrs:
            projection_wkt = self._h5_str_to_python(self._h5_file.attrs['Projection'])
            self._projection = DoProjection(wkt=projection_wkt)

    def _derive_activity_array_from_wse(self, wse_values, degenerate_polys_mask):
        """Build the activity array from a WSE solution dataset.

        A value is flagged active (1) when the WSE is above the minimum elevation of its cell and
        inactive (0) otherwise. Nothing is stored if there are no WSE values or if the geometry has
        no 'Cells Minimum Elevation' dataset.

        Args:
            wse_values (np.ndarray): The WSE solution values (all timesteps)
            degenerate_polys_mask (np.ndarray): Mask to filter out the degenerate polygons
        """
        if wse_values.size <= 0:
            return
        geometry_group = self._h5_file[f'Geometry/2D Flow Areas/{self._geom_name}']
        if 'Cells Minimum Elevation' not in geometry_group:
            return

        # Minimum elevations of the cells that were kept
        all_minimums = geometry_group['Cells Minimum Elevation'][:]
        kept_minimums = all_minimums[degenerate_polys_mask]

        # A cell whose WSE is less than or equal to its minimum elevation is dry.
        activity = np.zeros(wse_values.shape, dtype='u1')
        activity[wse_values > kept_minimums] = 1
        self._activity_array = activity

    def _read_cell_data(self, dset_name):
        """Reads a cell solution dataset.

        A dataset whose name ends in ' X' or ' Y' is combined with its sibling component into a single
        vector dataset, which is only read once no matter which component is requested. If the sibling
        component does not exist, the dataset is read as a scalar. Datasets without values are skipped.

        Args:
            dset_name (str): H5 group name of the dataset to look for in file.
        """
        self._logger.info(f'Reading solution data set - {dset_name}...')
        results = self._geometry_results_group
        base_name = dset_name[:-2]
        x_name, y_name = f'{base_name} X', f'{base_name} Y'
        # Only treat this as a vector if both components are really there; a lone ' X' or ' Y' dataset
        # is read as a scalar instead of failing on the missing sibling.
        is_vector = dset_name.endswith((' X', ' Y')) and x_name in results and y_name in results

        # Read the values from the solution file.
        if is_vector:  # Cell vector dataset
            if base_name in self._read_vector_dsets:
                return  # Already read this vector dataset
            self._read_vector_dsets.add(base_name)  # Don't read the other component when we get to it
            # Stack the components along a new last axis
            values = np.stack([results[x_name][:], results[y_name][:]], axis=2)
            # Cleanup the name for the GUI: drop the ' - ...' qualifier if present, and always drop the
            # component suffix.
            pos = base_name.find(' - ')
            dset_name = base_name[:pos] if pos > 0 else base_name
        else:
            values = results[dset_name][:]
        if values.size > 0:  # Have data to write
            self._write_xmdf_dataset(dset_name, values)

    def _write_xmdf_dataset(self, name, values, cell_based=True):
        """Write one solution dataset to an XMDF file and add it to the output datasets.

        For cell-based datasets the columns of the degenerate polygons are removed first, and the water
        surface dataset additionally initializes the activity array (so it needs to be the first dataset
        written).

        Args:
            name (str): Name of dataset.
            values (np.ndarray): The dataset values.
            cell_based (bool): True if a cell-based dataset, False otherwise
        """
        dset_uuid = self._next_dataset_uuid()
        xmdf_path = os.path.join(self._temp_dir, f'{dset_uuid}.h5')

        if cell_based:
            # Keep every column except those of the degenerate polygons
            keep_cells = np.ones(values.shape[1], dtype=bool)
            keep_cells[self._degenerate_polys] = False
            values = values[:, keep_cells]
            if name == self.WSE_NAME:
                self._derive_activity_array_from_wse(values, keep_cells)

        # Activity only applies to cell datasets, and only if it could be derived
        has_activity = cell_based and self._activity_array is not None
        writer = DatasetWriter(
            h5_filename=xmdf_path, name=name, dset_uuid=dset_uuid, geom_uuid=self._geom_uuid,
            ref_time=self._reftime, time_units='Days', num_components=values.ndim - 1,
            location='cells' if cell_based else 'points', use_activity_as_null=has_activity,
        )
        # Write an XMDF formatted file that SMS can read.
        activity = self._activity_array if has_activity else None
        writer.write_xmdf_dataset(times=self._timesteps, data=values, activity=activity)
        self.do_datasets.append(writer)

    def _read_node_velocity(self):
        """Read the legacy node-based velocity dataset when both of its components are present.

        Notes:
            Starting with v6 HEC-RAS no longer writes the nodal velocity dataset. Velocity was also the only
            nodal dataset they ever outputted.
        """
        results = self._geometry_results_group
        if not ('Node X Vel' in results and 'Node Y Vel' in results):
            return  # Not a legacy solution
        self._logger.info('Reading solution data set - Velocity...')
        x_component = results['Node X Vel'][:]
        y_component = results['Node Y Vel'][:]
        if x_component.size <= 0:
            return  # Nothing to write
        # Combine the components along a new last axis, the layout the XMDF writer is given for vectors
        velocity = np.stack([x_component, y_component], axis=2)
        self._write_xmdf_dataset('Velocity', velocity, cell_based=False)

    def _find_terrain_file(self):
        """Search the solution file references for a terrain raster.

        The geometry's 'Terrain Filename' attribute is resolved relative to the solution file's folder.
        A .vrt file next to the terrain file (same base name) is preferred; otherwise the first 'File'
        attribute found under the terrain file's 'Terrain' group is used.

        Returns:
            str: Path to the terrain raster or empty string if not found.
        """
        geometry_group = self._h5_file[f'Geometry/2D Flow Areas/{self._geom_name}']
        if 'Terrain Filename' not in geometry_group.attrs:
            return ''
        solution_dir = os.path.dirname(self._filename)
        terrain_h5_path = os.path.join(solution_dir, self._h5_str_to_python(geometry_group.attrs['Terrain Filename']))
        if not os.path.exists(terrain_h5_path):
            return ''

        # Check for a .vrt file with a basename that matches the terrain .hdf basename. We couldn't find anywhere
        # the .vrt file is referenced in the HEC-RAS files. Hope it is always this pattern.
        # The terrain path already includes the solution folder, so it must not be joined to it again.
        vrt_filename = f'{os.path.splitext(terrain_h5_path)[0]}.vrt'
        if os.path.isfile(vrt_filename):
            return vrt_filename

        # Open the other h5 file to get the tif name
        terrain_file = ''
        with h5py.File(terrain_h5_path, 'r') as terrain_h5_file:
            if 'Terrain' in terrain_h5_file:
                terrain_group = terrain_h5_file['Terrain']
                for sub_group in terrain_group:
                    if 'File' in terrain_group[sub_group].attrs:
                        terrain_file = self._h5_str_to_python(terrain_group[sub_group].attrs['File'])
                        break

        if not terrain_file:
            return ''  # No raster referenced; don't hand back a bare directory
        # Assume the .tif file is in the same directory as the .h5 file
        return os.path.join(os.path.dirname(terrain_h5_path), terrain_file)

    def _interpolate_raster_elevations(self, raster_filename, ugrid):
        """Extract cell elevations from a single raster or a VRT file.

        The elevation of each cell is sampled at the cell centroid. Cells whose centroid is outside the
        raster, or that fall on a no-data pixel, are reported in a single warning.

        Args:
            raster_filename (str): Path to the raster file
            ugrid (XmUGrid): The target geometry

        Returns:
            np.ndarray: The extracted cell elevations, one per cell. An empty list if the raster file does not
            exist.
        """
        if not os.path.isfile(raster_filename):
            return []

        self._logger.info('Extracting cell elevations from terrain raster...')
        # get_cell_centroid() returns a pair; the location is its second item
        cell_centroids = [ugrid.get_cell_centroid(i)[1] for i in range(ugrid.cell_count)]

        # Read raster metadata
        raster_file = RasterInput(raster_filename)
        xorigin = raster_file.xorigin
        yorigin = raster_file.yorigin
        pixel_width = raster_file.pixel_width
        pixel_height = raster_file.pixel_height
        xsize, ysize = raster_file.resolution
        no_data_value = raster_file.nodata_value
        # NOTE(review): presumably nodata_value can be None when the raster defines none - confirm. None
        # cannot be stored in a float array, so NaN marks out-of-bounds cells in that case.
        out_of_bounds_value = np.nan if no_data_value is None else no_data_value

        # Interpolate the raster values to the grid cell centroids
        dataset_vals = np.zeros(len(cell_centroids))
        no_data_cells = []
        for i, point in enumerate(cell_centroids):
            # Find the pixel containing the point. Floor (not truncation) keeps points just outside the
            # origin edge from being mapped onto pixel 0.
            xoff = int(np.floor((point[0] - xorigin) / pixel_width))
            yoff = int(np.floor((point[1] - yorigin) / pixel_height))
            # Valid pixel offsets run from 0 to size - 1
            if not (0 <= xoff < xsize and 0 <= yoff < ysize):
                dataset_vals[i] = out_of_bounds_value
                no_data_cells.append(str(i + 1))
                continue  # Out of raster bounds
            # Read the raster value using the bilinear resampling option. Slower than reading the entire array into
            # memory all at once, but we want to do a little more than just the pixel value.
            elevation = raster_file.get_raster_values(xoff, yoff, 1, 1, resample_alg='bilinear')
            dataset_vals[i] = elevation[0][0]
            if np.any(elevation == no_data_value):  # In-bounds but no data, keep the no data value and warn user
                no_data_cells.append(str(i + 1))

        # Report warnings for cells that we could not extract an elevation for
        if no_data_cells:
            self._logger.warning(f'Could not extract elevations for the following cells: {", ".join(no_data_cells)}')

        return dataset_vals

    def _add_built_data(self):
        """Hand the built UGrids and datasets to the XMS Query; does nothing without a Query (testing)."""
        query = self._query
        if query:  # pragma: no cover
            # Geometries first, then the datasets
            for solution_ugrid in self.do_ugrids:
                query.add_ugrid(solution_ugrid)
            for solution_dataset in self.do_datasets:
                query.add_dataset(solution_dataset)

    def read(self):
        """Read the whole solution file, then pass the built geometries and datasets to SMS.

        The solution file is only open while it is being read.
        """
        with h5py.File(self._filename, 'r') as h5_file:
            self._h5_file = h5_file
            self._read_projection()
            self._read_solutions(h5_file[self.RESULTS_GROUP_PATH])
        # Send built data to SMS
        self._add_built_data()

    def send(self):
        """Send the queued data to SMS; does nothing when there is no Query (testing)."""
        query = self._query
        if not query:
            return
        query.send()
