"""Class to combine the attributes of multiple input coverages."""
# 1. Standard python modules
from collections import OrderedDict
import logging

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.data_objects.parameters import FilterLocation
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.tuflowfv.components.tuflowfv_component import UNINITIALIZED_COMP_ID
from xms.tuflowfv.data.material_data import MaterialData
from xms.tuflowfv.data.bc_data import get_default_bc_data


# GIS feature type codes. CoverageCollector.bc_types stores one of these per BC coverage index.
GIS_POINT, GIS_NODESTRING, GIS_POLYGON = 0, 1, 2


class CoverageCollector:
    """Class to combine the attributes of multiple input coverages.

    Boundary condition and material attributes of every linked coverage are merged into simulation-wide
    collections with new ids. Lookup dicts map each coverage's original ids to the merged ids.
    """

    # Material override toggles and the attributes they gate: (toggle_name, (gated_att_name, ...)).
    # The gated attributes are only compared when the toggle is enabled.
    _VISCOSITY_OVERRIDES = (
        ('override_horizontal_eddy_viscosity', ('horizontal_eddy_viscosity',)),
        ('override_horizontal_eddy_viscosity_limits',
         ('horizontal_eddy_viscosity_minimum', 'horizontal_eddy_viscosity_maximum')),
        ('override_vertical_eddy_viscosity_limits',
         ('vertical_eddy_viscosity_minimum', 'vertical_eddy_viscosity_maximum')),
    )
    _DIFFUSIVITY_OVERRIDES = (
        ('override_horizontal_scalar_diffusivity', ('horizontal_scalar_diffusivity',)),
        ('override_horizontal_scalar_diffusivity_limits',
         ('horizontal_scalar_diffusivity_minimum', 'horizontal_scalar_diffusivity_maximum')),
        ('override_vertical_scalar_diffusivity_limits',
         ('vertical_scalar_diffusivity_minimum', 'vertical_scalar_diffusivity_maximum')),
    )

    def __init__(self, xms_data):
        """Constructor.

        Args:
            xms_data (XmsData): Simulation data retrieved from SMS
        """
        self._logger = logging.getLogger('xms.tuflowfv')
        self._xms_data = xms_data
        self._next_mat_id = 1
        self._mismatched_material_names = {}  # {mat_name: {mismatched_mat_names}}
        self.next_bc_id = 1  # Next id to assign to a BC stored in self.bcs
        # Polygons were added later. They are actually stored in the 'nodestring' variables
        self.bcs = OrderedDict()  # {bc_id: (xr.Dataset, xr.Dataset) or None}  - bc_id to (atts, curve)
        self.bc_lookup = {}  # {cov_idx: {TargetType: {feature_id: bc_id}}}
        self.bc_points = []  # [[do_point, xr.Dataset, xr.Dataset]]  - [geometry, atts, curve] of 2dm QC points
        self.material_list = OrderedDict()  # {mat_id: xr.Dataset}
        self.material_lookup = {}  # {cov_idx: {comp_id: mat_id}}
        self.bc_types = {}  # {cov_idx: GIS_*}
        self.default_mat_id = UNINITIALIZED_COMP_ID
        self._collect_coverages()

    def _collect_coverages(self):
        """Visit all the linked coverages of a simulation."""
        self._visit_all_bc_coverages()
        self._visit_all_material_coverages()

    def _visit_all_bc_coverages(self):
        """Visit all linked BC coverages of a simulation."""
        self._logger.info('Merging boundary conditions...')
        for cov_idx, (bc_cov, bc_comp) in enumerate(zip(self._xms_data.bc_covs, self._xms_data.bc_comps)):
            # Visit the disjoint points, then the arcs, then the polygons of this coverage.
            self.bc_lookup[cov_idx] = {TargetType.point: {}, TargetType.arc: {}, TargetType.polygon: {}}
            arc_data = bc_comp.data.bcs
            polygon_data = bc_comp.data.polygons
            point_data = bc_comp.data.points
            export_format = bc_comp.data.globals.attrs['export_format']
            for do_point in bc_cov.get_points(FilterLocation.PT_LOC_DISJOINT):
                self._visit_all_bc_features(bc_comp, point_data, cov_idx, do_point, TargetType.point, bc_cov.name,
                                            export_format)
            for do_arc in bc_cov.arcs:
                self._visit_all_bc_features(bc_comp, arc_data, cov_idx, do_arc, TargetType.arc, bc_cov.name)
            for do_poly in bc_cov.polygons:
                self._visit_all_bc_features(bc_comp, polygon_data, cov_idx, do_poly, TargetType.polygon, bc_cov.name)

    def _visit_all_bc_features(self, bc_comp, feature_data, cov_idx, do_feature, target_type, cov_name,
                               export_format=''):
        """Merge the attributes of a single BC coverage feature into the simulation-wide BC collections.

        Args:
            bc_comp (BcComponent): The coverage's component
            feature_data (Dataset): The data for a feature type
            cov_idx (int): Index of the coverage in the lookup dict
            do_feature (data_objects feature): The feature's data_objects class
            target_type (TargetType): The type of feature to visit
            cov_name (str): Name of the coverage
            export_format (str): Whether we are writing to the 2dm/fvc or a GIS shapefile
        """
        comp_id = bc_comp.get_comp_id(target_type, do_feature.id)
        atts = None
        if (comp_id is not None and comp_id == UNINITIALIZED_COMP_ID) or feature_data is None:
            if target_type != TargetType.arc:  # Points and polygons without attributes are not exported
                feature_type = 'point' if target_type == TargetType.point else 'polygon'
                self._logger.warning(
                    f'No attributes found for Boundary Condition {feature_type} {do_feature.id} in the {cov_name} '
                    f'coverage. It will not be exported.'
                )
                return
            # Arcs without attributes (e.g. monitor lines) get an empty Dataset built from the default arc atts.
            default_data = get_default_bc_data(fill=False, gridded=False)
            coords = {'comp_id': np.array([], np.int64)}
            feature_data = xr.Dataset(data_vars=default_data, coords=coords)
        bc_atts = feature_data.where(feature_data.comp_id == comp_id, drop=True)
        # An empty selection has no BC type. Treat it like a monitor line (no atts) rather than truth-testing an
        # empty array, which NumPy has deprecated and newer releases reject.
        bc_type = bc_atts['type'].item() if bc_atts['type'].size > 0 else 'Monitor'
        if bc_type != 'Monitor':
            bc_curve = bc_comp.data.get_bc_curve(comp_id=comp_id, bc_type=bc_type, default=False)
            atts = (bc_atts, bc_curve)
        if bc_type == 'QC_POLY':
            self.bc_types[cov_idx] = GIS_POLYGON
        elif bc_type == 'QC':
            self.bc_types[cov_idx] = GIS_POINT
            if export_format == '2dm':  # Old style QC points are kept in self.bc_points, not numbered in self.bcs
                self.bc_points.append([do_feature, atts[0], atts[1]])
                return
        self.bcs[self.next_bc_id] = atts
        self.bc_lookup[cov_idx][target_type][do_feature.id] = self.next_bc_id
        self.next_bc_id += 1

    def _visit_all_material_coverages(self):
        """Visit all linked Material coverages of a simulation and build a union of their material lists."""
        self._logger.info('Merging material lists...')
        for cov_idx, (_, mat_comp) in enumerate(zip(self._xms_data.mat_covs, self._xms_data.mat_comps)):
            self.material_lookup[cov_idx] = {}
            for original_mat_id in mat_comp.data.materials.id.data.tolist():
                if original_mat_id == MaterialData.UNASSIGNED_MAT:
                    continue
                mat_atts = mat_comp.data.materials.where(mat_comp.data.materials.id == original_mat_id, drop=True)
                new_mat_id = self._compare_materials(mat_atts, False)
                self.material_lookup[cov_idx][original_mat_id] = new_mat_id
        attrs = self._xms_data.sim_data.global_set_mat.attrs
        if attrs['define_set_mat']:  # Add global default material to the list
            # If there is a global material defined, it will always be the second row in the simulation data's
            # material Dataset.
            set_mat = self._xms_data.sim_data.global_set_mat
            mat_atts = set_mat.where(set_mat.index == 1, drop=True)
            self.default_mat_id = self._compare_materials(mat_atts, True)

    @staticmethod
    def _override_atts_match(existing_mat, new_mat, toggle, gated_atts):
        """Check that an override toggle matches and, when it is enabled, that the attributes it gates match.

        Args:
            existing_mat (xr.Dataset): Existing material attributes to compare against
            new_mat (xr.Dataset): New material attributes to compare against
            toggle (str): Name of the override toggle attribute
            gated_atts (tuple[str]): Names of the attributes only compared when the toggle is enabled

        Returns:
            bool: True if the toggle and any gated attributes are equal
        """
        if not bool(getattr(existing_mat, toggle) == getattr(new_mat, toggle)):
            return False
        if getattr(existing_mat, toggle):  # Toggles are equal here, so checking one of them is enough
            return all(bool(getattr(existing_mat, att) == getattr(new_mat, att)) for att in gated_atts)
        return True

    def _compare_material_visosity(self, existing_mat, new_mat):
        """Check if a material shares viscosity attributes with an existing known material.

        Args:
            existing_mat (xr.Dataset): Existing material attributes to compare against
            new_mat (xr.Dataset): New material attributes to compare against

        Returns:
            bool: False if the material has different viscosity attributes from the existing materials
        """
        return all(
            self._override_atts_match(existing_mat, new_mat, toggle, gated_atts)
            for toggle, gated_atts in self._VISCOSITY_OVERRIDES
        )

    def _compare_material_diffusivity(self, existing_mat, new_mat):
        """Check if a material shares diffusivity attributes with an existing known material.

        Args:
            existing_mat (xr.Dataset): Existing material attributes to compare against
            new_mat (xr.Dataset): New material attributes to compare against

        Returns:
            bool: False if the material has different diffusivity attributes from the existing materials
        """
        return all(
            self._override_atts_match(existing_mat, new_mat, toggle, gated_atts)
            for toggle, gated_atts in self._DIFFUSIVITY_OVERRIDES
        )

    def _material_atts_match(self, existing_mat, new_mat):
        """Check whether two materials have the same attributes, ignoring their names.

        Args:
            existing_mat (xr.Dataset): Existing material attributes to compare against
            new_mat (xr.Dataset): New material attributes to compare against

        Returns:
            bool: True if every compared attribute is equal
        """
        if not bool(existing_mat.inactive == new_mat.inactive):
            return False
        if not self._override_atts_match(existing_mat, new_mat, 'override_bottom_roughness', ('bottom_roughness',)):
            return False
        if not self._compare_material_visosity(existing_mat, new_mat):
            return False
        if not self._compare_material_diffusivity(existing_mat, new_mat):
            return False
        bed_limits = ('bed_elevation_minimum', 'bed_elevation_maximum')
        if not self._override_atts_match(existing_mat, new_mat, 'override_bed_elevation_limits', bed_limits):
            return False
        return bool(existing_mat.spatial_reconstruction == new_mat.spatial_reconstruction)

    def _warn_mismatched_names(self, existing_mat_name, new_mat_name):
        """Warn, once per name pair, that two materials with identical attributes will be exported separately.

        Args:
            existing_mat_name (str): Name of the already collected material
            new_mat_name (str): Name of the material being added
        """
        warned_names = self._mismatched_material_names.setdefault(existing_mat_name, set())
        if new_mat_name not in warned_names:
            self._logger.warning(
                f'The two materials: {existing_mat_name}, {new_mat_name} have the same attributes but '
                'different names and will be exported as two separate materials.'
            )
            warned_names.add(new_mat_name)

    def _compare_materials(self, mat_atts, combine_mismatched_names):
        """Check if a new material we are adding matches a material we have already added to the list.

        Args:
            mat_atts (xr.Dataset): material attributes to compare against
            combine_mismatched_names (bool): If True, a material whose attributes match an existing material but
                whose name differs reuses the existing material's id. If False, it is added as a separate material
                and a warning is logged once per name pair.

        Returns:
            int: material id of the matching material, or the newly assigned id if none match
        """
        for mat_id, mat_dset in self.material_list.items():
            if not self._material_atts_match(mat_dset, mat_atts):
                continue
            if bool(mat_dset.name == mat_atts.name) or combine_mismatched_names:
                return mat_id
            # Same attributes but a different name: the materials will be exported separately.
            self._warn_mismatched_names(mat_dset.name.item(), mat_atts.name.item())
        mat_id = self._next_mat_id
        self.material_list[mat_id] = mat_atts
        self._next_mat_id += 1
        return mat_id

    def _mat_coverages_with_format(self, export_format):
        """Return the indices of linked Materials coverages whose export format equals export_format."""
        return [
            idx for idx, mat_comp in enumerate(self._xms_data.mat_comps)
            if mat_comp.data.info.attrs['export_format'] == export_format
        ]

    def _bc_coverages_with_format(self, export_format):
        """Return the indices of linked BC coverages whose export format equals export_format."""
        return [
            idx for idx, bc_comp in enumerate(self._xms_data.bc_comps)
            if bc_comp.data.globals.attrs['export_format'] == export_format
        ]

    def twodm_material_coverages(self):
        """Returns a list of the linked Materials coverages that have a .2dm export format.

        Returns:
            list[int]: See description
        """
        return self._mat_coverages_with_format('2dm')

    def twodm_bc_coverages(self):
        """Returns a list of the linked BC coverages that have a .2dm export format.

        Returns:
            list[int]: See description
        """
        return self._bc_coverages_with_format('2dm')

    def shapefile_material_coverages(self):
        """Returns a list of the linked Materials coverages that have a shapefile export format.

        Returns:
            list[int]: See description
        """
        return self._mat_coverages_with_format('Shapefile')

    def shapefile_bc_coverages(self):
        """Returns a list of the linked BC coverages that have a shapefile export format.

        Returns:
            list[int]: See description
        """
        return self._bc_coverages_with_format('Shapefile')
