"""Class for writing canal coverage data to model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property
import xml.etree.cElementTree as Et

# 2. Third party modules
from shapely.geometry import LineString, Point

# 3. Aquaveo modules
from xms.data_objects.parameters import FilterLocation
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.rsm.file_io import util
from xms.rsm.file_io.canal_bcs_writer import CanalBcsWriter
from xms.rsm.file_io.canal_monitor_writer import CanalMonitorWriter
from xms.rsm.file_io.water_body_info import WaterBodyInfo, WaterBodyInfoJunction


class CanalWriter:
    """Writes a canal coverage to RSM model input files.

    Produces the canal ``.map`` geometry file and, when applicable, the ``.init`` initial-head file and the
    ``.index`` dataset file, and adds the matching ``<network>`` elements to the HSE XML element held by the
    writer data. Canal boundary conditions and segment monitors are delegated to ``CanalBcsWriter`` and
    ``CanalMonitorWriter``.
    """
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self._wd = writer_data
        self._canal_info = writer_data.water_body_info['canal']
        self._canal_junction = writer_data.water_body_info['canal_junction']
        self._arc_start_id = writer_data.xms_data.waterbody_start_id
        self._map_filename = ''
        self._init_filename = ''
        self._canal_bc_writer = None
        self._canal_cov = None
        self._canal_comp = None
        self._gm = None
        self._default_canal = None
        self._map_file = None
        self._arc_data = None
        self._arc_parameters = None  # per-arc copy of the generic model arc parameters (set in _get_arc_data)
        self._init_head = []  # initial head per arc in coverage arc order; None when not specified
        self._init_head_arc_ids = []  # coverage arc id matching each entry of self._init_head
        self._arc_index_data = {}  # str(property dict) -> [property dict, [coverage arc ids]]
        self._monitor_data = []
        self._cur_arc_id = -1
        self._cur_canal_id = -1

    def append_bcs(self, bcs):
        """Append boundary conditions to the writer.

        Args:
            bcs (list): List of boundary condition XML elements.
        """
        for bc in bcs:
            self._bc_xml.append(bc)

    @cached_property
    def _network_xml(self):
        """Get (creating on first access) the 'network' XML element under the HSE XML element."""
        return Et.SubElement(self._wd.xml_hse, 'network')

    @cached_property
    def _bc_xml(self):
        """Get (creating on first access) the 'network_bc' XML element under the network element."""
        return Et.SubElement(self._network_xml, 'network_bc')

    @cached_property
    def _monitor_set(self):
        """Get the names of the arc parameter groups whose label starts with 'Monitor '."""
        ap = self._gm.arc_parameters.copy()
        monitor_grps = [ap.group(nm) for nm in ap.group_names]
        return {gp.group_name for gp in monitor_grps if gp.label.startswith('Monitor ')}

    def _get_arc_data(self, arc):
        """Gather the data for one coverage arc and queue it for the various output files.

        Uses the 'specified_canal' group when it is active for the arc, otherwise the default canal group.
        Records the initial head, the indexed properties, the boundary conditions, the arc geometry and
        any monitor groups for the arc.

        Args:
            arc (xms.data_objects.parameters.Spatial.Arc.Arc): coverage arc
        """
        self._cur_arc_id = arc.id
        self._cur_canal_id = canal_id = arc.id + self._arc_start_id
        self._arc_parameters = self._gm.arc_parameters.copy()
        self._arc_data = self._default_canal
        arc_label = self._arc_data.parameter('label').value
        comp_id = self._canal_comp.get_comp_id(TargetType.arc, arc.id)
        if comp_id is None or comp_id < 0:
            comp_id = util.UNINITIALIZED_COMP_ID

        if comp_id != util.UNINITIALIZED_COMP_ID:
            canal_val = self._canal_comp.data.feature_values(TargetType.arc, comp_id)
            self._arc_parameters.restore_values(canal_val)
            if self._arc_parameters.group('specified_canal').is_active:
                self._arc_data = self._arc_parameters.group('specified_canal')
                arc_label = self._arc_data.parameter('label').value

        # save the initial head to write later (None when the arc does not specify one)
        head = None
        if self._arc_data.parameter('b_init_head').value:
            head = self._arc_data.parameter('init_head').value
        self._init_head.append(head)
        self._init_head_arc_ids.append(arc.id)

        arc_data_dict = {}
        if self._arc_data.parameter('b_specify_index').value:
            arc_data_dict['index'] = self._arc_data.parameter('index').value
        if self._arc_data.parameter('b_mannings').value:
            arc_data_dict['mannings'] = self._arc_data.parameter('mannings').value
        if self._arc_data.parameter('b_leakage_coeff').value:
            arc_data_dict['leakage_coeff'] = self._arc_data.parameter('leakage_coeff').value
        if self._arc_data.parameter('b_bank').value:
            arc_data_dict['bank_height'] = self._arc_data.parameter('bank_height').value
            arc_data_dict['bank_coeff'] = self._arc_data.parameter('bank_coeff').value
        if self._arc_data.parameter('b_levee').value:
            arc_data_dict['levee'] = self._arc_data.parameter('levee').value
        if arc_data_dict:
            # arcs with an identical property set are grouped under one entry
            arc_data_and_ids = self._arc_index_data.setdefault(str(arc_data_dict), [arc_data_dict, []])
            arc_data_and_ids[1].append(arc.id)

        # save canal bcs if they exist
        self._canal_bc_writer.add_arc_data(canal_id, self._arc_data)
        # save the arc linestring for the canal monitor writer
        arc_pts = [(pt.x, pt.y, 0.0) for pt in arc.get_points(FilterLocation.PT_LOC_ALL)]
        self._canal_info.append(WaterBodyInfo(arc_label, canal_id, LineString(arc_pts)))
        self._store_monitor_info()

    def _write_canal_xsect_props(self):
        """Write the trapezoid cross-section line for the current canal to the map file."""
        ad = self._arc_data
        bw = ad.parameter('bottom_width').value
        be = ad.parameter('bottom_elev').value
        ss = ad.parameter('side_slope').value
        self._map_file.write(f'type trapezoid {bw} {be} {ss}\n')

    def _write_canal_properties(self, arc):
        """Write the canal properties for the current arc data to the map file.

        Args:
            arc (xms.data_objects.parameters.Spatial.Arc.Arc): coverage arc (currently unused)
        """
        arc_data = self._arc_data
        if arc_data.parameter('b_flowtype').value:
            val = arc_data.parameter('flowtype').value
            # the map file stores the position of the selected option, not its text
            ival = arc_data.parameter('flowtype').options.index(val)
            self._map_file.write(f'flowtype {ival}\n')

    def _setup_for_write(self):
        """Setup the class for writing the canal data."""
        gm_default_canal = self._canal_comp.canal_helper.gm_default_canal
        ap = gm_default_canal.arc_parameters.copy()
        self._default_canal = ap.group('default_canal')
        self._gm = self._canal_comp.canal_helper.generic_model
        cov_name = self._canal_cov.name.replace(' ', '_')
        self._map_filename = f'CANALS_{cov_name}.map'
        self._canal_bc_writer = CanalBcsWriter(cov_name, self._canal_comp.data, self._arc_start_id, self._wd)

    def _map_sibling_filename(self, extension):
        """Return the map filename with only its trailing '.map' replaced by ``extension``.

        Args:
            extension (str): new extension including the dot, or '' for the bare stem.

        Returns:
            str: the derived filename.
        """
        stem = self._map_filename
        if stem.endswith('.map'):
            stem = stem[:-len('.map')]
        return stem + extension

    def _write_map_file(self):
        """Writes the map file for the canal coverage and records canal junctions."""
        with open(self._map_filename, 'w') as self._map_file:
            f = self._map_file
            f.write('MAP\n')
            f.write('BEGCOV\n')
            f.write(f'COVNAME {self._canal_cov.name}\n')
            f.write('COVELEV 0.0\n')
            f.write('COVATTS GENERAL\n')

            # Nodes
            node_to_arc = {}
            nodes = self._canal_cov.get_points(FilterLocation.PT_LOC_CORNER)
            for pt in nodes:
                f.write('NODE\n')
                f.write(f'XY {pt.x} {pt.y} {pt.z}\n')
                f.write(f'ID {pt.id}\n')
                f.write('END\n')
                node_to_arc[pt.id] = WaterBodyInfoJunction('', pt.id, Point(pt.x, pt.y), [])

            # Arcs
            start_id = self._arc_start_id
            for arc in self._canal_cov.arcs:
                self._get_arc_data(arc)
                f.write('ARC\n')
                self._write_canal_xsect_props()
                canal_id = arc.id + start_id
                f.write(f'ID {canal_id}\n')
                f.write(f'NODES {arc.start_node.id} {arc.end_node.id}\n')
                node_to_arc[arc.start_node.id].canal_ids.append(canal_id)
                node_to_arc[arc.end_node.id].canal_ids.append(canal_id)
                self._write_canal_properties(arc)
                verts = arc.vertices
                f.write(f'ARCVERTICES {len(verts)}\n')
                for v in verts:
                    f.write(f'{v.x} {v.y} {v.z}\n')
                f.write('END\n')

            f.write('ENDCOV\n')

            # add the junctions to the canal_junction info
            for v in node_to_arc.values():
                if len(v.canal_ids) > 1:  # only add junctions with more than one canal
                    v.canal_ids.sort()
                    self._canal_junction.append(v)

    def _write_init_file(self):
        """Writes the init file for the canal coverage.

        Nothing is written when no arc specifies an initial head. When only some arcs specify one, a
        warning listing the coverage arc ids without a value is logged and nothing is written.
        """
        filename = self._map_sibling_filename('.init')
        # if all init head values are None, don't write the file
        count_none = self._init_head.count(None)
        if count_none == len(self._init_head):
            return
        # if any arcs have an init head value of None then report a warning and don't write the file
        if count_none > 0:
            ids_with_none = [
                arc_id for arc_id, head in zip(self._init_head_arc_ids, self._init_head) if head is None
            ]
            msg = (
                'Some canal arcs have no initial head value. The canal init file will not be written.\n'
                f'Canal arc ids with no initial head value: {ids_with_none}'
            )
            self._wd.logger.warning(msg)
            return

        with open(filename, 'w') as f:
            self._init_filename = filename
            f.write('netinit\n')
            for head in self._init_head:
                f.write(f'{head}\n')

    def _index_file_name(self):
        """Get the index file name."""
        return self._map_sibling_filename('.index')

    def _write_index_file(self):
        """Writes the index file for the canal coverage.

        Duplicate user-specified indexes (after the first) and unspecified indexes are replaced by the
        smallest positive integers not already in use.
        """
        if len(self._arc_index_data) < 1:
            return
        # do one pass to make sure all specified indexes are unique
        idx_set = set()
        for arc_dd, arc_ids in self._arc_index_data.values():
            if 'index' in arc_dd:
                idx = arc_dd['index']
                if idx in idx_set:
                    msg = (
                        f'The canal index value "{idx}" is not unique and will be replaced with a unique value.\n'
                        f'Canal arc ids with duplicate index value: {arc_ids}'
                    )
                    self._wd.logger.warning(msg)
                    arc_dd['index'] = -1
                else:
                    idx_set.add(idx)
        # now set the index on each arc data
        # NOTE(review): positions are computed as arc_id - 1, which assumes coverage arc ids are 1..n - confirm
        arc_indexes = [-1] * len(self._canal_cov.arcs)
        next_idx = 1
        for arc_dd, arc_ids in self._arc_index_data.values():
            if arc_dd.get('index', -1) == -1:
                while next_idx in idx_set:
                    next_idx += 1
                arc_dd['index'] = next_idx
                idx_set.add(next_idx)
            idx = arc_dd['index']
            for arc_id in arc_ids:
                arc_indexes[arc_id - 1] = idx
        # write the index file (the exporter receives the filename stem without an extension)
        util.export_ds_file(self._map_sibling_filename(''), arc_indexes, 'index')

    def _write_xml(self):
        """Add the canal network elements (geometry, initial, bcs, indexed arcs) to the XML."""
        Et.SubElement(self._network_xml, 'geometry', {'file': self._map_filename})
        if self._init_filename:
            Et.SubElement(self._network_xml, 'initial', {'file': self._init_filename})
        self._canal_bc_writer.write()
        self.append_bcs(self._canal_bc_writer.bcs)
        self._write_index_data_xml(self._network_xml)

    def _write_index_data_xml(self, xml_parent):
        """Write the index data to the XML parent.

        Args:
            xml_parent (xml.etree.cElementTree.SubElement): XML parent element
        """
        if len(self._arc_index_data) < 1:
            return

        arcs = Et.SubElement(xml_parent, 'arcs')
        idx_xml = Et.SubElement(arcs, 'indexed', {'file': self._index_file_name()})
        for arc_dd, _arc_ids in self._arc_index_data.values():
            xsentry = Et.SubElement(idx_xml, 'xsentry', {'id': str(arc_dd['index'])})
            if 'mannings' in arc_dd:
                Et.SubElement(xsentry, 'arcflow', {'n': str(arc_dd['mannings'])})
            if 'leakage_coeff' in arc_dd:
                Et.SubElement(xsentry, 'arcseepage', {'leakage_coeff': str(arc_dd['leakage_coeff'])})
            if 'bank_height' in arc_dd:
                atts = {'bank_height': str(arc_dd['bank_height']), 'bank_coeff': str(arc_dd['bank_coeff'])}
                Et.SubElement(xsentry, 'arcoverbank', atts)
            if 'levee' in arc_dd:
                Et.SubElement(xsentry, 'arclevee', {'coeff': str(arc_dd['levee'])})

    def write(self):
        """Write the canal coverage files (map, init, index) and the canal network XML elements.

        Does nothing when there is no canal coverage.
        """
        self._canal_cov, self._canal_comp = self._wd.xms_data.canal_coverage_and_component
        if self._canal_cov is None:
            return
        self._setup_for_write()
        self._write_map_file()
        self._write_init_file()
        self._write_index_file()
        self._write_xml()

    def _store_monitor_info(self):
        """Store the active monitor groups of the current arc for the canal monitor writer."""
        active_grps = set(self._arc_parameters.active_group_names)
        if not active_grps.intersection(self._monitor_set):  # no monitor groups to write
            return

        ap = self._arc_parameters.copy()  # make a copy of the arc parameters
        # iterate over a snapshot: remove_group mutates the collection behind group_names
        for gp in list(ap.group_names):  # remove groups that are not active monitor groups
            if gp not in self._monitor_set or gp not in active_grps:
                ap.remove_group(gp)

        md = (self._cur_canal_id, ap, self._cur_arc_id, self._canal_cov.name)
        self._monitor_data.append(md)

    def write_monitor(self, xml_parent):
        """Write the segment monitor data to the XML parent.

        Args:
            xml_parent (xml.etree.cElementTree.SubElement): XML parent element
        """
        writer = CanalMonitorWriter(xml_parent, self._monitor_data)
        writer.write()
