"""This module writes out the mp file for CMS-Flow."""

# 1. Standard Python modules
import datetime
from functools import cached_property
import logging
import math
from typing import Any, Callable, cast, Iterable, Optional, Sequence, TypeAlias

# 2. Third party modules
import h5py
import numpy
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.api.dmi import ModelCheckError
from xms.constraint import UGrid2d
from xms.data_objects.parameters import Arc, Coverage
from xms.datasets.dataset_reader import DatasetReader
from xms.gdal.utilities.gdal_utils import is_geographic, is_local, transform_points_from_wkt
from xms.grid.geometry.geometry import angle_between_edges_2d
from xms.grid.ugrid import UGrid
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.cmsflow.data.bc_data import BCData
from xms.cmsflow.dmi._cmsflow_model_check import check_forcing_errors
from xms.cmsflow.dmi.xms_data import XmsData
from xms.cmsflow.extraction.transient_dataset_extractor import TransientDatasetExtractor
from xms.cmsflow.feedback.xmlog import XmLog
from xms.cmsflow.mapping.coverage_mapper import CoverageMapper

# Well-known text for a geographic (latitude/longitude, in degrees) coordinate system using the North American 1983
# datum on the GRS 1980 spheroid.
# NOTE(review): the literal keeps its leading newline and indentation; presumably whatever consumes it tolerates the
# surrounding whitespace — confirm before comparing it as a plain string.
DEFAULT_GEOGRAPHIC_WKT = (
    """
    GEOGCS["Geographic Coordinate System",DATUM["D_NORTH_AMERICAN_1983",SPHEROID["GRS_1980",6378137,298.257222101]],
    PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]
    """
)

# NOTE(review): despite the "2d" in its name, this alias is a 3-tuple (presumably x, y, z) — confirm intent.
Pt2d: TypeAlias = tuple[float, float, float]


class CMSFlowMpExporter:
    """
    Exporter for the CMS-Flow _mp.h5 file.

    The file is written as a pseudo-XMDF HDF5 file. Model data goes under `PROPERTIES/Model Params`:
    - meteorological station curves or a constant wind curve,
    - temperature parameters,
    - one `Boundary_#<arc id>` group per assigned boundary condition arc,
    - latitudes/longitudes when Coriolis comes from the projection,
    - WSE forcing data extracted from a parent dataset.
    """
    def __init__(self, data: XmsData, cov_mapper: Optional[CoverageMapper] = None):
        """
        Construct the exporter.

        Args:
            data: Access to the simulation, its domain grid, coverages, and datasets.
            cov_mapper: The coverage mapper used when exporting the other CMS-Flow files. Defaultable for tests; when
                omitted, a new mapper is created and initialized by `export()`.
        """
        self._params_group: Optional[h5py.Group] = None  # Essentially the root group in the file (as far as CMS cares).
        # I left cov_mapper defaultable for tests. When exporting, use the same mapper as the other files.
        self._cov_mapper = cov_mapper if cov_mapper else CoverageMapper(False)
        self._init_mapper = cov_mapper is None  # Init for tests
        self._arc_ex_snapper = None
        self._snap_id_to_original_id = None
        self._data = data
        # Boundary groups whose WSE forcing is extracted from a dataset rather than read from a curve.
        self._groups_needing_extraction: list[h5py.Group] = []
        # cell_id -> feature ID of the first arc that snapped to that cell.
        self._cells_on_arcs: dict[int, int] = {}
        # feature_id, cell_id, time_step (position within the extracted data, not an absolute dataset index)
        self._elevation_nans: set[tuple[int, int, int]] = set()
        self._velocity_nans: set[tuple[int, int, int]] = set()
        # feature_id, cell_id, is_elevation
        self._found_nans: set[tuple[int, int, bool]] = set()

    def _export_meteorological(self):
        """
        Write meteorological station curves or the constant wind curve, depending on the simulation's wind type.

        'Meteorological stations' writes one `Met Stations/Sta<n>` group per station. 'Spatially constant' writes a
        single `WindCurve` group. Any other wind type writes nothing.
        """
        wind_types = {
            'None': 'None',
            'Spatially constant': 'Constant',
            'Meteorological stations': 'Stations',
            'Temporally and spatially varying from file': 'File'
        }
        wind_type = wind_types[self._data.sim_data.wind.attrs['WIND_TYPE']]
        if wind_type == 'Stations':
            station_grp = self._params_group.create_group('Met Stations')
            curve_ids = self._data.sim_data.meteorological_stations_table.direction.data.tolist()
            for idx, curve_id in enumerate(curve_ids):
                curve_grp = station_grp.create_group(f'Sta{idx + 1}')
                curve = self._data.sim_data.direction_curve_from_meteorological_station(int(curve_id))
                curve_grp.create_dataset('Times', data=curve['time'].data.tolist())
                curve_grp.create_dataset('Direction', data=curve['direction'].data.tolist())
                curve_grp.create_dataset('Magnitude', data=curve['velocity'].data.tolist())
        elif wind_type == 'Constant':  # Stations write their own curves above; don't write the info twice.
            wind_table = self._data.sim_data.wind_from_table
            wind_grp = self._params_group.create_group('WindCurve')
            wind_grp.create_dataset('Direction', data=wind_table.direction.data)
            wind_grp.create_dataset('Magnitude', data=wind_table.velocity.data)
            wind_grp.create_dataset('Times', data=wind_table.time.data)

    def _export_temperature(self):
        """Write the atmospheric table to a `TemperatureParameters` group if temperature is being calculated."""
        if self._data.sim_data.salinity.attrs['CALCULATE_TEMPERATURE'] != 1:
            return

        table = self._data.sim_data.atmospheric_table
        temp_grp = self._params_group.create_group('TemperatureParameters')
        temp_grp.create_dataset('Times', data=table.time.data)
        temp_grp.create_dataset('AirTemp', data=table.air_temp.data)
        temp_grp.create_dataset('DewPoint', data=table.dewpoint.data)
        temp_grp.create_dataset('CloudCover', data=table.cloud_cover.data)
        temp_grp.create_dataset('SolarRadiation', data=table.solar_radiation.data)

    def _add_curve(
        self,
        attributes: xr.Dataset,
        control_attribute_name: str,
        curve_attribute_name: str,
        arc_grp: h5py.Group,
        curve_function: Callable,
        x_title: str,
        y_title: str,
        force_control_is_flag: bool = False,
    ):
        """
        Adds curve to group.

        Args:
            attributes: Attributes for the arc. Should be filtered to just the row for this arc.
            control_attribute_name: Name of the attribute that controls whether the curve should be written.
            curve_attribute_name: Name of the attribute that contains the curve's ID.
            arc_grp: H5 group for the arc's data. Should be at a location like `PROPERTIES/Model Params//Boundary_#`.
            curve_function: The method to call to get the curve.
            x_title: The title for the "x" data.
            y_title: The title for the "y" data.
            force_control_is_flag: Treat the control attribute as containing 0/1 instead of 'Constant'/'Curve'.
        """
        bc_type = attributes['bc_type'].item()
        control_value = attributes[control_attribute_name].item()

        if bc_type in ('Flow rate-forcing', 'WSE-forcing') and not force_control_is_flag:
            # Forcing arcs store a source name in the control attribute; only 'Curve' means a curve should be written.
            write = control_value == 'Curve'
        else:
            write = control_value == 1

        if not write:
            return

        curve = curve_function(attributes[curve_attribute_name].item(), False)
        if curve:
            columns = list(curve.keys())
            xs = curve[columns[0]].values
            ys = curve[columns[1]].values
        else:
            # This can happen if the user enables the option that requires the curve, but doesn't actually define it.
            XmLog().instance.warning(
                f'Zero curve location values found for {x_title} vs. {y_title}. Creating an '
                'empty dataset'
            )
            xs = []
            ys = []

        arc_grp.create_dataset(x_title, data=xs)
        arc_grp.create_dataset(y_title, data=ys)

    def _export_boundaries(self, bc_cov: Coverage):
        """
        Writes boundary data to file.

        Args:
            bc_cov: The boundary conditions coverage geometry.
        """
        self._arc_ex_snapper = self._cov_mapper.get_bc_arc_snapper()
        self._snap_id_to_original_id = self._cov_mapper.get_snap_id_to_original_id()

        for arc in bc_cov.arcs:
            self._export_boundary(arc)

    def _error_checks_ok(self, wkt: str) -> bool:
        """
        Check if the arcs in the coverage look okay to export.

        Logs errors informing the user what's wrong if any problems are found.

        Args:
            wkt: Well-known text of the domain grid's projection. Used to check whether Coriolis can be computed
                from the projection.

        Returns:
            Whether the arcs look safe to export. If False, then exporting the arcs won't work.
        """
        _coverage, component = self._data.bc_coverage
        # Look up the assigned arcs without raising if the coverage UUID or arc target type is missing.
        assigned_arcs = None
        if component and component.comp_to_xms and component.cov_uuid in component.comp_to_xms:
            assigned_arcs = component.comp_to_xms[component.cov_uuid].get(TargetType.arc)
        if not assigned_arcs:
            XmLog().instance.error('No assigned boundary condition arcs. Aborting.')
            return False

        errors = check_forcing_errors(self._data)

        if self._data.sim_data.flow.attrs.get('LATITUDE_CORIOLIS', '') == 'From projection' and is_local(wkt):
            short = 'Invalid average latitude specified in model control.'
            long = (
                'The Average Latitude for Coriolis is set to be computed from the projection, but the projection '
                'is local and has no latitude information.'
            )
            fix = (
                'Assign the linked UGrid a global or geographic projection, or change the Average Latitude for '
                'Coriolis setting in the model control Flow tab to a constant.'
            )
            errors.append(ModelCheckError(problem=short, description=long, fix=fix))

        for error in errors:
            XmLog().instance.error(error.problem_text)
        return not errors

    def _warn_if_necessary(self):
        """Warn about any extracted WSE forcing dataset whose last time step is before the simulation ends."""
        if not _uses_extracted_wse_forcing(self._data):
            return

        _coverage, component = self._data.bc_coverage
        data: BCData = component.data
        _start, end = self._get_simulation_start_and_end()

        # (uuid, reader) pairs; the velocity dataset is optional.
        sources = [(data.wse_forcing_wse_source, self._forced_elevation_dataset)]
        if data.wse_forcing_velocity_source:
            sources.append((data.wse_forcing_velocity_source, self._forced_velocity_dataset))

        for dataset_uuid, dataset in sources:
            if dataset_ends_before(dataset, end):
                path = self._data.tree_path(dataset_uuid)
                XmLog().instance.warning(f'The dataset "{path}" ends before the simulation.')

    def _export_boundary(self, arc: Arc):
        """
        Export all the information for an arc.

        If an arc is not assigned, nothing will be written for it.

        Args:
            arc: The arc to export information for.
        """
        if arc.id not in self._component_id_map:
            return

        _coverage, component = self._data.bc_coverage
        component_id = self._component_id_map[arc.id]
        attributes = component.data.arcs.loc[{'comp_id': component_id}]
        bc_type = attributes['bc_type'].item()

        arc_grp = self._params_group.create_group(f'Boundary_#{arc.id}')

        cells_written = self._export_cells(arc, arc_grp)
        self._export_temp_and_salinity(attributes, arc_grp)

        if bc_type == 'Flow rate-forcing':
            self._add_curve(
                attributes, 'flow_source', 'flow_curve', arc_grp, component.data.flow_curve_from_id, 'Times', 'Flow'
            )
        elif bc_type == 'WSE-forcing':
            self._export_wse_curves(attributes, arc_grp)
            # Extraction reads the group's 'Cells' dataset, which doesn't exist if the arc's cells were rejected.
            if cells_written and attributes['wse_source'].item() == 'Extracted':
                self._groups_needing_extraction.append(arc_grp)

    def _export_temp_and_salinity(self, attributes: xr.Dataset, arc_grp: h5py.Group):
        """
        Export an arc's temperature and salinity curves.

        Args:
            attributes: The arc's attributes.
            arc_grp: The arc's group in the .h5 file.
        """
        _coverage, component = self._data.bc_coverage
        # salinity
        self._add_curve(
            attributes,
            'use_salinity_curve',
            'salinity_curve',
            arc_grp,
            component.data.salinity_curve_from_id,
            'Sal_Times',
            'Salinity',
            force_control_is_flag=True
        )

        # temperature
        self._add_curve(
            attributes,
            'use_temperature_curve',
            'temperature_curve',
            arc_grp,
            component.data.temperature_curve_from_id,
            'Temp_Times',
            'Temperature',
            force_control_is_flag=True
        )

    def _export_wse_curves(self, attributes: xr.Dataset, arc_grp: h5py.Group):
        """
        Export the WSE curves for an arc.

        Args:
            attributes: The arc's attributes.
            arc_grp: The arc's group in the .h5 file.
        """
        _coverage, component = self._data.bc_coverage
        self._add_curve(
            attributes, 'wse_source', 'wse_forcing_curve', arc_grp, component.data.wse_forcing_curve_from_id, 'Times',
            'WaterLevel'
        )
        self._add_curve(
            attributes, 'wse_offset_type', 'wse_offset_curve', arc_grp, component.data.wse_offset_curve_from_id,
            'Offset_Times', 'Offset'
        )

    def _export_cells(self, arc: Arc, arc_grp: h5py.Group) -> bool:
        """
        Export the cells for an arc.

        Logs an error if the arc is degenerate (fewer than two cells) or shares a cell with an earlier arc.

        Args:
            arc: The arc to export cells for.
            arc_grp: The arc's group in the .h5 file.

        Returns:
            Whether the 'Cells' dataset was written. It is not written if the arc shares a cell with an earlier arc.
        """
        # get the cell ids of the snapped arc locations.
        snap_arc = self._arc_ex_snapper.get_snapped_points(arc)
        # convert first from small quadtree to original, then from 0 based to 1 based
        cell_ids = [self._snap_id_to_original_id[cell_id] + 1 for cell_id in snap_arc['id']]

        if cell_ids and cell_ids[0] == cell_ids[-1]:
            # When users make arc loops, they expect the same BC to be applied to each cell on the arc. But the snapper
            # duplicates the start and end cell for loops, and CMS responds to duplicate cells by doubling the applied
            # values, which violates user expectations. Discarding the duplicate cell means this technically doesn't
            # match what the user created, but it *does* match what the user *wanted*, which is what matters.
            cell_ids.pop()

        if len(cell_ids) < 2:
            XmLog().instance.error(f'Arc {arc.id} was degenerate.')

        for cell_id in cell_ids:
            if cell_id in self._cells_on_arcs:
                message = f'Arcs {arc.id} and {self._cells_on_arcs[cell_id]} both snapped to cell {cell_id}.'
                XmLog().instance.error(message)
                return False
            self._cells_on_arcs[cell_id] = arc.id

        arc_grp.create_dataset('Cells', data=numpy.array(cell_ids))
        return True

    @cached_property
    def _forced_velocity_dataset(self) -> Optional[DatasetReader]:
        """The simulation's velocity forcing dataset, or None if no velocity source is assigned."""
        _coverage, component = self._data.bc_coverage
        dataset_uuid = component.data.wse_forcing_velocity_source
        if not dataset_uuid:
            return None
        return self._data.get_dataset(dataset_uuid)

    @cached_property
    def _forced_elevation_dataset(self) -> Optional[DatasetReader]:
        """The simulation's elevation forcing dataset."""
        _coverage, component = self._data.bc_coverage
        return self._data.get_dataset(component.data.wse_forcing_wse_source)

    @cached_property
    def _forced_elevation_geometry(self) -> Optional[UGrid2d]:
        """
        The geometry that the elevation and velocity forcing datasets are extracted from.

        Presumably None if `get_ugrid` can't find the geometry — confirm against XmsData.
        """
        _coverage, component = self._data.bc_coverage
        return self._data.get_ugrid(component.data.wse_forcing_geometry)

    @cached_property
    def _component_id_map(self) -> dict[int, int]:
        """A mapping from feature_id -> component_id for all arcs in the BC coverage."""
        _coverage, component = self._data.bc_coverage

        if not component.cov_uuid:  # pragma: nocover
            raise AssertionError('Coverage component not initialized with coverage UUID.')

        if component.comp_to_xms and component.cov_uuid not in component.comp_to_xms:  # pragma: nocover
            raise AssertionError('Coverage component initialized with wrong coverage UUID.')

        if not component.comp_to_xms:
            logging.getLogger('xms.cmsflow').error('No assigned boundary condition arcs. Aborting.')
            return {}

        arc_comp_to_xms = component.comp_to_xms[component.cov_uuid].get(TargetType.arc, {})
        return {
            feature_id: component_id
            for component_id, feature_ids in arc_comp_to_xms.items()
            for feature_id in feature_ids
        }

    def export(self):
        """
        Write the CMS-Flow _mp.h5 file.

        Logs an error and returns without writing if the domain grid or the boundary conditions coverage is missing.
        """
        logger = logging.getLogger('xms.cmsflow')
        if not self._data.ugrid:
            logger.error('Unable to find CMS-Flow domain grid. Aborting.')
            return

        coverage, component = self._data.bc_coverage
        if not coverage:
            logger.error('Unable to find CMS-Flow boundary conditions. Aborting.')
            return

        proj_name = self._data.project_path.stem

        co_grid = self._data.ugrid
        wkt = self._data.ugrid_projection.well_known_text
        if self._init_mapper:  # This should really just be in the tests.
            # Both the grid and the coverage were verified above, so the boundary conditions can always be set.
            self._cov_mapper.set_boundary_conditions(coverage, component)
            self._cov_mapper.set_quadtree(co_grid, wkt)
            self._cov_mapper.set_activity(self._data.activity_coverage)

        self._write_mp_file(proj_name, coverage, co_grid, wkt)

    def _write_mp_file(self, proj_name, bc_cov, co_grid, wkt):
        """
        Open the mp file for writing and write its contents if the pre-export checks pass.

        The file is created (truncating any existing one) even when the checks fail.

        Args:
            proj_name (str): The name of the CMS-Flow project.
            bc_cov (Coverage): The boundary conditions coverage geometry.
            co_grid (CoGrid): The quadtree geometry.
            wkt: Well-known text of the projection of `co_grid`.
        """
        with h5py.File(f'{proj_name}_mp.h5', 'w') as file:
            if not self._error_checks_ok(wkt):
                return

            self._warn_if_necessary()
            self._export_default_groups(file, co_grid.uuid)

            # Export meteorological data
            self._export_meteorological()
            self._export_temperature()
            self._export_boundaries(bc_cov)

            self._export_lats_and_lons(wkt, co_grid, file)
            self._export_extracted_data()

    def _export_default_groups(self, file, quad_uuid):
        """
        Export the default groups and properties.

        Also sets `self._params_group` to the `PROPERTIES/Model Params` group that the other exporters write into.

        Args:
            file (h5py.File): The h5 file being written to.
            quad_uuid (str): The UUID of the quadtree geometry.
        """
        # Make this a pseudo XMDF file. Won't really use the XMDF library, just h5py.
        file_type = file.create_dataset('File Type', (1, ), dtype='S5')
        file_type[0] = numpy.array(bytes("Xmdf", 'utf-8'))
        file_version = file.create_dataset('File Version', (1, ), dtype='f')
        file_version[0] = numpy.array([99.99])
        # Create the root dataset group.
        props = file.create_group("PROPERTIES")
        # Write the geometry's UUID to the file.
        uuid = props.create_dataset("Guid", (1, ), 'S36')
        uuid[0] = bytes(quad_uuid, 'utf-8')
        # Write the Model property
        model = props.create_dataset("Model", (1, ), 'S9')
        model[0] = bytes('CMS-FLOW', 'utf-8')
        # Create Model Params group
        self._params_group = props.create_group("Model Params")

    def _export_lats_and_lons(self, projection: str, ugrid: UGrid, file: h5py.File):
        """
        Export the latitude and longitude datasets, if necessary.

        Args:
            projection: Well-known text of the projection that `ugrid` is in.
            ugrid: UGrid containing geometry to export latitudes and longitudes of.
            file: Where to export the datasets to.
        """
        if self._data.sim_data.flow.attrs.get('LATITUDE_CORIOLIS', '') != 'From projection':
            return  # Only need to export these if we're getting Coriolis from a projection.

        # This can raise an exception if the projection is local, but the pre-checks abort the export before we get
        # here if that's the case.
        projected_points = make_geographic(projection, ugrid)

        lons = [point[0] for point in projected_points]
        file.create_dataset('PROPERTIES/Model Params/Lons', data=lons)

        lats = [point[1] for point in projected_points]
        file.create_dataset('PROPERTIES/Model Params/Lats', data=lats)

    def _export_extracted_data(self):
        """Write all the WSE forcing data for extracted WSE forcing arcs, reporting any extracted null values."""
        if not self._groups_needing_extraction:
            return

        self._export_extracted_times()
        self._export_extracted_elevation()
        self._report_nans(True)
        self._export_extracted_velocity()
        self._report_nans(False)

    def _export_extracted_times(self):
        """Write each extraction group's `WSE_Times`: the extracted time steps as hours from the simulation start."""
        seconds_per_hour = 60 * 60
        dataset = self._forced_elevation_dataset
        start_time, end_time = self._get_simulation_start_and_end()
        start_index, end_index = get_indexes_within_simulation(dataset, start_time, end_time)
        hour_offsets = [
            (dataset.ref_time + dataset.timestep_offset(index) - start_time).total_seconds() / seconds_per_hour
            for index in range(start_index, end_index)
        ]
        # get_indexes_within_simulation will give us the index just before the simulation starts if it doesn't match
        # up exactly. We shift the time forward to match the start of the simulation exactly, so there's always a time
        # right at the start. The list is empty if no time step falls within the simulation.
        if hour_offsets:
            hour_offsets[0] = 0.0

        for group in self._groups_needing_extraction:
            group.create_dataset('WSE_Times', data=hour_offsets)

    def _export_extracted_elevation(self):
        """Write a `WSE/WaterLevel_<cell>` dataset of extracted elevations for every cell of every extraction group."""
        extractor = self._make_elevation_extractor()

        for group in self._groups_needing_extraction:
            wse_group = group.create_group('WSE')
            for cell in group['Cells']:
                dataset = extractor.scalars_for(cell)
                self._check_extracted_data_for_nan(group.name, cell, dataset, True)
                wse_group.create_dataset(f'WaterLevel_{cell}', data=dataset)

    def _check_extracted_data_for_nan(self, group_name: str, cell: int, dataset: np.ndarray, is_elevation: bool):
        """
        Check if extracted data has NaN in it.

        This typically happens if the dataset had a NaN or null value in one of its time steps at the extraction site,
        or if there wasn't a place in the parent mesh to extract from.

        Only the first NaN found for each (arc, cell, dataset kind) is recorded. Its time step is the position within
        the extracted data, which begins at the first time step used for the simulation.

        Args:
            group_name: Name of the group in the _mp.h5 file that contains the arc's data. Used to find the arc's
                feature ID for error reporting.
            cell: Cell in the domain grid the data was extracted from. Used for reporting errors to the user.
            dataset: The data for the current arc and cell (and, in the case of velocity, the current face). Checked to
                see if it contains NaN.
            is_elevation: Whether this is an elevation dataset, as opposed to a velocity one.
        """
        nan_times = np.where(np.isnan(dataset))[0]

        if len(nan_times) == 0:
            return

        # Group names look like '.../Boundary_#<feature id>'.
        feature_id = int(group_name.rpartition('#')[2])

        key = (feature_id, cell, is_elevation)
        if key in self._found_nans:
            return
        self._found_nans.add(key)

        first_nan_time = int(nan_times[0])
        nans = self._elevation_nans if is_elevation else self._velocity_nans
        nans.add((feature_id, cell, first_nan_time))

    def _report_nans(self, is_elevation: bool):
        """
        Report any nan values, if necessary.

        Args:
            is_elevation: Whether to report for the elevation dataset, as opposed to the velocity one.
        """
        nans = self._elevation_nans if is_elevation else self._velocity_nans
        if not nans:
            return

        grid_path = self._data.tree_path(self._data.ugrid.uuid)
        coverage_path = self._data.tree_path(self._data.bc_coverage[0].uuid)
        if is_elevation:
            dataset = self._forced_elevation_dataset
            note = 'Note: Elevation values are extracted at centers of cells.'
        else:
            dataset = self._forced_velocity_dataset
            note = 'Note: Velocity values are extracted at centers of all four cell faces.'
        dataset_path = self._data.tree_path(dataset.uuid)

        # Recorded time steps are positions within the extracted data. The extracted data begins at this index of the
        # dataset, so the index must be added back to get the dataset's own time step.
        start_time, end_time = self._get_simulation_start_and_end()
        first_index, _end_index = get_indexes_within_simulation(dataset, start_time, end_time)

        XmLog().instance.error(
            'One or more boundary condition arcs used the Extracted WSE source and extracted null values.'
        )
        XmLog().instance.error(f'Arc IDs are in: {coverage_path}')
        XmLog().instance.error(f'Cell IDs are in: {grid_path}')
        XmLog().instance.error(f'Representative time step indexes are in: {dataset_path}')
        XmLog().instance.error(note)
        XmLog().instance.error(f'{"Arc ID":>9} {"Cell ID":>9} {"Time Step":>12}')

        for feature_id, cell_id, time_step in sorted(nans):
            offset: datetime.timedelta = dataset.timestep_offset(first_index + time_step)
            hours, remainder = divmod(offset.seconds, 3600)
            minutes, seconds = divmod(remainder, 60)
            time = f'{offset.days} {hours:02}:{minutes:02}:{seconds:02}'
            XmLog().instance.error(f'{feature_id:9} {cell_id:9} {time:>12}')

    def _make_elevation_extractor(self) -> TransientDatasetExtractor:
        """Make an extractor for elevation data at the centroid of each cell of every extraction group."""
        dataset = self._forced_elevation_dataset
        start_time, end_time = self._get_simulation_start_and_end()
        start_index, end_index = get_indexes_within_simulation(dataset, start_time, end_time)
        extractor = TransientDatasetExtractor(self._forced_elevation_geometry, dataset, list(range(start_index, end_index)))
        quadtree_ug = self._cov_mapper.quadtree_ugrid
        for group in self._groups_needing_extraction:
            cell_ids = group['Cells']
            # Cell IDs in the file are 1-based; the quadtree is 0-based.
            locations = [quadtree_ug.get_cell_centroid(cell_id - 1)[1] for cell_id in cell_ids]
            extractor.add_locations(locations, cell_ids)

        extractor.extract()
        return extractor

    def _export_extracted_velocity(self):
        """Export all the extracted velocity data for all directions, if a velocity source is assigned."""
        _coverage, component = self._data.bc_coverage
        if not component.data.wse_forcing_velocity_source:
            return

        extractor = self._make_velocity_extractor()
        for direction in ('Top', 'Bottom', 'Left', 'Right'):
            self._export_extracted_velocity_direction(extractor, direction)

    def _export_extracted_velocity_direction(self, extractor: TransientDatasetExtractor, direction: str):
        """
        Export all the extracted velocity data for a specific direction.

        The extracted (u, v) components are projected onto a unit vector and written as `<direction>/Velocity_<cell>`.
        Left/Right use the grid's rotated x direction. Top/Bottom use that direction rotated a further 90 degrees.

        Args:
            extractor: Source of extracted data.
            direction: One of 'Top', 'Bottom', 'Left', or 'Right'.
        """
        rotation = math.radians(self._data.ugrid.angle)
        if direction not in ('Left', 'Right'):
            rotation += math.pi / 2
        i = math.cos(rotation)
        j = math.sin(rotation)

        for group in self._groups_needing_extraction:
            velocity_group = group.create_group(direction)
            for cell in group['Cells']:
                identifier = f'{direction}{cell - 1}'  # Identifiers use 0-based cell indexes.
                us = extractor.scalars_for(identifier, 0)
                vs = extractor.scalars_for(identifier, 1)
                self._check_extracted_data_for_nan(group.name, cell, us, False)
                self._check_extracted_data_for_nan(group.name, cell, vs, False)

                values = [i * u + j * v for u, v in zip(us, vs)]
                velocity_group.create_dataset(f'Velocity_{cell}', dtype=float, data=values)

    def _make_velocity_extractor(self) -> TransientDatasetExtractor:
        """Make an extractor for velocity data at the four face midpoints of each cell of every extraction group."""
        dataset = self._forced_velocity_dataset
        start_time, end_time = self._get_simulation_start_and_end()
        start_index, end_index = get_indexes_within_simulation(dataset, start_time, end_time)
        extractor = TransientDatasetExtractor(self._forced_elevation_geometry, dataset, list(range(start_index, end_index)))
        quadtree_ug = self._cov_mapper.quadtree_ugrid
        locs = quadtree_ug.locations
        angle = self._cov_mapper.get_quadtree().angle

        for group in self._groups_needing_extraction:
            for cell_id in group['Cells']:
                cell_index = cell_id - 1  # Cell IDs in the file are 1-based; the quadtree is 0-based.
                left, right, top, bottom = face_midpoints(quadtree_ug, locs, angle, cell_index)
                identifiers = [f'Left{cell_index}', f'Right{cell_index}', f'Top{cell_index}', f'Bottom{cell_index}']
                extractor.add_locations([left, right, top, bottom], identifiers)

        extractor.extract()
        return extractor

    def _get_simulation_start_and_end(self) -> tuple[datetime.datetime, datetime.datetime]:
        """
        Get the start and end time of the simulation.

        The end is the start plus the duration. `SIM_DURATION_UNITS` is used directly as a `datetime.timedelta`
        keyword, so it presumably holds names like 'hours' or 'days' — confirm against the sim data definition.
        """
        general = self._data.sim_data.general.attrs
        start_time = datetime.datetime.fromisoformat(general['DATE_START'])
        units = general['SIM_DURATION_UNITS']
        # timedelta rejects numpy integer types, which attribute values may be stored as, so convert explicitly.
        duration = float(general['SIM_DURATION_VALUE'])
        end_time = start_time + datetime.timedelta(**{units: duration})

        return start_time, end_time


def get_indexes_within_simulation(dataset: DatasetReader, start_time: datetime.datetime,
                                  end_time: datetime.datetime) -> tuple[int, int]:
    """
    Find the span of a dataset's time step indexes that covers the simulation.

    Args:
        dataset: Dataset whose time steps are examined.
        start_time: When the simulation starts.
        end_time: When the simulation ends.

    Returns:
        A ``(first, past_end)`` pair.

        ``first`` is the last time step at or before ``start_time``, or 0 when every step is after it.

        ``past_end`` is one more than the last time step at or before ``end_time``, or -1 when every step is after
        ``end_time`` (or the dataset has no time steps).
    """
    ref_time = dataset.ref_time
    step_times = [ref_time + dataset.timestep_offset(step) for step in range(dataset.num_times)]

    # Datasets often lack a step that lands exactly on the simulation start. In that case the latest step before
    # the start is used and treated as if it lined up exactly.
    steps_not_after_start = [step for step, when in enumerate(step_times) if when <= start_time]
    first = steps_not_after_start[-1] if steps_not_after_start else 0

    last_in_range = next(
        (step for step in reversed(range(len(step_times))) if step_times[step] <= end_time),
        None,
    )
    past_end = -1 if last_in_range is None else last_in_range + 1

    return first, past_end


def dataset_ends_before(dataset: DatasetReader, end_time: datetime.datetime) -> bool:
    """
    Determine whether a dataset runs out of time steps before a given time.

    Args:
        dataset: Dataset whose final time step is checked.
        end_time: Time to compare the dataset's final time step against (the simulation end, per callers' intent).

    Returns:
        True if the final time step is strictly earlier than ``end_time``, otherwise False.
    """
    # NOTE(review): with num_times == 0 this requests offset index -1; the outcome depends on DatasetReader --
    # TODO confirm empty datasets cannot reach here.
    reference = dataset.ref_time
    final_time = reference + dataset.timestep_offset(dataset.num_times - 1)
    return final_time < end_time


def make_geographic(projection: str, ugrid: UGrid | Any) -> Sequence[tuple[float, float, float]]:
    """
    Get a UGrid's cell centroids in geographic coordinates.

    The UGrid itself is not changed; only its cell centroids are read. If ``projection`` is already geographic, the
    centroids are returned as-is. Otherwise they are reprojected from ``projection`` into the NAD83 (GRS 1980)
    geographic coordinate system defined by ``DEFAULT_GEOGRAPHIC_WKT``.

    Args:
        projection: The WKT of the current projection of the grid.
        ugrid: The UGrid to read centroids from, or an object whose ``ugrid`` attribute holds it.

    Returns:
        The centroid of every cell, in cell index order, in geographic coordinates.

    Raises:
        ValueError: If ``projection`` is a local projection, which cannot be reprojected.
    """
    if is_local(projection):
        raise ValueError('Unable to reproject local coordinates to geographic.')

    # Accept wrapper objects (e.g. a constrained grid) that expose the underlying UGrid as `.ugrid`.
    if not isinstance(ugrid, UGrid):
        ugrid = ugrid.ugrid

    # get_cell_centroid returns a pair whose second element is the centroid point.
    centers = [ugrid.get_cell_centroid(i)[1] for i in range(ugrid.cell_count)]

    if is_geographic(projection):
        return centers

    projected_centers = transform_points_from_wkt(centers, projection, DEFAULT_GEOGRAPHIC_WKT)

    return cast(list[tuple[float, float, float]], projected_centers)


def possible_face_midpoints(ugrid: UGrid, cell_index: int, locations) -> Iterable[Pt2d]:
    """
    Generate every point that might be the midpoint of one of a cell's faces.

    In a rectilinear grid each face is a single edge, so the edge midpoints are the face midpoints. In a quadtree,
    a neighboring cell may be more refined, which splits the shared face into two edges. The face's midpoint is
    then a vertex of this cell rather than the midpoint of any one edge.

    Illustration: the Xs mark the face midpoints of the large cell on the left. The X on its right face is also a
    vertex, because that face borders refined cells.

    +--X--+----+
    |     |    |
    X     X----+
    |     |    |
    +--X--+----+

    Args:
        ugrid: The UGrid to search for midpoints in.
        cell_index: Index of the cell whose candidate midpoints are wanted.
        locations: List-like of the UGrid's point locations, indexed by point index.

    Returns:
        Generator yielding each edge midpoint first, followed by each of the cell's vertices.
    """
    # Candidates from splitting each edge in half.
    for start_idx, end_idx in ugrid.get_cell_edges(cell_index):
        ax, ay, az = locations[start_idx]
        bx, by, bz = locations[end_idx]
        yield (ax + bx) / 2, (ay + by) / 2, (az + bz) / 2

    # Candidates from faces split by refined neighbors: the cell's own vertices.
    for vertex in ugrid.get_cell_locations(cell_index):
        yield vertex


def face_midpoints(ugrid: UGrid, locations: list, grid_angle: float, cell_index: int) -> tuple[Pt2d, Pt2d, Pt2d, Pt2d]:
    """
    Get the midpoints of each of a quad cell's faces, in order of left, right, top, bottom.

    Each candidate from possible_face_midpoints() is classified by its angle from the cell centroid, measured
    relative to the grid's I-axis. Candidates within 1 degree of 0/360, 90, 180 or 270 degrees become the right,
    top, left or bottom midpoint respectively; any other candidate is ignored. If several candidates match the same
    side, the last one wins.

    Args:
        ugrid: The UGrid
        locations: The UGrid locations
        grid_angle: The Grid angle (from the CoGrid)
        cell_index: Index of the cell to find the midpoint of.

    Returns:
        The midpoints of each of the cell's edges, in order of left, right, top, bottom.

    Raises:
        ValueError: If no candidate matched one or more of the four sides.
    """
    left = right = top = bottom = None

    for possible_midpoint in possible_face_midpoints(ugrid, cell_index, locations):
        angle = side_angle(ugrid, grid_angle, cell_index, possible_midpoint)

        if math.isclose(angle, 0.0, abs_tol=1) or math.isclose(angle, 360.0, abs_tol=1):
            right = possible_midpoint
        elif math.isclose(angle, 90.0, abs_tol=1):
            top = possible_midpoint
        elif math.isclose(angle, 180.0, abs_tol=1):
            left = possible_midpoint
        elif math.isclose(angle, 270.0, abs_tol=1):
            bottom = possible_midpoint
        # Otherwise: we're iterating over *possible* face midpoints, so some won't be *actual* face midpoints, and
        # failure to identify one is normal.

    # An explicit check rather than `assert`, so the validation still runs under `python -O` instead of letting
    # None midpoints flow on to the caller.
    sides = {'left': left, 'right': right, 'top': top, 'bottom': bottom}
    missing = [side for side, point in sides.items() if point is None]
    if missing:
        missing_text = ', '.join(missing)
        raise ValueError(f'Could not identify the {missing_text} face midpoint(s) of cell {cell_index}.')

    return left, right, top, bottom


def side_angle(ugrid: UGrid, angle: float, cell_index: int, side_center: Pt2d) -> float:
    """
    Find the CCW angle between the grid's I-axis and the segment from a cell's centroid to the given point.

    The world-space angle is the result of angle_between_edges_2d between the +X axis and the centroid-to-point
    direction, converted to degrees. The grid angle is then subtracted to make it relative to the I-axis.

    Args:
        ugrid: The UGrid geometry
        angle: The grid angle in degrees (from the CoGrid). Any value is accepted, including ones more than a full
            turn from zero.
        cell_index: Index of the cell containing the centroid to use.
        side_center: Center of one of the cell's sides.

    Returns:
        The computed angle in degrees, wrapped into the 0-360 range.
    """
    _, (centroid_x, centroid_y, _) = ugrid.get_cell_centroid(cell_index)
    side_x, side_y, _ = side_center

    side_direction = (side_x - centroid_x, side_y - centroid_y)
    origin = (0.0, 0.0)
    x_axis = (1.0, 0.0)
    world_angle = math.degrees(angle_between_edges_2d(x_axis, origin, side_direction))

    # Modulo wraps any difference into range. A single +/-360 adjustment would only fix values within one turn,
    # leaving larger grid angles out of range.
    return (world_angle - angle) % 360.0


def _log_forcing_error(message: str, arcs: list[int]):
    """
    Log an error saying that arcs using extracted WSE forcing cannot be exported.

    Args:
        message: Reason the export is impossible; inserted after "but" in the logged sentence.
        arcs: IDs of the arcs that cannot be exported.
    """
    arc_list = ', '.join(map(str, arcs))
    error_text = (
        f'Cannot export boundary conditions: One or more arcs use extracted WSE forcing, but {message}. '
        f'Arcs with the following IDs cannot be exported: {arc_list}'
    )
    XmLog().instance.error(error_text)


def _uses_extracted_wse_forcing(xms_data: XmsData) -> bool:
    """
    Determine whether any arc in the simulation's BC coverage uses Extracted WSE forcing.

    Args:
        xms_data: Provider of the simulation's ``(coverage, component)`` BC coverage pair.

    Returns:
        True as soon as one arc component ID reports extracted WSE forcing. False if there is no BC component or
        no arc uses it.
    """
    _, bc_component = xms_data.bc_coverage
    if not bc_component:
        return False

    bc_data: BCData = bc_component.data
    # Iterating this mapping yields the arc component IDs that are queried on the BC data.
    arc_ids = bc_component.comp_to_xms[bc_component.cov_uuid][TargetType.arc]
    return any(bc_data.arc_uses_extracted_wse_forcing(comp_id) for comp_id in arc_ids)
