"""RenumberUGridTool class."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.tree import tree_util, TreeNode
from xms.constraint import UGrid3d
from xms.grid.ugrid import UGrid
from xms.tool_core import IoDirection, Tool
from xms.tool_core.grid_argument import GridArgument

# 4. Local modules
from xms.hgs.components import dmi_util
from xms.hgs.tools.ugrid_renumberer import UGridRenumberer

# Indices into the argument list built by RenumberUGridTool.initial_arguments().
ARG_INPUT_GRID = 0  # 'input_grid': the existing grid to renumber
ARG_OUTPUT_GRID = 1  # 'output_grid': the renumbered grid the tool creates


class RenumberUGridTool(Tool):
    r"""Tool to renumber a UGrid the way HydroGeoSphere requires.

    ::

           8----------7    4---------6
          /|         /|    |\       /|
         / |        / |    | \     / |
        5----------6  |    |  \   /  |
        |  4-------|--3    1---\ /---3
        | /        | /      \   5   /
        |/         |/        \  |  /
        1----------2          \ | /
                               \|/
                                2

    Cell nodes must be numbered as shown, bottom to top, counterclockwise looking down, and
    cells and nodes must be numbered in layers, bottom to top. See HydroGeoSphere reference manual.
    """
    def __init__(self) -> None:
        """Initializes the class."""
        super().__init__(name='Renumber UGrid for HydroGeoSphere')
        self._point_locations: list[tuple[float]] | None = None
        self._co_grid = None  # Input constrained grid, read during validation (or lazily in testing)
        self._ugrid = None  # Do 'self._co_grid.ugrid' only once as it's costly
        self._query = None  # Set by set_data_handler() when the data handler provides one

        # For testing
        # import os
        # os.environ['XMSTOOL_GUI_TESTING'] = 'YES'

    def initial_arguments(self) -> list[GridArgument]:
        """Get initial arguments for tool.

        Must override.

        The input grid defaults to the grid returned by dmi_util.get_default_grid(), and the output
        grid name defaults to that grid's name with '-hgs' appended (empty if there is no default grid).

        Returns:
            (list): A list of the initial tool arguments.
        """
        in_grid = dmi_util.get_default_grid(self._query)
        in_path = _argument_path_from_node(in_grid)
        out_name = in_grid.name + '-hgs' if in_grid else ''
        arguments = [
            self.grid_argument(name='input_grid', description='Input grid', value=in_path),
            self.grid_argument(
                name='output_grid', description='Output grid', io_direction=IoDirection.OUTPUT, value=out_name
            )
        ]
        return arguments

    def set_data_handler(self, data_handler) -> None:
        """Set up query attribute if we have a XMSDataHandler.

        Args:
            data_handler: The data handler. If it has a '_query' attribute, that query is kept for
                looking up the default grid in initial_arguments().
        """
        super().set_data_handler(data_handler)
        if hasattr(self._data_handler, "_query"):
            self._query = self._data_handler._query

    def validate_arguments(self, arguments: list[GridArgument]) -> dict[str, str]:
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name (empty if all valid).
        """
        errors: dict[str, str] = {}
        self._validate_input_grid(errors, arguments[ARG_INPUT_GRID])
        return errors

    def _validate_input_grid(self, errors: dict[str, str], argument: GridArgument) -> None:
        """Validate that the input grid can be read and is made of cells HydroGeoSphere can use.

        The grid must have at least one cell, all cells must be 3D, and the cells must be either all
        hexahedra or all triangular prisms (all 8-point or all 6-point polyhedra are also accepted).
        Validation stops at the first problem so the reported message names that specific problem.

        The grid that is read is cached in self._co_grid (and its UGrid in self._ugrid) for later use
        by _build_co_grid().

        Args:
            errors (dict): Dictionary of errors keyed by argument name. Updated in place.
            argument (GridArgument): The grid argument.
        """
        if argument.value is None:
            return

        key = argument.name
        self._co_grid = self.get_input_grid(argument.text_value)
        if not self._co_grid:
            errors[key] = 'Could not read grid.'
            return

        self._ugrid = self._co_grid.ugrid
        if self._ugrid.cell_count <= 0:
            errors[key] = 'Grid has no cells.'
            return

        if not self._co_grid.check_all_cells_3d():
            errors[key] = 'Must have all 3D cells.'
            return

        # We don't check for stacked anymore because stacked also means the grid must be numbered in a stacked
        # way, which is exactly what this tool is meant to fix.

        cell_types = UGrid.cell_type_enum
        if self._co_grid.check_all_cells_are_of_type(cell_types.HEXAHEDRON):
            return
        if self._co_grid.check_all_cells_are_of_type(cell_types.WEDGE):
            return

        # Could be polyhedra. Assuming a stacked grid, all cells must be hexes (8 nodes) or all prisms (6 nodes).
        if not self._all_polyhedra_hex_or_prism():
            errors[key] = 'Grid cells must be either all hexahedron or all triangular prisms.'

    def _all_polyhedra_hex_or_prism(self) -> bool:
        """Return True if every cell is a polyhedron and all have 8 points, or all have 6 points.

        Mixing 8-point and 6-point polyhedra, any other point count, or any non-polyhedron cell
        returns False. Stops at the first cell that rules the grid out.

        Returns:
            (bool): True if the polyhedral cells are uniformly hex-like or uniformly prism-like.
        """
        ugrid = self._ugrid
        polyhedron = UGrid.cell_type_enum.POLYHEDRON
        expected_count = None  # Point count of the first cell; every other cell must match it
        for cell_idx in range(ugrid.cell_count):
            if ugrid.get_cell_type(cell_idx) != polyhedron:
                return False
            point_count = ugrid.get_cell_point_count(cell_idx)
            if point_count not in (6, 8):
                return False
            if expected_count is not None and point_count != expected_count:
                return False  # Mixed hexes and prisms
            expected_count = point_count
        return expected_count is not None

    def _build_co_grid(self, arguments: list[GridArgument]) -> UGrid3d:
        """Builds and returns the new, renumbered, UGrid.

        Args:
            arguments (list[GridArgument]): The tool arguments.

        Returns:
            (UGrid3d): The new UGrid.
        """
        if self._co_grid is None:  # This can be None only in testing (validation was skipped)
            self._co_grid = self.get_input_grid(arguments[ARG_INPUT_GRID].text_value)
            self._ugrid = self._co_grid.ugrid

        renumberer = UGridRenumberer(self._co_grid, self._ugrid)
        return renumberer.build_cogrid()

    def run(self, arguments: list[GridArgument]) -> None:
        """Override to run the tool.

        Builds the renumbered grid and sets it as the output grid argument.

        Args:
            arguments (list[GridArgument]): The tool arguments.
        """
        new_co_grid_3d = self._build_co_grid(arguments)
        self.set_output_grid(new_co_grid_3d, arguments[ARG_OUTPUT_GRID])


def _argument_path_from_node(tree_node: TreeNode) -> str:
    """Build a tool-argument path for a project tree node.

    Tool arguments reference tree items relative to the project root, so a leading 'Project/'
    is removed from the node's full tree path.

    Args:
        tree_node: The tree node, or a falsy value when there is none.

    Returns:
        The node's tree path without a leading 'Project/', or '' when tree_node is falsy.
    """
    if tree_node:
        return tree_util.tree_path(tree_node).removeprefix('Project/')
    return ''
