"""Data2dFromData3d class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from logging import Logger
import uuid

# 2. Third party modules
from rtree import index
from shapely.geometry import LineString, Point, Polygon

# 3. Aquaveo modules
from xms.constraint import Grid, UGrid2dFromUGrid3dCreator, UnconstrainedGrid
from xms.constraint.ugrid_2d_from_ugrid_3d_creator import cogrid_2d_from_locations_and_cells
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules


# Keys identifying the per-column summary datasets; Data2dFromData3d uses them as dictionary keys for
# its summary dataset writers.
# NOTE(review): the numbering starts at 2 - presumably these mirror tool argument indices; confirm with the caller.
HIGHEST_ACTIVE_KEY = 2  # first active value found in the column
AVERAGE_KEY = 3  # mean of the column's active values
MAX_KEY = 4  # largest of the column's active values
MIN_KEY = 5  # smallest of the column's active values
LAYERS_KEY = 6  # not referenced by any live code visible in this file


class Data2dFromData3d():
    """Algorithm that converts a 3D UGrid dataset to a 2D UGrid and datasets."""
    def __init__(self, in_ds: DatasetReader, co_grid: Grid, logger: Logger, compute_highest_value=True,
                 compute_average_value=False, compute_max_value=False, compute_min_value=False,
                 compute_each_layer=False, output_ugrid_name=''):
        """
        Initializes the class.

        Args:
            in_ds: The input dataset
            co_grid: The input constrained grid
            logger: The logger that outputs to the user
            compute_highest_value: Whether to compute the highest active value in the column
            compute_average_value: Whether to compute the average value in the column
            compute_max_value: Whether to compute the maximum value in the column
            compute_min_value: Whether to compute the minimum value in the column
            compute_each_layer: Whether to compute the value for each layer in the column
            output_ugrid_name: The name for the output UGrid
        """
        # input dataset and grid
        self._in_ds = in_ds
        self._in_ug_num_layers = max(co_grid.cell_layers)  # the largest layer number is used as the layer count
        self._in_co_grid = co_grid
        self._in_ug_is_stacked = co_grid.check_is_stacked_grid()
        self._in_ug = co_grid.ugrid
        self._in_ug_vertical_connections = None  # per cell: (cells above, cells below)
        self._in_ug_vert_refinement = False  # set when a cell has more than one cell above or below it

        # output grid and the bookkeeping used to build it
        self._out_grid = None
        self._orig_cell_idx = []  # for each 2D cell, the index of the 3D cell it was created from
        self._cell_columns = None
        # This attribute is assigned by _create_cell_columns; it was previously declared here under the
        # misspelled name ``_cell_col_dict``, which nothing ever read or wrote.
        self._cell_cols_dict = None
        self._cell_rtree = None  # spatial index, created while the grid is being built

        # output datasets
        self._data_sets = []
        self._ds_writers = None
        self._ds_layers = None

        # variables for processing the input dataset
        self._ds_time = None
        self._new_ds_cell_idx = None
        self._ds_cell_col = None
        self._ds_cell_scalars = None
        self._ds_cell_active = None

        self.logger = logger

        self._output_ugrid_name = output_ugrid_name

        self.compute_highest_value = compute_highest_value
        self.compute_average_value = compute_average_value
        self.compute_max_value = compute_max_value
        self.compute_min_value = compute_min_value
        self.compute_each_layer = compute_each_layer

    def data_2d_from_data_3d(self):
        """

        Converts a 3D UGrid dataset to a 2D UGrid and datasets.

        Returns:
            (Tuple<String, UnconstrainedGrid, List<DatasetWriter>>) A tuple that holds the resulting UGrid name,
            UGrid, and DatasetWriters from the operation.
        """
        self._create_out_ugrid()
        self._create_data_sets()
        return self._output_ugrid_name, self._out_grid, self._data_sets

    def _create_out_ugrid(self):
        """Creates the 2D output grid from the plan view of the 3D input grid.

        A stacked input grid is handed to ``_out_grid_from_stacked_grid``. Otherwise the vertical connectivity
        and the cell columns are built, one 2D cell is created from the plan view polygon of every cell listed
        in each column's selected layer, and extra points are inserted into neighboring cells wherever a vertex
        of one cell lies on an edge of another cell. The result is stored in ``self._out_grid`` as an
        ``UnconstrainedGrid`` that keeps the uuid of the grid that was built, and ``self._orig_cell_idx``
        records the 3D cell that each 2D cell came from. If no output name was supplied, the name of the input
        dataset is used.
        """
        self.logger.info('Creating 2D UGrid from 3D UGrid.')
        if self._in_ug_is_stacked:
            self._out_grid_from_stacked_grid()
        else:
            self._build_vertical_connectivity()
            self._create_cell_columns()
            ug_pts_dict = dict()  # (x, y, 0.0) -> index into ug_pts, so shared vertices are only stored once
            ug_pts = []
            cells = []  # point indices of each 2D cell
            cell_polys = []  # shapely polygon of each 2D cell, parallel to ``cells``
            self._cell_rtree = index.Index()
            for lay, col in self._cell_columns:
                for c_idx in col[lay]:
                    _, poly = self._in_ug.get_cell_plan_view_polygon(c_idx)
                    sh_poly = Polygon([(p[0], p[1]) for p in poly])
                    self._cell_rtree.insert(len(cell_polys), sh_poly.bounds)
                    cell_polys.append(sh_poly)
                    pt_idxs = []
                    for p in poly:
                        pt = (p[0], p[1], 0.0)  # the 2D grid is flattened to z = 0
                        if pt in ug_pts_dict:
                            pt_idxs.append(ug_pts_dict[pt])
                        else:
                            pt_idx = len(ug_pts)
                            ug_pts_dict[pt] = pt_idx
                            pt_idxs.append(pt_idx)
                            ug_pts.append(pt)
                    cells.append(pt_idxs)
                    self._orig_cell_idx.append(c_idx)
            # find any cell edges that must be split
            for idx, cell_poly in enumerate(cell_polys):
                adj_cells = list(self._cell_rtree.intersection(cell_poly.bounds))
                adj_cells.remove(idx)  # the cell always intersects its own bounds
                pts = set(cell_polys[idx].boundary.coords)
                for adj in adj_cells:
                    if adj < 0 or adj >= len(cell_polys):
                        # Skip ids that do not map to a cell. This used to be ``pass``, which fell through and
                        # indexed ``cell_polys`` with the bad id anyway.
                        continue  # pragma no cover - Don't know how to hit, lost coverage after GDAL 3.4.1 update.
                    adj_poly = cell_polys[adj]
                    check_pts = pts - set(cell_polys[adj].boundary.coords)
                    check_pts = [Point(p) for p in check_pts]
                    # common points removed, see if any remaining points touch the polygon
                    for p in check_pts:
                        if p.touches(adj_poly):  # insert the point into the adj cell
                            p_pts = list(cell_polys[adj].boundary.coords)
                            for i in range(1, len(p_pts)):  # make polygon into lines
                                line = LineString((p_pts[i - 1], p_pts[i]))
                                if p.distance(line) == 0.0:  # insert point on line that it touches
                                    # rebuild the cell and the polygon
                                    cell_pts = cells[adj]
                                    pt_idx = ug_pts_dict[(p.x, p.y, 0.0)]
                                    cells[adj] = cell_pts[:i] + [pt_idx] + cell_pts[i:]
                                    new_pts = [(ug_pts[j][0], ug_pts[j][1]) for j in cells[adj]]
                                    cell_polys[adj] = Polygon(new_pts)
                                    break  # the point is inserted once; stop scanning the remaining edges
            self._out_grid = cogrid_2d_from_locations_and_cells(ug_pts, cells, self._get_uuid)

        # re-wrap the result as an unconstrained grid while keeping the uuid it was built with
        grid_uuid = self._out_grid.uuid
        self._out_grid = UnconstrainedGrid(self._out_grid.ugrid)
        self._out_grid.uuid = grid_uuid

        if self._output_ugrid_name == '':
            self._output_ugrid_name = self._in_ds.name

    def _out_grid_from_stacked_grid(self):
        """Builds the cell columns, vertical connections, and output grid for a stacked input grid.

        Cells are taken to be numbered layer by layer, so the cell in layer ``k`` (zero based) of column
        ``c`` has index ``c + k * cells_per_layer``.
        """
        self.logger.info('Processing stacked grid.')
        num_layers = self._in_ug_num_layers
        cells_per_layer = int(len(self._in_co_grid.cell_layers) / num_layers)
        vertical = [None] * self._in_ug.cell_count
        self._in_ug_vertical_connections = vertical
        columns = []
        self._cell_columns = columns
        for col_idx in range(cells_per_layer):
            stack = [col_idx + (k * cells_per_layer) for k in range(num_layers)]
            # An empty list pads each end so every cell has an "above" and a "below" entry.
            padded = [[]] + [[cell] for cell in stack] + [[]]
            columns.append((1, padded))
            for k, cell_idx in enumerate(stack):
                # padded[k + 1] is the cell itself; its neighbors in the padded list are above and below it
                vertical[cell_idx] = (padded[k], padded[k + 2])
        self._orig_cell_idx = list(range(cells_per_layer))
        creator = UGrid2dFromUGrid3dCreator()
        self._out_grid = creator.create_2d_cogrid(self._in_co_grid, 'Top')

    def _build_vertical_connectivity(self):
        """Finds the cells directly above and below every cell of the input grid.

        The connections are first read from the adjacent cells across each cell's top and bottom faces. Cells
        in layer 3 or deeper that still have nothing above them are then matched, using plan view geometry,
        to cells higher up that have nothing below them.
        """
        ugrid = self._in_ug
        cell_layers = self._in_co_grid.cell_layers
        bot = ugrid.face_orientation_enum.ORIENTATION_BOTTOM
        top = ugrid.face_orientation_enum.ORIENTATION_TOP

        connections = []
        self._in_ug_vertical_connections = connections
        for c_idx in range(ugrid.cell_count):
            above = []
            below = []
            for face in range(ugrid.get_cell_3d_face_count(c_idx)):
                orient = ugrid.get_cell_3d_face_orientation(c_idx, face)
                neighbor = ugrid.get_cell_3d_face_adjacent_cell(c_idx, face)
                if neighbor <= -1:  # nothing on the other side of this face
                    continue
                if orient == top:
                    above.append(neighbor)
                elif orient == bot:
                    below.append(neighbor)
            if max(len(above), len(below)) > 1:
                self._in_ug_vert_refinement = True
            connections.append((above, below))

        # Spatial index of the cells that have nothing below them, leaving out the last two layers.
        self._cell_rtree = index.Index()
        deepest_indexed_layer = self._in_ug_num_layers - 2
        for c_idx in range(ugrid.cell_count):
            if cell_layers[c_idx] > deepest_indexed_layer or connections[c_idx][1]:
                continue
            _, outline = ugrid.get_cell_plan_view_polygon(c_idx)
            footprint = Polygon([(p[0], p[1]) for p in outline])
            self._cell_rtree.insert(c_idx, footprint.bounds)

        # Cells in layer 3 or deeper with nothing above them may belong under one of the indexed cells.
        for c_idx in range(ugrid.cell_count):
            cell_lay = cell_layers[c_idx]
            if cell_lay < 3 or connections[c_idx][0]:
                continue
            _, outline = ugrid.get_cell_plan_view_polygon(c_idx)
            footprint = Polygon([(p[0], p[1]) for p in outline])
            # Group the indexed cells whose centroid lies inside this cell's footprint by layer.
            by_layer = [[] for _ in range(self._in_ug_num_layers + 1)]
            for cand in list(self._cell_rtree.intersection(footprint.bounds)):
                _, ctr = ugrid.get_cell_centroid(cand)
                if footprint.contains(Point((ctr[0], ctr[1]))):
                    by_layer[cell_layers[cand]].append(cand)

            # Only the closest populated layer above this cell is connected to it.
            closest_above = None
            for lay in range(cell_lay - 1, 0, -1):
                if by_layer[lay]:
                    closest_above = by_layer[lay]
                    break
            if closest_above:
                connections[c_idx][0].extend(closest_above)
                for upper in closest_above:
                    connections[upper][1].append(c_idx)

    def _create_cell_columns(self):
        """Create columns of cells for further processing."""
        self._cell_columns = []
        self._cell_cols_dict = dict()
        for c_idx in range(self._in_ug.cell_count):
            if len(self._in_ug_vertical_connections[c_idx][0]) < 1:
                set_cells = set()
                cell_col = [[] for i in range(self._in_ug_num_layers + 1)]
                lay = self._in_co_grid.cell_layers[c_idx]
                cell_col[lay].append(c_idx)
                set_cells.add(c_idx)

                below = self._in_ug_vertical_connections[c_idx][1]
                while len(below) > 0:
                    lay = self._in_co_grid.cell_layers[below[0]]
                    cell_col[lay] = below
                    set_below = set()
                    for b in below:
                        set_below.update(self._in_ug_vertical_connections[b][1])
                    below = list(set_below)
                    set_cells.update(below)

                num_per_lay = [len(c) for c in cell_col]
                lay = num_per_lay.index(max(num_per_lay))
                for c in set_cells:
                    self._cell_cols_dict[c] = (lay, cell_col)
                self._cell_columns.append((lay, cell_col))

    def _create_data_sets(self):
        """Creates datasets from columns in a data frame."""
        self.logger.info('Creating datasets for 2D UGrid from 3D dataset.')
        self._set_up_data_sets()

        ds = self._in_ds
        times = list(ds.times)
        null = ds.null_value
        for i, time in enumerate(times):
            ds_vals = ds.values[i]
            ds_act = None if not ds.activity else ds.activity[i]
            self._ds_time = time
            for j, cell_idx in enumerate(self._orig_cell_idx):
                self._new_ds_cell_idx = j
                self._ds_cell_col = self._calc_ds_cell_col(cell_idx)
                # self._ds_cell_col = self._cell_columns[cell_idx]
                self._ds_cell_scalars = [ds_vals[s_idx] for s_idx in self._ds_cell_col]
                self._ds_cell_active = None
                if ds_act is not None or null is not None:
                    if null is not None:
                        self._ds_cell_active = [0 if s == null else 1 for s in self._ds_cell_scalars]
                    else:
                        self._ds_cell_active = [ds_act[s_idx] for s_idx in self._ds_cell_col]
                self._add_ds_value()
            self._end_time_step()

        self._ds_appending_finished()

    def _set_up_data_sets(self):
        """Prepares the writers and per time step buffers for every output dataset.

        Each entry is a three item list: the DatasetWriter (None when that dataset was not requested), a
        value buffer with one slot per 2D cell, and an activity buffer (None when the input dataset has
        neither activity nor a null value, or when there are no 2D cells). The summary datasets are stored
        by key in ``self._ds_writers`` and the per layer datasets, in layer order, in ``self._ds_layers``.
        """
        base_name = self._in_ds.name
        num_cells = len(self._orig_cell_idx)
        num_layers = self._in_ug_num_layers

        summary_keys = [HIGHEST_ACTIVE_KEY, AVERAGE_KEY, MAX_KEY, MIN_KEY]
        summary_flags = [self.compute_highest_value, self.compute_average_value, self.compute_max_value,
                         self.compute_min_value]
        summary_names = [f'{base_name}_{suffix}' for suffix in ('highest_active', 'average', 'max', 'min')]
        summary_writers = self._list_of_data_sets_from_args(summary_flags, summary_names)

        has_activity = self._in_ds.activity is not None or self._in_ds.null_value is not None
        use_activity = has_activity and num_cells > 0

        def new_entry(writer):
            """Pairs a writer with its own fresh value and activity buffers."""
            return [writer, [0.0] * num_cells, [1] * num_cells if use_activity else None]

        self._ds_writers = {key: new_entry(writer) for key, writer in zip(summary_keys, summary_writers)}

        layer_flags = [self.compute_each_layer] * num_layers
        layer_names = [f'{base_name}_layer_{number}' for number in range(1, num_layers + 1)]
        layer_writers = self._list_of_data_sets_from_args(layer_flags, layer_names)
        self._ds_layers = [new_entry(writer) for writer in layer_writers]

    @staticmethod
    def _get_uuid():
        """Returns a random uuid string (or maybe not so random if testing)."""
        return str(uuid.uuid4())

    def _calc_ds_cell_col(self, cell_idx):
        """Builds the vertical column of cells that passes through a cell.

        Starting from the cell, the vertical connections are followed upward and then downward. Where a cell
        has several neighbors in one direction, the one chosen by ``_find_cell_containing_pt`` for the
        starting cell's centroid is followed.

        Args:
            cell_idx (int): index of the cell

        Returns:
            (list): cell indexes of the column, ordered from the topmost cell down to the bottommost cell
        """
        connections = self._in_ug_vertical_connections
        _, centroid = self._in_ug.get_cell_centroid(cell_idx)
        probe = Point((centroid[0], centroid[1]))

        chains = []
        for side in (0, 1):  # 0 follows the cells above, 1 follows the cells below
            chain = []
            neighbors = connections[cell_idx][side]
            while len(neighbors) > 0:
                chain.append(self._find_cell_containing_pt(probe, neighbors))
                neighbors = connections[chain[-1]][side]
            chains.append(chain)

        upward, downward = chains
        # the upward walk was collected nearest first, so flip it to read top down
        return upward[::-1] + [cell_idx] + downward

    def _find_cell_containing_pt(self, sh_pt, cell_list):
        """Finds the first cell that contains the point.

        Args:
            sh_pt (Point): shapely point class
            cell_list (list(int)): list of cell indexes

        Returns:
            (int): index of cell containing point
        """
        ret_cell = cell_list[0]
        if len(cell_list) > 1:  # find the first above cell that contains cell_idx
            for c_idx in cell_list:
                _, poly = self._in_ug.get_cell_plan_view_polygon(c_idx)
                sh_poly = Polygon([(p[0], p[1]) for p in poly])
                if sh_poly.contains(sh_pt):
                    ret_cell = c_idx
                    break
        return ret_cell

    def _add_ds_value(self):
        """Writes the current column's values into the buffers of every requested dataset.

        Reads the column prepared in ``self._ds_cell_col``, ``self._ds_cell_scalars`` and
        ``self._ds_cell_active`` and stores results at position ``self._new_ds_cell_idx``. Inactive values
        are left out of the summary statistics unless the whole column is inactive; in that case the
        statistics are taken over all values and the cell is flagged inactive. A statistic is only computed
        when its dataset was requested. When per layer output is requested, each cell's value goes to the
        dataset of its layer, and layers with no cell in this column are flagged inactive.
        """
        scalars = self._ds_cell_scalars
        active = self._ds_cell_active
        out_idx = self._new_ds_cell_idx

        column_not_active = False
        s_list = scalars
        if active:
            if active.count(0) == len(active):  # all cells in column NOT active
                column_not_active = True
            else:
                s_list = [s for s, a in zip(scalars, active) if a != 0]

        for key in (HIGHEST_ACTIVE_KEY, AVERAGE_KEY, MAX_KEY, MIN_KEY):
            dsw = self._ds_writers[key]
            if dsw[0] is None:
                continue  # dataset not requested, so do not spend time computing its statistic
            if key == HIGHEST_ACTIVE_KEY:
                value = s_list[0]  # the column is ordered top down, so the first value is the highest one
            elif key == AVERAGE_KEY:
                value = sum(s_list) / len(s_list)
            elif key == MAX_KEY:
                value = max(s_list)
            else:
                value = min(s_list)
            dsw[1][out_idx] = value
            if column_not_active:
                dsw[2][out_idx] = 0

        if self.compute_each_layer:
            cell_layers = self._in_co_grid.cell_layers  # looked up once rather than once per cell
            set_lay = set()
            for i, cell in enumerate(self._ds_cell_col):
                lay = cell_layers[cell]
                set_lay.add(lay)
                dsw = self._ds_layers[lay - 1]  # layers are numbered from 1
                dsw[1][out_idx] = scalars[i]
                if active and active[i] == 0:
                    dsw[2][out_idx] = 0

            # set inactive any missing layers
            all_lay = set(range(1, len(self._ds_layers) + 1))
            for lay in all_lay - set_lay:
                dsw = self._ds_layers[lay - 1]
                if dsw[2] is None:
                    # the input had no activity information, but a missing layer still has to be flagged
                    dsw[2] = [1] * len(dsw[1])
                dsw[2][out_idx] = 0

    def _end_time_step(self):
        """Appends a time step to the dataset writer."""
        dsw_list = []
        for _, dsw in self._ds_writers.items():
            if dsw[0] is not None:
                dsw_list.append(dsw)
        for dsw in self._ds_layers:
            if dsw[0] is not None:
                dsw_list.append(dsw)

        for dsw in dsw_list:
            dsw[0].append_timestep(time=self._ds_time, data=dsw[1], activity=dsw[2])
            dsw[1] = [0.0] * len(dsw[1])
            dsw[2] = [1] * len(dsw[2]) if dsw[2] is not None else None

    def _ds_appending_finished(self):
        """Finished appending to the dataset writers."""
        for _, dsw in self._ds_writers.items():
            if dsw[0] is not None:
                dsw[0].appending_finished()
                self._data_sets.append(dsw[0])

        for dsw in self._ds_layers:
            if dsw[0] is not None:
                dsw[0].appending_finished()
                self._data_sets.append(dsw[0])

    def _list_of_data_sets_from_args(self, arg_list, name_list):
        """Creates a list of datasets given the arg_list.

        Args:
            arg_list (list): list of tool argument values
            name_list (list): list of strings for dataset names

        Returns:
            (list(DatasetWriter)): a list of the datasets

        """
        ret_val = []
        for arg, name in zip(arg_list, name_list):
            if arg:
                _writer = DatasetWriter(
                    name=name,
                    geom_uuid=self._out_grid.uuid,
                    num_components=1,
                    ref_time=self._in_ds.ref_time,
                    time_units=self._in_ds.time_units,
                    null_value=self._in_ds.null_value,
                    location='cells',
                )

                ret_val.append(_writer)
            else:
                ret_val.append(None)
        return ret_val
