"""Reads the *.srhgeom files."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import shlex
import sys

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.ugrid_builder import UGridBuilder
from xms.core.filesystem import filesystem
from xms.data_objects.parameters import Projection
from xms.grid.ugrid import UGrid as XmUGrid

# 4. Local modules


class GeomReader:
    """Reads a SRH-2D geom file.

    The file is loaded into ``self.lines`` and parsed sequentially: a header (name and units), then
    element lines, node lines, and finally node string lines. ``self.curr_line`` is a cursor into
    ``self.lines`` that is shared by (and advanced by) all of the ``_read_*`` methods, so they must be
    called in the order used by ``read``.
    """
    def __init__(self):
        """Constructor."""
        self.logger = logging.getLogger('xms.srh')
        self.data = {}  # The data that is read
        self.lines = []  # The lines in the file
        self.curr_line = 0  # The current line in the file
        self.line_count = 0  # Number of lines in the file (which equals len(self.lines))
        self.temp_mesh_file = ''  # Path to the file where the mesh is saved
        self.monitor_line_ids = None  # list of monitor line ids if 'MonitorString' used
        self.cogrid = None  # The grid

    def _read_and_units(self):
        """Reads the name and the length units.

        The name is the second token of the line at index 1 with any surrounding quotes removed. The
        units are stored as the complete, unparsed line at index 3. The cursor is left at index 4.
        """
        self.logger.info('Reading geometry name and units.')
        # Use shlex to handle quoted strings
        # NOTE(review): 'darwin' also contains 'win', so macOS takes the non-POSIX path too. The quotes
        # that non-POSIX mode leaves in place are removed by the strip() below - confirm this is intended.
        words = shlex.split(self.lines[1], posix="win" not in sys.platform)
        self.data['name'] = words[1].strip('"\'')
        self.data['units'] = self.lines[3]
        self.curr_line = 4

    def _read_elements(self):
        """Reads the elements.

        Consumes lines starting at the cursor until a non-blank line whose first word is not 'Elem'
        (case-insensitive) is found. Blank lines are skipped. Node ids are converted from the 1-based
        values in the file to 0-based indices and stored in ``self.data['elements']``.
        """
        self.logger.info('Reading elements.')
        elements = []
        words_to_skip = 2  # Skip over the 'Elem' word and the element number
        while self.curr_line < self.line_count:
            words = self.lines[self.curr_line].split()
            if words:
                if words[0].lower() != 'elem':
                    break
                elements.append([(int(i) - 1) for i in words[words_to_skip:]])
            self.curr_line += 1

        self.data['elements'] = elements

    def _read_nodes(self):
        """Reads the nodes.

        Consumes lines starting at the cursor until a non-blank line whose first word is not 'Node'
        (case-insensitive) is found. The values after the node number are converted to floats and stored,
        one list per node, in ``self.data['nodes']``.
        """
        self.logger.info('Reading nodes.')
        nodes = []
        words_to_skip = 2  # Skip over the 'Node' word and the node number
        while self.curr_line < self.line_count:
            words = self.lines[self.curr_line].split()
            if words:  # Skip blank lines (as _read_elements does) instead of failing on words[0]
                if words[0].lower() != 'node':
                    break
                nodes.append([float(i) for i in words[words_to_skip:]])
            self.curr_line += 1

        self.data['nodes'] = nodes

    def _read_node_string(self, words):
        """Reads a node string.

        The nodestring may wrap on several lines. Keep reading lines until the first word is not an integer.

        On return the cursor is on the last line that belongs to this node string (or past the end of the
        file), so the caller is expected to advance it by one before reading the next line.

        Args:
            words (:obj:`list[str]`): First split line of the nodestring minus 'Nodestring' and the nodestring number.

        Returns:
            (:obj:`list[int]`): Node ids comprising the nodestring.
        """
        node_string = []
        start_line = self.curr_line
        while words:
            if not words[0].isdigit():
                # This line is not part of the node string. Back up so that the caller's increment lands
                # on it again. Only back up if a continuation line was actually consumed; otherwise the
                # cursor would move before the header line and the caller would never get past it.
                if self.curr_line > start_line:
                    self.curr_line -= 1
                break
            node_string.extend([(int(i) - 1) for i in words])
            self.curr_line += 1
            if self.curr_line >= self.line_count:
                break
            words = self.lines[self.curr_line].split()

        return node_string

    def _read_node_strings(self):
        """Reads the node strings.

        Lines starting with 'NodeString', 'MonitorString' or 'BCString' (case-insensitive) are read from the
        cursor to the end of the file; all other lines are ignored. The result is stored in
        ``self.data['node_strings']`` as a dict of node string id -> list of 0-based node indices.

        If any 'BCString' entries exist they become the node strings, and the 'MonitorString' entries are
        appended with new ids that follow the largest BC string id. Those new ids are recorded in
        ``self.monitor_line_ids``.
        """
        self.logger.info('Reading node strings.')
        self.monitor_line_ids = None  # Do not carry over ids from a previous read
        node_strings = {}
        monitor_strings = {}
        bc_strings = {}
        # Maps the (lowercase) keyword at the start of a line to the dict that collects its entries
        destinations = {'nodestring': node_strings, 'monitorstring': monitor_strings, 'bcstring': bc_strings}
        node_string_id_word = 1
        words_to_skip = 2  # Skip over the 'Nodestring' word and the node string number
        while self.curr_line < self.line_count:
            words = self.lines[self.curr_line].split()
            destination = destinations.get(words[0].lower()) if words else None
            if destination is not None:
                node_string = self._read_node_string(words[words_to_skip:])
                destination[int(words[node_string_id_word])] = node_string
            self.curr_line += 1
        # convert monitorstring and bcstring into nodestrings
        # NOTE(review): monitor strings are only kept when at least one BC string exists, and any plain
        # 'NodeString' entries are discarded when BC strings exist - confirm that this is intended.
        if bc_strings:
            node_strings = bc_strings
            max_key = max(bc_strings.keys())
            if monitor_strings:
                self.monitor_line_ids = []
            for idx, value in enumerate(monitor_strings.values()):
                new_id = max_key + idx + 1
                self.monitor_line_ids.append(new_id)
                node_strings[new_id] = value

        self.data['node_strings'] = node_strings

    def _get_cell_stream(self):
        """Returns the cell stream.

        Each element contributes its cell type, its node count and then its node indices. Elements with
        3 nodes are written as triangles; all other elements are written as quads using their first 4 nodes.

        Returns:
            (:obj:`list`): The flat cell stream built from ``self.data['elements']``.
        """
        cell_stream = []
        for element in self.data['elements']:
            if len(element) == 3:
                cell_stream.extend([XmUGrid.cell_type_enum.TRIANGLE, 3, element[0], element[1], element[2]])
            else:
                cell_stream.extend([XmUGrid.cell_type_enum.QUAD, 4, element[0], element[1], element[2], element[3]])
        return cell_stream

    def _build_mesh(self):
        """Builds the mesh and writes it to disk.

        The grid is kept in ``self.cogrid`` and the path of the file it was written to is kept in
        ``self.temp_mesh_file``.
        """
        self.logger.info('Building the mesh.')
        cell_stream = self._get_cell_stream()
        xmugrid = XmUGrid(self.data['nodes'], cell_stream)
        co_builder = UGridBuilder()
        co_builder.set_is_2d()
        co_builder.set_ugrid(xmugrid)
        self.cogrid = co_builder.build_grid()
        self.temp_mesh_file = filesystem.temp_filename()
        self.cogrid.write_to_file(self.temp_mesh_file, True)

    def geom_projection_from_grid_units(self):
        """Get a data objects projection from the grid units.

        The units are U.S. survey feet if the units line that was read contains 'foot' (case-insensitive)
        and meters otherwise.

        Returns:
            (:obj:`xms.data_objects.parameters.Projection`): xms projection
        """
        if 'foot' in self.data['units'].lower():
            units = 'FEET (U.S. SURVEY)'
        else:
            units = 'METERS'
        proj = Projection(horizontal_units=units, vertical_units=units)
        return proj

    def read(self, filename):
        """Reads an SRH geom file and returns a dictionary containing the data.

        Args:
            filename (:obj:`str`): Filepath of geom file to be read.

        Returns:
            (:obj:`dict`): The imported data

        Raises:
            Exception: Any error raised while reading or parsing is logged and then re-raised.
        """
        try:
            # The whole file is loaded up front, so it can be closed before parsing starts
            with open(filename, 'r') as file:
                self.lines = file.read().splitlines()
            self.line_count = len(self.lines)
            self._read_and_units()
            self._read_elements()
            self._read_nodes()
            self._read_node_strings()
            self._build_mesh()
            return self.data

        except Exception as error:  # pragma: no cover
            self.logger.exception(f'Error reading geometry: {str(error)}')
            raise
