"""IdomainMapper class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import csv

# 2. Third party modules
from shapely.geometry import LineString

# 3. Aquaveo modules
from xms.constraint import Grid
from xms.coverage.polygons import polygon_orienteer
from xms.data_objects.parameters import FilterLocation

# 4. Local modules
from xms.mf6.components import dis_builder
from xms.mf6.geom import geom
from xms.mf6.mapping import grid_intersector, map_util
from xms.mf6.misc import log_util


def _layer_range_exists(table_def):
    """Tell whether the table definition carries the layer range attributes.

    Only the 'From layer' column is looked for; its presence is taken to mean the layer range
    ('From layer', 'To layer') is part of the table.

    Args:
        table_def (dict): The table definition.

    Returns:
        (bool): True if a column named 'From layer' is present, else False.
    """
    columns = table_def['columns'] if 'columns' in table_def else []
    return any('name' in column and column['name'] == 'From layer' for column in columns)


def _point_in_polygon(polygon, point):
    """Tell whether a point lies within a polygon's outer ring but not within any of its holes.

    Args:
        polygon (list of lists of tuple): Outer ring first, followed by any inner (hole) rings.
        point (tuple of float): xyz point.

    Returns:
        (bool): True if the point is in (or on) the outer ring and not inside a hole.
    """
    outer_ring = polygon[0]
    # The last vertex of every ring is dropped before testing (presumably a repeated closing point - confirm).
    outer_result = geom.point_in_polygon_2d(outer_ring[:-1], point)
    if outer_result not in (0, 1):  # Neither in nor on the outer ring
        return False
    # Only a result of 1 for a hole excludes the point; remaining holes aren't tested once one matches.
    return not any(geom.point_in_polygon_2d(hole[:-1], point) == 1 for hole in polygon[1:])


class IdomainMapper:
    """Handles creating IDOMAIN from coverages.

    Does point in polygon of cell centers, and adds cells under arcs if the arcs have a type that isn't NONE. Looks
    at the 'From layer' and 'To layer' attributes in the table, and not in the coverage setup, as we currently
    don't have that info come across to Python from GMS.
    """
    def __init__(self, dis_package, cogrid: 'Grid', coverages, coverage_att_files):
        """Initializes the class.

        Args:
            dis_package (DIS, DISV, or DISU): the dis package from mf6
            cogrid: The constrained grid (may be None, in which case no UGrid is cached).
            coverages (list): list of the coverages used to generate the IDOMAIN array
            coverage_att_files (dict): dict, keyed by coverage uuid, of dicts mapping a feature type ('polygons',
                'arcs') to the filepath of that feature type's attribute table
        """
        self._dis_package = dis_package
        self._cogrid = cogrid
        self._ugrid = cogrid.ugrid if cogrid else None  # UGrid of the cogrid, so we only get it once
        self._coverages = coverages
        self._coverage_att_files = coverage_att_files
        self._log = log_util.get_logger()
        self._polygons_exist = False  # Set True once a coverage with a polygon att table is encountered
        # Layer stuff
        self._do_layers = False  # True for DIS6/DISV6, where 'From layer'/'To layer' restrict the cells used
        self._layers = []  # List of bool flags indicating which layers have been covered by the layer range
        self._intersector = None  # GridIntersect class from flopy, created the first time an arc needs it
        self.idomain = None  # Flat list of 0/1 values, one per cell, built by do_work()
        self.tops = None
        self.bottoms = None

    def do_work(self):
        """Creates the IDOMAIN array in the dis package."""
        self._init_layer_info()
        self._find_cells_in_polygons()
        self._update_dis_package()

    def _init_layer_info(self):
        """Initialize the variables needed to deal with layers."""
        grid_info = self._dis_package.grid_info()
        self._do_layers = self._dis_package.ftype in ['DIS6', 'DISV6']
        if self._do_layers:
            self._layers = [False] * grid_info.nlay  # The layers we've encountered via 'From layer', 'To layer'

    def _find_cells_in_polygons(self):
        """Builds self.idomain: 1 for cells inside a polygon (or under a typed arc), otherwise 0.

        Cell centers are tested against every polygon of every coverage that has a polygon att table. When layers
        apply (DIS6/DISV6), only cells whose layer is in the polygon's 'From layer'..'To layer' range are considered.
        """
        cell_centers = dis_builder.get_cell_centers2d(self._cogrid, self._ugrid)
        self.idomain = [0] * len(cell_centers)
        grid_info = self._dis_package.grid_info()  # Same for every coverage, so get it once

        # Loop through coverages
        for coverage in self._coverages:
            att_files = self._coverage_att_files[coverage.uuid]
            if 'polygons' not in att_files:
                continue
            self._polygons_exist = True

            # Read the att tables
            polygon_atts, do_polygons = self._read_att_table(att_files['polygons'])
            if not do_polygons:
                continue
            if 'arcs' in att_files:
                arc_atts, do_arcs = self._read_att_table(att_files['arcs'])
            else:  # No arc att table, so only the polygons contribute
                arc_atts, do_arcs = {}, False
            done_arc_ids = set()  # Arcs shared by adjacent polygons only need to be intersected once

            # Loop through polygons
            for polygon in coverage.polygons:
                from_layer, to_layer = self._layer_info_from_att_table(polygon_atts, polygon.id)
                self._update_layers_array(from_layer, to_layer)
                poly_points = polygon_orienteer.get_polygon_point_lists(polygon)

                # Loop through cells doing point in poly. Cells already active can't change, so skip them.
                for cell_idx, point in enumerate(cell_centers):
                    if self.idomain[cell_idx]:
                        continue
                    if self._in_layer_range(grid_info, cell_idx, from_layer, to_layer):
                        if _point_in_polygon(poly_points, point):
                            self.idomain[cell_idx] = 1
                if do_arcs:
                    self._add_cells_under_typed_arcs(polygon, arc_atts, self.idomain, done_arc_ids)

        if self._do_layers:
            self._check_for_layers_not_covered()

    def _add_cells_under_typed_arcs(self, polygon, arc_atts, idomain, done_arc_ids=None):
        """Adds cells intersected by arcs if the arc type is not NONE.

        Uses flopy GridIntersect class.

        Args:
            polygon (xms.data_objects.parameters.Polygon): A polygon.
            arc_atts (dict): Dict of the arc attribute table, keyed by arc ID.
            idomain (list of int): The IDOMAIN. Modified in place.
            done_arc_ids (set | None): IDs of arcs already handled. Arcs in it are skipped, and every arc handled here
                is added to it. If None, all of the polygon's arcs are handled.
        """
        if done_arc_ids is None:
            done_arc_ids = set()
        grid_info = self._dis_package.grid_info()
        for arc in polygon.arcs:
            arc_id = arc.id
            if arc_id in done_arc_ids:
                continue
            done_arc_ids.add(arc_id)
            if arc_id not in arc_atts or arc_atts[arc_id].get('Type', 'NONE') == 'NONE':
                continue

            # Get from and to layer, converted to int the same way as for polygons
            from_layer, to_layer = self._layer_info_from_att_table(arc_atts, arc_id)

            # Intersect the arc with the grid
            if self._intersector is None:
                _, self._intersector = grid_intersector.create_flopy_grid_and_intersector(
                    self._dis_package, self._cogrid, self._ugrid
                )
            point_list = [(p.x, p.y, p.z) for p in arc.get_points(FilterLocation.PT_LOC_ALL)]
            rec = self._intersector.intersect(LineString(point_list))

            # Add intersected cells to idomain
            for cellid in rec.cellids.tolist():
                cellid = grid_info.fix_cellid(cellid)
                if not self._do_layers or from_layer <= cellid[0] <= to_layer:
                    idomain[grid_info.cell_index_from_modflow_cellid(cellid)] = 1

    def _in_layer_range(self, grid_info, cell_idx, from_layer, to_layer):
        """Returns True if the cell is in the layer range.

        Args:
            grid_info (GridInfo): Number of rows, cols etc.
            cell_idx (int): 0-based index of the cell.
            from_layer (int): The 'From layer'.
            to_layer (int): The 'To layer'.

        Returns:
            (bool): See description. Always True when layers don't apply (DISU).
        """
        if not self._do_layers:
            return True
        cellid = grid_info.modflow_cellid_from_cell_index(cell_idx)
        return from_layer <= cellid[0] <= to_layer

    def _update_layers_array(self, from_layer, to_layer):
        """Updates the array that keeps track of what layers are covered.

        Layers are 1-based. The parts of the range outside 1..nlay are ignored. A blank range such as (0, 0)
        therefore no longer wraps around to mark the last layer, and a 'To layer' past the bottom no longer raises
        IndexError.

        Args:
            from_layer (int): The 'From layer' for the polygon.
            to_layer (int): The 'To layer' for the polygon.
        """
        if not self._do_layers:
            return
        first = max(from_layer, 1)
        last = min(to_layer, len(self._layers))
        for layer in range(first, last + 1):
            self._layers[layer - 1] = True

    def _layer_info_from_att_table(self, att_table, feature_id):
        """Returns 'from layer' and 'to layer'.

        Args:
            att_table (dict): Dict of the att table.
            feature_id (int): ID of the current feature.

        Returns:
            (tuple): tuple containing:
                - (int): The 'From layer' (0 if layers don't apply or the feature/value is missing).
                - (int): The 'To layer' (0 if layers don't apply or the feature/value is missing).
        """
        if not self._do_layers:
            return 0, 0

        feature_atts = att_table.get(feature_id, {})  # Features missing from the table get an empty range
        from_layer = int(feature_atts.get('From layer', 0))
        to_layer = int(feature_atts.get('To layer', 0))
        return from_layer, to_layer

    def _read_att_table(self, att_file):
        """Returns a dict of the atts we need by reading from the att table file.

        Args:
            att_file (str): Filepath of att table file.

        Returns:
            (tuple): tuple containing:
                - (dict): The atts, keyed by feature ID.
                - (bool): True if we can use this given the layer requirements.
        """
        atts = {}
        table_def = map_util.read_table_definition_file(att_file)
        if self._do_layers and not _layer_range_exists(table_def):
            return atts, False
        column_names = [column['name'] for column in table_def['columns']]
        # newline='' is what the csv module requires for files it reads
        with open(att_file, 'r', newline='') as att_csv_file:
            reader = csv.DictReader(att_csv_file, fieldnames=column_names)
            next(reader, None)  # Skip the header row (tolerates an empty file)
            for row in reader:
                feature_id = int(row['ID'])
                from_layer, to_layer = map_util.layer_info_from_row(self._do_layers, row)
                feature_dict = {}
                if 'Type' in column_names:
                    feature_dict['Type'] = row['Type']
                feature_dict['From layer'] = from_layer
                feature_dict['To layer'] = to_layer
                atts[feature_id] = feature_dict
        return atts, True

    def _check_for_layers_not_covered(self):
        """Gives warning message for layers not covered by polygons."""
        layers_not_covered = [k + 1 for k, covered in enumerate(self._layers) if not covered]
        if layers_not_covered:
            not_covered = ', '.join(map(str, layers_not_covered))
            msg = f'The following layers are not assigned to any polygons and will be inactivated: {not_covered}'
            self._log.warning(msg)

    def _update_dis_package(self):
        """Updates the IDOMAIN for the DIS package."""
        griddata = self._dis_package.block('GRIDDATA')
        if not griddata.has('IDOMAIN'):
            griddata.add_array(self._dis_package.new_array('IDOMAIN', layered=False))
        idomain_array = griddata.array('IDOMAIN')
        _, _, shape = self._dis_package.array_size_and_layers('IDOMAIN', layered=idomain_array.layered)
        idomain_array.set_values(values=self.idomain, shape=shape, combine=False)
        idomain_array.dump_to_temp_files()
