diff --git a/tools/eval_linemod.py b/tools/eval_linemod.py
index c07dcdf52b4ec75a5885cde41143d34616c3dca5..26ffdf2073f60d840af3dd3f3ae8c30934a40fa7 100644
--- a/tools/eval_linemod.py
+++ b/tools/eval_linemod.py
@@ -4,6 +4,7 @@ import os
 import random
 import numpy as np
 import yaml
+import copy
 import torch
 import torch.nn as nn
 import torch.nn.parallel
@@ -18,6 +19,8 @@ from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod
 from lib.network import PoseNet, PoseRefineNet
 from lib.loss import Loss
 from lib.loss_refiner import Loss_refine
+from lib.transformations import quaternion_matrix, quaternion_from_matrix
+from lib.knn import KNearestNeighbor
 
 parser = argparse.ArgumentParser()
 parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir')
@@ -29,9 +32,11 @@ num_objects = 13
 objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15]
 num_points = 500
 iteration = 2
+bs = 1
 dataset_config_dir = 'datasets/linemod/dataset_config'
 output_result_dir = 'experiments/eval_result/linemod'
-
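+# 1-NN matcher used to compute the ADD-S metric for symmetric objects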
+knn = KNearestNeighbor(1)
 
 estimator = PoseNet(num_points = num_points, num_obj = num_objects)
 estimator.cuda()
@@ -73,19 +78,73 @@ for i, data in enumerate(testdataloader, 0):
                                                      Variable(target).cuda(), \
                                                      Variable(model_points).cuda(), \
                                                      Variable(idx).cuda()
+
     pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
-    _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, 0.0, False)
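+    # Normalize each per-point quaternion prediction to unit length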
+    pred_r = pred_r / torch.norm(pred_r, dim=2).view(1, num_points, 1)
+    pred_c = pred_c.view(bs, num_points)
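+    # Select the pose hypothesis with the highest predicted confidence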
+    how_max, which_max = torch.max(pred_c, 1)
+    pred_t = pred_t.view(bs * num_points, 1, 3)
+
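+    # Initial pose estimate (quaternion my_r, translation my_t) taken from the most confident point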
+    my_r = pred_r[0][which_max[0]].view(-1).cpu().data.numpy()
+    my_t = (points.view(bs * num_points, 1, 3) + pred_t)[which_max[0]].view(-1).cpu().data.numpy()
+    my_pred = np.append(my_r, my_t)
+
     for ite in range(0, iteration):
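+        # Map the observed cloud into the object frame of the current estimate: (p - t) * R = R^T (p - t)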
+        T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(num_points, 1).contiguous().view(1, num_points, 3)
+        my_mat = quaternion_matrix(my_r)
+        R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3)
+        my_mat[0:3, 3] = my_t
+
+        new_points = torch.bmm((points - T), R).contiguous()
         pred_r, pred_t = refiner(new_points, emb, idx)
-        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points)
+        pred_r = pred_r.view(1, 1, -1)
+        pred_r = pred_r / (torch.norm(pred_r, dim=2).view(1, 1, 1))
+        my_r_2 = pred_r.view(-1).cpu().data.numpy()
+        my_t_2 = pred_t.view(-1).cpu().data.numpy()
+        my_mat_2 = quaternion_matrix(my_r_2)
+        my_mat_2[0:3, 3] = my_t_2
+
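+        # Compose the refiner's delta pose with the current estimate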
+        my_mat_final = np.dot(my_mat, my_mat_2)
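+        # Split the composed matrix back into a quaternion and a translation for the next iteration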
+        my_r_final = copy.deepcopy(my_mat_final)
+        my_r_final[0:3, 3] = 0
+        my_r_final = quaternion_from_matrix(my_r_final, True)
+        my_t_final = np.array(my_mat_final[0:3, 3])
+
+        my_pred = np.append(my_r_final, my_t_final)
+        my_r = my_r_final
+        my_t = my_t_final
+
+    # Here 'my_pred' is the final pose estimation result after refinement ('my_r': quaternion, 'my_t': translation)
+
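+    # Apply the final pose to the model points and measure the average distance to the target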
+    model_points = model_points[0].cpu().detach().numpy()
+    my_r = quaternion_matrix(my_r)[:3, :3]
+    pred = np.dot(model_points, my_r.T) + my_t
+    target = target[0].cpu().detach().numpy()
+
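+    # Symmetric objects use ADD-S (closest-point matching via 1-NN); the rest use plain ADD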
+    if idx[0].item() in sym_list:
+        pred = torch.from_numpy(pred.astype(np.float32)).cuda().transpose(1, 0).contiguous()
+        target = torch.from_numpy(target.astype(np.float32)).cuda().transpose(1, 0).contiguous()
+        inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
+        target = torch.index_select(target, 1, inds.view(-1) - 1)
+        dis = torch.mean(torch.norm((pred.transpose(1, 0) - target.transpose(1, 0)), dim=1), dim=0).item()
+    else:
+        dis = np.mean(np.linalg.norm(pred - target, axis=1))
 
-    if dis.item() < diameter[idx[0].item()]:
+    if dis < diameter[idx[0].item()]:
         success_count[idx[0].item()] += 1
-        print('No.{0} Pass! Distance: {1}'.format(i, dis.item()))
-        fw.write('No.{0} Pass! Distance: {1}\n'.format(i, dis.item()))
+        print('No.{0} Pass! Distance: {1}'.format(i, dis))
+        fw.write('No.{0} Pass! Distance: {1}\n'.format(i, dis))
     else:
-        print('No.{0} NOT Pass! Distance: {1}'.format(i, dis.item()))
-        fw.write('No.{0} NOT Pass! Distance: {1}\n'.format(i, dis.item()))
+        print('No.{0} NOT Pass! Distance: {1}'.format(i, dis))
+        fw.write('No.{0} NOT Pass! Distance: {1}\n'.format(i, dis))
     num_count[idx[0].item()] += 1
 
 for i in range(num_objects):
diff --git a/tools/eval_ycb.py b/tools/eval_ycb.py
index 515a2eacb0c2bb318866004c25d850a88d915ab2..5991029519a4ca1cc19267da784f18583578b463 100644
--- a/tools/eval_ycb.py
+++ b/tools/eval_ycb.py
@@ -206,9 +206,8 @@ for now in range(0, 2949):
                 T = Variable(torch.from_numpy(my_t.astype(np.float32))).cuda().view(1, 3).repeat(num_points, 1).contiguous().view(1, num_points, 3)
                 my_mat = quaternion_matrix(my_r)
                 R = Variable(torch.from_numpy(my_mat[:3, :3].astype(np.float32))).cuda().view(1, 3, 3)
-                my_mat[0][3] = my_t[0]
-                my_mat[1][3] = my_t[1]
-                my_mat[2][3] = my_t[2]
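+                # Write the translation into the last column of the 4x4 pose matrix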
+                my_mat[0:3, 3] = my_t
                 
                 new_cloud = torch.bmm((cloud - T), R).contiguous()
                 pred_r, pred_t = refiner(new_cloud, emb, index)
@@ -218,15 +217,12 @@
                 my_t_2 = pred_t.view(-1).cpu().data.numpy()
                 my_mat_2 = quaternion_matrix(my_r_2)
 
-                my_mat_2[0][3] = my_t_2[0]
-                my_mat_2[1][3] = my_t_2[1]
-                my_mat_2[2][3] = my_t_2[2]
+                my_mat_2[0:3, 3] = my_t_2
 
                 my_mat_final = np.dot(my_mat, my_mat_2)
                 my_r_final = copy.deepcopy(my_mat_final)
-                my_r_final[0][3] = 0
-                my_r_final[1][3] = 0
-                my_r_final[2][3] = 0
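+                # Zero the translation so only the rotation is converted to a quaternion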
+                my_r_final[0:3, 3] = 0
                 my_r_final = quaternion_from_matrix(my_r_final, True)
                 my_t_final = np.array([my_mat_final[0][3], my_mat_final[1][3], my_mat_final[2][3]])
 
@@ -234,6 +230,8 @@ for now in range(0, 2949):
                 my_r = my_r_final
                 my_t = my_t_final
 
+            # Here 'my_pred' is the final pose estimation result after refinement ('my_r': quaternion, 'my_t': translation)
+
             my_result.append(my_pred.tolist())
         except ZeroDivisionError:
             print("PoseCNN Detector Lost {0} at No.{1} keyframe".format(itemid, now))