"""Classes for writing lakes (lake definitions, HPM entries and lake bcs) to RSM model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property
import xml.etree.cElementTree as Et

# 2. Third party modules
from shapely.geometry import Polygon

# 3. Aquaveo modules
from xms.coverage.polygons.polygon_orienteer import get_polygon_point_lists
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.rsm.data import lake_data_def as ldd
from xms.rsm.file_io import util
from xms.rsm.file_io.bc_val_writer import BcValWriter
from xms.rsm.file_io.lake_monitor_writer import LakeMonitorWriter
from xms.rsm.file_io.water_body_info import WaterBodyInfo
from xms.rsm.file_io.water_body_rain_et_writer import WaterBodyRainEtData, WaterBodyRainEtWriter


class _WriterData:
    """Shared state handed to the lake writer helper classes.

    Holds references taken from the top level writer data plus the state of the lake/polygon currently
    being written.
    """
    def __init__(self, writer_data, lake_label_id):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
            lake_label_id (dict): Lake label -> lake id map, filled in as labeled lakes are written.
        """
        self.logger = writer_data.logger
        self.lake_label_id = lake_label_id
        self.lake_info = writer_data.water_body_info['lake']
        self.xms_data = writer_data.xms_data
        self.xml_hse = writer_data.xml_hse
        self.csv_writer = writer_data.csv_writer
        self.rule_curve_label_id = writer_data.rule_curve_label_id
        self.pp = ldd.generic_model().polygon_parameters
        # Lake ids are offset from the water body start id; the id is incremented before each lake is written.
        self.cur_lake_id = writer_data.xms_data.waterbody_start_id + 100_000
        # State describing the coverage/polygon/lake currently being written.
        self.cur_poly_id = -1
        self.cur_poly_pts = []
        self.cur_msg = ''
        self.cov = None
        self.comp = None
        self.cov_name = ''
        self.cur_lake_xml = None
        # BC elements are collected here and appended to the xml after all lakes.
        self.bcs = []
        # Names of the polygon parameter groups whose label starts with "Monitor ".
        monitor_grps = (self.pp.group(nm) for nm in self.pp.group_names)
        self.monitor_set = {gp.group_name for gp in monitor_grps if gp.label.startswith('Monitor ')}
        self.monitor_data = []

    @cached_property
    def lakes_xml(self):
        """Get the <lakes> xml element, creating it under the HSE element on first access."""
        return Et.SubElement(self.xml_hse, 'lakes')

    @cached_property
    def bc_xml(self):
        """Get the <lake_bc> xml element, creating it under <lakes> on first access."""
        return Et.SubElement(self.lakes_xml, 'lake_bc')


class LakeWriter:
    """Writer class for lakes in the RSM control file."""
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self.lake_label_id = {}
        self._data = _WriterData(writer_data, self.lake_label_id)
        # Optional polygon parameter groups and the helper that writes each one.
        self._write_methods = {
            'rain': self._write_rain,
            'refet': self._write_refet,
            'hpm': self._write_hpm,
            'bc': self._write_bc,
        }

    def write(self):
        """Write the lakes portion (lakes, HPM entries and lake bcs) of the control file."""
        for cov, comp in self._data.xms_data.lake_coverages:
            self._data.cov = cov
            self._data.comp = comp
            self._data.cov_name = cov.name
            self._write_lake_coverage()
        # Put BCs after all the lakes. bc_xml is created lazily on first access, so the <lake_bc>
        # element is only added (after all lake elements) when there is at least one BC.
        for bc in self._data.bcs:
            self._data.bc_xml.append(bc)

    def write_monitor(self, xml_parent):
        """Write the monitor data collected while writing the lakes to the xml.

        Args:
            xml_parent (xml.etree.cElementTree.Element): xml parent element
        """
        lmw = LakeMonitorWriter(xml_parent, self._data.monitor_data)
        lmw.write()

    def _write_lake_coverage(self):
        """Write every polygon of the current lake coverage that has the "Lake" item checked."""
        for poly in self._data.cov.polygons:
            self._data.cur_poly_pts = get_polygon_point_lists(poly)
            self._data.cur_poly_id = poly.id
            self._data.cur_msg = f'Polygon id: "{poly.id}" in lake coverage "{self._data.cov_name}"'
            comp_id = self._data.comp.get_comp_id(TargetType.polygon, poly.id)
            # Polygons without component data (no id, a negative id, or the uninitialized marker) are skipped.
            if comp_id is None or comp_id < 0 or comp_id == util.UNINITIALIZED_COMP_ID:
                continue
            _, p_val = self._data.comp.data.feature_type_values(TargetType.polygon, comp_id)
            self._data.pp.restore_values(p_val)
            active_groups = self._data.pp.active_group_names
            if not active_groups:
                continue
            if 'lake' not in active_groups:  # if the lake group is not active then skip this polygon
                msg = f'{self._data.cur_msg} skipped because the "Lake" item is not checked.'
                self._data.logger.warning(msg)
                continue

            self._write_lake()
            # 'lake' has no entry in _write_methods, so the lake itself is only written once (above).
            for grp_name in active_groups:
                if grp_name in self._write_methods:
                    self._write_methods[grp_name]()

            self._store_monitor_info()

    def _write_lake(self):
        """Assign the next lake id and write the lake data to the xml."""
        self._data.cur_lake_id += 1
        writer = _LakeWriter(self._data)
        writer.write()

    def _write_rain(self):
        """Write the rain data to the xml."""
        writer = _LakeRainEtWriter('rain', self._data)
        writer.write()

    def _write_refet(self):
        """Write the refet data to the xml."""
        writer = _LakeRainEtWriter('refet', self._data)
        writer.write()

    def _write_hpm(self):
        """Write the HPM data to the xml."""
        writer = _LakeHpmWriter(self._data)
        writer.write()

    def _write_bc(self):
        """Write the boundary condition data to the xml."""
        writer = _LakeBcWriter(self._data)
        writer.write()

    def _store_monitor_info(self):
        """Store the active monitor groups of the current lake for writing by write_monitor."""
        active_grps = set(self._data.pp.active_group_names)
        if not active_grps.intersection(self._data.monitor_set):  # no monitor groups to write
            return

        pp = self._data.pp.copy()  # make a copy of the polygon parameters
        # Iterate over a snapshot of the names: removing groups while iterating a live collection would
        # skip entries (or fail outright).
        for gp in list(pp.group_names):
            if gp not in self._data.monitor_set or gp not in active_grps:
                pp.remove_group(gp)

        md = (self._data.cur_lake_id, pp, self._data.cur_poly_id, self._data.cov_name)
        self._data.monitor_data.append(md)


class _LakeWriter:
    """Class for writing a lake element (and its stage-storage package) to the xml."""
    def __init__(self, data):
        """Constructor.

        Args:
            data (_WriterData): Shared writer state for the current lake polygon.
        """
        self.data = data
        self._package_write = {
            'SSTable': self._write_sstable,
            'cylinder': self._write_cylinder,
            'parabolic': self._write_parabolic,
            # 'polynomial': self._write_polynomial,  TODO: Implement polynomial package
        }

    def write(self):
        """Write the lake data to the xml.

        Adds a ``lake`` element under ``data.lakes_xml``, records the label -> id mapping when a label is
        given, and appends a WaterBodyInfo for the lake footprint to ``data.lake_info``.

        Raises:
            RuntimeError: If the lake package is not supported or required table data is missing.
        """
        gp = self.data.pp.group('lake')
        label = gp.parameter('label').value
        head0 = gp.parameter('head0').value
        package = gp.parameter('package').value
        # Validate before touching the xml so an unsupported package does not leave a partial <lake>.
        write_package = self._package_write.get(package)
        if write_package is None:
            msg = f'{self.data.cur_msg} uses unsupported lake package "{package}". Aborting.'
            self.data.logger.error(msg)
            raise RuntimeError(msg)
        atts = {
            'id': str(self.data.cur_lake_id),
            'head0': str(head0),
        }
        if label:
            atts['label'] = label
            self.data.lake_label_id[label] = self.data.cur_lake_id
        atts['package'] = str(package)
        self.data.cur_lake_xml = Et.SubElement(self.data.lakes_xml, 'lake', attrib=atts)
        write_package()
        # Only the first point list (presumably the outer boundary) is used for the footprint; z is dropped.
        poly_pts = [(p[0], p[1]) for p in self.data.cur_poly_pts[0]]
        sh_poly = Polygon(poly_pts)
        self.data.lake_info.append(WaterBodyInfo(label, self.data.cur_lake_id, sh_poly))

    def _write_sstable(self):
        """Write the stage-volume and stage-area table data to the xml."""
        gp = self.data.pp.group('lake')
        sstab_xml = Et.SubElement(self.data.cur_lake_xml, 'SSTable')
        sv_table = gp.parameter('sv').value
        err_msg = f'{self.data.cur_msg} has no stage-volume table data. Aborting.'
        self._write_table(sv_table, sstab_xml, 'sv', err_msg)
        sa_table = gp.parameter('sa').value
        err_msg = f'{self.data.cur_msg} has no stage-area table data. Aborting.'
        self._write_table(sa_table, sstab_xml, 'sa', err_msg)

    def _write_table(self, table, xml_parent, xml_tag, err_msg):
        """Write a two column table to the xml.

        Args:
            table: Rows whose first two items are written as "<col0> <col1>".
            xml_parent (xml.etree.cElementTree.Element): Element the table element is added to.
            xml_tag (str): Tag of the table element.
            err_msg (str): Message logged when the table is empty.

        Raises:
            RuntimeError: If the table has no rows.
        """
        if len(table) < 1:
            self.data.logger.error(err_msg)
            raise RuntimeError
        table_xml = Et.SubElement(xml_parent, xml_tag)
        for item in table:
            # 'remove_me' wrapper tags are written per row; presumably stripped in a later pass — TODO confirm
            line_elem = Et.SubElement(table_xml, 'remove_me')
            line_elem.text = f'{item[0]} {item[1]}'

    def _write_cylinder(self):
        """Write the cylinder lake data to the xml."""
        gp = self.data.pp.group('lake')
        bot = gp.parameter('bottom').value
        top_area = gp.parameter('toparea').value
        atts = {
            'bot': str(bot),
            'toparea': str(top_area),
        }
        Et.SubElement(self.data.cur_lake_xml, 'cylinder', atts)

    def _write_parabolic(self):
        """Write the parabolic lake data to the xml."""
        gp = self.data.pp.group('lake')
        top = gp.parameter('top').value
        bot = gp.parameter('para_bottom').value
        top_area = gp.parameter('para_toparea').value
        atts = {
            'top': str(top),
            'bot': str(bot),
            'toparea': str(top_area),
        }
        Et.SubElement(self.data.cur_lake_xml, 'parabolic', atts)


class _LakeRainEtWriter:
    """Adapter that writes a lake's rain or reference ET data through WaterBodyRainEtWriter."""
    def __init__(self, group_name, data):
        """Constructor.

        Args:
            group_name (str): Polygon parameter group to write (this file passes 'rain' or 'refet').
            data (_WriterData): Shared writer state for the current lake polygon.
        """
        rain_et_data = WaterBodyRainEtData(
            data.xms_data,
            group_name,
            data.cur_lake_xml,
            data.pp,
            data.csv_writer,
            'Lake',
            data.cur_poly_id,
            data.cov_name,
        )
        self.wb_rain_et_writer = WaterBodyRainEtWriter(rain_et_data)

    def write(self):
        """Delegate writing of the rain/refet data to the wrapped water body writer."""
        self.wb_rain_et_writer.write()


class _LakeHpmWriter:
    """Class for writing lake HPM data to the xml."""
    def __init__(self, data):
        """Constructor.

        Args:
            data (_WriterData): Shared writer state for the current lake polygon.
        """
        self.data = data
        self.hpm_xml = None
        self.lake_et_xml = None
        self._export_package = {
            ldd.HPM_LITZONE: self._litzone,
            ldd.HPM_PAN_CONST: self._pan_const,
            ldd.HPM_PAN_VAR: self._pan_var,
            #  ldd.HPM_SFWMM: self._sfwmm,  TODO: Implement SFWMM package
        }

    def write(self):
        """Write the lake HPM data (an ``hpmEntry`` with a ``lakeET`` child) under ``data.lakes_xml``.

        Raises:
            RuntimeError: If the HPM package is not supported.
        """
        gp = self.data.pp.group('hpm')
        package = gp.parameter('package').value
        # Validate before touching the xml so an unsupported package does not leave a partial hpmEntry.
        export = self._export_package.get(package)
        if export is None:
            msg = f'{self.data.cur_msg} uses unsupported HPM package "{package}". Aborting.'
            self.data.logger.error(msg)
            raise RuntimeError(msg)

        atts = {
            'id': str(self.data.cur_lake_id),
        }
        label = gp.parameter('label').value
        if label:
            atts['label'] = label
        tag = gp.parameter('tag').value
        if tag:
            atts['tag'] = tag
        self.hpm_xml = Et.SubElement(self.data.lakes_xml, 'hpmEntry', attrib=atts)

        et_atts = {'package': package}
        package_label = gp.parameter('package_label').value
        if package_label:
            et_atts['label'] = package_label
        self.lake_et_xml = Et.SubElement(self.hpm_xml, 'lakeET', et_atts)
        export()

    def _litzone(self):
        """Write the littoral zone package: the A coefficients and an optional stage-area table."""
        gp = self.data.pp.group('hpm')
        a_xml = Et.SubElement(self.lake_et_xml, 'A')
        open_water_coef = gp.parameter('open_water_coef').value
        et_coeff = gp.parameter('et_coef').value
        a_xml.text = f'{open_water_coef} {et_coeff}'
        # TODO monitor for B and C (from the 'monthly_et' / 'monthly_rain' parameters)
        sa_table = gp.parameter('sa').value
        if len(sa_table) > 0:
            sa_xml = Et.SubElement(self.lake_et_xml, 'sa')
            for item in sa_table:
                line_elem = Et.SubElement(sa_xml, 'remove_me')
                line_elem.text = f'{item[0]} {item[1]}'

    def _pan_const(self):
        """Write the constant pan coefficient package as "<et coef> <area>" text."""
        gp = self.data.pp.group('hpm')
        pca_et_coef = gp.parameter('pca_et_coef').value
        pca_area = gp.parameter('pca_area').value
        self.lake_et_xml.text = f'{pca_et_coef} {pca_area}'

    def _pan_var(self):
        """Write the variable pan coefficient package as "<et coef>" text."""
        gp = self.data.pp.group('hpm')
        pva_et_coef = gp.parameter('pva_et_coef').value
        self.lake_et_xml.text = f'{pva_et_coef}'


class _LakeBcWriter:
    """Builds the boundary condition element for the current lake polygon.

    The element is not inserted into the tree here; it is queued in ``data.bcs`` so that all lake
    BCs can be written after the lakes.
    """
    def __init__(self, data):
        """Constructor.

        Args:
            data (_WriterData): Shared writer state for the current lake polygon.
        """
        self.data = data
        self.atts = {'lakeID': str(data.cur_lake_id)}
        self._bc_write = dict(
            lakesource=self._write_lakesource,
            owet=self._write_owet,
            lakeHeadBC=self._write_head,
            lakeghb=self._write_ghb,
        )
        self._cur_bc_xml = None
        self.gp = data.pp.group('bc')
        self.bc_type = self.gp.parameter('bc_type').value
        # CSV name is built from the BC type, coverage name and polygon id, with spaces turned into '_'.
        raw_name = f'{self.bc_type}_{data.cov_name}_{data.cur_poly_id}.csv'
        data.csv_writer.set_desired_filename(raw_name.replace(' ', '_'))

    def write(self):
        """Build the BC element for ``bc_type``, add its value data and queue it in ``data.bcs``."""
        if label := self.gp.parameter('label').value:
            self.atts['label'] = label
        bc_id = self.gp.parameter('bc_id').value
        if bc_id > 0:
            self.atts['id'] = str(bc_id)
        build_bc = self._bc_write[self.bc_type]
        build_bc()
        self._write_bc_val()
        self.data.bcs.append(self._cur_bc_xml)

    def _write_bc_val(self):
        """Write the boundary condition value to the xml."""
        BcValWriter(
            self._cur_bc_xml,
            self.gp,
            self.data.csv_writer,
            self.data.rule_curve_label_id,
            self.data.cur_msg,
        ).write()

    def _write_lakesource(self):
        """Build a ``lakesource`` element; a package and its area are only written when specified."""
        package = self.gp.parameter('lksrc_pkg').value
        area = self.gp.parameter('lksrc_area').value
        has_package = package != 'Not specified'
        if has_package:
            self.atts['package'] = package
        self._cur_bc_xml = Et.Element('lakesource', self.atts)
        if has_package:
            Et.SubElement(self._cur_bc_xml, 'A').text = str(area)

    def _write_owet(self):
        """Build an ``owet`` element with its stage-area table; abort when the table is empty."""
        table = self.gp.parameter('sa').value
        if len(table) == 0:
            self.data.logger.error(f'No stage-area table data for {self.data.cur_msg}. Aborting.')
            raise RuntimeError
        self._cur_bc_xml = Et.Element('owet', self.atts)
        tab_xml = Et.SubElement(self._cur_bc_xml, 'sa')
        for row in table:
            Et.SubElement(tab_xml, 'remove_me').text = f'{row[0]} {row[1]}'

    def _write_head(self):
        """Build a ``lakeHeadBC`` element."""
        self._cur_bc_xml = Et.Element('lakeHeadBC', self.atts)

    def _write_ghb(self):
        """Build a ``lakeghb`` element with its flow direction and conductance coefficient."""
        self.atts.update(
            flowdir=self.gp.parameter('flowdir').value,
            kcoeff=str(self.gp.parameter('kcoeff').value),
        )
        self._cur_bc_xml = Et.Element('lakeghb', self.atts)
