"""Class for writing mesh bcs to model input files."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property
import xml.etree.cElementTree as Et

# 2. Third party modules
from shapely.geometry import LineString, Point

# 3. Aquaveo modules

# 4. Local modules
from xms.rsm.data.bc_data_def import generic_model
from xms.rsm.file_io.bc_val_writer import BcValWriter
from xms.rsm.file_io.mesh_bc_snapper import MeshBcSnapper


class _MeshBcVars:
    """Bundle of per-coverage values shared by every mesh bc feature writer."""
    def __init__(self, bc_writer, ug_locs, csv_writer, bc_comp, cov_name, rule_curve_label_id):
        """Store the values shared by all features of one bc coverage.

        Args:
            bc_writer (MeshBcWriter): writer that owns the mesh bc xml element
            ug_locs (list): point locations of the grid/mesh
            csv_writer (CsvWriter): writer used for csv side files
            bc_comp (BcComponent): the bc component of the coverage
            cov_name (str): coverage name, used when building csv file names
            rule_curve_label_id (dict): maps rule curve labels to ids
        """
        # Grid/mesh related state.
        self.bc_writer, self.ug_locs = bc_writer, ug_locs
        # Output helper.
        self.csv_writer = csv_writer
        # Coverage related state.
        self.bc_comp, self.cov_name = bc_comp, cov_name
        self.rule_curve_label_id = rule_curve_label_id


class _ArcBcVars:
    """Per-arc values needed to write one arc boundary condition."""
    def __init__(self, arc_id, bc_type, gm_arc, pt_string):
        """Capture the arc data and resolve its generic model group.

        Args:
            arc_id (int): id of the arc
            bc_type (str): bc type of the arc; also the name of the generic model group to use
            gm_arc (xms.gmi.generic_model.Section): arc section of the generic model
            pt_string (list): 0-based grid/mesh point indices the arc snapped to
        """
        group_for_type = gm_arc.group(bc_type)
        self.arc_id, self.bc_type = arc_id, bc_type
        self.gm_group = group_for_type
        self.pt_string = pt_string


class _PtBcVars:
    """Per-point values needed to write one point boundary condition."""
    def __init__(self, pt_id, bc_type, gm_pt, cell_idx):
        """Capture the point data and resolve its generic model group.

        Args:
            pt_id (int): id of the point
            bc_type (str): bc type of the point; also the name of the generic model group to use
            gm_pt (xms.gmi.generic_model.Section): point section of the generic model
            cell_idx (int): 0-based index of the grid/mesh cell the point snapped to
        """
        self.pt_id, self.bc_type = pt_id, bc_type
        # Parameters of this bc type, taken from the point section.
        self.gm_group = gm_pt.group(bc_type)
        self.cell_idx = cell_idx


class _MeshBc:
    """Base class for mesh boundary condition data."""
    def __init__(self, mesh_bc_vars):
        """Constructor.

        Args:
            mesh_bc_vars (_MeshBcVars): mesh bc variables shared by all features of the coverage
        """
        # Values shared by every feature of the current coverage.
        self.bc_writer = mesh_bc_vars.bc_writer
        self.ug_locs = mesh_bc_vars.ug_locs
        self.csv_writer = mesh_bc_vars.csv_writer
        self.bc_comp = mesh_bc_vars.bc_comp
        self.cov_name = mesh_bc_vars.cov_name
        self.rule_curve_label_id = mesh_bc_vars.rule_curve_label_id
        # Per-feature values; subclasses fill these in.
        self.bc_type = ''
        self.gm_group = None
        self.bc_id = -1
        # Attributes and tag of the xml element written for this feature.
        self._xml_atts = {}
        self._xml_sub_tag = ''
        self._new_sub_element = None
        self._bc_val_xml = None
        self._bc_val = None
        self._bc_val_type = None

    def _add_element(self):
        """Add a sub element for this bc to the mesh bc xml element.

        The element uses self._xml_sub_tag as its tag and self._xml_atts as its attributes. A non-empty
        'label' parameter of the generic model group is added as a 'label' attribute.
        """
        label = self.gm_group.parameter('label').value
        if label:
            self._xml_atts['label'] = label
        self._new_sub_element = Et.SubElement(self.bc_writer.mesh_bc_xml, self._xml_sub_tag, self._xml_atts)


class _MeshBcArc(_MeshBc):
    """Class for writing mesh bc arcs."""
    def __init__(self, mesh_bc_vars, arc_bc_vars, feature_id):
        """Constructor.

        Args:
            mesh_bc_vars (_MeshBcVars): mesh bc variables
            arc_bc_vars (_ArcBcVars): arc bc variables
            feature_id (str): string identifier of the feature and the coverage
        """
        super().__init__(mesh_bc_vars)
        self.bc_type = arc_bc_vars.bc_type
        self.gm_group = arc_bc_vars.gm_group
        self.pt_string = arc_bc_vars.pt_string
        self.bc_id = arc_bc_vars.arc_id
        self.feature_id = feature_id
        self._wts2pts_xml = None
        self._uniform = False
        # Dispatch table: bc type -> method that writes the xml for that type.
        self._arc_method = {
            'noflow': self._arc_noflow,
            'wallhead': self._arc_wallhead,
            'wallghb': self._arc_wallghb,
            'walluf': self._arc_walluf,
        }

    def _add_nodelist(self):
        """Add a 'nodelist' child holding the arc's grid/mesh point indices, converted to 1-based."""
        nd_list = Et.SubElement(self._new_sub_element, 'nodelist')
        nd_list.text = ' '.join(f'{idx + 1}' for idx in self.pt_string)

    def _add_wts2pts_to_xml(self):
        """Add the head values of the arc to the xml.

        When the 'uniform' parameter is set, a 'uniform' child is written with only the upstream head.
        Otherwise a 'wts2pts' child is written with upstream and downstream heads plus the point weights.
        """
        self._uniform = self.gm_group.parameter('uniform').value
        if self._uniform:
            self._wts2pts_xml = Et.SubElement(self._new_sub_element, 'uniform')
        else:
            self._wts2pts_xml = Et.SubElement(self._new_sub_element, 'wts2pts')
        self._add_head_to_wts2pts('upstream_')
        if not self._uniform:
            self._add_head_to_wts2pts('downstream_')
            self._add_wts_to_wts2pts_xml()

    def _add_head_to_wts2pts(self, head_type):
        """Add one head value (upstream or downstream) to the wts2pts/uniform item.

        Args:
            head_type (str): prefix of the head parameters to write ('upstream_' or 'downstream_')
        """
        # Non-uniform heads go in numbered 'entry' children: 1 for upstream, 2 for downstream.
        entry_id = '1' if head_type == 'upstream_' else '2'
        if self._uniform:
            self._bc_val_xml = self._wts2pts_xml
        else:
            self._bc_val_xml = Et.SubElement(self._wts2pts_xml, 'entry', {'id': entry_id})
        # NOTE(review): upstream and downstream heads request the same csv file name; confirm the csv
        # writer makes the names unique so one does not overwrite the other.
        csv_file = f'{self._xml_sub_tag}_{self.cov_name}_{self.bc_id}.csv'
        self.csv_writer.set_desired_filename(csv_file)
        bc_val_writer = BcValWriter(
            self._bc_val_xml, self.gm_group, self.csv_writer, self.rule_curve_label_id, self.feature_id
        )
        bc_val_writer.set_bc_val_prefix(head_type)
        bc_val_writer.write()

    def _add_wts_to_wts2pts_xml(self):
        """Add the 'wts' child: each point's normalized distance along the arc, rounded to 3 decimals.

        The first point always gets 0.0 and the last point always gets 1.0.
        """
        xy = self.ug_locs
        locs = [(xy[pt_idx][0], xy[pt_idx][1]) for pt_idx in self.pt_string]
        ls = LineString(locs)
        wts = [0.0]
        for p in locs[1:-1]:
            wts.append(ls.project(Point(p), normalized=True))
        wts.append(1.0)
        xml_wts = Et.SubElement(self._wts2pts_xml, 'wts')
        xml_wts.text = ' '.join(f'{round(w, 3)}' for w in wts)

    def write_arc(self):
        """Write the arc bc to the xml.

        Raises:
            KeyError: if bc_type is not 'noflow', 'wallhead', 'wallghb' or 'walluf'.
        """
        self._xml_sub_tag = self.bc_type
        self._arc_method[self.bc_type]()

    def _arc_noflow(self):
        """Adds xml element for noflow."""
        self._xml_atts['section'] = self.gm_group.parameter('section').value
        self._add_element()
        self._add_nodelist()

    def _arc_wallhead(self):
        """Adds xml element for wallhead: the same element as noflow plus the head values."""
        self._arc_noflow()
        self._add_wts2pts_to_xml()

    def _arc_wallghb(self):
        """Adds xml element for wallghb."""
        self._xml_atts['value'] = f'{self.gm_group.parameter("ghb_factor").value}'
        self._add_element()
        self._add_nodelist()
        self._add_wts2pts_to_xml()

    def _arc_walluf(self):
        """Adds xml element for walluf (uniform flow)."""
        self._xml_atts['value'] = f'{self.gm_group.parameter("uniform_flow").value}'
        self._add_element()
        self._add_nodelist()


class _MeshBcPt(_MeshBc):
    """Class for writing mesh bc points."""
    def __init__(self, mesh_bc_vars, pt_bc_vars, feature_id):
        """Constructor.

        Args:
            mesh_bc_vars (_MeshBcVars): mesh bc variables
            pt_bc_vars (_PtBcVars): point bc variables
            feature_id (str): string identifier of the feature and the coverage
        """
        super().__init__(mesh_bc_vars)
        self.bc_type = pt_bc_vars.bc_type
        self.gm_group = pt_bc_vars.gm_group
        self.bc_id = pt_bc_vars.pt_id
        self.cell_idx = pt_bc_vars.cell_idx
        self.feature_id = feature_id
        # Dispatch table: bc type -> method that sets the xml attributes for that type.
        self._pt_method = {'well': self._pt_well, 'cellhead': self._pt_cellhead, 'cellghb': self._pt_cellghb}

    def write_pt(self):
        """Write the point bc to the xml.

        Sets the type specific attributes, adds the element, and writes the bc value into it. The csv file
        name uses bc_id, which the type specific method may have replaced with a user supplied id.

        Raises:
            KeyError: if bc_type is not 'well', 'cellhead' or 'cellghb'.
        """
        self._xml_sub_tag = self.bc_type
        self._pt_method[self.bc_type]()
        self._add_element()
        csv_filename = f'{self._xml_sub_tag}_{self.cov_name}_{self.bc_id}.csv'
        self.csv_writer.set_desired_filename(csv_filename)
        bc_val_writer = BcValWriter(
            self._new_sub_element, self.gm_group, self.csv_writer, self.rule_curve_label_id, self.feature_id
        )
        bc_val_writer.write()

    def _pt_well(self):
        """Set the xml attributes for a well and expose its flow type to the bc value writer."""
        self._xml_atts['cellid'] = f'{self.cell_idx + 1}'
        well_id = self.gm_group.parameter('well_id').value
        if well_id > 0:
            self._xml_atts['wellid'] = f'{well_id}'
            self.bc_id = well_id
        # Copy the well's flow type into the temporary 'bc_val_flt_type' parameter
        # (presumably the name BcValWriter reads the value type from).
        val_type = self.gm_group.parameter('well_flow_type').value
        if 'bc_val_flt_type' not in self.gm_group.parameter_names:
            self.gm_group.add_text(name='bc_val_flt_type', label='Temp variable for writer', default=val_type)
        else:
            # MeshBcWriter reuses one generic model point section for every point, so the temp parameter
            # can already exist from an earlier well; overwrite it with this well's flow type.
            self.gm_group.parameter('bc_val_flt_type').value = val_type

    def _check_add_bc_id(self):
        """Add a 'bcid' attribute and use it as bc_id when the 'bc_id' parameter is positive."""
        if self.gm_group.parameter('bc_id').value > 0:
            self._xml_atts['bcid'] = f"{self.gm_group.parameter('bc_id').value}"
            self.bc_id = self.gm_group.parameter('bc_id').value

    def _pt_cellghb(self):
        """Set the xml attributes for a cell ghb: 1-based cell id and the ghb factor."""
        self._xml_atts['id'] = f'{self.cell_idx + 1}'
        self._xml_atts['value'] = f"{self.gm_group.parameter('ghb_factor').value}"

    def _pt_cellhead(self):
        """Set the xml attributes for a cell head: 1-based cell id and optional bc id."""
        self._xml_atts['id'] = f'{self.cell_idx + 1}'
        self._check_add_bc_id()


class MeshBcWriter:
    """Writer for the mesh bc portion of the RSM control file."""
    def __init__(self, writer_data):
        """Constructor.

        Args:
            writer_data (WriterData): Class with information needed to write model input files.
        """
        self._logger = writer_data.logger
        self._xms_data = writer_data.xms_data
        self._csv_writer = writer_data.csv_writer
        self._rule_curve_label_id = writer_data.rule_curve_label_id
        self._bc_snapper = MeshBcSnapper(writer_data.xms_data)
        self._bc_cov = None
        self._bc_comp = None
        self._mesh_bc_vars = None
        self._xml_mesh_parent = writer_data.xml_mesh
        self._ug_locs = self._xms_data.xmugrid.locations
        gm = generic_model()
        # Generic model sections; reused for every feature by restoring that feature's stored values.
        self._gm_pt = gm.point_parameters
        self._gm_arc = gm.arc_parameters
        # Identifier of the point / arc currently being written; passed to the bc writers as feature_id.
        self._err_msg = ''
        self._feature_id = ''

    @cached_property
    def mesh_bc_xml(self):
        """Get the mesh bc xml element, creating it under the mesh element on first access."""
        return Et.SubElement(self._xml_mesh_parent, 'mesh_bc')

    def write(self):
        """Write the mesh bc portion of the control file.

        Does nothing when there are no bc coverages. Otherwise the coverages are snapped to the grid/mesh
        and the arcs, then the points, of each coverage are written.
        """
        if not self._xms_data.bc_coverages:
            return
        self._logger.info('Processing Boundary Condition coverages.')
        # intersect coverages with the grid/mesh
        self._bc_snapper.generate_snap()
        for cov_comp in self._xms_data.bc_coverages:
            self._bc_cov = cov_comp[0]
            self._bc_comp = cov_comp[1]
            self._logger.info(f'Processing coverage: {cov_comp[0].name}')
            # Spaces are replaced because the name is used in csv file names.
            cov_name = self._bc_cov.name.replace(' ', '_')
            self._mesh_bc_vars = _MeshBcVars(
                self, self._ug_locs, self._csv_writer, self._bc_comp, cov_name, self._rule_curve_label_id
            )
            self._process_cov_arcs()
            self._process_cov_points()

    def _process_cov_points(self):
        """Process the points in the current coverage."""
        pt_grid_cell_snap = self._bc_snapper.pt_grid_cell_snap[self._bc_cov.uuid]
        for pt_id, data in pt_grid_cell_snap.items():
            self._err_msg = f'point id: "{pt_id}" in coverage: "{self._bc_cov.name}"'
            self._gm_pt.restore_values(data['comp_val'])
            pt_bc_vars = _PtBcVars(pt_id, data['comp_type'], self._gm_pt, data['cell_idx'])
            mesh_bc = _MeshBcPt(self._mesh_bc_vars, pt_bc_vars, self._err_msg)
            mesh_bc.write_pt()

    def _process_cov_arcs(self):
        """Process the arcs in the current coverage."""
        arc_grid_pt_snap = self._bc_snapper.arc_grid_pt_snap[self._bc_cov.uuid]
        for arc_id, data in arc_grid_pt_snap.items():
            # Same format as the point identifier (the original had a stray leading space).
            self._feature_id = f'arc id: "{arc_id}" in coverage: "{self._bc_cov.name}"'
            self._gm_arc.restore_values(data['comp_val'])
            arc_bc_vars = _ArcBcVars(arc_id, data['comp_type'], self._gm_arc, data['pt_string'])
            mesh_bc = _MeshBcArc(self._mesh_bc_vars, arc_bc_vars, self._feature_id)
            mesh_bc.write_arc()
