"""UGrid3dFromRastersTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import pandas as pd

# 3. Aquaveo modules
try:
    from xms.guipy.settings import get_file_browser_directory
except ImportError:  # pragma no cover - optional import
    get_file_browser_directory = None
from xms.tool_core import IoDirection, Tool
from xms.tool_core.table_definition import IntColumnType, StringColumnType, TableDefinition

# 4. Local modules
from xms.tool.algorithms.ugrids import ugrid_3d_from_rasters_creator
from xms.tool.algorithms.ugrids.ugrid_3d_from_rasters_creator import Columns, UGrid3dFromRastersCreator, \
    write_sublayers_file

# Indices into the tool's argument list. They are used to look up entries of the ``arguments`` list in
# UGrid3dFromRastersTool.validate_arguments(), enable_arguments() and run(), so they must match the order
# in which UGrid3dFromRastersTool.initial_arguments() builds that list.
ARG_INPUT_2D_UGRID = 0
ARG_INPUT_RASTERS_TABLE = 1
ARG_INPUT_TARGET_LOCATION = 2
ARG_INPUT_MIN_LAYER_THICKNESS = 3
ARG_INPUT_3D_UGRID_NAME = 4
ARG_OUTPUT_3D_UGRID = 5
ARG_OUTPUT_SUBLAYERS_CSV = 6  # The "write sublayers CSV" toggle
ARG_OUTPUT_SUBLAYERS_CSV_FILENAME = 7  # The CSV file picker


class UGrid3dFromRastersTool(Tool):
    """Tool to create a 3D UGrid from a 2D UGrid and rasters.

    The user picks a 2D UGrid and fills in a table of rasters (one row per horizon). The table is checked in
    validate_arguments(), kept tidy in enable_arguments() and handed to UGrid3dFromRastersCreator in run().
    """
    def __init__(self):
        """Initializes the class."""
        super().__init__(name='3D UGrid from Rasters')
        # TODO: Get rid of this after testing
        # import os
        # os.environ['XMSTOOL_GUI_TESTING'] = 'YES'

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of the returned list must match the module-level ARG_* index constants.

        Returns:
            (list): A list of the initial tool arguments.
        """
        rasters = self._data_handler.get_available_rasters()
        table_def, df = _create_table(rasters)
        default_ugrid_2d = self._find_default_ugrid_2d()
        # The default CSV location is only available when the optional GUI settings module was imported
        path = None
        if get_file_browser_directory is not None:
            path = os.path.join(get_file_browser_directory(), 'sublayers.csv')
        arguments = [
            self.grid_argument(name='ugrid_2d', description='2D UGrid', value=default_ugrid_2d, optional=False),
            self.table_argument(
                name='rasters_table', description='Rasters', value=df, optional=False, table_definition=table_def
            ),
            self.string_argument(
                name='target_location',
                description='Target location',
                value='Cell tops and bottoms',
                choices=['Cell tops and bottoms', 'Points'],
                optional=False
            ),
            self.float_argument(
                name='min_layer_thickness',
                description='Minimum layer thickness',
                value=0.0,
                min_value=0.0,
                optional=False
            ),
            self.string_argument(name='3d_ugrid_name', description='3D UGrid name', value='', optional=True),
            self.grid_argument(
                name='ugrid_3d', description='The 3D UGrid', hide=True, optional=True, io_direction=IoDirection.OUTPUT
            ),
            self.bool_argument('output_sublayers', description='Write sublayers CSV', value=False),
            self.file_argument(
                name='sublayers_csv',
                description='CSV file',
                file_filter='CSV file (*.csv)',
                default_suffix='csv',
                io_direction=IoDirection.OUTPUT,
                value=path
            ),
        ]
        return arguments

    def _find_default_ugrid_2d(self) -> str:
        """Returns the path to the first readable 2D UGrid, if there is one, else ''."""
        for ugrid_path in self._data_handler.get_available_grids():
            co_grid = self.get_input_grid(ugrid_path)
            # A grid that could not be read is skipped (same falsy check as _validate_input_ugrid)
            if co_grid and co_grid.check_all_cells_2d():
                return ugrid_path
        return ''

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors: dict[str, str] = {}
        self._validate_input_ugrid(errors, arguments[ARG_INPUT_2D_UGRID])
        rasters_argument = arguments[ARG_INPUT_RASTERS_TABLE]
        self._validate_rasters_table(errors, rasters_argument.name, rasters_argument.value)
        return errors

    def _validate_input_ugrid(self, errors, argument):
        """Validate grid is specified and 2D.

        At most one error is recorded, for the first check that fails.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (GridArgument): The grid argument.
        """
        key = argument.name
        if not argument.value:  # Nothing selected ('' or None)
            errors[key] = 'No 2D UGrid selected.'
            return

        co_grid = self.get_input_grid(argument.text_value)
        if not co_grid:
            errors[key] = 'Could not read grid.'
        elif co_grid.ugrid.cell_count <= 0:
            errors[key] = 'Grid has no cells.'
        elif not co_grid.check_all_cells_2d():
            errors[key] = 'Grid cells must all be 2D.'

    def _validate_rasters_table(self, errors, argument_name: str, df: pd.DataFrame):
        """Validate the rasters table.

        The checks run in order and stop at the first one that reports a problem.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument_name (str): Argument name.
            df (pd.DataFrame): DataFrame.
        """
        if df is None:  # pragma no cover - should never happen, but just in case
            return

        if len(df) == 0:
            errors[argument_name] = 'Rasters table is empty.'
            return
        checks = (
            self._validate_rasters_column,
            self._validate_proportions_column,
            self._validate_non_zero_layers_created,
        )
        for check in checks:
            if not check(errors, argument_name, df):
                return

    def _validate_rasters_column(self, errors: dict[str, str], key: str, df: pd.DataFrame) -> bool:
        """Validate that no raster is used in more than one row of the rasters table.

        Args:
            errors (dict[str, str]): Dictionary of errors keyed by argument name.
            key (str): Key for errors dict.
            df (pd.DataFrame): Pandas data frame.

        Returns:
            (bool): True if OK, False if invalid.
        """
        first_row_of: dict[str, int] = {}  # raster -> 1-based row number where it was first seen
        for row_number, raster in enumerate(df[Columns.RASTER], start=1):
            if raster in first_row_of:
                errors[key] = f'"{raster}" used in both row {first_row_of[raster]} and row {row_number}.'
                return False
            first_row_of[raster] = row_number
        return True

    def _validate_proportions_column(self, errors: dict[str, str], key: str, df: pd.DataFrame) -> bool:
        """Validate the proportions column in the rasters table.

        Args:
            errors (dict[str, str]): Dictionary of errors keyed by argument name.
            key (str): Key for errors dict.
            df (pd.DataFrame): Pandas data frame.

        Returns:
            (bool): True if OK, False if invalid.
        """
        proportions = df[Columns.PROPORTIONS]
        sublayers = df[Columns.SUBLAYERS]
        # The last row is never checked; _fix_proportions() blanks its proportions anyway
        for row in range(len(df) - 1):
            if proportions.iloc[row] == '':
                continue
            # get_proportions_list reports problems by adding an entry to errors under key
            ugrid_3d_from_rasters_creator.get_proportions_list(
                proportions.iloc[row], row, sublayers.iloc[row], errors, key
            )
            if key in errors:
                return False
        return True

    def _validate_non_zero_layers_created(self, errors: dict[str, str], key: str, df: pd.DataFrame) -> bool:
        """Validate that a non-zero number of layers would be created.

        Args:
            errors (dict[str, str]): Dictionary of errors keyed by argument name.
            key (str): Key for errors dict.
            df (pd.DataFrame): Pandas data frame.

        Returns:
            (bool): True if OK, False if invalid.
        """
        total_layer_count, _ = ugrid_3d_from_rasters_creator.compute_total_num_layers(df)
        if total_layer_count < 1:
            errors[key] = 'No layers would be created.'
            return False
        return True

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        # Automatically number the horizons and fix the last row of Sublayers and Proportions
        df = arguments[ARG_INPUT_RASTERS_TABLE].value
        if df is not None and len(df) > 0:
            self._fix_horizons(df)
            self._fix_sublayers(df)
            self._fix_proportions(df)
        # Hide the output CSV filename picker if we don't need it.
        arguments[ARG_OUTPUT_SUBLAYERS_CSV_FILENAME].show = arguments[ARG_OUTPUT_SUBLAYERS_CSV].value

    def _fix_horizons(self, df) -> bool:
        """Fix the horizons if necessary.

        Horizons are numbered from len(df) in the first row down to 1 in the last row.

        Args:
            df(pd.DataFrame): The data frame. Modified in place.

        Returns:
            (bool): True if we made any changes
        """
        correct_horizons = list(range(len(df), 0, -1))
        if df[Columns.HORIZON].to_list() != correct_horizons:
            df[Columns.HORIZON] = correct_horizons
            return True
        return False

    def _fix_sublayers(self, df) -> bool:
        """Fix the sublayers if necessary.

        Args:
            df(pd.DataFrame): The data frame. Modified in place.

        Returns:
            (bool): True if we made any changes
        """
        # The minimum sublayers should be 1 except for the last row, which should be 0
        changes = False
        last_row = len(df) - 1
        # Read and write by position so the row that is fixed is always the row that was inspected,
        # whatever labels the data frame's index happens to use.
        column = df.columns.get_loc(Columns.SUBLAYERS)
        for row in range(len(df)):
            value = df.iloc[row, column]
            if row == last_row and value != 0:  # Last row
                df.iloc[row, column] = 0
                changes = True
            elif row < last_row and value < 1:
                df.iloc[row, column] = 1
                changes = True
        return changes

    def _fix_proportions(self, df) -> bool:
        """Fix the proportions if necessary.

        The last row's proportions are cleared if they are not already empty.

        Args:
            df(pd.DataFrame): The data frame. Modified in place.

        Returns:
            (bool): True if we made any changes
        """
        if len(df) == 0:  # No last row to fix
            return False
        # Positional access so the write hits the last row regardless of the index labels
        column = df.columns.get_loc(Columns.PROPORTIONS)
        if df.iloc[-1, column] != '':
            df.iloc[-1, column] = ''
            return True
        return False

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        # Get arguments
        name_2d = arguments[ARG_INPUT_2D_UGRID].text_value
        co_grid_2d = self.get_input_grid(name_2d)
        raster_table = arguments[ARG_INPUT_RASTERS_TABLE].value
        target_location = arguments[ARG_INPUT_TARGET_LOCATION].value
        min_layer_thickness = arguments[ARG_INPUT_MIN_LAYER_THICKNESS].value
        name_3d = arguments[ARG_INPUT_3D_UGRID_NAME].value

        # Replace raster tree paths with file names. Copy raster_table so history isn't messed up
        df = raster_table.copy()
        df[Columns.RASTER] = [self.get_input_raster_file(raster) for raster in df[Columns.RASTER].to_list()]

        # Build the 3d grid
        creator = UGrid3dFromRastersCreator(self.default_wkt, self.logger)
        co_grid_3d = creator.create_3d_co_grid(co_grid_2d, df, target_location, min_layer_thickness)
        co_grid_3d.uuid = creator.get_uuid()

        # Set the output grid with the right name
        arguments[ARG_OUTPUT_3D_UGRID].value = _get_ugrid_3d_name(name_2d, name_3d)
        self.set_output_grid(co_grid_3d, arguments[ARG_OUTPUT_3D_UGRID])

        # Optionally, write the horizons to a csv file
        if arguments[ARG_OUTPUT_SUBLAYERS_CSV].value:
            write_sublayers_file(arguments[ARG_OUTPUT_SUBLAYERS_CSV_FILENAME].value, df)


def _create_table(rasters: list[str]) -> tuple[TableDefinition, pd.DataFrame]:
    """Builds the rasters table definition and the data frame created from it.

    Args:
        rasters (list[str]): Available rasters, offered as the choices of the raster column. The first
            one, if any, is the default for that column.

    Returns:
        (tuple[TableDefinition, pd.DataFrame]): The table definition and the result of its to_pandas().
    """
    # Tool tips, one per column
    horizon_tip = 'Horizon number, in order of deposition (higher numbers on top)'
    raster_tip = 'The raster'
    fill_tip = 'Create cell layers below this raster (if not the bottom raster).'
    clip_tip = 'Prevent all rasters below from going above this raster.'
    sublayers_tip = 'Number of cell layers between this raster and the one below.'
    proportions_tip = (
        'Space delimited list of integers describing sublayer proportions. E.g. "1 1 2" means the top'
        ' layer and the one below it are the same thickness, and the bottom layer is twice as thick.'
    )

    first_raster = rasters[0] if rasters else ''
    definition = TableDefinition(
        [
            IntColumnType(
                header=Columns.HORIZON, tool_tip=horizon_tip, default=1, low=1, spinbox=False, enabled=False
            ),
            StringColumnType(header=Columns.RASTER, tool_tip=raster_tip, default=first_raster, choices=rasters),
            IntColumnType(header=Columns.FILL, tool_tip=fill_tip, default=1, checkbox=True, spinbox=False),
            IntColumnType(header=Columns.CLIP, tool_tip=clip_tip, default=0, checkbox=True, spinbox=False),
            IntColumnType(
                header=Columns.SUBLAYERS, tool_tip=sublayers_tip, default=1, low=0, high=100, spinbox=True
            ),
            StringColumnType(header=Columns.PROPORTIONS, tool_tip=proportions_tip),
        ]
    )
    return definition, definition.to_pandas()


def _get_ugrid_3d_name(name_2d, name_3d):
    """Returns a name for the 3D ugrid.

    If they provided a name (name_3d) it is just returned. Otherwise, we append "3d" to 2D ugrid name.
    E.g.: "blah/blah/my 2D ugrid" -> "my 2D ugrid 3d"

    Args:
        name_2d (str): Name of the 2D UGrid.
        name_3d (str): Name of the 3D UGrid. If provided, it will just be returned.

    Returns:
        (str): Name of the 3D ugrid.
    """
    if name_3d:
        return name_3d
    else:  # Use 2D grid name + ' 3d'
        words = name_2d.split('/')
        name_3d = f'{words[-1]} 3d'
        return name_3d
