"""FootprintCalculator class."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from io import StringIO
import json
import os

# 2. Third party modules
import pandas as pd
from shapely.geometry import LineString, Point

# 3. Aquaveo modules
from xms.constraint.grid_reader import read_grid_from_file
from xms.constraint.ugrid_builder import UGridBuilder
from xms.core.filesystem import filesystem as xmf
from xms.data_objects.parameters import Coverage
from xms.gdal.rasters import raster_utils as ru
from xms.grid.ugrid.ugrid import UGrid
from xms.guipy.dialogs.process_feedback_dlg import LogEchoQSignalStream
from xms.tool.algorithms.mesh_2d.bridge_footprint import ArcType, PierEndType, PierType
from xms.tool.mesh_2d import BridgeFootprintTool
from xms.tool_gui.xms_data_handler import convert_to_geodataframe

# 4. Local modules
from xms.bridge import structure_util as su
from xms.bridge.calc.culvert_calculator import CulvertCalculator
import xms.bridge.gui.struct_dlg.arc_props as ap


class FootprintCalculator:
    """Class for running the bridge footprint tool."""
    def __init__(self, struct_comp, coverage, data_folder, cl_arc_pts, wkt, vertical_units=''):
        """Store the inputs and create the footprint tool with its initial arguments.

        Args:
            struct_comp (:obj:`StructureComponent`): the structure component
            coverage (:obj:`data_objects - Coverage`): input coverage to the footprint tool
            data_folder (:obj:`str`): path to the data folder for the tool
            cl_arc_pts (:obj:`list`): list of the coordinates of the centerline arc
            wkt (:obj:`str`): projection string
            vertical_units (:obj:`str`): vertical units
        """
        # inputs handed in by the caller
        self._struct_comp = struct_comp
        self.cov = coverage
        # a generic name is used when no coverage is supplied
        if coverage:
            self.cov_name = coverage.name
        else:
            self.cov_name = 'coverage'
        self.data_folder = data_folder
        self.cl_arc_pts = cl_arc_pts
        self.wkt = wkt
        self.vertical_units = vertical_units

        # data derived from the inputs
        self.gdf = convert_to_geodataframe(coverage, wkt)
        self.input_cov_file = os.path.join(data_folder, 'coverages', f'{self.cov_name}.h5')
        self.culvert_calc = CulvertCalculator(struct_comp, coverage)

        # the footprint tool and its argument set
        footprint_tool = BridgeFootprintTool()
        footprint_tool.echo_output = False
        footprint_tool.set_gui_data_folder(data_folder)
        tool_args = footprint_tool.initial_arguments()
        tool_args[footprint_tool.ARG_INPUT_COVERAGE].value = self.cov_name
        self.tool = footprint_tool
        self.args = tool_args

        # filled in by setup_tool() and load_tool_results()
        self.struct_data = None
        self.bridge_mesh_file = ''
        self.cov_file = ''
        self.bridge_mesh = None
        self.bridge_mesh_coverage = None
        self.culvert_out_data = None

    def setup_tool(self, struct_data):
        """Configure the footprint tool arguments from the structure definition.

        If the structure width is not greater than 0.0, an error message is stored on
        ``self.culvert_calc`` and nothing else is configured.

        Args:
            struct_data (:obj:`StructureData`): data that defines the 3d structure
        """
        # a new culvert calculator is created on every call
        self.culvert_calc = CulvertCalculator(self._struct_comp, self.cov)
        self.struct_data = struct_data
        data_dict = struct_data.data_dict
        if data_dict['bridge_width'] <= 0.0:
            self.culvert_calc.err_msg = 'Structure width must be greater than 0.0.'
            return

        tool = self.tool
        args = self.args

        # clear the per-arc overrides on the tool
        tool.arc_data = None
        tool.arc_id_to_type = None
        tool.bridge_num_side_elem = None

        data_dict['srh_mapping_info'] = ''
        tool.center_line_num_segments = None
        is_culvert = data_dict['structure_type'] == 'Culvert'
        if not is_culvert and data_dict['specify_bridge_cl_num_segments']:
            tool.center_line_num_segments = data_dict['number_bridge_cl_segments']

        # bridge deck settings
        args[tool.ARG_BRIDGE_WIDTH].value = data_dict['bridge_width']
        args[tool.ARG_BRIDGE_WRAPPING_WIDTH].value = data_dict['bridge_wrapping_width']
        tool.wrap_upstream = data_dict['bridge_wrap_upstream'] != 0
        tool.wrap_downstream = data_dict['bridge_wrap_downstream'] != 0
        args[tool.ARG_SPECIFY_SEGMENT_COUNT].value = bool(data_dict['specify_number_segments'])
        args[tool.ARG_SEGMENT_COUNT].value = data_dict['number_segments']
        args[tool.ARG_HAS_ABUTMENTS].value = bool(data_dict['has_abutments'])

        # pier settings; the pier type stays None when nothing is selected
        args[tool.ARG_PIER_TYPE].value = None
        if data_dict['pier_type'] != '-- None Selected --':
            args[tool.ARG_PIER_TYPE].value = data_dict['pier_type']
        pier_arg_keys = (
            (tool.ARG_PIER_DIAMETER, 'pier_group_diameter'),
            (tool.ARG_GROUP_WRAPPING_WIDTH, 'pier_wrapping_width'),
            (tool.ARG_PIER_GROUP_COUNT, 'pier_group_num_in_group'),
            (tool.ARG_PIER_GROUP_SPACING, 'pier_group_spacing'),
            (tool.ARG_WALL_WIDTH, 'pier_wall_width'),
            (tool.ARG_WALL_WRAPPING_WIDTH, 'pier_wrapping_width'),
            (tool.ARG_WALL_PIER_LENGTH, 'pier_wall_length'),
            (tool.ARG_SIDE_ELEMENT_COUNT, 'pier_wall_num_side_elem'),
            (tool.ARG_PIER_END_TYPE, 'pier_wall_end_type'),
        )
        for arg_name, data_key in pier_arg_keys:
            args[arg_name].value = data_dict[data_key]

        # output names, and removal of any output files left from an earlier run
        args[tool.ARG_OUTPUT_GRID].value = 'bridge_mesh'
        cov_out_name = f'{self.cov_name}_struct_cov'
        args[tool.ARG_OUTPUT_COVERAGE].value = cov_out_name
        self.bridge_mesh_file = os.path.join(self.data_folder, 'grids', 'bridge_mesh.xmc')
        xmf.removefile(self.bridge_mesh_file)
        self.cov_file = os.path.join(self.data_folder, 'coverages', f'{cov_out_name}.h5')
        xmf.removefile(self.cov_file)
        self.bridge_mesh = None
        self.bridge_mesh_coverage = None
        self.culvert_out_data = None

        # reset these so the dialog doesn't show warning or error from a previous run
        LogEchoQSignalStream.logged_error = False
        LogEchoQSignalStream.logged_warning = False

        if is_culvert:
            self._prep_tool_for_culvert()
        tool.coverage_elevation = self.base_elevation()
        if not is_culvert and data_dict['specify_arc_properties'] == 1:
            self._set_arc_data_from_arc_properties(data_dict['arc_properties'])
            tool.bridge_num_side_elem = data_dict['bridge_num_side_elem']

    def _set_arc_data_from_arc_properties(self, csv):
        """Build the tool's per-arc data from the arc properties table.

        The table is parsed from ``csv`` and its columns are renamed with
        ``ap.arc_property_column_names()``. If no row has the type 'Bridge', an error
        message is stored on ``self.culvert_calc`` and the tool is left untouched.
        Otherwise the arc data returned by the tool is updated row by row and stored in
        ``self.tool.arc_data``; arcs whose row has an unrecognized type are dropped.

        Args:
            csv (:obj:`str`): csv string with arc properties

        Raises:
            KeyError: if a row refers to an arc id that is not in the tool's arc data, or a
                wall pier row has an end type other than 'Sharp', 'Round' or 'Square'.
        """
        df = pd.read_csv(StringIO(csv))
        df.columns = ap.arc_property_column_names()
        if 'Bridge' not in df['Type'].values:
            self.culvert_calc.err_msg = (
                'No "Bridge" arc found in arc properties. Set arc properties before '
                'creating bridge mesh.'
            )
            return

        tool = self.tool
        args = self.args
        data_dict = self.struct_data.data_dict

        # The tool builds its arc data with the id -> type lookup in place and the pier
        # type set to 'Wall'; both are put back right after the call.
        tool.arc_id_to_type = {row['Arc ID']: row['Type'] for _, row in df.iterrows()}
        args[tool.ARG_PIER_TYPE].value = 'Wall'
        arc_data = tool.get_arc_data(self.gdf, args)
        args[tool.ARG_PIER_TYPE].value = '-- None Selected --'
        tool.arc_id_to_type = None

        arc_id_to_idx = {arc_item['id']: idx for idx, arc_item in enumerate(arc_data)}
        remove_idx = set()
        wall_end_type = {'Sharp': PierEndType.SHARP, 'Round': PierEndType.ROUND, 'Square': PierEndType.SQUARE}
        for _, row in df.iterrows():
            # NOTE(review): a row whose arc id is missing from the tool's arc data raises
            # KeyError here - confirm the properties table can never be out of date.
            arc_idx = arc_id_to_idx[row['Arc ID']]
            arc_item = arc_data[arc_idx]
            arc_type = row['Type']
            if arc_type == 'Bridge':
                arc_item['arc_type'] = ArcType.BRIDGE
                arc_item['bridge_width'] = data_dict['bridge_width']
                arc_item['bridge_wrapping_width'] = data_dict['bridge_wrapping_width']
                arc_item['bridge_num_segments'] = data_dict['number_segments']
                arc_item['bridge_specify_num_segments'] = data_dict['specify_number_segments']
            elif arc_type == 'Abutment':
                arc_item['arc_type'] = ArcType.ABUTMENT
            elif arc_type == 'Wall pier (WP)':
                arc_item['arc_type'] = ArcType.PIER
                arc_item['pier_type'] = PierType.WALL
                arc_item['pier_size'] = row[ap.WP_WIDTH]
                arc_item['pier_element_wrap_width'] = row[ap.ELEM_WRAP_WIDTH]
                arc_item['pier_length'] = row[ap.WP_LENGTH]
                arc_item['pier_num_side_elements'] = data_dict['bridge_num_side_elem']
                arc_item['pier_end_type'] = wall_end_type[row[ap.WP_END_TYPE]]
            elif arc_type == 'Pier group (PG)':
                arc_item['arc_type'] = ArcType.PIER
                arc_item['pier_type'] = PierType.GROUP
                arc_item['pier_size'] = row[ap.PG_DIAMETER]
                arc_item['pier_element_wrap_width'] = row[ap.ELEM_WRAP_WIDTH]
                arc_item['number_piers'] = row[ap.PG_NUM_PIERS]
                arc_item['pier_spacing'] = row[ap.PG_SPACING]
                arc_item['pier_num_side_elements'] = data_dict['bridge_num_side_elem']
            else:
                # any other type is excluded from the data handed to the tool
                remove_idx.add(arc_idx)

        # Build the filtered list once, after every row has been processed, rather than
        # rebuilding it on each loop iteration.
        tool.arc_data = [arc_item for idx, arc_item in enumerate(arc_data) if idx not in remove_idx]

    def _prep_tool_for_culvert(self):
        """Create the culvert mesh coverage and adjust the tool arguments for a culvert.

        The tool arguments are only changed when the culvert calculator reports a
        culvert arc after building the coverage.
        """
        calc = self.culvert_calc
        calc.err_msg = ''
        data_dict = self.struct_data.data_dict

        # rise and span are only passed along for box culverts
        culvert_rise = culvert_span = None
        if data_dict['culvert_type'] == 'Box':
            culvert_rise = data_dict['culvert_rise']
            culvert_span = data_dict['culvert_span']
        ave_top = su.compute_ave_y_from_xy_data(self.struct_data.curves['top_profile'])
        ave_bot = 0.5 * (data_dict['culvert_up_invert'] + data_dict['culvert_dn_invert'])

        # assemble the calculator input in the same key order as before
        leading_keys = ('culvert_type', 'bridge_width', 'culvert_diameter', 'culvert_num_seg_barrel')
        culvert_dict = {key: data_dict[key] for key in leading_keys}
        culvert_dict['coverage_file'] = self.input_cov_file
        culvert_dict['has_abutments'] = data_dict['has_abutments']
        culvert_dict['culvert_embed_depth'] = data_dict['culvert_embed_depth']
        culvert_dict['culvert_rise'] = culvert_rise
        culvert_dict['culvert_span'] = culvert_span
        culvert_dict['ave_top'] = ave_top
        culvert_dict['ave_bot'] = ave_bot
        culvert_dict['num_barrels'] = data_dict['culvert_num_barrels']
        culvert_dict['culvert_wall_width'] = data_dict['culvert_wall_width']
        culvert_dict['culvert_arc_properties'] = self._culvert_arc_properties()

        calc.new_coverage_for_culvert_mesh(culvert_dict)
        if not calc.has_culvert_arc():
            return

        tool = self.tool
        args = self.args
        if calc.has_abutments:
            args[tool.ARG_HAS_ABUTMENTS].value = True
        args[tool.ARG_BRIDGE_WIDTH].value = calc.bridge_width
        args[tool.ARG_PIER_TYPE].value = 'Wall'
        args[tool.ARG_WALL_WIDTH].value = calc.wall_width
        # make wrap same as wall width
        args[tool.ARG_WALL_WRAPPING_WIDTH].value = calc.wall_width
        args[tool.ARG_WALL_PIER_LENGTH].value = 10 * data_dict['bridge_width']

    def _culvert_arc_properties(self):
        """Build the arc id to arc type lookup for culverts with properties on arcs.

        Returns:
            (:obj:`dict`): arc type keyed by arc id; empty when arc properties are not
            specified or the culvert calculator holds an error message
        """
        data_dict = self.struct_data.data_dict
        if data_dict['specify_culvert_arc_properties'] == 0:
            return {}
        table = pd.read_csv(StringIO(data_dict['culvert_arc_properties']))
        arc_ids = table['col_0'].to_list()
        arc_types = table['col_1'].to_list()
        if self.culvert_calc.err_msg != '':
            return {}
        return dict(zip(arc_ids, arc_types))

    def load_tool_results(self, raster_file, arc_data_dict):
        """Read the tool's output files into this object.

        When the mesh file exists it is read, its point elevations are replaced with the
        computed ones, the grid is rebuilt as a 2D grid, and the file is deleted. When the
        coverage file exists it is copied to a temporary file and loaded from there.

        Args:
            raster_file (:obj:`str`): path to a raster for elevations
            arc_data_dict (:obj:`dict`): dict with arc data
        """
        mesh_path = self.bridge_mesh_file
        if os.path.isfile(mesh_path):
            # self.bridge_mesh must be set before the elevations are computed
            mesh = read_grid_from_file(mesh_path)
            self.bridge_mesh = mesh
            xy_locs = mesh.ugrid.locations
            z_vals = self.calc_bridge_mesh_elevations(raster_file, arc_data_dict)
            xyz_locs = [(pt[0], pt[1], z_vals[idx]) for idx, pt in enumerate(xy_locs)]
            cell_stream = mesh.ugrid.cellstream
            mesh.ugrid.locations = xyz_locs
            grid_builder = UGridBuilder()
            grid_builder.set_ugrid(UGrid(xyz_locs, cell_stream))
            grid_builder.set_is_2d()
            self.bridge_mesh = grid_builder.build_grid()
            os.remove(mesh_path)

        cov_path = self.cov_file
        if os.path.isfile(cov_path):
            # the coverage is loaded from a temporary copy of the tool output
            cov_copy = xmf.temp_filename()
            xmf.copyfile(cov_path, cov_copy)
            self.bridge_mesh_coverage = Coverage(cov_copy)

    def calc_bridge_mesh_elevations(self, raster_file, arc_data_dict):
        """Compute the bridge mesh elevations and store the SRH mapping information.

        Besides returning the elevations, this method writes a JSON string with the
        upstream/downstream arcs, the projection and the culvert polygon to
        ``struct_data.data_dict['srh_mapping_info']``. For a culvert structure with a
        culvert arc it also replaces the upstream/downstream profile curves on the
        structure data and adds several entries to ``arc_data_dict``.

        Args:
            raster_file (:obj:`str`): path to a raster for elevations
            arc_data_dict (:obj:`dict`): dict with arc data

        Returns:
            (:obj:`list[float]`): list of elevations
        """
        struct_data = self.struct_data
        base_elev = self.base_elevation()
        # no mesh means no locations, which leaves the elevation list empty
        locs = self.bridge_mesh.ugrid.locations if self.bridge_mesh else []
        elev = []
        if len(locs) > 0:
            if raster_file:
                interp = ru.interpolate_raster_to_points
                raster_vals, no_data_val = interp(raster_file, locs, self.wkt, self.vertical_units)
                # points that get the raster's no-data value fall back to the base elevation
                elev = [v if v != no_data_val else base_elev for v in raster_vals]
            else:
                # without a raster every point gets the base elevation
                elev = [base_elev] * len(locs)

        culvert_poly = None
        up_arc, down_arc, _ = self.get_up_down_arcs()
        if not self.culvert_calc.has_culvert_arc():
            pass  # do nothing if there is no culvert arc in the coverage
        elif struct_data.data_dict['structure_type'] == 'Culvert':
            # NOTE(review): culvert_data_for_profiles() can return None; presumably that
            # cannot happen once has_culvert_arc() is true - confirm.
            culvert_data = self.culvert_data_for_profiles()
            culvert_data['up_arc'] = up_arc
            culvert_data['down_arc'] = down_arc
            culvert_data['tmp_main_file'] = os.path.join(self.data_folder, 'structure.nc')
            culvert_data['locs'] = locs
            # NOTE(review): the same list object is returned below, so update_bridge_elev
            # presumably adjusts the elevations in place - confirm.
            culvert_data['elev'] = elev
            culvert_data['top_profile'] = struct_data.curves['top_profile']
            culvert_data['wkt'] = self.wkt
            culvert_data['base_elev'] = self.calc_culvert_base_elev()

            self.culvert_calc.update_bridge_elev(culvert_data)
            # NOTE(review): 'profiles' is not set in this method; presumably it is filled in
            # by update_bridge_elev - confirm. Indices 0/1 are used as the upstream/downstream
            # profiles and 2/3 as the culvert bottom upstream/downstream profiles.
            struct_data.curves['upstream_profile'] = culvert_data['profiles'][0]
            struct_data.curves['downstream_profile'] = culvert_data['profiles'][1]
            arc_data_dict['culvert_bottom_up_profile'] = culvert_data['profiles'][2]
            arc_data_dict['culvert_bottom_dn_profile'] = culvert_data['profiles'][3]
            arc_data_dict['top_up'] = self.culvert_calc.top_up_df
            arc_data_dict['top_dn'] = self.culvert_calc.top_dn_df
            arc_data_dict['culvert_ug_base_elev'] = culvert_data['base_elev']
            arc_data_dict['match_parametric_values'] = False
            culvert_poly = self.culvert_calc.barrels_poly
            # modify up_arc and down_arc to be trimmed to the culvert, these are used to map bcs to SRH at the
            # opening and exit of the culvert
            up_ls = LineString(up_arc)
            dn_ls = LineString(down_arc)
            up_pts = []
            dn_pts = []
            # each polygon point goes to whichever line it is closer to (ties go to the
            # downstream line), paired with its projected distance along that line
            for p in culvert_poly:
                sh_pt = Point(p)
                if up_ls.distance(sh_pt) < dn_ls.distance(sh_pt):
                    up_pts.append((up_ls.project(sh_pt), p))
                else:
                    dn_pts.append((dn_ls.project(sh_pt), p))
            # order the points by their distance along the line; the trimmed arcs replace
            # the full arcs in the SRH mapping information below
            up_pts.sort()
            up_arc = [p[1] for p in up_pts]
            dn_pts.sort()
            down_arc = [p[1] for p in dn_pts]

        # culvert_poly stays None unless the culvert branch above ran
        srh_data = {
            'up_arc': up_arc,
            'down_arc': down_arc,
            'wkt': self.wkt,
            'culvert_poly': culvert_poly,
        }
        struct_data.data_dict['srh_mapping_info'] = json.dumps(srh_data)
        return elev

    def get_up_down_arcs(self):
        """Get the upstream and downstream arcs of the bridge.

        When a bridge mesh is available the arcs come from the tool's upstream and
        downstream lines; otherwise they are offset from the centerline arc.

        Returns:
            (:obj:`tuple(list,list,str)`): the upstream points, the downstream points, and a
            message (empty when the arcs come from the tool)
        """
        if self.bridge_mesh is None:
            width = self.struct_data.data_dict['bridge_width']
            up_arc, down_arc, msg = su.offset_centerline(self.cl_arc_pts, width)
            return up_arc, down_arc, msg

        # only the x and y of each tool point are kept
        tool = self.tool
        up_arc = [(pt[0], pt[1]) for pt in tool.bridge_upstream_line]
        down_arc = [(pt[0], pt[1]) for pt in tool.bridge_downstream_line]
        return up_arc, down_arc, ''

    def culvert_data_for_profiles(self):
        """Collect the culvert data needed to generate profiles.

        Starts from the culvert calculator's arc data and copies the culvert settings from
        the structure data into it. For a 'Box' culvert the diameter is replaced by the
        span and the rise is reduced by the embed depth.

        Returns:
            (:obj:`dict`): the culvert data, or None when the calculator has no arc data
        """
        data_dict = self.struct_data.data_dict
        culvert_data = self.culvert_calc.culvert_arc_data(data_dict['has_abutments'])
        if culvert_data is None:
            return culvert_data

        copied_keys = (
            'culvert_type',
            'culvert_diameter',
            'culvert_rise',
            'culvert_span',
            'culvert_up_invert',
            'culvert_dn_invert',
            'culvert_embed_depth',
            'culvert_num_barrels',
            'culvert_wall_width',
        )
        for key in copied_keys:
            culvert_data[key] = data_dict[key]

        if culvert_data['culvert_type'] == 'Box':
            culvert_data['culvert_diameter'] = culvert_data['culvert_span']
            exposed_rise = culvert_data['culvert_rise'] - culvert_data['culvert_embed_depth']
            culvert_data['culvert_rise'] = exposed_rise
        return culvert_data

    def base_elevation(self):
        """Get the base elevation of the 3D Structure.

        Returns:
            The computed culvert base elevation for a 'Culvert' structure, otherwise the
            'bridge_pier_base_elev' value from the structure data.
        """
        data_dict = self.struct_data.data_dict
        pier_base_elev = data_dict['bridge_pier_base_elev']
        if data_dict['structure_type'] == 'Culvert':
            return self.calc_culvert_base_elev()
        return pier_base_elev

    def calc_culvert_base_elev(self):
        """Calculate the culvert base elevation.

        The base is the lower of the two inverts minus one tenth of the difference between
        the average top profile elevation and the average invert elevation.

        Returns:
            (:obj:`float`): base elevation
        """
        data_dict = self.struct_data.data_dict
        up_invert = data_dict['culvert_up_invert']
        dn_invert = data_dict['culvert_dn_invert']
        ave_bot = 0.5 * (up_invert + dn_invert)
        ave_top = su.compute_ave_y_from_xy_data(self.struct_data.curves['top_profile'])
        lowest_invert = min(up_invert, dn_invert)
        average_height = ave_top - ave_bot
        return lowest_invert - (0.1 * average_height)
