"""LanduseMapper class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import values_with_nans
from xms.gdal.utilities import gdal_utils as gu
from xms.mesher.meshing import mesh_utils

# 4. Local modules

# How many grid points are processed between "Processing point ..." progress messages.
log_frequency = 10000


def set_log_frequency(frequency):
    """Change how often progress messages are logged while points are processed.

    Args:
        frequency (int):  Number of points between progress messages; replaces the module-level
            ``log_frequency`` value.
    """
    # Rebind the module-level setting rather than creating a local variable.
    global log_frequency
    log_frequency = frequency


class LanduseMapper():
    """Mannings N Creator class for landuse calculations involving a grid and landuse raster.

    For each grid point, a square box sized by the grid's edge-length size function is overlaid on the landuse
    raster. Every raster code inside the box is translated through the landuse mapping table, and the point
    receives the average of the mapped values.
    """

    def __init__(self, logger, raster, grid, default_value, landuse_codes, landuse_vals, default_dset=None,
                 lnd=None, grid_wkt=None):
        """Constructor.

        Args:
            logger (logging.Logger):  The logger.
            raster (RasterInput):  The landuse raster to work on.
            grid (cogrid):  The grid containing the points to process.
            default_value (float):  Starting value for every point when ``default_dset`` is not given; points
                that receive no mapped raster value keep it.
            landuse_codes (list of int):  List of landuse code values.
            landuse_vals (list):  List of landuse values, parallel with ``landuse_codes``.
            default_dset (dataset):  The default dataset; its first time step supplies the starting values.
            lnd (dataset):  The locked nodes dataset; points whose locked value is 1.0 (or NaN) are skipped.
            grid_wkt (str): The grid's projection. When given, grid locations are reprojected to the raster's.
        """
        self._logger = logger
        self._grid = grid
        self._grid_wkt = grid_wkt
        self._default_dset = default_dset
        self._dataset_vals = None
        self._x_origin = None
        self._y_origin = None
        self._pixel_width = 0.0
        self._pixel_height = 0.0
        self._pixel_area = 0.0
        self._xsize = 0.0
        self._ysize = 0.0
        self._raster_values = None
        self._raster_wkt = None
        self._default_value = default_value
        self._lnd = lnd
        self._landuse_codes = np.array(landuse_codes)
        self._landuse_vals = landuse_vals
        self._mapping_table = dict(zip(landuse_codes, landuse_vals))
        # Used as an insertion-ordered set of raster codes that had no entry in the mapping table.
        self._bad_codes_found = {}

        self._initialize_raster(raster)

    def _initialize_raster(self, raster):
        """Initializes some raster data from the raster passed in.

        Args:
            raster (RasterInput):  The raster to process.
        """
        self._x_origin = raster.xorigin
        self._y_origin = raster.yorigin
        self._pixel_width = raster.pixel_width
        self._pixel_height = raster.pixel_height
        self._xsize, self._ysize = raster.resolution
        self._raster_values = raster.get_raster_values()
        self._raster_wkt = raster.wkt

    def process_points(self):
        """Process the points to fill in the dataset values.

        Logs a warning listing any raster codes that had no entry in the mapping table.

        Returns:
            numpy.ndarray:  One calculated value per grid point.
        """
        # Calculate the size function and set the dataset values default
        self._logger.info('Calculating size function...')
        if self._grid_wkt is not None:
            # If the grid had a projection then transform the points to the raster projection.
            # NOTE(review): this replaces the locations on the caller's grid object; they are not restored.
            locations = self._grid.locations
            self._grid.locations = gu.transform_points_from_wkt(locations, self._grid_wkt, self._raster_wkt)
        size_func = mesh_utils.size_function_from_edge_lengths(self._grid)

        self._fill_array(size_func)
        if self._bad_codes_found:
            warning = "No Manning's values found for the following codes in the landuse raster; make sure you are " \
                      "using the correct landuse lookup table:"
            warning += ''.join(f' {code}' for code in self._bad_codes_found)
            self._logger.warning(warning)

        return self._dataset_vals

    def _fill_array(self, size_func):
        """Fill ``self._dataset_vals`` with values extracted from the land use raster.

        Args:
            size_func (Sequence): The size function for the grid points, parallel with the grid locations.
        """
        point_locations = self._grid.locations.tolist()  # Iterate on a pure Python list, numpy array iteration is slow
        num_points = len(point_locations)

        self._dataset_vals = self._initial_dataset_values(num_points)
        lnd = self._locked_node_values(num_points)
        self._pixel_area = self._pixel_width * self._pixel_height

        # Loop on the grid points and size function
        self._logger.info(f'Processing point 1 of {num_points}...')
        for i, (point, size_val) in enumerate(zip(point_locations, size_func)):
            if (i + 1) % log_frequency == 0:
                self._logger.info(f'Processing point {i + 1} of {num_points}...')
            # Skip locked points, points whose locked-node value is NaN (presumably null/inactive), and
            # points without a size.
            if math.isnan(lnd[i]) or lnd[i] == 1.0 or math.isnan(size_val):
                continue
            self._process_point(i, point, size_val)

    def _initial_dataset_values(self, num_points):
        """Build the starting value array; points that receive no mapped raster value keep these values.

        A new float array is always returned for two reasons. Writing results into it never modifies the
        caller's default dataset. An integer default value cannot truncate the fractional values written
        later.

        Args:
            num_points (int): Number of grid points.

        Returns:
            numpy.ndarray: Float array of starting values.
        """
        if self._default_dset is not None:
            return np.array(self._default_dset.values[0], dtype=float)
        return np.full(num_points, self._default_value, dtype=float)

    def _locked_node_values(self, num_points):
        """Get per-point locked-node values, or all zeros (nothing locked) when no dataset was given.

        Args:
            num_points (int): Number of grid points.

        Returns:
            Sequence of float: Per-point values; 1.0 means locked. NaN entries come from ``values_with_nans``.
        """
        if self._lnd is None:
            return np.zeros(num_points)
        values = self._lnd.values[0]
        activity = None if self._lnd.activity is None else self._lnd.activity[0]
        return values_with_nans(self._grid, values, activity, self._lnd.null_value)

    def _process_point(self, idx, point, size_val):
        """Assign the average mapped value of the raster pixels around one grid point.

        Based on the size function, a square box of side ``size_val`` centred on the point is clipped to the
        raster. SMS currently uses a polygon, so results will differ slightly. Codes with no mapping are
        recorded in ``self._bad_codes_found`` and ignored. If no pixel in the box maps to a value, the point
        keeps its current value.

        Args:
            idx (int): Index of the grid location
            point (tuple): The grid location; only x and y are used.
            size_val (float): Size function value for the location
        """
        x = point[0]
        y = point[1]
        half_size = size_val / 2.0
        xmin = x - half_size
        xmax = x + half_size
        ymin = y - half_size
        ymax = y + half_size

        # Pixel offsets of the box's upper-left corner. math.floor (not int) makes a corner left of or above the
        # raster origin get a negative offset. Truncating toward zero would snap it onto column/row 0, and the box
        # would then read one extra column/row. Row math assumes a north-up raster (negative pixel_height),
        # as suggested by the abs() below -- TODO confirm for other raster orientations.
        xoff = math.floor((xmin - self._x_origin) / self._pixel_width)
        yoff = math.floor((ymax - self._y_origin) / self._pixel_height)
        xcount = int((xmax - xmin) / self._pixel_width) + 1
        ycount = int((ymax - ymin) / abs(self._pixel_height)) + 1

        # Clip the box so the starting row/col is on the raster and the whole read stays on the raster.
        if xoff < 0:
            xcount += xoff
            xoff = 0
        if yoff < 0:
            ycount += yoff
            yoff = 0
        xcount = min(xcount, self._xsize - xoff)
        ycount = min(ycount, self._ysize - yoff)
        if xcount < 1 or ycount < 1:
            return  # Entire bounding box of point is outside raster extents, trivial reject

        # Read the small box of pixel values that surround the point
        codes = self._raster_values[yoff:yoff + ycount, xoff:xoff + xcount].ravel().tolist()
        scalars = []
        for code in codes:
            found, value = self._get_mapping_value(code)
            if found:
                scalars.append(value)

        if scalars:
            # Every pixel has the same area, so the area-weighted average reduces to the plain mean.
            self._dataset_vals[idx] = sum(scalars) / len(scalars)

    def _get_mapping_value(self, code):
        """Gets the landuse mapping value (Mannings, canopy coefficient, etc) from the landuse code.

        Codes missing from the mapping table are recorded in ``self._bad_codes_found``.

        Args:
            code (int):  The landuse code.

        Returns:
            tuple(bool, float):  Whether the mapping value was found, and the landuse mapping value corresponding
            to the code passed in (0.0 when not found).
        """
        value = self._mapping_table.get(code)
        if value is None:
            self._bad_codes_found[code] = None
            return False, 0.0
        return True, value
