"""Utility functions for units."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data.base_file_data import BaseFileData

# Some common units
# Each constant is a unit specification string. In this module, the upper-case 'L' (length) and 'T' (time)
# placeholders are the ones that string_from_units() replaces with actual unit names when they are defined.
UNITS_LENGTH = '[L]'  # Length
UNITS_TIME = '[T]'  # Time
UNITS_AREA = '[L^2]'  # Area
UNITS_SS = '[L^-1]'  # Specific storage
UNITS_K = '[L/T]'  # Hydraulic conductivity
UNITS_INFILTRATION = '[L/T]'  # Infiltration rate (ET, recharge)
UNITS_COND = '[L^2/T]'  # Conductance
UNITS_DIFFC = '[L^2/T]'  # Diffusion coefficient
UNITS_CONCENTRATION = '[concentration]'  # Concentration
UNITS_BULK_D = '[M/L^3]'  # Bulk density
UNITS_DIST_C = '[L^3/M]'  # Distribution coefficient
UNITS_DEGREES = '[DEGREES]'  # Angle in degrees (contains no 'L' or 'T', so nothing is substituted)
UNITS_UNITLESS = '[unitless]'  # Lower case on purpose: contains no upper-case 'L' or 'T' placeholder
UNITS_TEMPERATURE = '[temperature]'  # Lower case on purpose: contains no upper-case 'L' or 'T' placeholder

# Upper-case time unit names that dataset_time_units_from_tdis_time_units() checks the TDIS TIME_UNITS option
# against. 'UNKNOWN' is listed here but is treated by that function as "no units".
MF6_TIME_UNITS = ['UNKNOWN', 'SECONDS', 'MINUTES', 'HOURS', 'DAYS', 'YEARS']


def _substitute_unit(units_spec: str, substitutions: dict[str, str], unit: str, unknown: str) -> str:
    """Returns units_spec with every occurrence of one unit token replaced by its substitution.

    The spec is returned unchanged when the token does not occur in it, when substitutions has no (or an empty)
    entry for the token, or when that entry, upper-cased, equals unknown.

    Args:
        units_spec (str): A unit specification string (e.g. '[L^3/T]').
        substitutions (dict[str, str]): Maps a unit token to the text that should replace it.
        unit (str): The single unit token to replace (e.g. 'L', or 'T').
        unknown (str): Upper-case string meaning the unit is undefined or unknown (e.g. 'UNKNOWN'). Pass '' to
            accept any non-empty substitution.

    Returns:
        (str): See description.
    """
    replacement = substitutions.get(unit, '')
    if unit not in units_spec:
        return units_spec
    # An empty or "unknown" replacement leaves the generic token in place
    if not replacement or replacement.upper() == unknown:
        return units_spec
    return units_spec.replace(unit, replacement)


def _substitute_units(substitutions: dict[str, str], units_spec: str) -> str:
    """Returns a units string given a unit specification string (e.g. '[L^3/T]') and substitutions.

    Only the 'L' (length) and 'T' (time) placeholders are replaced. A placeholder is left as is when
    substitutions has no entry for it, or the entry is empty or 'UNKNOWN' (case-insensitive).

    The substitutions dict is not modified, so the same dict can be reused for several specs.

    Args:
        substitutions (dict[str, str]): Dict defining what should be substituted, keyed by 'L' and/or 'T'.
        units_spec (str): A unit specification string (e.g. '[L^3/T]').

    Returns:
        (str): See description.
    """
    replacements = {}
    for placeholder in ('L', 'T'):
        unit_name = substitutions.get(placeholder)
        # Skip missing, empty or unknown units so the generic placeholder stays visible
        if unit_name and unit_name.upper() != 'UNKNOWN':
            replacements[placeholder] = unit_name

    # str.translate maps each character of the original spec exactly once, so letters inside a unit name that
    # was just inserted are never substituted again (no 'FEET' -> 'FEEYEARS').
    return units_spec.translate(str.maketrans(replacements))


def _get_unit_substitutions(data: BaseFileData, units_spec: str) -> dict[str, str]:
    """Returns a dict with the unit substitution info.

    The dict maps 'L' to the LENGTH_UNITS option and 'T' to the TIME_UNITS option. A key is only present when the
    placeholder occurs in units_spec and the option is defined with a non-empty value other than 'UNKNOWN'.

    Args:
        data: Data class. If it is falsy, or has no model / mfsim, the corresponding unit is not looked up.
        units_spec: A unit specification string (e.g. '[L^3/T]').

    Returns:
        (dict[str, str]): See description. Empty if nothing can be substituted.
    """
    substitutions: dict[str, str] = {}

    # Length units come from the options block of whatever data.model.get_dis() returns
    if 'L' in units_spec and data and data.model:
        dis = data.model.get_dis()
        if dis and dis.options_block.defined('LENGTH_UNITS'):
            length = dis.options_block.get('LENGTH_UNITS')  # Reuse dis rather than calling get_dis() again
            if length and length.upper() != 'UNKNOWN':
                substitutions['L'] = length

    # Time units come from the options block of data.mfsim.tdis
    if 'T' in units_spec and data and data.mfsim:
        tdis = data.mfsim.tdis
        if tdis and tdis.options_block.defined('TIME_UNITS'):
            time = tdis.options_block.get('TIME_UNITS')
            if time and time.upper() != 'UNKNOWN':
                substitutions['T'] = time
    return substitutions


def string_from_units(data: BaseFileData, units_spec: str) -> str:
    """Builds a units string from a unit specification string (e.g. '[L^3/T]').

    The placeholders looked up and replaced are:

    Length = L
    Time = T

    A placeholder with no defined unit is left as is.

    Args:
        data: Data class used to look up the length and time units.
        units_spec: A unit specification string (e.g. '[L^3/T]').

    Returns:
        (str): See description.
    """
    return _substitute_units(_get_unit_substitutions(data, units_spec), units_spec)


def dataset_time_units_from_tdis_time_units(tdis_time_units: str) -> str:
    """Returns a dataset time units string given a TDIS time units string, or 'None' if unable.

    Args:
        tdis_time_units: The TIME_UNITS option in the TDIS file. Compared case-insensitively; may be empty.

    Returns:
        (str): The units in title case (e.g. 'DAYS' -> 'Days'). If the units are empty, 'UNKNOWN', or not in
        MF6_TIME_UNITS, the string 'None' is returned (never the None object or an empty string).
    """
    if tdis_time_units:
        upper_units = tdis_time_units.upper()
        if upper_units != 'UNKNOWN' and upper_units in MF6_TIME_UNITS:
            return tdis_time_units.title()
    return 'None'  # Default to 'None' because empty string will crash when creating the dataset writer
