"""Context manager for running Query in playback mode during tests."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import difflib
import filecmp
import logging
import os
import shutil
import sys
import tempfile

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.core.filesystem import filesystem
from xms.guipy.dialogs.process_feedback_dlg import LogEchoQSignalStream

# 4. Local modules


class QueryPlayback:
    """Context manager for running Query in playback mode during tests.

    On entry the recorded data is copied to a scratch playback folder, logging from the 'xms' logger is piped to a
    file in that folder, and the trigger file is written. On exit the sent data and (optionally) the logging output
    are compared against their baselines.
    """
    def __init__(self, *args, **kwargs):
        """Construct the wrapper.

        Args:
            *args:
                recording_folder (str): The path to the recording folder
            **kwargs:
                recording_folder (str): The path to the recording folder (if not passed positionally)
                rerecord (bool): If True, will delete the recording folder to force a re-record
                compare_log (bool): If True, will capture logging output and compare to baseline
                xms_temp (str): Full path to the XMS temp directory that should be used
                process_temp (str): Full path to the process temp directory that should be used
                app_name (str): Name of the XMS app that should be used
                app_version (str): XMS app version that should be used
                notes_db (str): Full path to the XMS notes database file that should be used
                project_path (str): Absolute filename of the XMS project that should be used
                running_xms_tests (bool): Set the flag that indicates the XMS app is running tests itself
                sent_data_file (str): Name of sent_data.base file.
                request_file (str): Name of the request.rec file.
                logging_file (str): Name of the logging.base file.

        Raises:
            ValueError: If no recording folder was provided.
        """
        if args:
            self._recording_folder = args[0]
            kwargs['recording_folder'] = self._recording_folder
        else:
            self._recording_folder = kwargs.get('recording_folder')
            if not self._recording_folder:
                raise ValueError('Must provide the path to the recording folder.')
        self._logger = None
        self._trigger_file = None  # Path of the trigger file written in __enter__
        self._created_trigger_file = False  # True if __enter__ created the trigger file (it did not exist before)
        self._sysstdout = sys.stdout  # Restore sys std output streams on exit in case code uses ProcessFeedbackDlg
        self._sysstderr = sys.stderr

        # Check if we need to compare the logging output
        self._compare_log = kwargs.get('compare_log', True)

        # Delete the recording folder if we need to force a re-record.
        self._rerecord = kwargs.get('rerecord', False)
        if self._rerecord and os.path.isdir(self._recording_folder):
            shutil.rmtree(self._recording_folder, ignore_errors=True)

        # Set specified environment variables
        self.configure_environment(**kwargs)

        self._playback_folder = XmEnv.xms_environ_playback_folder()

    def __enter__(self):
        """Write the file that will trigger playback (or re-record) mode when the next Query object is constructed.

        Returns:
            QueryPlayback: self
        """
        # Copy all recorded data files to a temp folder so we don't accidentally clean up recorded data.
        shutil.rmtree(self._playback_folder, ignore_errors=True)
        if self._rerecord and not os.path.isdir(self._recording_folder):
            # The constructor deleted the recording folder to force a re-record, so there is nothing to copy.
            # Still create the playback folder because the log file handler below opens a file inside it.
            os.makedirs(self._playback_folder, exist_ok=True)
        else:
            shutil.copytree(self._recording_folder, self._playback_folder)

        # Reset the static error flags on the ProcessFeedbackDlg stream in case a previous test has us in a bad state.
        LogEchoQSignalStream.reset_flags()

        self._logger = logging.getLogger('xms')
        self._logger.setLevel(logging.DEBUG)
        log_file = os.path.join(self._playback_folder, XmEnv.LOGGING_OUT_FILE)
        filesystem.removefile(log_file)  # Clear the old logging file
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        formatter = logging.Formatter('%(name)s - %(levelname)s - %(message)s')
        fh.setFormatter(formatter)
        self._logger.addHandler(fh)

        # Delete the old send data output file if it exists to avoid false positives in tests.
        output_file = os.path.join(self._playback_folder, XmEnv.SENT_DATA_OUT_FILE)
        filesystem.removefile(output_file)

        # Write the file that will trigger record or playback when the next Query object is constructed.
        # Remember the path (and whether we created it) so __exit__ cleans up exactly what was written here.
        self._trigger_file = XmEnv.xms_environ_record_trigger_file()
        self._created_trigger_file = not os.path.isfile(self._trigger_file)
        with open(self._trigger_file, 'w') as f:
            f.write(f'"{self._recording_folder}"')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore global state and compare output to the baselines if not re-recording.

        Args:
            exc_type: unused
            exc_value: The exception raised inside the with block, or None
            traceback: unused

        Returns:
            None: Always, so an exception raised inside the with block is never suppressed.

        Raises:
            RuntimeError: If the sent data or logging output is missing or does not match its baseline.
        """
        sys.stdout = self._sysstdout
        sys.stderr = self._sysstderr

        self._detach_log_handlers()  # Disable piping log output to file.

        # Reset the static error flags on the ProcessFeedbackDlg stream so as not to mess up other tests.
        LogEchoQSignalStream.reset_flags()

        # Remove the trigger file. Done before the early return below so a failed or re-recording run does not
        # leave a trigger file behind.
        self._remove_trigger_files()

        # If an unhandled exception was thrown or we are re-recording, don't bother comparing baseline.
        # The playback folder is left in place in this case, as it is when a comparison below fails.
        if self._rerecord or exc_value is not None:
            return

        self._check_sent_data()
        self._check_log_output()

        shutil.rmtree(self._playback_folder, ignore_errors=True)

    def _detach_log_handlers(self):
        """Close and remove every handler attached to the logger configured in __enter__."""
        if self._logger is None:  # __enter__ never ran
            return
        # Iterate over a snapshot: removeHandler() mutates the handlers list, and removing from the list being
        # iterated would skip every other handler.
        for handler in list(self._logger.handlers):
            handler.close()
            self._logger.removeHandler(handler)

    def _remove_trigger_files(self):
        """Delete the record/playback trigger file(s) associated with this process."""
        # Only delete the file written in __enter__ if this object created it. A pre-existing file is left alone.
        if self._trigger_file and self._created_trigger_file:
            filesystem.removefile(self._trigger_file)
        # NOTE(review): this '_<name>_<pid>' path does not match the name written in __enter__ or the one set in
        # configure_environment. It is still removed here in case other code creates it - confirm and simplify.
        path, filename = os.path.split(XmEnv.RECORD_TRIGGER_FILE)
        full_filename = os.path.normpath(os.path.join(path, f'_{filename}_{os.getpid()}'))
        filesystem.removefile(full_filename)

    def _check_sent_data(self):
        """Compare the sent data output to the baseline.

        Raises:
            RuntimeError: If either file is missing or the two files differ.
        """
        sent_data_base = os.path.join(self._playback_folder, XmEnv.xms_environ_sent_data_base_file())
        sent_data_out = os.path.join(self._playback_folder, XmEnv.SENT_DATA_OUT_FILE)
        if not os.path.isfile(sent_data_out):
            raise RuntimeError(f'Unable to find output file: {sent_data_out}')
        if not os.path.isfile(sent_data_base):
            raise RuntimeError(f'Unable to find baseline file: {sent_data_base}')
        if not filecmp.cmp(sent_data_base, sent_data_out):
            raise RuntimeError(
                f'Expected sent data does not match:\nExpected: {sent_data_base}\nFound: {sent_data_out}'
            )

    @staticmethod
    def _unexpected_log_changes(base_lines, out_lines):
        """Get the diff lines between two logs, ignoring the "Elapsed time:" message.

        Args:
            base_lines (list[str]): Lines of the baseline log
            out_lines (list[str]): Lines of the log that was just written

        Returns:
            list[str]: The added ('+ ') and removed ('- ') diff lines that are not "INFO - Elapsed time:" messages
        """
        diff = difflib.Differ().compare(base_lines, out_lines)
        return [  # Skip changed lines if it is the "Elapsed time:" message from the ProcessFeedbackDlg
            line for line in diff if line.startswith(('+ ', '- ')) and 'INFO - Elapsed time:' not in line
        ]

    def _check_log_output(self):
        """Compare the logging output to the baseline.

        Does nothing if log comparison was disabled with compare_log=False.

        Raises:
            RuntimeError: If either file is missing or the logs differ by more than the elapsed time message.
        """
        if not self._compare_log:
            return

        log_file_base = os.path.join(self._playback_folder, XmEnv.xms_environ_logging_base_file())
        log_file_out = os.path.join(self._playback_folder, XmEnv.LOGGING_OUT_FILE)
        if not os.path.isfile(log_file_out):
            raise RuntimeError(f'Unable to find output file: {log_file_out}')
        if not os.path.isfile(log_file_base):
            raise RuntimeError(f'Unable to find baseline file: {log_file_base}')

        with open(log_file_base) as f:
            base_lines = f.readlines()
        with open(log_file_out) as f:
            out_lines = f.readlines()
        changes = self._unexpected_log_changes(base_lines, out_lines)
        if changes:
            raise RuntimeError(f'Unexpected logger output found: {changes}')

    @staticmethod
    def configure_environment(**kwargs):
        """Set system environment variables that are usually set by XMS.

        Args:
            **kwargs:
                recording_folder (str): The path to the recording folder. If given, a unique playback folder name
                    next to it is stored in the playback folder environment variable.
                xms_temp (str): Full path to the XMS temp directory that should be used
                process_temp (str): Full path to the process temp directory that should be used
                app_name (str): Name of the XMS app that should be used
                app_version (str): XMS app version that should be used
                notes_db (str): Full path to the XMS notes database file that should be used
                project_path (str): Absolute filename of the XMS project that should be used
                running_xms_tests (bool): Set the flag that indicates the XMS app is running tests itself
                sent_data_file (str): Name of sent_data.base file.
                request_file (str): Name of the request.rec file.
                logging_file (str): Name of the logging.base file.
        """
        xms_temp = kwargs.get('xms_temp')
        if xms_temp:
            # Ensure the XMS temp directory exists. May not when running Python tests.
            os.makedirs(xms_temp, exist_ok=True)
            os.environ[XmEnv.ENVIRON_XMS_TEMP_FOLDER] = xms_temp

        process_temp = kwargs.get('process_temp')
        if process_temp:
            # Ensure the process temp directory exists.
            os.makedirs(process_temp, exist_ok=True)
            os.environ[XmEnv.ENVIRON_PROCESS_TEMP_FOLDER] = process_temp

        app_name = kwargs.get('app_name')
        if app_name:
            os.environ[XmEnv.ENVIRON_XMS_APP_NAME] = app_name

        app_version = kwargs.get('app_version')
        if app_version:
            os.environ[XmEnv.ENVIRON_XMS_APP_VERSION] = str(app_version)  # Accept a float or int

        notes_db = kwargs.get('notes_db')
        if notes_db:
            os.environ[XmEnv.ENVIRON_NOTES_DATABASE] = notes_db

        project_path = kwargs.get('project_path')
        if project_path:
            os.environ[XmEnv.ENVIRON_PROJECT_PATH] = project_path

        running_xms_tests = kwargs.get('running_xms_tests')
        if running_xms_tests is not None:
            os.environ[XmEnv.ENVIRON_RUNNING_TESTS] = str(running_xms_tests).upper()  # Accept a bool

        # Set these to the default state before potentially changing them. A previous test may have changed them
        # and we want them to have their default value to start with.
        os.environ[XmEnv.ENVIRON_XMS_SENT_DATA_BASE_FILE] = XmEnv.SENT_DATA_BASE_FILE
        os.environ[XmEnv.ENVIRON_XMS_PLAYBACK_RECORD_FILE] = XmEnv.PLAYBACK_RECORD_FILE
        os.environ[XmEnv.ENVIRON_XMS_LOGGING_BASE_FILE] = XmEnv.LOGGING_BASE_FILE

        sent_data_file = kwargs.get('sent_data_file')
        if sent_data_file is not None:
            os.environ[XmEnv.ENVIRON_XMS_SENT_DATA_BASE_FILE] = sent_data_file

        request_file = kwargs.get('request_file')
        if request_file is not None:
            os.environ[XmEnv.ENVIRON_XMS_PLAYBACK_RECORD_FILE] = request_file

        logging_file = kwargs.get('logging_file')
        if logging_file is not None:
            os.environ[XmEnv.ENVIRON_XMS_LOGGING_BASE_FILE] = logging_file

        # Set playback folder to be parallel to recording folder with recording_folder_name_tmp9dkfjwl
        recording_folder = kwargs.get('recording_folder')
        if recording_folder:
            parent_dir = os.path.normpath(os.path.join(recording_folder, '..'))
            # Normalize first so a trailing path separator does not produce an empty base name.
            prefix = f'{os.path.basename(os.path.normpath(recording_folder))}_'
            # Only a unique name is needed here; the folder itself is (re)created in __enter__. Create and remove
            # it explicitly rather than relying on garbage collection of a TemporaryDirectory object.
            playback_folder = tempfile.mkdtemp(dir=parent_dir, prefix=prefix)
            os.rmdir(playback_folder)
            os.environ[XmEnv.ENVIRON_XMS_PLAYBACK_FOLDER] = playback_folder

        if not os.path.isfile(XmEnv.RECORD_TRIGGER_FILE):
            os.environ[XmEnv.ENVIRON_XMS_RECORD_TRIGGER_FILE] = f'{XmEnv.RECORD_TRIGGER_FILE}_{os.getpid()}'
