"""EfdcGridReader class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import uuid

# 2. Third party modules
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid import UGrid as XmUGrid

# 4. Local modules
from xms.tool.algorithms.ugrids import curvilinear_grid_ij as cgij
from xms.tool.file_io.curvilinear.curvilinear_reader_base import CurvilinearReaderBase as ReaderBase


class EfdcGridReader(ReaderBase):
    """Reader for EFDC curvilinear grid files."""

    class CornerPoint:
        """Accumulator for one grid corner point that is shared between adjacent quad cells.

        Every cell attached to the corner contributes its own estimate of the corner's x-y location. The estimates
        are averaged when the cells are stitched together.
        """

        def __init__(self, index):
            """Create a corner point that has no candidate locations yet.

            Args:
                 index (int): 0-based index in the XmUGrid point list
            """
            self.index = index
            # Two parallel lists, one entry per contributing cell: [[x0, x1, ...], [y0, y1, ...]]
            self._candidate_coords = [[], []]

        def add_candidate_location(self, x, y):
            """Record one cell's estimate of where this corner point lies.

            Args:
                x (float): x-coordinate location with respect to a specific quad cell.
                y (float): y-coordinate location with respect to a specific quad cell.
            """
            x_candidates, y_candidates = self._candidate_coords
            x_candidates.append(x)
            y_candidates.append(y)

        def average_location(self):
            """Compute the stitched location as the mean of every recorded candidate.

            Returns:
                tuple(float, float, float): The x,y,z coordinate tuple. x and y are the means of the candidates added
                so far, and z is always 0.0 because this format relies on cell elevations.
            """
            # axis=1 averages along each of the two rows, i.e. all x values and all y values separately.
            mean_x, mean_y = np.mean(self._candidate_coords, axis=1)
            return mean_x, mean_y, 0.0

    def __init__(self, dxdy_filename, lxly_filename, grid_name, logger):
        """Set up the reader state for a pair of EFDC grid input files.

        Args:
            dxdy_filename (str): Path to the EFDC dxdy.inp file
            lxly_filename (str): Path to the EFDC lxly.inp file
            grid_name (str): Optional user input for output grid name. If not
                specified, will try to read from file.
            logger (logging.Logger): The tool logger
        """
        super().__init__(grid_name, logger)

        # Input file locations
        self._dxdy_filename, self._lxly_filename = dxdy_filename, lxly_filename

        # Per-cell values gathered while the quads are built; written out later as output datasets
        self._depth_values, self._zrough_values = [], []
        self._veg_type_values, self._wind_shelter_values = [], []

        # Shared corner points keyed by their point i-j coordinates: {(pt_i, pt_j): CornerPoint}
        self._corner_points = {}

    def _build_efdc_grid(self, dxdy_df, lxly_df):
        """Build a curvilinear UGrid from data read from EFDC formatted grid files.

        Args:
            dxdy_df (pd.DataFrame): The cell sizes DataFrame read from dxdy.inp
            lxly_df (pd.DataFrame): The centroid DataFrame read from lxly.inp
        """
        self.logger.info('Extracting cell elevations...')
        self._cell_elevations = dxdy_df.z.values  # We'll rely on cell elevations for this format instead of point Z
        self._build_quads(dxdy_df, lxly_df)
        self._stitch_points()

    def _build_quads(self, dxdy_df, lxly_df):
        """Append one quad per cell to the cellstream and collect the per-cell output values.

        The corner locations computed for each cell are only candidates at this stage; they are recorded so the
        shared corners can be stitched later.

        Args:
            dxdy_df (pd.DataFrame): The cell sizes DataFrame read from dxdy.inp
            lxly_df (pd.DataFrame): The centroid DataFrame read from lxly.inp
        """
        self.logger.info('Building grid cells from imported data...')
        quad_type = XmUGrid.cell_type_enum.QUAD
        # The two tables are walked in lockstep, so row N of each is taken to describe the same cell.
        for size_row, center_row in zip(dxdy_df.itertuples(), lxly_df.itertuples()):
            cell_ij = size_row.Index  # (i, j) key of the row in the MultiIndex
            cell_i, cell_j = cell_ij[0], cell_ij[1]
            corner_indices = self._add_cell(
                cell_i=cell_i,
                cell_j=cell_j,
                dx=size_row.dx,
                dy=size_row.dy,
                center_x=center_row.x,
                center_y=center_row.y,
                cue=center_row.cue,
                cve=center_row.cve,
                cun=center_row.cun,
                cvn=center_row.cvn,
            )
            # Cellstream entry layout: cell type, number of points, then the point indices.
            self._cellstream.extend([quad_type, 4, *corner_indices])

            # Cell i-j coordinates for the output datasets are reported without the border buffer.
            self._i_values.append(cell_i - 2)
            self._j_values.append(cell_j - 2)

            self._depth_values.append(size_row.depth)
            self._zrough_values.append(size_row.zrough)
            self._veg_type_values.append(size_row.veg_type)
            self._wind_shelter_values.append(center_row.wind_shelter)

    def _add_cell(self, cell_i, cell_j, dx, dy, center_x, center_y, cue, cve, cun, cvn):
        """Compute the four corner locations of a cell and register each as a candidate corner point.

        With cu=(cue,cun) and cv=(cve,cvn), the corners are the centroid c offset by half the cell size along each
        vector:
            BL = c - cv*DY/2 - cu*DX/2
            BR = c - cv*DY/2 + cu*DX/2
            TR = c + cv*DY/2 + cu*DX/2
            TL = c + cv*DY/2 - cu*DX/2

        Args:
            cell_i (int): i-coordinate of the cell
            cell_j (int): j-coordinate of the cell
            dx (float): Width of the cell
            dy (float): Height of the cell
            center_x (float): x-coordinate of the cell centroid
            center_y (float): y-coordinate of the cell centroid
            cue (float): Curvilinear east u-vector component we need to rotate the cell
            cve (float): Curvilinear east v-vector component we need to rotate the cell
            cun (float): Curvilinear north u-vector component we need to rotate the cell
            cvn (float): Curvilinear north v-vector component we need to rotate the cell

        Returns:
            tuple(int, int, int, int): The XmUGrid point indices of the quad cell in CCW order:
            [LOC_BOTTOM_LEFT, LOC_BOTTOM_RIGHT, LOC_TOP_RIGHT, LOC_TOP_LEFT]
        """
        half_width, half_height = dx * 0.5, dy * 0.5
        # Half-size offsets along cu (applied with dx) and cv (applied with dy)
        u_x, u_y = half_width * cue, half_width * cun
        v_x, v_y = half_height * cve, half_height * cvn

        # Midpoints of the bottom and top edges: c -/+ cv*DY/2
        bottom_x, bottom_y = center_x - v_x, center_y - v_y
        top_x, top_y = center_x + v_x, center_y + v_y

        # Corner coordinates in CCW order starting at the bottom left corner
        corners = (
            (bottom_x - u_x, bottom_y - u_y, cgij.LOC_BOTTOM_LEFT),
            (bottom_x + u_x, bottom_y + u_y, cgij.LOC_BOTTOM_RIGHT),
            (top_x + u_x, top_y + u_y, cgij.LOC_TOP_RIGHT),
            (top_x - u_x, top_y - u_y, cgij.LOC_TOP_LEFT),
        )
        return tuple(
            self._add_candidate_point(cell_i, cell_j, corner_x, corner_y, location)
            for corner_x, corner_y, location in corners
        )

    def _add_candidate_point(self, cell_i, cell_j, x_coord, y_coord, location):
        """Register one cell's estimate of a corner point location and return the point's index.

        Corner points are keyed by point i-j coordinates. The bottom left corner of a cell shares the cell's own
        i-j; the right corners are one further in i, and the top corners are one further in j.

        Args:
            cell_i (int): i-coordinate of the cell
            cell_j (int): j-coordinate of the cell
            x_coord (float): The corner point's x-coordinate with respect to the cell's centroid
            y_coord (float): The corner point's y-coordinate with respect to the cell's centroid
            location (int): The corner point's location with respect to the cell centroid. One of the cgij.LOC_*
                constants.

        Returns:
            int: The point's index in the XmUGrid point list.
        """
        # Translate the cell i-j and corner location into the point i-j.
        #   LOC_BOTTOM_LEFT  -> (cell_i,     cell_j)
        #   LOC_BOTTOM_RIGHT -> (cell_i + 1, cell_j)
        #   LOC_TOP_RIGHT    -> (cell_i + 1, cell_j + 1)
        #   LOC_TOP_LEFT     -> (cell_i,     cell_j + 1)
        on_right_side = location in (cgij.LOC_BOTTOM_RIGHT, cgij.LOC_TOP_RIGHT)
        on_top_side = location in (cgij.LOC_TOP_RIGHT, cgij.LOC_TOP_LEFT)
        point_key = (
            cell_i + 1 if on_right_side else cell_i,
            cell_j + 1 if on_top_side else cell_j,
        )

        # The first cell to touch a corner creates its CornerPoint. The new point's index is the number of corner
        # points already stored, so indices follow insertion order.
        if point_key not in self._corner_points:
            self._corner_points[point_key] = self.CornerPoint(index=len(self._corner_points))
        shared_point = self._corner_points[point_key]

        # This cell's estimate becomes one of the candidates for the final stitched location.
        shared_point.add_candidate_location(x_coord, y_coord)
        return shared_point.index

        # DEBUG - Comment out everything above, uncomment the lines below, and comment out the call to _stitch_points()
        #         in _build_efdc_grid() to return the unstitched quads.
        # self._points.append((x_coord, y_coord, 0.0))
        # return len(self._points) - 1

    def _stitch_points(self):
        """Build the XmUGrid point list by averaging the locations of the points computed from each adjacent cell."""
        self.logger.info('Stitching corner points...')
        # Iterate in insertion order, should match point list order/indices used to build the cellstream.
        for corner_point in self._corner_points.values():
            self._points.append(corner_point.average_location())

    def _read_dxdy_file(self):
        """Read the dxdy.inp file.

        Returns:
            pd.DataFrame: The Dx/Dy DataFrame with i-j coordinates as the MultiIndex
        """
        self.logger.info('Reading cell size data from the dxdy.inp file...')
        try:
            columns = ['i', 'j', 'dx', 'dy', 'depth', 'z', 'zrough', 'veg_type']
            df = pd.read_csv(self._dxdy_filename, sep='\\s+', header=None, names=columns, index_col=['i', 'j'],
                             comment='C')
            # zrough and veg type are optional and may not exist
            df['zrough'] = df['zrough'].fillna(0.0)
            df['veg_type'] = df['veg_type'].fillna(0)
            return df
        except Exception as e:
            raise RuntimeError('Error reading dxdy.inp file. Aborting import:') from e

    def _read_lxly_file(self):
        """Read the lxly.inp file.

        Returns:
            pd.DataFrame: The Lx/Ly DataFrame with i-j coordinates as the MultiIndex
        """
        self.logger.info('Reading cell centroid data from the lxly.inp file...')
        try:
            columns = ['i', 'j', 'x', 'y', 'cue', 'cve', 'cun', 'cvn', 'wind_shelter']
            df = pd.read_csv(self._lxly_filename, sep='\\s+', header=None, names=columns, index_col=['i', 'j'],
                             comment='C')
            df['wind_shelter'] = df['wind_shelter'].fillna(0.0)  # wind shelter is optional and may not exist
            return df
        except Exception as e:
            raise RuntimeError('Error reading lxly.inp file. Aborting import:') from e

    def _create_dataset(self, dset_name, values, dsets):
        """Write one cell-located dataset under a fresh UUID and add its writer to the output list.

        Args:
            dset_name (str): Name of the dataset
            values (list): The dataset values
            dsets (list): The output dataset list, will append to
        """
        writer = DatasetWriter(
            name=dset_name,
            dset_uuid=str(uuid.uuid4()),  # Every dataset gets its own random UUID
            geom_uuid=self._cogrid_uuid,
            location='cells',
        )
        # NOTE(review): presumably the arguments are (times, values per time), i.e. one timestep at 0.0 - confirm
        # against the DatasetWriter API.
        writer.write_xmdf_dataset([0.0], [values])
        dsets.append(writer)

    def create_datasets(self):
        """Overload to add EFDC-specific datasets.

        Starts from the datasets produced by the base class and appends the depth, Z roughness, vegetation type and
        wind shelter cell datasets, in that order.

        Returns:
            list[DatasetWriter]: The output datasets
        """
        dsets = super().create_datasets()  # Creates the cell i-j coordinate datasets
        # (log message, dataset name, cell values) for each EFDC-specific dataset, in output order
        efdc_outputs = (
            ('Writing output cell depth dataset...', 'Depth', self._depth_values),
            ('Writing output cell Z roughness dataset...', 'Z Roughness', self._zrough_values),
            ('Writing output cell vegetation type dataset...', 'Vegetation Type', self._veg_type_values),
            ('Writing output cell wind shelter dataset...', 'Wind Shelter', self._wind_shelter_values),
        )
        for log_message, dset_name, cell_values in efdc_outputs:
            self.logger.info(log_message)
            self._create_dataset(dset_name, cell_values, dsets)
        return dsets

    def read(self):
        """Import a curvilinear grid from an EFDC formatted file."""
        dxdy_df = self._read_dxdy_file()
        if dxdy_df is not None:
            lxly_df = self._read_lxly_file()
            if lxly_df is not None:
                if not dxdy_df.index.equals(lxly_df.index):
                    raise RuntimeError(
                        'Cell size definitions in the dxdy.inp file do not match the cell centroid definitions in the '
                        'lxly.inp file. Aborting import.'
                    )
                self._build_efdc_grid(dxdy_df, lxly_df)
