diff --git a/tools/eval_linemod.py b/tools/eval_linemod.py index c07dcdf52b4ec75a5885cde41143d34616c3dca5..26ffdf2073f60d840af3dd3f3ae8c30934a40fa7 100644 --- a/tools/eval_linemod.py +++ b/tools/eval_linemod.py @@ -4,6 +4,7 @@ import os import random import numpy as np import yaml +import copy import torch import torch.nn as nn import torch.nn.parallel @@ -18,6 +19,8 @@ from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod from lib.network import PoseNet, PoseRefineNet from lib.loss import Loss from lib.loss_refiner import Loss_refine +from lib.transformations import euler_matrix, quaternion_matrix, quaternion_from_matrix +from lib.knn.__init__ import KNearestNeighbor parser = argparse.ArgumentParser() parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir') @@ -29,9 +32,10 @@ num_objects = 13 objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15] num_points = 500 iteration = 2 +bs = 1 dataset_config_dir = 'datasets/linemod/dataset_config' output_result_dir = 'experiments/eval_result/linemod' - +knn = KNearestNeighbor(1) estimator = PoseNet(num_points = num_points, num_obj = num_objects) estimator.cuda() @@ -73,19 +77,65 @@ for i, data in enumerate(testdataloader, 0): Variable(target).cuda(), \ Variable(model_points).cuda(), \ Variable(idx).cuda() + pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx) - _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, 0.0, False) + pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, num_points, 1) + pred_c = pred_c.view(bs, num_points) + how_max, which_max = torch.max(pred_c, 1) + pred_t = pred_t.view(bs * num_points, 1, 3) + + my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy() + my_t = (points.view(bs * num_points, 1, 3) + pred_t)[which_max[0]].view(-1).cpu().data.numpy() + my_pred = np.append(my_r, my_t) + for ite in range(0, iteration): + T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(num_points, 1).contiguous().view(1, num_points, 3) + my_mat = quaternion_matrix(my_r) + R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3) + my_mat[0:3, 3] = my_t + + new_points = torch.bmm((points - T), R).contiguous() pred_r, pred_t = refiner(new_points, emb, idx) - dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points) + pred_r = pred_r.view(1, 1, -1) + pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1)) + my_r_2 = pred_r.view(-1).cpu().data.numpy() + my_t_2 = pred_t.view(-1).cpu().data.numpy() + my_mat_2 = quaternion_matrix(my_r_2) + my_mat_2[0:3, 3] = my_t_2 + + my_mat_final = np.dot(my_mat, my_mat_2) + my_r_final = copy.deepcopy(my_mat_final) + my_r_final[0:3, 3] = 0 + my_r_final = quaternion_from_matrix(my_r_final, True) + my_t_final = np.array([my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]]) + + my_pred = np.append(my_r_final, my_t_final) + my_r = my_r_final + my_t = my_t_final + + # Here 'my_pred' is the final pose estimation result after refinement ('my_r': quaternion, 'my_t': translation) + + model_points = model_points[0].cpu().detach().numpy() + my_r = quaternion_matrix(my_r)[:3, :3] + pred = np.dot(model_points, my_r.T) + my_t + target = target[0].cpu().detach().numpy() + + if idx[0].item() in sym_list: + pred = torch.from_numpy(pred.astype(np.float32)).cuda().transpose(1, 0).contiguous() + target = torch.from_numpy(target.astype(np.float32)).cuda().transpose(1, 0).contiguous() + inds = knn(target.unsqueeze(0), pred.unsqueeze(0)) + target = torch.index_select(target, 1, inds.view(-1) - 1) + dis = torch.mean(torch.norm((pred.transpose(1, 0) - target.transpose(1, 0)), dim=1), dim=0).item() + else: + dis = np.mean(np.linalg.norm(pred - target, axis=1)) - if dis.item() < diameter[idx[0].item()]: + if dis < diameter[idx[0].item()]: success_count[idx[0].item()] += 1 - print('No.{0} Pass! Distance: {1}'.format(i, dis.item())) - fw.write('No.{0} Pass! Distance: {1}\n'.format(i, dis.item())) + print('No.{0} Pass! Distance: {1}'.format(i, dis)) + fw.write('No.{0} Pass! Distance: {1}\n'.format(i, dis)) else: - print('No.{0} NOT Pass! Distance: {1}'.format(i, dis.item())) - fw.write('No.{0} NOT Pass! Distance: {1}\n'.format(i, dis.item())) + print('No.{0} NOT Pass! Distance: {1}'.format(i, dis)) + fw.write('No.{0} NOT Pass! Distance: {1}\n'.format(i, dis)) num_count[idx[0].item()] += 1 for i in range(num_objects): diff --git a/tools/eval_ycb.py b/tools/eval_ycb.py index 515a2eacb0c2bb318866004c25d850a88d915ab2..5991029519a4ca1cc19267da784f18583578b463 100644 --- a/tools/eval_ycb.py +++ b/tools/eval_ycb.py @@ -206,9 +206,7 @@ for now in range(0, 2949): T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(num_points, 1).contiguous().view(1, num_points, 3) my_mat = quaternion_matrix(my_r) R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3) - my_mat[0][3] = my_t[0] - my_mat[1][3] = my_t[1] - my_mat[2][3] = my_t[2] + my_mat[0:3, 3] = my_t new_cloud = torch.bmm((cloud - T), R).contiguous() pred_r, pred_t = refiner(new_cloud, emb, index) @@ -218,15 +216,11 @@ for now in range(0, 2949): my_t_2 = pred_t.view(-1).cpu().data.numpy() my_mat_2 = quaternion_matrix(my_r_2) - my_mat_2[0][3] = my_t_2[0] - my_mat_2[1][3] = my_t_2[1] - my_mat_2[2][3] = my_t_2[2] + my_mat_2[0:3, 3] = my_t_2 my_mat_final = np.dot(my_mat, my_mat_2) my_r_final = copy.deepcopy(my_mat_final) - my_r_final[0][3] = 0 - my_r_final[1][3] = 0 - my_r_final[2][3] = 0 + my_r_final[0:3, 3] = 0 my_r_final = quaternion_from_matrix(my_r_final, True) my_t_final = np.array([my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]]) @@ -234,6 +228,8 @@ for now in range(0, 2949): my_r = my_r_final my_t = my_t_final + # Here 'my_pred' is the final pose estimation result after refinement ('my_r': quaternion, 'my_t': translation) + my_result.append(my_pred.tolist()) except ZeroDivisionError: print("PoseCNN Detector Lost {0} at No.{1} keyframe".format(itemid, now))