"""Perform tidal constituent extraction using tidal constituent component data."""

__copyright__ = '(C) Copyright Aquaveo 2025'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import datetime

# 2. Third party modules
from harmonica.tidal_constituents import Constituents
import xarray as xr

# 3. Aquaveo modules
from xms.guipy.time_format import ISO_DATETIME_FORMAT

# 4. Local modules
from xms.tides.data import tidal_data as td


# Value of the 'Constituent' column in the user defined table that flags a row as a non-standard constituent, whose
# properties are read from the table itself instead of being looked up in harmonica. Rows with any other value index
# into td.STANDARD_CONSTITUENTS. Not to be confused with td.USER_DEFINED_INDEX, which is the component 'source' index.
USER_DEFINED_INDEX = 15


class TidalExtractor:
    """Uses harmonica to extract tidal constituent properties and datasets.

    The harmonica model is selected from the tidal component's 'source' attribute when the extractor is constructed.
    """

    def __init__(self, tidal_data):
        """Initialize the extractor.

        Args:
            tidal_data (TidalData): The tidal constituent component data file.
        """
        self.data = tidal_data
        self.model = None
        self._init_model()

    def _init_model(self):
        """Initialize the harmonica tidal constituent extraction model.

        Maps the component's source index to the matching harmonica model name. Any source without a dedicated
        harmonica model (including user defined) falls back to the ADCIRC database. In the user defined case, that
        database is only used to get properties of rows that reference one of the standard constituents.
        """
        source = self.data.info.attrs['source'].item()
        harmonica_sources = {
            td.LEPROVOST_INDEX: 'leprovost',
            td.FES2014_INDEX: 'fes2014',
            td.TPX08_INDEX: 'tpxo8',
            td.TPX09_INDEX: 'tpxo9',
        }
        self.model = Constituents(harmonica_sources.get(source, 'adcirc2015'))

    def get_constituent_properties(self):
        """Get the frequency, nodal factor, and equilibrium argument for a component's enabled constituents.

        For a user defined source, every row of the user defined table is used. For any other source, only the
        component's enabled constituents are queried from harmonica.

        Returns:
            xarray.Dataset: Dataset containing frequency, nodal factor, equilibrium argument (and the other constituent
            properties, e.g. amplitude and Earth tide reduction factor) for the constituents, sorted by the 'con'
            coordinate. None if the source is user defined and the user defined table has no rows.
        """
        source = self.data.info.attrs['source'].item()
        start_dt = datetime.datetime.strptime(self.data.info.attrs['reftime'], ISO_DATETIME_FORMAT)
        if source == td.USER_DEFINED_INDEX:
            return self._get_user_table_properties(start_dt)
        # Ask harmonica for the enabled constituents' properties.
        cons = self.data.cons.where(self.data.cons.enabled == 1, drop=True)
        return self._extract_standard_con_properties(cons['name'].data.tolist(), start_dt)

    def _get_user_table_properties(self, start_dt):
        """Get the constituent properties for every row of the user defined table.

        Rows flagged with USER_DEFINED_INDEX carry their own properties in the table. All other rows reference a
        standard constituent by its index into td.STANDARD_CONSTITUENTS, and their properties come from harmonica.

        Args:
            start_dt (datetime.datetime): Reference (start) time of the extraction.

        Returns:
            xarray.Dataset: Combined properties of all rows, sorted by the 'con' coordinate. None if the table has no
            rows.
        """
        user_table = self.data.user
        pieces = []

        # Non-standard rows: properties are taken straight from the table.
        user_rows = user_table.where(user_table.Constituent == USER_DEFINED_INDEX, drop=True)
        if user_rows.sizes.get('index', 0) > 0:
            pieces.append(self._extract_user_con_properties(user_rows))

        # Standard rows: translate the constituent index to its name and ask harmonica for the properties.
        standard_rows = user_table.where(user_table.Constituent != USER_DEFINED_INDEX, drop=True)
        if standard_rows.sizes.get('index', 0) > 0:
            con_names = [
                td.STANDARD_CONSTITUENTS[con_idx] for con_idx in standard_rows['Constituent'].data.astype(int).tolist()
            ]
            pieces.append(self._extract_standard_con_properties(con_names, start_dt))

        if not pieces:
            return None
        # Each piece is sorted on its own. Re-sort after concatenating so this branch matches the 'con' ordering of
        # every other return path.
        combined = pieces[0] if len(pieces) == 1 else xr.concat(pieces, 'con')
        return combined.sortby('con')

    def get_amplitude_and_phase(self, locs, node_ids):
        """Get the amplitude and phase for the enabled constituents at the given locations.

        Args:
            locs (list): List of ocean boundary node locations. [(x1, y1), ..., (xN, yN)]
            node_ids (list): List of integer ocean boundary node ids parallel with locs. Used to build coords of
                the returned xarray.Dataset.

        Returns:
            xarray.Dataset: Dataset of the harmonica components (amplitude, phase, ...) for the enabled constituents,
            with dimensions ('con', 'node_id') and sorted by 'con'. None if no constituents are enabled or harmonica
            returned no location data.
        """
        cons = self.data.cons.where(self.data.cons.enabled == 1, drop=True)  # Get the enabled constituents
        con_names = cons['name'].data.tolist()
        if not con_names:
            # NOTE(review): harmonica may treat an empty constituent list as "all constituents", so do not query it
            # when nothing is enabled.
            return None
        # Query harmonica for amplitude and phase at nodal values (one pandas.DataFrame per location).
        dfs = self.model.get_components(locs, cons=con_names, positive_ph=True)
        # Convert pandas.DataFrames to xarray Datasets and rename default pandas dimension name to 'con'
        dsets = [df.to_xarray().rename({'index': 'con'}) for df in dfs.data]
        if not dsets:
            return None
        # Build a data cube of the location Datasets along a new 'node_id' dimension.
        all_dsets = xr.concat(dsets, 'node_id')
        # Assign passed in ocean boundary node ids as coords of the node_id dimension, and order the dimensions so the
        # first is the constituent and the second is node id.
        return all_dsets.assign_coords({'node_id': node_ids}).transpose('con', 'node_id').sortby('con')

    def _extract_standard_con_properties(self, standard_cons, start_dt):
        """Get the frequency, nodal factor, and equilibrium argument for standard constituents.

        Used both for the enabled constituents of a database source and for user defined table rows that reference a
        standard constituent.

        Args:
            standard_cons (list): List of the constituent names to query harmonica for properties
            start_dt (datetime.datetime): Reference (start) time of the extraction.

        Returns:
            xarray.Dataset: Dataset containing frequency, nodal factor, and equilibrium argument (plus the other
            harmonica properties except 'speed') for the standard constituents, sorted by 'con'.
        """
        # Nodal factors are evaluated at the middle of the run; 'run_duration' is interpreted as hours.
        middle_dt = start_dt + datetime.timedelta(hours=self.data.info.attrs['run_duration'] / 2)
        df = self.model.get_nodal_factor(standard_cons, start_dt, timestamp_middle=middle_dt)
        # Drop properties we don't care about and rename the coords
        dset = df.drop(['speed'], axis=1).to_xarray()
        return dset.rename({'index': 'con'}).sortby('con')

    def _extract_user_con_properties(self, user_cons):
        """Get the frequency, nodal factor, and equilibrium argument for non-standard user defined constituents.

        Args:
            user_cons (xarray.Dataset): Rows of the user defined table that are flagged as non-standard constituents,
                indexed by the table's 'index' dimension.

        Returns:
            xarray.Dataset: Dataset containing tidal potential amplitude, frequency, nodal factor, equilibrium argument,
            and Earth tide reduction factor for the non-standard constituents, indexed and sorted by 'con' (the
            constituent name).
        """
        user_cons = user_cons.set_coords(['Name'])  # Set the constituent name as the index coordinate
        user_cons = user_cons.swap_dims({'index': 'Name'})
        user_cons = user_cons.reset_coords(['index'], drop=True)  # Drop the default index
        # These columns are not constituent properties (amplitude/phase are boundary forcing values).
        user_cons = user_cons.drop_vars(['Constituent', 'Amplitude', 'Phase'])
        # Table column names contain literal '__new_line__' markers; map them to the harmonica property names.
        user_cons = user_cons.rename({
            'Name': 'con',
            'Tidal __new_line__ Potential __new_line__ Amplitude': 'amplitude',
            'Frequency': 'frequency',
            'Nodal __new_line__ Factor': 'nodal_factor',
            'Equilibrium __new_line__ Argument': 'equilibrium_argument',
            'Earth Tide __new_line__ Reduction __new_line__ Factor': 'earth_tide_reduction_factor',
        }).sortby('con')
        return user_cons
