"""Mf6ProgressTracker class."""

__copyright__ = "(C) Copyright Aquaveo 2019"
__license__ = "All rights reserved"

# 1. Standard Python modules
import shlex

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query

# 4. Local modules
from xms.mf6.components import dmi_util
from xms.mf6.file_io import io_factory


class Mf6ProgressTracker:
    """Class to keep track of progress as the simulation runs.

    All state is stored in class attributes and every method is static. ``start_tracking`` fills in
    the state and registers ``progress_function`` with the XMS progress loop. Presumably the loop
    invokes that function repeatedly while MODFLOW 6 runs.
    """
    prog = None  # XMS progress loop (``session.progress_loop``), set by start_tracking()
    query = None  # Query used to report percent complete back to XMS
    echo_file = None  # Path of the model's command-line output (echo) file
    echo_pos = 0  # Position just past the last complete line already scanned in echo_file
    stress_periods: list[int] = []  # Number of time steps in each stress period, in order
    total_ts = 0  # Total number of time steps (sum of stress_periods)

    def __init__(self):
        """Initializes the class."""
        pass

    @staticmethod
    def calculate_progress(curr_period, curr_ts):
        """Calculates the progress at a given time step.

        Also a helper function for progress_function.

        Args:
            curr_period (int): The current stress period (1-based).
            curr_ts (int): The current time step within the current stress period.

        Returns:
            (float): Progress as a percent of all time steps in all stress periods, clamped to
            [0, 100]. Returns 100.0 when the total number of time steps is zero or unknown.
        """
        total_ts = Mf6ProgressTracker.total_ts
        if total_ts <= 0:
            # Nothing to measure against (tracking not started, or no time steps); this also
            # avoids dividing by zero.
            return 100.0
        # All time steps of the stress periods before the current one, plus the steps reached in it.
        previous_periods = Mf6ProgressTracker.stress_periods[:max(curr_period - 1, 0)]
        elapsed_ts = sum(previous_periods) + curr_ts
        return max(0.0, min(elapsed_ts / total_ts * 100.0, 100.0))

    @staticmethod
    def _parse_solving_line(line):
        """Extracts the stress period and time step from a 'Solving:  Stress period:' echo line.

        Args:
            line (str): A line read from the echo file.

        Returns:
            (tuple[int, int] | None): (stress period, time step), or None if the line is not a
            'Solving:  Stress period:' line or its numbers cannot be parsed.
        """
        if not line.strip().startswith('Solving:  Stress period:'):
            return None
        try:
            # Tokens 0-2 are the 'Solving:  Stress period:' prefix, token 3 is the stress period
            # and token 6 the time step (tokens 4-5 presumably 'Time step:').
            tokens = shlex.split(line)
            return int(tokens[3]), int(tokens[6])
        except (IndexError, ValueError):
            return None  # Malformed or truncated line; ignore it

    @staticmethod
    def progress_function():
        """Scans new echo-file output and sends the progress to query as a percentage.

        Reads from ``echo_pos`` onward and keeps the last stress period / time step found on a
        complete 'Solving:' line. It then reports the corresponding percentage through
        ``query.update_progress_percent``. Nothing is reported when no new 'Solving:' line was
        found or when the file cannot be read yet.
        """
        if not Mf6ProgressTracker.echo_file:
            Mf6ProgressTracker.echo_file = Mf6ProgressTracker.prog.command_line_output_file
        if not Mf6ProgressTracker.echo_file:
            return  # Output file not known yet; try again on the next call

        latest = None
        try:
            # errors="replace" keeps stray non-text bytes from aborting the scan.
            with open(Mf6ProgressTracker.echo_file, "r", errors="replace") as f:
                f.seek(Mf6ProgressTracker.echo_pos)
                echo_line = f.readline()
                while echo_line:
                    if not echo_line.endswith(('\n', '\r')):
                        # Partial last line (presumably still being written); re-read it next time.
                        break
                    parsed = Mf6ProgressTracker._parse_solving_line(echo_line)
                    if parsed is not None:
                        latest = parsed
                    # Move past every complete line so later calls don't rescan it.
                    Mf6ProgressTracker.echo_pos = f.tell()
                    echo_line = f.readline()
        except OSError:
            pass  # File might not exist yet, or is temporarily unreadable

        if latest is not None:
            percent_done = Mf6ProgressTracker.calculate_progress(*latest)
            Mf6ProgressTracker.query.update_progress_percent(percent_done)

    @staticmethod
    def start_tracking():
        """Entry point for the MODFLOW 6 progress script.

        Creates the progress Query and reads the simulation's TDIS6 package to get the number of
        time steps in each stress period. It then registers progress_function with the session's
        progress loop and starts the loop. Returns without tracking if there is no session.
        """
        Mf6ProgressTracker.query = Query(progress_script=True, timeout=300000)
        session = Mf6ProgressTracker.query.xms_agent.session
        if not session:
            return

        Mf6ProgressTracker.prog = session.progress_loop

        # Get the total number of time steps for all stress periods. Column 1 of the TDIS period
        # data holds each period's time step count (presumably NSTP; verify against the reader).
        tdis_mainfile = dmi_util.package_mainfile_from_query(Mf6ProgressTracker.query, 'TDIS6')
        reader = io_factory.reader_from_ftype('TDIS6')
        tdis_data = reader.read(tdis_mainfile)
        Mf6ProgressTracker.stress_periods = tdis_data.period_df.iloc[:, 1].tolist()
        Mf6ProgressTracker.total_ts = sum(Mf6ProgressTracker.stress_periods)

        Mf6ProgressTracker.prog.set_progress_function(Mf6ProgressTracker.progress_function)
        Mf6ProgressTracker.prog.start_loop()
