import numpy as np
import open3d as o3d
from matplotlib import pyplot as plt
from matplotlib import image
from numpy.linalg import norm
import math
import cv2
from pose import rotation_matrix, convert2


def fps(points, centroid, n_samples):
    """Select ``n_samples`` points from ``points`` by farthest point sampling.

    The first sample is always ``points[0]``. Each following sample is the
    not-yet-selected point whose squared Euclidean distance to its nearest
    already-selected sample is largest. Ties go to the lowest index, because
    the pool of remaining indices is kept in ascending order.

    Args:
        points: array-like of shape [N, 3] (any [N, D] works) holding the
            whole point cloud.
        centroid: unused; accepted only so existing callers keep working.
        n_samples: number of points to select, typically << N.

    Returns:
        ndarray of shape [n_samples, D] with the selected points in selection
        order (an empty [0, D] array when ``n_samples`` is 0).

    Raises:
        ValueError: if ``n_samples`` is negative or larger than N.
    """
    points = np.asarray(points)
    n_points = len(points)
    if n_samples < 0 or n_samples > n_points:
        raise ValueError(
            f"n_samples must be between 0 and the number of points "
            f"({n_points}), got {n_samples}"
        )

    sample_inds = np.zeros(n_samples, dtype=int)  # [S]
    if n_samples == 0:
        return points[sample_inds]

    # Indices of points that have not been selected yet (ascending order).
    points_left = np.arange(n_points)  # [P]
    # Squared distance from every point to its nearest selected sample.
    dists = np.full(n_points, np.inf)  # [P]

    # Seed the sampling with the first point.
    selected = 0
    sample_inds[0] = points_left[selected]
    points_left = np.delete(points_left, selected)  # [P - 1]

    for i in range(1, n_samples):
        last_added = sample_inds[i - 1]

        # Only the newest sample can shrink a point's nearest-sample distance.
        dist_to_last = ((points[last_added] - points[points_left]) ** 2).sum(-1)  # [P - i]
        dists[points_left] = np.minimum(dist_to_last, dists[points_left])

        # Take the remaining point farthest from all current samples.
        selected = np.argmax(dists[points_left])
        sample_inds[i] = points_left[selected]
        points_left = np.delete(points_left, selected)

    return points[sample_inds]


def unit_vector_gt(outimage, x, y, keypoint):
    """Store, at ``outimage[x][y]``, unit vectors pointing toward each keypoint.

    For keypoint ``i`` the vector from ``(x, y)`` to the keypoint is
    normalised. Its x component goes into channel ``2 * i`` and its y
    component into channel ``2 * i + 1``. If ``(x, y)`` coincides with a
    keypoint, the direction is undefined and a zero vector is stored instead
    of dividing by zero.

    Args:
        outimage: mutable array indexable as ``outimage[x][y][channel]`` with
            at least ``2 * len(keypoint[0])`` channels; modified in place.
        x, y: pixel coordinates. They are also the first and second index
            into ``outimage``.
        keypoint: container where ``keypoint[0][i][0]`` is the (x, y) pair of
            keypoint ``i``. This is presumably the output of
            ``cv2.projectPoints`` (image points shaped [K, 1, 2]); confirm at
            the caller.

    NOTE(review): ``outimage`` is indexed ``[x][y]``. If it is a
    (rows, cols, C) image and ``x`` is the column coordinate, this is
    transposed. Confirm with the caller.
    """
    for i, kp in enumerate(keypoint[0]):
        xdiff = float(kp[0][0]) - x
        ydiff = float(kp[0][1]) - y
        mag = math.sqrt(ydiff ** 2 + xdiff ** 2)

        if mag == 0.0:
            # Pixel lies exactly on the keypoint: no defined direction.
            ux = uy = 0.0
        else:
            ux, uy = xdiff / mag, ydiff / mag

        outimage[x][y][i * 2] = ux
        outimage[x][y][i * 2 + 1] = uy


def plot_unit_vector(point, vector):
    """Debug helper: draw ``vector`` as an arrow anchored at ``point[0]``.

    Args:
        point: sequence whose first element holds the arrow origin; its first
            two entries are used as (x, y). Presumably one entry of
            ``cv2.projectPoints`` output shaped [1, 2]; confirm at the caller.
        vector: sequence whose first two entries are the arrow's (u, v)
            components.

    After the arrow, two matplotlib quiver demo plots are shown. They do not
    depend on the arguments.
    """
    origin = point[0]

    # Pass X, Y, U, V explicitly. With only three positional arguments,
    # ``quiver`` interprets them as (U, V, C), which would draw the origin as
    # the vector and treat ``vector`` as colour data.
    plt.quiver(origin[0], origin[1], vector[0], vector[1],
               color=['r', 'b', 'g'], scale=21)
    plt.show()

    # --- Demo plots below are independent of the arguments. ---
    plt.quiver([0, 0, 0], [0, 0, 0], [1, -2, 4], [1, 2, -7], angles='xy', scale_units='xy', scale=1)
    plt.xlim(-10, 10)
    plt.ylim(-10, 10)
    plt.show()

    X = np.arange(-10, 10, 1)
    Y = np.arange(-10, 10, 1)
    U, V = np.meshgrid(X, Y)

    fig, ax = plt.subplots()
    q = ax.quiver(X, Y, U, V)
    ax.quiverkey(q, X=0.3, Y=1.1, U=10,
                 label='Quiver key, length = 10', labelpos='E')

    plt.show()


def process(pcd, R_exp, tVec, camera, img, id, *, n_keypoints=8):
    """Sample keypoints on a point cloud, project them into ``img`` and show them.

    Keypoints are chosen with ``fps`` (farthest point sampling) and projected
    with ``cv2.projectPoints`` using all-zero distortion coefficients. Each
    ``3-D point ==> 2-D projection`` pair is printed, and the projections are
    drawn as red dots over ``img``.

    Args:
        pcd: point cloud exposing ``.points`` (an open3d ``PointCloud`` in this
            script).
        R_exp: rotation passed to ``cv2.projectPoints`` as ``rvec`` after
            conversion to float64. In this script it is the output of
            ``rotation_matrix``. NOTE(review): OpenCV documents ``rvec`` as a
            Rodrigues vector; confirm the form passed in.
        tVec: translation passed to ``cv2.projectPoints`` as ``tvec`` after
            conversion to float64.
        camera: 3x3 camera intrinsic matrix.
        img: image array shown underneath the projected points.
        id: index of this call. It is only referenced by the commented-out
            ``savefig`` line.
        n_keypoints: number of keypoints to sample (keyword-only, default 8).
    """
    points = np.asarray(pcd.points)
    # fps ignores its centroid argument; the mean is passed only to fill it.
    keypoints_3d = fps(points, points.mean(0), n_keypoints)

    camera = np.array(camera)
    R_exp = np.array(R_exp, dtype="float64")
    tVec = np.array(tVec, dtype="float64")
    dist_coeffs = np.zeros(shape=[8, 1], dtype='float64')  # no lens distortion

    # Project only the sampled keypoints, not the whole cloud.
    projected, _ = cv2.projectPoints(keypoints_3d, R_exp, tVec, camera, dist_coeffs)

    for p3, p2 in zip(keypoints_3d, projected):
        print(p3, '==>', p2)

    fig, ax = plt.subplots()
    ax.imshow(img)
    # Each projected entry is a [1, 2] row: column 0 is x, column 1 is y.
    ax.plot(projected[:, 0, 0], projected[:, 0, 1],
            linestyle='none', marker='.', color="red")
    # plt.savefig(f"result/img_{id}.png")  # call before plt.show() to capture the figure
    plt.show()
    plt.close(fig)

# ==============================================================================


# Candidate rotation parameters recorded per object (fed to `rotation_matrix`):
# Banana [ 0.085909521758723559, -0.10993803084296662, 2.4625571948393996 ]
# Pear [2.940213420813008, -0.91304327878070535, -2.6173582584096144]
# Orange [ 2.9788743387155581, -0.27449049579661572, -1.2476345641806936 ]
x, y, z = [2.9788743387155581, -0.27449049579661572, -1.2476345641806936]  # Orange

# All 48 combinations of an ordering of (x, y, z) with a sign pattern.
# Outer loop: sign pattern, applied position-wise after reordering.
# Inner loop: the six orderings, in the order xyz, xzy, yxz, yzx, zxy, zyx.
dic = [
    [sa * a, sb * b, sc * c]
    for sa, sb, sc in (
        (1, 1, 1),
        (-1, -1, -1),
        (-1, 1, 1),
        (1, -1, 1),
        (1, 1, -1),
        (-1, -1, 1),
        (-1, 1, -1),
        (1, -1, -1),
    )
    for a, b, c in (
        (x, y, z),
        (x, z, y),
        (y, x, z),
        (y, z, x),
        (z, x, y),
        (z, y, x),
    )
]
# Inputs for the projection check (hard-coded local paths).
point_cloud = "/home/mahmoud/GUIMOD/Models/Orange/orange2_visual2.ply"

pose = np.load('/home/mahmoud/GUIMOD/Pose/Orange/0.npy')
# Axis change applied to the stored translation: (a, b, c) -> (-b, -c, a).
# NOTE(review): this looks like the usual ROS body -> camera optical frame
# change; confirm against how the poses were recorded.
# A plain ndarray is used instead of the deprecated np.matrix.
new = np.array([[0.0000000, -1.0000000, 0.0000000],
                [0.0000000, 0.0000000, -1.0000000],
                [1.0000000, 0.0000000, 0.0000000]])
# Translation column of the pose (rows 0-2 of column 3);
# presumably `pose` is a 4x4 homogeneous transform — confirm.
t_org = pose[0:3, 3]
# Reshape keeps the (1, 3) row shape the matrix product used to produce.
tVec = (new @ t_org).reshape(1, 3)

print(tVec)

img = image.imread('/media/mahmoud/F/GUIMOD/data/1/grabber_1/color/image/0_0.png')
camera = [[1386.4138492513919, 0.0, 960.5], [0.0, 1386.4138492513919, 540.5], [0.0, 0.0, 1.0]]
pcd = o3d.io.read_point_cloud(point_cloud)  # Read the point cloud


if __name__ == '__main__':
    # For every candidate in `dic`, build a rotation with `rotation_matrix`
    # and show the projected keypoints. The candidate's index is passed to
    # `process` as its `id`.
    #
    # The object is selected by the module-level paths (`point_cloud`, the
    # pose .npy and the image). Alternatives used before:
    #   point_cloud = "/home/mahmoud/GUIMOD/Models/Pear/pear2_visual2.ply"
    #   pose = np.load('/home/mahmoud/GUIMOD/Pose/Pear/1.npy')
    #   img = image.imread('/media/mahmoud/F/GUIMOD/data/1/grabber_2/color/image/0_0.png')
    #
    # Earlier note — XYZ change x with z and remove - from y:
    # R_xyz = [[0.9902972, -0.0852859, 0.1097167],
    #          [0.0018819, -0.7812214, -0.6242512],
    #          [0.1389529, 0.6184007, -0.7734809]]
    for i, pos in enumerate(dic):
        R_exp = rotation_matrix(pos)
        process(pcd, R_exp, tVec, camera, img, i)


"""
D: [0.0, 0.0, 0.0, 0.0, 0.0]
K: [1386.4138492513919, 0.0, 960.5, 0.0, 1386.4138492513919, 540.5, 0.0, 0.0, 1.0]
R: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
P: [1386.4138492513919, 0.0, 960.5, -0.0, 0.0, 1386.4138492513919, 540.5, 0.0, 0.0, 0.0, 1.0, 0.0]
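The camera_calibration package outputs four matrices: (1) D, the distortion parameters; (2) K, the intrinsic camera matrix; (3) R, the rectification matrix; (4) P, the projection matrix of the processed (rectified) image. The set above matches the `camera` intrinsics used in this script; a second camera's calibration follows.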
D: [0.0, 0.0, 0.0, 0.0, 0.0]
K: [1086.5054444841007, 0.0, 640.5, 0.0, 1086.5054444841007, 360.5, 0.0, 0.0, 1.0]
R: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
P: [1086.5054444841007, 0.0, 640.5, -0.0, 0.0, 1086.5054444841007, 360.5, 0.0, 0.0, 0.0, 1.0, 0.0]
"""