"""AdvectiveCourantNumber Algorithm."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.gdal.utilities import gdal_utils as gu
from xms.grid.ugrid import UGrid
from xms.mesher.meshing import mesh_utils

# 4. Local modules


class CoastDatasetCalc:
    """Point-wise coastal calculations driven by a depth (or elevation) dataset.

    Each public calculation (``gravity_waves_courant``, ``gravity_waves_timestep``,
    ``wavelength``, ``celerity``) reads timestep 0 of ``dataset``, evaluates one value per
    point and appends the result as a single timestep to ``dataset_builder``.
    """

    # Newton iteration settings for solving the linear dispersion relation for k*h.
    _DISPERSION_TOL = 0.1e-3
    _DISPERSION_MAX_ITER = 100

    def __init__(self, logger=None, dataset_is_depth=False, gravity_const=9.80665, timestep_seconds=1.0,
                 min_depth=0.1, courant_number=1.0, period=20.0, dataset=None, co_grid=None, wkt='',
                 dataset_builder=None) -> None:
        """Initializes the class.

        Args:
            logger: The logger to use (``info``/``warning`` are called on it).
            dataset_is_depth: True if the dataset holds depths; False if it holds elevations,
                which are negated to obtain depths.
            gravity_const: The gravity constant.
            timestep_seconds: The time step in seconds (used by the Courant number calculation).
            min_depth: The minimum depth; smaller depths are clamped up to this value.
            courant_number: The Courant number (used by the time step calculation).
            period: The wave period (used by the wavelength/celerity calculations).
            dataset: The input dataset.
            co_grid: The constrained UGrid.
            wkt: The wkt of projection.
            dataset_builder: The dataset builder that receives the computed timestep.
        """
        self.logger = logger
        self.dataset_is_depth = dataset_is_depth
        self.gravity_const = gravity_const
        self.timestep_seconds = timestep_seconds
        self.min_depth = min_depth
        self.courant_number = courant_number
        self.period = period
        self.dataset = dataset
        self.co_grid = co_grid
        self.default_wkt = wkt
        self.dataset_builder = dataset_builder
        self._calc_size_func = True
        self._ugrid = co_grid.ugrid if co_grid else None
        self._grid_projected = False  # True once self._ugrid has been converted to UTM
        self._size_func = None
        self._depth_vals = None
        self.depth_activity = None  # activity returned with the depth values (set by _get_depth)
        self._out_vals = None
        self._calc = None  # per-point calculation callback used by _do_calc
        self._w = 0.0
        self._gx = 0.0
        self._vx = None  # per-point k*h values from the dispersion solve
        self._idx = -1
        self._pt_size = 0.0
        self._pt_depth = 0.0

    def _transform_grid_if_geographic(self):
        """Converts the working UGrid from geographic coordinates to UTM when needed.

        Only done when a size function will be computed, so that the edge-length based sizes
        are in projected (UTM) units. The conversion is applied at most once per instance;
        repeated calculations must not re-project coordinates that are already UTM.
        """
        if not self._calc_size_func or self._grid_projected:
            return
        sr = gu.wkt_to_sr(self.default_wkt) if self.default_wkt else None
        if sr and sr.IsGeographic():
            self.logger.info('Converting geographic coordinates to UTM...')
            locs, _ = gu.convert_lat_lon_pts_to_utm(self._ugrid.locations)
            self._ugrid = UGrid(locs, self._ugrid.cellstream)
            self._grid_projected = True

    def _calc_size_function(self):
        """Calculates the nodal size function.

        When no size function is needed, every point gets size 1.0 so every point is evaluated.
        Points with a non-positive (or NaN) size are reported and later skipped by ``_do_calc``.
        """
        if not self._calc_size_func:
            self._size_func = np.ones(len(self._ugrid.locations), dtype=np.float64)
            return
        self.logger.info('Calculating size function...')
        self._size_func = np.array(mesh_utils.size_function_from_edge_lengths(self._ugrid))
        # "not > 0" (rather than "<= 0") also catches NaN sizes
        indexes = np.where(~(self._size_func > 0.0))[0]
        if indexes.size > 0:
            indexes += 1  # report 1-based point ids
            self.logger.warning(f'{len(indexes)} disjoint points found and assigned NULL value. Point ids: {indexes}')

    def _get_depth(self):
        """Loads timestep 0 of the dataset as float64 depths clamped to at least ``min_depth``.

        Also stores the timestep's activity in ``self.depth_activity`` and, when the dataset has
        activity, configures the dataset/builder activity handling.
        """
        activity = self.dataset.activity
        # Explicit None/size test: truth-testing a multi-element numpy array raises ValueError.
        if activity is not None and np.size(activity) > 0:
            self.dataset_builder.use_activity_as_null = True
            if self.dataset.values.shape != activity.shape:  # Nodal dataset values with cell activity
                self.dataset.activity_calculator = CellToPointActivityCalculator(self._ugrid)
                self.dataset_builder.activity_calculator = CellToPointActivityCalculator(self._ugrid)

        depth_vals, self.depth_activity = self.dataset.timestep_with_activity(0)
        # Copy into a float64 array before negating so the dataset's own array is never
        # modified in place and integer-typed values do not make the negation fail.
        depth_vals = np.array(depth_vals, dtype=np.float64)
        if not self.dataset_is_depth:
            depth_vals *= -1.0  # elevations -> depths
        indexes = np.where(depth_vals < self.min_depth)[0]
        if indexes.size > 0:
            indexes += 1  # report 1-based point ids
            self.logger.warning(f'{len(indexes)} points found with depth less than minimum depth. Point ids: {indexes}')
        # Make sure depth is at least self.min_depth
        depth_vals[depth_vals < self.min_depth] = self.min_depth
        self._depth_vals = depth_vals

    def _do_calc(self):
        """Evaluates ``self._calc`` at every point and appends the values as one timestep.

        Points with a non-positive size keep ``self.dataset.null_value``.
        """
        self._transform_grid_if_geographic()
        self._calc_size_function()
        self._get_depth()
        vals = [self.dataset.null_value] * len(self._depth_vals)
        self.logger.info('Calculating values at points...')
        for i, depth in enumerate(self._depth_vals):
            if self._size_func[i] > 0.0:
                # The per-point callbacks read these attributes
                self._idx = i
                self._pt_size = self._size_func[i]
                self._pt_depth = depth
                vals[i] = self._calc()
        vals = np.array(vals, dtype=float)
        self.dataset_builder.append_timestep(self.dataset.times[0], vals, self.depth_activity)

    def _calc_gravity_wave_courant(self):
        """Returns sqrt(g * h) * dt / dx for the current point."""
        return math.sqrt(self.gravity_const * float(self._pt_depth)) * self.timestep_seconds / self._pt_size

    def gravity_waves_courant(self):
        """Calculates the gravity waves Courant number at each point.

        Always uses the edge-length size function, even if a wavelength/celerity calculation
        was run earlier on this instance.
        """
        self._calc_size_func = True
        self._calc = self._calc_gravity_wave_courant
        self._do_calc()

    def _calc_gravity_waves_timestep(self):
        """Returns Cr * dx / sqrt(g * h) for the current point."""
        return self.courant_number * self._pt_size / math.sqrt(self.gravity_const * self._pt_depth)

    def gravity_waves_timestep(self):
        """Calculates the time step giving the requested gravity waves Courant number at each point.

        Always uses the edge-length size function, even if a wavelength/celerity calculation
        was run earlier on this instance.
        """
        self._calc_size_func = True
        self._calc = self._calc_gravity_waves_timestep
        self._do_calc()

    def _init_dispersion_terms(self):
        """Precomputes w = 2*pi/period and w**2/g, and allocates per-point k*h storage."""
        self._w = 2.0 * math.pi / self.period
        self._gx = self._w * self._w / self.gravity_const
        self._vx = np.zeros(self.dataset.num_values, dtype=np.float64)

    def _calc_wavelength_celerity_vx(self):
        """Solves the linear dispersion relation for vx = k*h at the current point.

        Finds vx with vx * tanh(vx) = w**2 * h / g using Newton's method started at vx = 1.
        The result is stored in ``self._vx[self._idx]`` (when allocated) and returned. NaN is
        returned when the depth is NaN or the iteration does not converge within
        ``_DISPERSION_MAX_ITER`` steps (e.g. a non-positive right-hand side), instead of
        looping forever.
        """
        # self._gx holds w**2 / g, so gx = w**2 * h / g
        gx = self._gx * self._pt_depth
        vx = math.nan
        if not math.isnan(gx):
            guess = 1.0
            for _ in range(self._DISPERSION_MAX_ITER):
                wx = math.tanh(guess)
                fx = (guess * wx) - gx
                fpr = guess * (1.0 - wx * wx) + wx  # d/dx [x * tanh(x)]
                guess -= fx / fpr
                if abs(fx) <= self._DISPERSION_TOL:
                    vx = guess
                    break
        if self._vx is not None:
            self._vx[self._idx] = vx
        return vx

    def _calc_wavelength(self):
        """Returns the wavelength 2*pi*h / (k*h) for the current point."""
        vx = self._calc_wavelength_celerity_vx()
        return 2.0 * math.pi * self._pt_depth / vx

    def wavelength(self):
        """Calculates the wavelength at each point.

        No size function is needed, so every point is evaluated. The k*h values are kept for
        a subsequent ``celerity`` call.
        """
        self._calc_size_func = False
        self._init_dispersion_terms()
        self._calc = self._calc_wavelength
        self._do_calc()

    def _calc_celerity(self):
        """Returns the celerity sqrt((g / k) * tanh(k*h)) from the stored k*h at the current point."""
        vx = self._vx[self._idx]
        if not vx > 0.0:  # NaN or non-positive k*h: celerity is undefined at this point
            return math.nan
        k = vx / self._pt_depth
        return math.sqrt((self.gravity_const / k) * math.tanh(vx))

    def _calc_celerity_with_solve(self):
        """Solves for k*h at the current point, then returns the celerity."""
        self._calc_wavelength_celerity_vx()
        return self._calc_celerity()

    def celerity(self):
        """Calculates the celerity at each point.

        Reuses the k*h values from a previous ``wavelength`` call on this instance; otherwise
        solves the dispersion relation per point. No size function is needed.
        """
        self._calc_size_func = False
        if self._vx is None:
            self._init_dispersion_terms()
            self._calc = self._calc_celerity_with_solve
        else:
            self._calc = self._calc_celerity
        self._do_calc()
