"""Calculator for comparing simulations."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import sys

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.srh.floodway.xms_getter import XmsGetter


def _compute_average(values, null_value):
    avg = 0.0
    if not values:
        return avg
    count = 0
    for value in values:
        if value != null_value:
            avg += value
            count += 1
    if count:
        return avg / count
    return avg


def _get_min_diff(values_1, values_2, null_value_1, null_value_2):
    min_diff = sys.float_info.max
    if not values_1 or not values_2:
        return 0.0
    values = zip(values_1, values_2)
    found = False
    for value_1, value_2 in values:
        if value_1 != null_value_1 and value_2 != null_value_2:
            found = True
            diff = abs(value_2 - value_1)
            if diff < min_diff:
                min_diff = diff
    if found:
        return min_diff
    return 0.0


def _get_max_diff(values_1, values_2, null_value_1, null_value_2):
    max_diff = 0.0
    if not values_1 or not values_2:
        return max_diff
    values = zip(values_1, values_2)
    for value_1, value_2 in values:
        if value_1 != null_value_1 and value_2 != null_value_2:
            diff = abs(value_2 - value_1)
            if diff > max_diff:
                max_diff = diff
    return max_diff


def _get_avg_diff(values_1, values_2, null_value_1, null_value_2):
    avg_diff = 0.0
    if not values_1 or not values_2:
        return avg_diff
    count = 0
    values = zip(values_1, values_2)
    for value_1, value_2 in values:
        if value_1 != null_value_1 and value_2 != null_value_2:
            avg_diff += abs(value_2 - value_1)
            count += 1
    if count:
        return avg_diff / count
    return avg_diff


class SimulationCompareCalc:
    """Compares the WSE and velocity-magnitude (VMag) datasets of a floodplain and a floodway simulation.

    Produces per-arc averages and min/max/average differences for both quantities.
    """
    def __init__(
        self,
        query,
        fp_geom_uuid,
        fp_geom_name,
        fp_wse_uuid,
        fp_vmag_uuid,
        fw_geom_uuid,
        fw_geom_name,
        fw_wse_uuid,
        fw_vmag_uuid,
        cross_section_uuid,
        simulation_compare_data=None
    ):
        """Initializes the simulation comparison calculator.

        Args:
            query (:obj:`Query`): XMS interprocess communication object
            fp_geom_uuid (:obj:`UUID`): The floodplain geometry UUID
            fp_geom_name (:obj:`str`): The floodplain geometry name
            fp_wse_uuid (:obj:`UUID`): The floodplain WSE UUID
            fp_vmag_uuid (:obj:`UUID`): The floodplain vmag UUID
            fw_geom_uuid (:obj:`UUID`): The floodway geometry UUID
            fw_geom_name (:obj:`str`): The floodway geometry name
            fw_wse_uuid (:obj:`UUID`): The floodway WSE UUID
            fw_vmag_uuid (:obj:`UUID`): The floodway vmag UUID
            cross_section_uuid (:obj:`UUID`): The cross section UUID
            simulation_compare_data (:obj:`dict`): Data used instead of querying XMS (for testing). The keys
                read are 'fp_wse_data', 'fp_null', 'fw_wse_data', 'fw_null', 'fp_vmag_data' and 'fw_vmag_data'.
        """
        # Input
        self.query = query
        self.fp_geom_uuid = fp_geom_uuid
        self.fp_geom_name = fp_geom_name
        self.fp_wse_uuid = fp_wse_uuid
        self.fp_vmag_uuid = fp_vmag_uuid
        self.fw_geom_uuid = fw_geom_uuid
        self.fw_geom_name = fw_geom_name
        self.fw_wse_uuid = fw_wse_uuid
        self.fw_vmag_uuid = fw_vmag_uuid
        self.cross_section_uuid = cross_section_uuid
        self.simulation_compare_data = simulation_compare_data

    def compute_data(self):
        """Computes per-arc comparison statistics between the floodplain and floodway datasets.

        Each dataset is read from XMS when both its dataset UUID and its geometry UUID are set.
        Otherwise it is read from ``self.simulation_compare_data``.

        Returns:
            (:obj:`dict`): Maps column names to lists with one entry per compared arc. The columns, in order, are:
            'Arc ID'
            'Average <floodplain geometry name> base WSE'
            'Average <floodway geometry name> revised WSE'
            'Minimum WSE Difference'
            'Maximum WSE Difference'
            'Average WSE Difference'
            'Average <floodplain geometry name> base VMag'
            'Average <floodway geometry name> revised VMag'
            'Minimum VMag Difference'
            'Maximum VMag Difference'
            'Average VMag Difference'
        """
        floodplain_average_wse = f'Average {self.fp_geom_name} base WSE'
        floodway_average_wse = f'Average {self.fw_geom_name} revised WSE'
        floodplain_average_vmag = f'Average {self.fp_geom_name} base VMag'
        floodway_average_vmag = f'Average {self.fw_geom_name} revised VMag'
        # Setup the default return dictionary
        all_data = {
            'Arc ID': [],
            floodplain_average_wse: [],
            floodway_average_wse: [],
            'Minimum WSE Difference': [],
            'Maximum WSE Difference': [],
            'Average WSE Difference': [],
            floodplain_average_vmag: [],
            floodway_average_vmag: [],
            'Minimum VMag Difference': [],
            'Maximum VMag Difference': [],
            'Average VMag Difference': []
        }

        # Get the WSE information
        if self.fp_wse_uuid and self.fp_geom_uuid:
            fp_wse_data, fp_null = self._get_xms_data(self.fp_wse_uuid, self.fp_geom_uuid)
        else:
            fp_wse_data = self.simulation_compare_data['fp_wse_data']
            fp_null = self.simulation_compare_data['fp_null']
        if self.fw_wse_uuid and self.fw_geom_uuid:
            fw_wse_data, fw_null = self._get_xms_data(self.fw_wse_uuid, self.fw_geom_uuid)
        else:
            fw_wse_data = self.simulation_compare_data['fw_wse_data']
            fw_null = self.simulation_compare_data['fw_null']
        # Arcs are paired positionally (dict iteration order), and the floodplain arc ID labels each row.
        # NOTE(review): this assumes both datasets list the same arcs in the same order — confirm.
        all_data['Arc ID'].extend(fp_arc_id for fp_arc_id, _ in zip(fp_wse_data, fw_wse_data))
        self._append_statistics(
            all_data, fp_wse_data, fw_wse_data, fp_null, fw_null, floodplain_average_wse, floodway_average_wse, 'WSE'
        )

        # Get the VMag information
        if self.fp_vmag_uuid and self.fp_geom_uuid:
            fp_vmag_data, fp_null = self._get_xms_data(self.fp_vmag_uuid, self.fp_geom_uuid)
        else:
            # No VMag-specific null value is read here; the null value from the WSE step is reused.
            fp_vmag_data = self.simulation_compare_data['fp_vmag_data']
        # The floodway dataset is read with the floodway geometry, so check the floodway geometry UUID.
        if self.fw_vmag_uuid and self.fw_geom_uuid:
            fw_vmag_data, fw_null = self._get_xms_data(self.fw_vmag_uuid, self.fw_geom_uuid)
        else:
            fw_vmag_data = self.simulation_compare_data['fw_vmag_data']
        self._append_statistics(
            all_data, fp_vmag_data, fw_vmag_data, fp_null, fw_null, floodplain_average_vmag, floodway_average_vmag,
            'VMag'
        )

        return all_data

    @staticmethod
    def _append_statistics(all_data, fp_data, fw_data, fp_null, fw_null, fp_avg_key, fw_avg_key, quantity):
        """Appends the per-arc averages and differences for one quantity to the result columns.

        Args:
            all_data (:obj:`dict`): The result dictionary being filled in; its lists are appended to.
            fp_data (:obj:`dict`): Floodplain data keyed by arc ID; index 0 of each entry holds the arc's values.
            fw_data (:obj:`dict`): Floodway data keyed by arc ID; index 0 of each entry holds the arc's values.
            fp_null (:obj:`float`): Null value of the floodplain dataset.
            fw_null (:obj:`float`): Null value of the floodway dataset.
            fp_avg_key (:obj:`str`): Column receiving the floodplain averages.
            fw_avg_key (:obj:`str`): Column receiving the floodway averages.
            quantity (:obj:`str`): 'WSE' or 'VMag'; used to build the difference column names.
        """
        for fp_arc_id, fw_arc_id in zip(fp_data, fw_data):
            fp_values = fp_data[fp_arc_id][0]
            fw_values = fw_data[fw_arc_id][0]
            all_data[fp_avg_key].append(_compute_average(fp_values, fp_null))
            all_data[fw_avg_key].append(_compute_average(fw_values, fw_null))
            all_data[f'Minimum {quantity} Difference'].append(_get_min_diff(fp_values, fw_values, fp_null, fw_null))
            all_data[f'Maximum {quantity} Difference'].append(_get_max_diff(fp_values, fw_values, fp_null, fw_null))
            all_data[f'Average {quantity} Difference'].append(_get_avg_diff(fp_values, fw_values, fp_null, fw_null))

    def _get_xms_data(self, data_uuid, geom_uuid):
        """Retrieves a dataset's values from XMS along with the dataset's null value.

        Args:
            data_uuid (:obj:`UUID`): UUID of the dataset to read.
            geom_uuid (:obj:`UUID`): Geometry UUID passed to ``XmsGetter.retrieve_xms_data``.

        Returns:
            (:obj:`tuple`): A 2-tuple. The first element is the data returned by
            ``XmsGetter.retrieve_xms_data``. The second is the dataset's null value, or -999.0 when
            the dataset defines none or cannot be found.
        """
        xms_getter = XmsGetter(self.query, self.cross_section_uuid, [], dataset_uuid=data_uuid)
        # NOTE(review): the meaning of the two False flags is defined by XmsGetter.retrieve_xms_data — confirm there.
        data = xms_getter.retrieve_xms_data(geom_uuid, False, False)
        dataset = self.query.item_with_uuid(data_uuid)
        null_value = -999.0
        # Guard against a missing dataset so the default null value is used instead of raising AttributeError.
        if dataset is not None and dataset.null_value is not None:
            null_value = dataset.null_value
        return data, null_value
