"""Class for managing interprocess communication with XMS."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from io import StringIO
import os

# 2. Third party modules
import pandas as pd

# 3. Aquaveo modules
from xms.api.tree import tree_util
from xms.components.display.display_options_io import write_display_option_ids
from xms.constraint import read_grid_from_file
from xms.guipy.data.target_type import TargetType
from xms.guipy.time_format import XmsTimeFormatter

# 4. Local modules
import xms.bridge.gui.struct_dlg.arc_props as ap


class XmsData:
    """Class for managing interprocess communication with XMS.

    Wraps an XMS query object and caches items (constrained grids, datasets, coverages) fetched by uuid.
    When there is no query, the uuid lookup methods return None and the add methods do nothing.
    """
    def __init__(self, query, structure_comp):
        """Constructor.

        Args:
            query (:obj:`xms.api.dmi.Query`): Object for communicating with XMS (may be None)
            structure_comp (:obj:`xms.bridge.structure_component.StructureComponent`): The structure
        """
        self._query = query
        self._structure_comp = structure_comp
        self._coverage = None
        self._display_projection = None
        self._display_projection_wkt = ''
        self._display_projection_vertical_units = ''
        self._project_tree = None
        self._co_grids = {}
        self._wse_datasets = {}
        self._velocity_mag_datasets = {}
        self._coverages = {}
        # Created lazily by time_strings_from_dataset_times. Initialized before querying so the attribute
        # always exists, even if the query step raises.
        self._dataset_time_formatter = None
        self._set_data_from_query()

    def _set_data_from_query(self):
        """Set the class data from query (no-op when there is no query)."""
        if self._query is None:
            return

        # Only look up the parent coverage (and load the arc component ids) if it was not supplied.
        if self._structure_comp is not None and self._coverage is None:
            self._coverage = self._query.parent_item()
            self._query.load_component_ids(self._structure_comp, arcs=True)
        dp = self._query.display_projection
        self._display_projection = dp
        self._display_projection_wkt = '' if not dp else dp.well_known_text
        self._display_projection_vertical_units = '' if not dp else dp.vertical_units
        self._project_tree = self._query.project_tree

    def set_query_structure_comp_cov(self, query, structure_comp, structure_cov):
        """Set the query, structure component, and coverage (used by SRH2D).

        Args:
            query (:obj:`xms.api.dmi.Query`): Object for communicating with XMS
            structure_comp (:obj:`xms.bridge.structure_component.StructureComponent`): The structure
            structure_cov (:obj:`xms.data_objects.parameters.Coverage`): The coverage
        """
        self._query = query
        self._structure_comp = structure_comp
        self._coverage = structure_cov
        self._set_data_from_query()

    @property
    def project_tree(self):
        """Returns the project tree."""
        return self._project_tree

    def tree_item_from_uuid(self, uuid):
        """Returns the project tree item with the given uuid.

        Args:
            uuid (str): The uuid of the tree item.

        Returns:
            The tree node, or None if there is no query/project tree or no node has the uuid.
        """
        if self._query is None or self._project_tree is None:
            return None
        # If there is no tree item associated with the uuid then the return value is None.
        return tree_util.find_tree_node_by_uuid(self._project_tree, uuid)

    @property
    def structure_component(self):
        """Returns the structure component associated with this 3d structure."""
        return self._structure_comp

    @property
    def coverage(self):
        """Returns the coverage associated with this 3d structure."""
        return self._coverage

    @property
    def display_projection(self):
        """Returns the display projection."""
        return self._display_projection

    @property
    def display_projection_wkt(self):
        """Returns the well known text of the display projection ('' if there is none)."""
        return self._display_projection_wkt

    @property
    def vertical_units(self):
        """Returns the vertical units of the display projection ('' if there is none)."""
        return self._display_projection_vertical_units

    def raster_file_from_uuid(self, uuid):
        """Returns the raster file from the uuid (None if there is no query or no such item)."""
        if self._query is None:
            return None
        # If there is no raster associated with the uuid then the return value is None.
        return self._query.item_with_uuid(uuid)

    def add_ugrid(self, ugrid):
        """Adds a ugrid to the Query (no-op when there is no query).

        Args:
            ugrid (Ugrid): The ugrid to add to the Query.
        """
        if self._query:
            self._query.add_ugrid(ugrid)

    def add_coverage(self, coverage):
        """Adds a coverage to the Query (no-op when there is no query).

        Args:
            coverage (Coverage): The coverage to add to the Query.
        """
        if self._query:
            self._query.add_coverage(coverage)

    def update_component_ids(self):
        """Map each coverage arc's component id to its arc id, then refresh the display id files.

        Does nothing unless both the coverage and the structure component are set.
        """
        if self._coverage and self._structure_comp:
            for arc in self.coverage.arcs:
                self._structure_comp.update_component_id(TargetType.arc, arc.id, arc.id)
            _update_display_id_files(self._structure_comp)
            self._structure_comp.add_display_message()

    def _cached_item_with_uuid(self, cache, uuid, convert=None):
        """Return an item from ``cache``, fetching it from XMS by uuid on first request.

        Args:
            cache (dict): The uuid -> item cache to read from and populate.
            uuid (str): The uuid of the item.
            convert (callable): Optional function applied to the fetched item before it is cached.

        Returns:
            The cached item, or None if there is no query or XMS has no item with the uuid
            (misses are not cached, so a later call will query XMS again).
        """
        if self._query is None:
            return None
        if uuid not in cache:
            item = self._query.item_with_uuid(uuid)
            if item is None:
                return None
            cache[uuid] = item if convert is None else convert(item)
        return cache[uuid]

    def co_grid_from_uuid(self, uuid):
        """Returns the co_grid given an uuid.

        Args:
            uuid (str): The uuid of the co_grid.

        Returns:
            A constrained Grid read from the item's cogrid file, or None if unavailable.
        """
        return self._cached_item_with_uuid(
            self._co_grids, uuid, lambda do_grid: read_grid_from_file(do_grid.cogrid_file)
        )

    def wse_dataset_from_uuid(self, uuid):
        """Returns the water surface elevation dataset given an uuid.

        Args:
            uuid (str): The uuid of the dataset.

        Returns:
            (Dataset): The dataset, or None if unavailable.
        """
        return self._cached_item_with_uuid(self._wse_datasets, uuid)

    def velocity_mag_dataset_from_uuid(self, uuid):
        """Returns the velocity magnitude dataset given an uuid.

        Args:
            uuid (str): The uuid of the dataset.

        Returns:
            (Dataset): The dataset, or None if unavailable.
        """
        return self._cached_item_with_uuid(self._velocity_mag_datasets, uuid)

    def time_strings_from_dataset_times(self, dataset):
        """Returns the formatted time step strings of a dataset.

        Args:
            dataset (Dataset): The dataset

        Returns:
            (list): One formatted string per time step; empty if there is no dataset, or if there is
            no query to build the time formatter from.
        """
        if not dataset:
            return []
        if self._dataset_time_formatter is None:
            # The formatter needs XMS's global time settings, which only the query can provide.
            if self._query is None:
                return []
            self._dataset_time_formatter = XmsTimeFormatter(self._query.global_time_settings)
        formatter = self._dataset_time_formatter

        # Set the reference time (possibly None) in case there is no zero time.
        formatter.ref_time = dataset.ref_time

        return [formatter.format_time(dataset.timestep_offset(ts_idx)) for ts_idx in range(dataset.num_times)]

    def coverage_from_uuid(self, uuid):
        """Returns the coverage from the uuid.

        Args:
            uuid (str): The uuid of the coverage.

        Returns:
            (:obj:`xms.data_objects.parameters.Coverage`): The coverage, or None if unavailable.
        """
        return self._cached_item_with_uuid(self._coverages, uuid)

    def display_categories(self):
        """Returns the display categories read from the structure component's display options file."""
        return _display_categories_from_file(self._structure_comp.disp_opts_file)

    def display_category_color(self):
        """Returns a dict of category description -> color name from the structure component."""
        categories = self.display_categories()
        return {cat.description: cat.options.color.name() for cat in categories.categories}


def _display_categories_from_file(disp_opts_file):
    """Load the category display options stored in a display options file.

    Args:
        disp_opts_file: The display options file, passed to ``read_display_options_from_json``.

    Returns:
        (CategoryDisplayOptionList): A category list populated from the file's JSON contents.
    """
    from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
    from xms.components.display.display_options_io import read_display_options_from_json
    option_list = CategoryDisplayOptionList()
    option_list.from_dict(read_display_options_from_json(disp_opts_file))
    return option_list


def _update_display_id_files(structure_comp):
    """Regenerate the per-category display id files beside the structure component's main file.

    Every known ``*.display_ids`` file in the component's directory is deleted first. The arc ids from
    the component's arc property table are then grouped by arc type. A file is written only for
    categories that received at least one id; arcs whose type matches no category are ignored.

    Args:
        structure_comp: The structure component. Its ``main_file`` locates the output directory.
    """
    id_file_names = (
        'bridge.display_ids', 'pier.display_ids', 'abutment.display_ids', 'culvert.display_ids',
        'embankment.display_ids'
    )

    def id_file_for(arc_type):
        """Return the display id file name for an arc type, or None when the type has no file."""
        if arc_type == 'Bridge':
            return 'bridge.display_ids'
        if arc_type in {ap.ARC_TYPE_PG, ap.ARC_TYPE_WP}:
            return 'pier.display_ids'
        if arc_type == 'Abutment':
            return 'abutment.display_ids'
        if arc_type == 'Culvert':
            return 'culvert.display_ids'
        if arc_type == 'Embankment':
            return 'embankment.display_ids'
        return None

    comp_dir = os.path.dirname(structure_comp.main_file)

    # Clear out stale files so categories that are now empty do not keep old ids.
    for id_file_name in id_file_names:
        stale_path = os.path.join(comp_dir, id_file_name)
        if os.path.exists(stale_path):
            os.remove(stale_path)

    grouped_ids = {id_file_name: [] for id_file_name in id_file_names}
    all_ids, all_types = _arc_types_from_structure_component(structure_comp)
    for arc_id, arc_type in zip(all_ids, all_types):
        target = id_file_for(arc_type)
        if target is not None:
            grouped_ids[target].append(arc_id)

    for id_file_name, ids in grouped_ids.items():
        if ids:
            write_display_option_ids(os.path.join(comp_dir, id_file_name), ids)


def _arc_types_from_structure_component(structure_comp):
    """Returns a dict of arc id, arc type from the structure component."""
    rval = [], []
    struct_type = structure_comp.data.data_dict['structure_type']
    if struct_type == 'Bridge':
        use_arc_props = structure_comp.data.data_dict['specify_arc_properties']
        csv_str = structure_comp.data.data_dict['arc_properties']
    else:  # Culvert
        use_arc_props = structure_comp.data.data_dict['specify_culvert_arc_properties']
        csv_str = structure_comp.data.data_dict['culvert_arc_properties']
    if use_arc_props and csv_str:
        df = pd.read_csv(StringIO(csv_str))
        rval = df[df.columns[0]].to_list(), df[df.columns[1]].to_list()
    return rval
