"""Apply to a GenCade simulation."""

__copyright__ = "(C) Copyright Aquaveo 2019"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import math
import os
import uuid

# 2. Third party modules
import numpy as np
from shapely.geometry import LineString, Point as shPt
from shapely.ops import nearest_points
import xarray as xr

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.components.display.display_options_io import (read_display_options_from_json,
                                                       write_display_option_line_locations,
                                                       write_display_options_to_json)
from xms.components.display.xms_display_message import DrawType, XmsDisplayMessage
from xms.core.filesystem import filesystem as xfs
from xms.data_objects.parameters import Component, FilterLocation
# from xms.gdal.utilities.gdal_utils import is_geographic
from xms.grid.geometry.geometry import distance_2d
from xms.guipy.data.category_display_option_list import CategoryDisplayOptionList
from xms.guipy.data.target_type import TargetType
from xms.interp.interpolate.interp_linear import InterpLinear
from xms.mesher import meshing

# 4. Local modules
from xms.gencade.components import mapped_grid_component as mgc


class GridMapper:
    """Class for mapping a 1D grid coverage to a GenCade simulation.

    The mapper reads the begin, end, and refine points of the source coverage, builds the 1D grid as a
    shapely LineString, writes the grid line and tick mark draw locations next to the new mainfile, and
    stores the cell sizes and grid attributes on the mapped component.
    """

    # Tick mark half-lengths, expressed as a fraction of the distance between the grid end points.
    _END_TICK_FRACTION = 0.010  # Ticks drawn at the first and last grid location
    _TENTH_TICK_FRACTION = 0.008  # Ticks drawn at every 10th interior location
    _FIFTH_TICK_FRACTION = 0.005  # Ticks drawn at every 5th (but not 10th) interior location
    _UNIT_TICK_FRACTION = 0.002  # Ticks drawn at all other interior locations

    def __init__(self, new_mainfile, grid_data):
        """Construct the mapper.

        Args:
            new_mainfile (:obj:`str`): Path to the new mapped component's mainfile
            grid_data (:obj:`GridData`): Data of the source 1D Gencade coverage
        """
        self.logger = logging.getLogger('xms.gencade')
        self.new_mainfile = new_mainfile
        self._struct_cov = None
        self._struct_comp = None
        self._query = None  # Created in _find_cov_points()
        self.grid_data = grid_data
        self.mapped_comp = None  # The Python component
        self.points = self.grid_data['grid_cov'].get_points(FilterLocation.PT_LOC_DISJOINT)
        self.grid_locs = []  # Coordinates of the grid line, filled in by _create_map_component()
        self._grid_line_string = None  # None until the grid has been computed successfully
        self._specified_begin_pt = None
        self._specified_end_pt = None

    def _update_display_uuids(self, mapped_disp):
        """Give the copied display options a unique UUID and register them on the mapped component.

        Args:
            mapped_disp (:obj:`str`): Filepath to the mapped grid display options JSON file
        """
        # Read the copied default display options, and save ourselves a copy with a randomized UUID.
        self.logger.info('Initializing display lists.')
        categories = CategoryDisplayOptionList()
        json_dict = read_display_options_from_json(mapped_disp)
        categories.from_dict(json_dict)
        # The component UUID is the name of the folder containing the mainfile.
        categories.comp_uuid = os.path.basename(os.path.dirname(self.new_mainfile))
        categories.uuid = str(uuid.uuid4())  # Generate a new UUID for the mapped component display
        categories.is_ids = False  # Switch to a free location draw
        # Set projection of free locations to be that of the mesh/current display
        categories.projection = {'wkt': self.grid_data['wkt']}
        write_display_options_to_json(mapped_disp, categories)
        # Save our display list UUID to the main file
        self.mapped_comp.data.info.attrs['display_uuid'] = categories.uuid
        self.mapped_comp.display_option_list = [
            XmsDisplayMessage(file=self.mapped_comp.disp_opts_file, draw_type=DrawType.draw_at_locations),
        ]

    def _find_cov_points(self):
        """Find the begin, end, and refine points of the grid coverage.

        An error is logged if the coverage does not contain exactly one begin point and one end point.

        Returns:
            (:obj:`tuple`): (begin, end, refine_pts) where begin and end are (x, y, z) tuples and refine_pts is
            a list of (x, y, z, size) tuples. (None, None, None) is returned if more than one begin or end point
            is found. begin and/or end is None if the coverage has no point of that type.
        """
        point_target = TargetType.point
        grid_comp = self.grid_data['grid_comp']
        self._query = Query()
        grid_comp.refresh_component_ids(self._query, points=True)
        data = grid_comp.data

        begin_loc = None
        end_loc = None
        refine_pts = []
        for pt in self.points:
            comp_id = grid_comp.get_comp_id(point_target, pt.id)
            if comp_id is None or comp_id < 0:
                continue  # No usable component id for this point, nothing to look up
            pt_atts = data.points.loc[dict(comp_id=[comp_id])]
            point_type = pt_atts['point_type']
            if point_type == 'Refine':
                refine_pts.append((pt.x, pt.y, pt.z, pt_atts['refine_size_const']))
            elif point_type == 'Begin':
                if begin_loc is not None:
                    self.logger.error('There must be one and only one begin point in the coverage.')
                    return None, None, None
                begin_loc = (pt.x, pt.y, pt.z)
            elif point_type == 'End':
                if end_loc is not None:
                    self.logger.error('There must be one and only one end point in the coverage.')
                    return None, None, None
                end_loc = (pt.x, pt.y, pt.z)

        # Report missing endpoints here so the caller does not abort without any feedback.
        if begin_loc is None:
            self.logger.error('There must be one and only one begin point in the coverage.')
        if end_loc is None:
            self.logger.error('There must be one and only one end point in the coverage.')
        return begin_loc, end_loc, refine_pts

    def _calc_grid_angle(self):
        """Compute the direction from the first grid location to the last grid location.

        Returns:
            (:obj:`float`): Angle in degrees of the vector from grid_locs[0] to grid_locs[-1], computed with
            atan2(dy, dx). Negative angles are shifted by 360 degrees so the result is not negative.
        """
        dx = self.grid_locs[-1][0] - self.grid_locs[0][0]
        dy = self.grid_locs[-1][1] - self.grid_locs[0][1]
        angle = math.atan2(dy, dx) * 180.0 / math.pi
        if angle < 0.0:
            angle += 360.0
        return angle

    def _write_dataset(self):
        """Compute the dx dataset from the grid locations and set the mapped component's info attrs.

        Stores the first and last grid location ('x0', 'y0', 'xend', 'yend'), the grid angle ('theta'), and the
        number of cells ('num_cells') in the info attrs. The 2D distances between consecutive grid locations
        are stored as the 'dx' variable of the component's locations Dataset.
        """
        first = self.grid_locs[0]
        last = self.grid_locs[-1]
        attrs = self.mapped_comp.data.info.attrs
        attrs['x0'] = first[0]
        attrs['y0'] = first[1]
        attrs['xend'] = last[0]
        attrs['yend'] = last[1]
        # One dx per cell, i.e. per pair of consecutive grid locations.
        dxs = [distance_2d(prev, cur) for prev, cur in zip(self.grid_locs, self.grid_locs[1:])]
        dset = xr.Dataset(data_vars={'dx': xr.DataArray(data=np.array(dxs, dtype=np.float64))})
        self.mapped_comp.data.locations = dset
        self.mapped_comp.data.info.attrs['theta'] = self._calc_grid_angle()
        self.mapped_comp.data.info.attrs['num_cells'] = len(dxs)

    @staticmethod
    def _tick_offsets(base_line, fraction):
        """Build the pair of lines parallel to base_line that bound the tick marks of one size.

        Args:
            base_line (:obj:`LineString`): Line between the first and last grid location
            fraction (:obj:`float`): Offset distance as a fraction of the length of base_line

        Returns:
            (:obj:`tuple`): base_line offset by the positive distance and by the negative distance
        """
        distance = fraction * base_line.length
        return base_line.parallel_offset(distance), base_line.parallel_offset(-distance)

    @staticmethod
    def _tick_location(offset_lines, location):
        """Build the draw location of a single tick mark.

        Args:
            offset_lines (:obj:`tuple`): The two offset lines the tick spans, see _tick_offsets()
            location (:obj:`tuple`): Coordinates of the grid location the tick is drawn at

        Returns:
            (:obj:`list`): [x0, y0, 0.0, x1, y1, 0.0], the points on each offset line nearest to location
        """
        grid_pt = shPt(location)
        p0 = nearest_points(offset_lines[0], grid_pt)[0].coords[0]
        p1 = nearest_points(offset_lines[1], grid_pt)[0].coords[0]
        return [p0[0], p0[1], 0.0, p1[0], p1[1], 0.0]

    def _create_map_component(self):
        """Generate the 1D grid, write its draw locations, and fill in the mapped component data.

        Returns:
            (:obj:`bool`): True if the grid was created, False if the grid could not be computed
        """
        self.logger.info(f'PID: {os.getpid()}')
        self.logger.info('Generating 1D grid from GenCade coverage.')
        self._compute_grid_line_string()
        if self._grid_line_string is None:
            return False

        line_locs = list(self._grid_line_string.coords)
        self.grid_locs = line_locs
        first, last = line_locs[0], line_locs[-1]
        # The first draw line is the grid itself: x, y, z of the first location then x, y, z of the last.
        draw_locs = [[p[i] for p in (first, last) for i in range(3)]]

        # Draw the ends biggest
        base_line = LineString((first, last))
        end_offsets = self._tick_offsets(base_line, self._END_TICK_FRACTION)
        for loc in (first, last):
            draw_locs.append(self._tick_location(end_offsets, loc))

        # Interior ticks: every 10th is largest, every 5th is medium, the rest are smallest.
        tenth_offsets = self._tick_offsets(base_line, self._TENTH_TICK_FRACTION)
        fifth_offsets = self._tick_offsets(base_line, self._FIFTH_TICK_FRACTION)
        unit_offsets = self._tick_offsets(base_line, self._UNIT_TICK_FRACTION)
        for i, loc in enumerate(line_locs[1:-1], 1):
            if i % 10 == 0:
                tick_offsets = tenth_offsets
            elif i % 5 == 0:
                tick_offsets = fifth_offsets
            else:
                tick_offsets = unit_offsets
            draw_locs.append(self._tick_location(tick_offsets, loc))

        self.logger.info('Creating mapped 1D grid data.')

        write_display_option_line_locations(
            str(os.path.join(os.path.dirname(self.new_mainfile), mgc.MAPPED_GRID_LOCATIONS)),
            draw_locs  # Needs to be a 2D list of locations. Ticks would be other lines in the list.
        )
        self._write_dataset()
        return True

    @staticmethod
    def _cell_count(length, cell_size):
        """Compute the number of cells of a constant size needed to cover a length.

        Args:
            length (:obj:`float`): Length to cover
            cell_size (:obj:`float`): Size of each cell, must be greater than zero

        Returns:
            (:obj:`int`): length / cell_size, rounded up to a whole number of cells
        """
        ncell = int(length / cell_size)
        if length > cell_size * ncell:
            ncell += 1
        return ncell

    def _compute_grid_line_string(self):
        """Compute the grid LineString from the begin and end points, the cell size, and the refine points.

        On success the result is stored in self._grid_line_string. On failure self._grid_line_string is left
        unchanged, which is how _create_map_component() detects the failure.

        Returns:
            (:obj:`bool`): True if the grid LineString was computed, False otherwise
        """
        begin, end, refine_pts = self._find_cov_points()
        if begin is None or end is None:
            return False  # Couldn't find the endpoints, abort
        self._specified_begin_pt = begin
        self._specified_end_pt = end
        cell_size = float(self.grid_data['grid_comp'].data.info.attrs['base_cell_size'])
        if not cell_size > 0.0:  # Written this way so NaN is rejected too
            self.logger.error('The base cell size of the grid must be greater than zero.')
            return False
        ls = LineString((begin, end))
        length = ls.length
        if length <= 0.0:
            # Coincident endpoints have no direction, the unit vector below would divide by zero.
            self.logger.error('The begin and end points of the grid must be at different locations.')
            return False
        uv = [(end[0] - begin[0]) / length, (end[1] - begin[1]) / length]
        # Move the end point so the grid is a whole number of cells long.
        ncell = self._cell_count(length, cell_size)
        new_length = cell_size * ncell
        uv1 = [new_length * p for p in uv]
        end = [begin[0] + uv1[0], begin[1] + uv1[1], 0.0]
        ls = LineString((begin, end))
        pts = [ls.interpolate(i * cell_size) for i in range(1, ncell)]
        self._grid_line_string = LineString([begin] + pts + [end])
        # figure out other locations on the line
        if len(refine_pts) > 0:
            # Offset lines on both sides of the grid, extended 2 cells past each end of the grid.
            dist = 2 * cell_size
            uv1 = [dist * p for p in uv]
            p0 = (begin[0] - uv1[0], begin[1] - uv1[1], 0.0)
            p1 = (end[0] + uv1[0], end[1] + uv1[1], 0.0)
            ls = LineString((p0, p1))
            dist = 0.01 * ls.length
            offsets = (ls.parallel_offset(dist), ls.parallel_offset(-dist))

            # Lengthen the last segment by 2 cells before the points are redistributed.
            new_length = cell_size * (ncell + 2)
            uv1 = [new_length * p for p in uv]
            end = [begin[0] + uv1[0], begin[1] + uv1[1], 0.0]
            self._grid_line_string = LineString([begin] + pts + [end])

            self._apply_refine_points(refine_pts, offsets)
        return True

    def _apply_refine_points(self, refine_pts, offsets):
        """Modify the grid_line_string to include the refine points.

        A size function is interpolated from points on the two offset lines: the base cell size at the offset
        line vertices and next to the specified begin and end points, and the refine size next to each refine
        point. The grid points are redistributed with that size function and then truncated at the first point
        that projects at or beyond the specified end point.

        Args:
            refine_pts (:obj:`list`): (x, y, z, size) of each refine point
            offsets (:obj:`tuple`): The two LineStrings offset from the grid
        """
        cell_size = float(self.grid_data['grid_comp'].data.info.attrs['base_cell_size'])
        offset_coords = list(offsets[0].coords) + list(offsets[1].coords)
        size_pts = [(p[0], p[1], cell_size) for p in offset_coords]

        def add_size_points(location, size):
            """Append the point nearest to location on each offset line with the given size."""
            target = shPt(location)
            for offset_line in (offsets[0], offsets[1]):
                nearest = nearest_points(offset_line, target)[0].coords[0]
                size_pts.append((nearest[0], nearest[1], size))

        # add back the specified begin and end points
        for p in (self._specified_begin_pt, self._specified_end_pt):
            add_size_points(p, cell_size)
        for p in refine_pts:
            add_size_points(p[:3], p[3])

        linear = InterpLinear(points=size_pts)
        redist = meshing.poly_redistribute_points.PolyRedistributePoints()
        redist.set_size_func(linear)
        grid_pts = list(self._grid_line_string.coords)
        new_pts = redist.redistribute(grid_pts)

        # Keep the redistributed points up to and including the first one at or past the specified end point.
        ls = LineString((self._specified_begin_pt, self._specified_end_pt))
        pts = []
        for pt in new_pts:
            dist = ls.project(shPt(pt))
            pts.append(pt)
            if dist >= ls.length:
                break
        self._grid_line_string = LineString(pts)

    def map_data(self):
        """Create a mapped grid component from the source 1D grid coverage.

        Copies the default display options next to the new mainfile, creates the mapped component, generates
        the grid, and commits the mapped component data.

        Returns:
            (:obj:`Component`): The mapped grid XMS component object, None if the grid could not be created.
        """
        # Copy over some default display options.
        line_default_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'gui', 'resources',
                                         mgc.MAPPED_GRID_DISPLAY)
        mapped_disp = os.path.join(os.path.dirname(self.new_mainfile), mgc.MAPPED_GRID_DISPLAY)
        xfs.copyfile(str(line_default_file), str(mapped_disp))
        # Create the new mapped component
        self.mapped_comp = mgc.MappedGridComponent(self.new_mainfile)
        # Set projection of free location draw to be the current display projection.
        self.mapped_comp.data.info.attrs['wkt'] = self.grid_data['wkt']
        # Setup the display lists
        self._update_display_uuids(mapped_disp)

        # Create the mapped grid Python component
        do_comp = None
        if self._create_map_component():
            self.logger.info('Writing applied data to new component files.')
            self.mapped_comp.data.commit()

            # Create the data_objects component
            do_comp = Component(
                name=f'{self.grid_data["grid_cov"].name} (applied)',
                comp_uuid=os.path.basename(os.path.dirname(self.new_mainfile)),
                main_file=self.new_mainfile,
                model_name='GenCade',
                unique_name='MappedGridComponent',
                locked=False
            )
        return do_comp

    def set_struct_events(self, struct_events, component):
        """Sets the struct events coverage.

        Args:
            struct_events (:obj:`Coverage`): A struct events coverage.
            component (:obj:`StructComponent`): A struct events component belonging to the coverage.
        """
        self._struct_cov = struct_events
        self._struct_comp = component


def map_grid(query, xms_data):
    """Map a source 1D grid coverage to a new mapped grid component.

    A folder named with a new UUID is created beside the source grid component's folder to hold the mapped
    component's files. If the mapping fails, the grid coverage is unlinked from the simulation.

    Args:
        query (:obj:`Query`): Object for communicating with XMS
        xms_data (:obj:`dict`): Dictionary containing the data retrieved from XMS for grid mapping
    Returns:
        (:obj:`tuple(Component, MappedGridComponent)`): The data_objects component (None if the mapping failed)
        and the Python component for the mapped grid
    """
    log = logging.getLogger('xms.gencade')

    # Create a folder for the new mapped component. This assumes we are in the component temp directory.
    log.info('Creating new applied component files.')
    source_comp_dir = os.path.dirname(xms_data['grid_comp'].main_file)
    mapped_comp_dir = os.path.join(os.path.dirname(source_comp_dir), str(uuid.uuid4()))
    os.makedirs(mapped_comp_dir, exist_ok=True)
    mapped_mainfile = os.path.join(mapped_comp_dir, mgc.MAPPED_GRID_MAINFILE)

    # Perform the mapping
    log.info('Mapping source 1D grid coverage to simulation.')
    grid_mapper = GridMapper(mapped_mainfile, xms_data)
    data_objects_comp = grid_mapper.map_data()

    # Unlink the grid coverage if an error occurred and there is something to unlink
    if data_objects_comp is None and xms_data['sim_uuid'] and xms_data['grid_cov']:
        query.unlink_item(xms_data['sim_uuid'], xms_data['grid_cov'].uuid)

    return data_objects_comp, grid_mapper.mapped_comp
