"""AttTableCoverageDump code."""

__copyright__ = '(C) Copyright Aquaveo 2020'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from enum import IntEnum
import logging
from pathlib import Path
import uuid

# 2. Third party modules
import h5py
import pandas as pd

# 3. Aquaveo modules
from xms.api._xmsapi.dmi import DataDumpIOBase
from xms.data_objects.parameters import Coverage, FilterLocation

# 4. Local modules
from xms.coverage import coverage_util
from xms.coverage.xy.xy_io import XyReader
from xms.coverage.xy.xy_series import XySeries


class ColumnTypeEnum(IntEnum):
    """Column data type (colatttype in XMS).

    Integer codes describing what kind of values an attribute table column holds. Being an ``IntEnum``, members
    compare equal to the plain integer codes.
    """
    COL_TYPE_INT = 0  # Integer column
    COL_TYPE_DOUBLE = 1  # Double column
    COL_TYPE_STR = 2  # String column
    COL_TYPE_BOOL = 3  # Bool column
    COL_TYPE_FLOAT = 4  # Float column


def get_att(df, feature_id, column, node=-1):
    """Look up a single attribute value in an attribute table.

    Args:
        df (DataFrame): The pandas dataframe representing the attribute table, indexed by feature ID.
        feature_id (int): Feature ID, which is the row key.
        column (str): Column name.
        node (int): If 0, the 'Node1 <column>' column is used; if 1, the 'Node2 <column>' column is used. Any other
            value looks up ``column`` as given.

    Returns:
        The value, or None if the feature ID or the (possibly node-prefixed) column doesn't exist.
    """
    try:
        # Arc node attributes live in columns prefixed with the node label.
        label = f'Node1 {column}' if node == 0 else f'Node2 {column}' if node == 1 else column
        value = df.loc[feature_id, label]
    except KeyError:
        value = None
    return value


class AttTableCoverageDump(DataDumpIOBase):
    """Class to interface with the GMS coverages that use attribute tables.

    The dump file is read lazily: constructing the object only stores the file path, and the file is read the first
    time the coverage or one of its tables is requested.
    """

    # Constants (names of H5 groups and datasets in the dump file)
    ARC_ATTS = 'Arc Atts'
    ARC_GROUP_ATTS = 'Arc Group Atts'
    COLUMN = 'Column'
    COVERAGE1 = 'Coverage1'
    DATA_TYPE = 'Data Type'
    GMS_ATTS = 'GMS Atts'
    Id = 'Id'
    ID = 'ID'
    MAP_DATA = 'Map Data'
    NAME = 'Name'
    NUMBER = 'Number'
    POINT_ATTS = 'Point Atts'
    POLYGON_ATTS = 'Polygon Atts'
    XY_SERIES_IDS_FLAG = 'XYSeries Ids Flag'

    def __init__(self, file_name=''):
        """Initializes the class.

        Args:
            file_name(str): Path to the H5 dump file
        """
        super().__init__()
        super().SetSelf(self)
        self._do_coverage = None  # data_objects Coverage
        self._has_read = False
        self._file_name = file_name
        # Feature type -> name of the H5 group holding that feature type's attribute table
        self._table_group_names = {
            'points': self.POINT_ATTS,
            'arcs': self.ARC_ATTS,
            'arc_groups': self.ARC_GROUP_ATTS,
            'polys': self.POLYGON_ATTS
        }
        # Column data type code (see ColumnTypeEnum) -> name of the H5 dataset holding the column values
        self._column_values_h5_datasets = {
            0: 'Integer Column Data',
            1: 'Double Column Data',
            2: 'String Column Data',
            3: 'Bool Column Data',
            4: 'Float Column Data'
        }

    def _ensure_read(self):
        """Reads the dump file if a file name was given and it has not been read yet."""
        if not self._has_read and self._file_name:
            self.ReadDump(self._file_name)

    @property
    def do_coverage(self):
        """Get the coverage geometry.

        Returns:
            xms.data_objects.parameters.Coverage: The activity coverage geometry
        """
        self._ensure_read()
        return self._do_coverage

    @do_coverage.setter
    def do_coverage(self, val):
        """Set the coverage geometry.

        Args:
            val (xms.data_objects.parameters.Coverage): The activity coverage geometry
        """
        self._do_coverage = val

    def get_xy_series(self) -> list[XySeries]:
        """Returns the xy series read from the dump file.

        Returns:
            (list[XySeries]): The xy series as returned by XyReader.read_from_h5().
        """
        self._ensure_read()
        return XyReader().read_from_h5(Path(self._file_name))

    def get_table(self, feature_type):
        """Returns the att table as a Pandas DataFrame.

        Args:
            feature_type (str): 'points', 'arcs', 'arc_groups', 'polys'

        Returns:
            (DataFrame): table as a DataFrame, or None if the table could not be read.
        """
        return self._read_table(feature_type)[0]

    def get_xy_series_columns(self, feature_type):
        """Returns a dict[str, bool] of which columns are xy series.

        Args:
            feature_type (str): 'points', 'arcs', 'arc_groups', 'polys'

        Returns:
            (dict[str, bool]): dict of columns and xy series flags (empty if the table could not be read).
        """
        return self._read_table(feature_type)[1]

    def _read_table(self, feature_type):
        """Returns the att table as tuple: a Pandas DataFrame, and a dict of which columns are xy series flags.

        Any error while reading is logged to the 'xms.coverage' logger rather than raised.

        Args:
            feature_type (str): 'points', 'arcs', 'arc_groups', 'polys'

        Returns:
            (tuple[DataFrame, dict[str, bool]]): Tuple: table as a DataFrame, dict of columns and xy series flags.
            On error the DataFrame is None.
        """
        assert feature_type in {'points', 'arcs', 'arc_groups', 'polys'}
        df = None
        xy_columns = {}

        try:
            self._ensure_read()

            # The context manager closes the file on success as well as on error.
            with h5py.File(self._file_name, 'r') as h5_file:
                att_table_group = h5_file.get(self._get_att_table_group_name(feature_type))
                if att_table_group is None:
                    raise RuntimeError('Table not found.')
                # Column values are copied into memory here, so they remain valid after the file is closed.
                columns_dict, xy_columns = self._read_columns(att_table_group)

            df = self._make_dataframe(columns_dict)

        except Exception as e:
            logger = logging.getLogger('xms.coverage')
            logger.error(str(e))

        return df, xy_columns

    def _make_dataframe(self, columns_dict):
        """Creates and returns a Pandas DataFrame from the columns dict.

        Args:
            columns_dict (dict[str, any]): Dict of columns of values to be turned into a DataFrame

        Returns:
            (DataFrame): The DataFrame.
        """
        df = pd.DataFrame(columns_dict)
        if self.ID in df.columns:
            df = df.set_index(self.ID)  # Set 'ID' column as index for fast lookup
        return df

    def _get_att_table_group_name(self, feature_type):
        """Returns the name of the att table group.

        Args:
            feature_type (str): 'points', 'arcs', 'arc_groups', 'polys'

        Returns:
            (str): The name of the att table group.
        """
        table_group_name = self._table_group_names[feature_type]
        return f'{self.MAP_DATA}/{self.COVERAGE1}/{self.GMS_ATTS}/{table_group_name}'

    def _read_columns(self, att_table_group):
        """Reads the columns and returns them as a dict.

        Args:
            att_table_group (H5 group): Group for the attribute table.

        Returns:
            (tuple[dict[str, any], dict[str, bool]]): Dict of columns of values to be turned into a DataFrame,
            and dict of columns and whether they are xy series IDs or not.

        Raises:
            RuntimeError: If one of the expected column groups is missing.
        """
        number = int(att_table_group[self.NUMBER][0])
        columns_dict = {}
        xy_columns = {}
        # Column groups are numbered starting at 1: 'Column1', 'Column2', ...
        for column_number in range(1, number + 1):
            column_group_name = f'{self.COLUMN}{column_number}'
            column_group = att_table_group.get(column_group_name)
            if column_group is None:
                raise RuntimeError(f'{column_group_name} not found.')
            self._read_column(column_group, columns_dict, xy_columns)
        return columns_dict, xy_columns

    def _read_column(self, column_group, columns_dict, xy_columns):
        """Reads and stores the column.

        The column is always recorded in xy_columns; it is only added to columns_dict if its values dataset exists.

        Args:
            column_group (H5 group): The H5 group for the column.
            columns_dict (dict[str, any]): Dict of columns of values to be turned into a DataFrame.
            xy_columns (dict[str, bool]): Dict containing which columns are used as an xy series flag.

        Raises:
            RuntimeError: If the column's data type code is not one of the known types.
        """
        name = column_group[self.NAME][0].astype(str)
        xy_series_flag = column_group[self.XY_SERIES_IDS_FLAG][0]
        xy_columns[name] = bool(xy_series_flag)

        data_type = int(column_group[self.DATA_TYPE][0])
        values_dataset_name = self._column_values_h5_datasets.get(data_type)
        if values_dataset_name is None:
            # Give a meaningful message instead of a bare KeyError showing only the number.
            raise RuntimeError(f'Unsupported data type {data_type} for column "{name}".')

        values_dataset = column_group.get(values_dataset_name)
        if values_dataset is None:
            return
        values = values_dataset[:]  # Copy the values into memory
        if data_type == ColumnTypeEnum.COL_TYPE_STR:
            values = values.astype(str)
        columns_dict[name] = values

    # @overrides
    def WriteDump(self, file_name):  # noqa: N802
        """Write the coverage geometry and attributes to an H5 file XMS can read.

        Args:
            file_name (str): Path to the output H5 file
        """
        # Write the geometry, giving the coverage a fresh UUID if it has none or only the placeholder one
        cov_uuid = self.do_coverage.uuid
        if not cov_uuid or cov_uuid.lower() == 'cccccccc-cccc-cccc-cccc-cccccccccccc':
            self.do_coverage.uuid = str(uuid.uuid4())
        self.do_coverage.write_h5(file_name)

        self._write_attributes(file_name)

    # @overrides
    def ReadDump(self, file_name):  # noqa: N802
        """Read coverage geometry and attributes from an H5 file written by XMS.

        If the coverage name read from the file is empty, nothing else is read and the coverage is left unset.

        Args:
            file_name (str): Path to the file to read
        """
        # Set the flag first: the do_coverage property used below would otherwise trigger another read.
        self._has_read = True

        # Read the coverage name and attributes; the context manager closes the file on every path.
        with h5py.File(file_name, 'r') as h5_file:
            coverage_name = coverage_util.read_coverage_name(h5_file)
            if coverage_name == "":
                return
            self._read_attributes(coverage_name, h5_file)

        # read geometry at the end
        self.do_coverage = Coverage(file_name, f'/{self.MAP_DATA}/{coverage_name}')
        self.do_coverage.get_points(FilterLocation.LOC_NONE)  # force geometry to load from H5

    def _read_attributes(self, coverage_name, h5_file):
        """Reads the attributes.

        Currently this only looks up the attribute group (which fails if the group is absent); nothing in the
        group is read yet.

        Args:
            coverage_name (str): Name of the coverage.
            h5_file: H5 file.
        """
        _ = h5_file[f'{self.MAP_DATA}/{coverage_name}/{self.GMS_ATTS}']  # temporary

    def _write_attributes(self, file_name):
        """Writes the attributes.

        Currently only the empty attribute group is created; no attribute datasets are written yet.

        Args:
            file_name (str): Filepath of coverage dump file.
        """
        # The context manager closes the file even if the group lookup or creation fails.
        with h5py.File(file_name, 'a') as f:
            cov_group = f[f'{self.MAP_DATA}/{self.COVERAGE1}']
            cov_group.create_group(self.GMS_ATTS)

    # @overrides
    def Copy(self):  # noqa: N802
        """Return a reference to this object."""
        return self

    # @overrides
    def GetDumpType(self):  # noqa: N802
        """Get the XMS coverage dump type."""
        return "xms.coverage.att_table_coverage"


def ReadDumpWithObject(file_name):  # noqa: N802
    """Create the coverage dump object for a dumped coverage file.

    Args:
        file_name (str): Filepath to the dumped coverage to read

    Returns:
        AttTableCoverageDump: The coverage dump object constructed for the given file
    """
    return AttTableCoverageDump(file_name)
