diff --git a/lib/loss.py b/lib/loss.py
index 4ad1064741d09f1048169bf162a07ff33790a1dc..8f78b881a0e1532341e836015dbf95569e56b277 100755
--- a/lib/loss.py
+++ b/lib/loss.py
@@ -8,9 +8,9 @@ import random
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor
 
-knn = KNearestNeighbor(1)
 
 def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     bs, num_p, _ = pred_c.size()
 
     pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))
@@ -67,7 +67,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points,
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
 
     # print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item())
-
+    del knn
     return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()
 
 
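Note on the two loss patches: this hunk (and the matching one in lib/loss_refiner.py below) moves the KNearestNeighbor instance from module scope into loss_calculation, so the object is built fresh on every call and its reference is dropped with del before the function returns. Below is a minimal runnable sketch of that lifetime pattern; Helper is a hypothetical stand-in for the KNN CUDA extension (whose real forward signature is not shown in this diff), and the idea that per-call construction lets the helper's buffers be reclaimed between iterations is an assumption, not something the patch states.

import torch

class Helper:
    """Hypothetical stand-in for lib.knn's KNearestNeighbor with k=1."""
    def __init__(self, k):
        self.k = k

    def forward(self, ref, query):
        # ref, query: (bs, 3, N) point tensors; return the index of the
        # nearest ref point for every query point (brute-force distances).
        d = torch.cdist(query.transpose(1, 2), ref.transpose(1, 2))
        return d.argmin(dim=2)

def loss_step(ref, query):
    knn = Helper(1)          # local instance, created per call, as in the patch
    inds = knn.forward(ref, query)
    del knn                  # drop the reference before returning, as in the patch
    return inds

print(loss_step(torch.rand(1, 3, 500), torch.rand(1, 3, 1000)).shape)  # torch.Size([1, 1000])
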
diff --git a/lib/loss_refiner.py b/lib/loss_refiner.py
index d2c2931ccebe756e273a45df39df63a04813fa41..6496c9ad96e154765b514e09f397a3cd1778fa99 100755
--- a/lib/loss_refiner.py
+++ b/lib/loss_refiner.py
@@ -8,9 +8,9 @@ import random
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor
 
-knn = KNearestNeighbor(1)
 
 def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     pred_r = pred_r.view(1, 1, -1)
     pred_t = pred_t.view(1, 1, -1)
     bs, num_p, _ = pred_r.size()
@@ -60,7 +60,7 @@ def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_poin
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
 
     # print('------------> ', dis.item(), idx[0].item())
-
+    del knn
     return dis, new_points.detach(), new_target.detach()
 
 
diff --git a/vanilla_segmentation/data_controller.py b/vanilla_segmentation/data_controller.py
index ed0891e0a723ae0cd83ec44de2771ee9f7482b77..37189e78b394f3f29467c2c2008d85f98cc55bb4 100644
--- a/vanilla_segmentation/data_controller.py
+++ b/vanilla_segmentation/data_controller.py
@@ -15,11 +15,10 @@ from PIL import ImageEnhance
 from PIL import ImageFilter
 
 class SegDataset(data.Dataset):
-    def __init__(self, root_dir, txtlist, use_noise, num=1000):
+    def __init__(self, root_dir, txtlist, use_noise, length):
         self.path = []
         self.real_path = []
         self.use_noise = use_noise
-        self.num = num
         self.root = root_dir
         input_file = open(txtlist)
         while 1:
@@ -33,13 +32,17 @@ class SegDataset(data.Dataset):
                 self.real_path.append(copy.deepcopy(input_line))
         input_file.close()
 
+        self.length = length
+        self.data_len = len(self.path)
         self.back_len = len(self.real_path)
-        self.length = len(self.path)
+
         self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
         self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         self.back_front = np.array([[1 for i in range(640)] for j in range(480)])
 
-    def __getitem__(self, index):
+    def __getitem__(self, idx):
+        index = random.randint(0, self.data_len - 10)
+
         label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[index])))
         meta = scio.loadmat('{0}/{1}-meta.mat'.format(self.root, self.path[index]))
         if not self.use_noise:
@@ -51,7 +54,7 @@ class SegDataset(data.Dataset):
             rgb = Image.open('{0}/{1}-color.png'.format(self.root, self.path[index])).convert("RGB")
             rgb = ImageEnhance.Brightness(rgb).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8))
             rgb = np.array(self.trancolor(rgb))
-            seed = random.randint(10, self.back_len - 10)
+            seed = random.randint(0, self.back_len - 10)
             back = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, self.path[seed])).convert("RGB")))
             back_label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[seed])))
             mask = ma.getmaskarray(ma.masked_equal(label, 0))
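Note on the dataset patch: the constructor now takes an explicit length instead of the unused num argument, and __getitem__ ignores the index supplied by the DataLoader; every call draws a uniformly random frame from the file list (the randint upper bound of data_len - 10 leaves the last few frames out of the draw, mirroring the margin kept on the background seed). Assuming __len__ returns self.length (that method is outside this hunk), an epoch becomes length random samples rather than one pass over the files. A small, hypothetical ToyDataset illustrating that behaviour under those assumptions:

import random
import torch.utils.data as data

class ToyDataset(data.Dataset):
    """Hypothetical stand-in for SegDataset's new sampling scheme."""
    def __init__(self, files, length):
        self.files = files            # plays the role of self.path
        self.data_len = len(files)
        self.length = length          # samples reported per epoch

    def __len__(self):
        # assumption: SegDataset.__len__ returns self.length as well
        return self.length

    def __getitem__(self, idx):
        # the sampler's idx is ignored; a random frame is drawn instead,
        # mirroring index = random.randint(0, self.data_len - 10) above
        index = random.randint(0, self.data_len - 10)
        return self.files[index]

ds = ToyDataset(['frame_%06d' % i for i in range(20000)], length=5000)
print(len(ds))        # 5000 samples per "epoch", independent of the 20000 files
print(ds[0], ds[0])   # same idx, (almost certainly) two different frames
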
diff --git a/vanilla_segmentation/train.py b/vanilla_segmentation/train.py
index 8779e993a2f47f0da8f33647ee23d4347b352eea..83759ff1e74c3c0068a5b7f9225a2147637ffaca 100644
--- a/vanilla_segmentation/train.py
+++ b/vanilla_segmentation/train.py
@@ -38,10 +38,10 @@ if __name__ == '__main__':
     random.seed(opt.manualSeed)
     torch.manual_seed(opt.manualSeed)
 
-    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True)
+    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers))
-    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False)
-    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(opt.workers))
+    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
+    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers))
 
     print(len(dataset), len(test_dataset))
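
Note on the training-script patch: both call sites now pass the new length argument explicitly (5000 samples per training epoch, 1000 per test pass), and the test loader is shuffled as well. Assuming __len__ returns that value, the per-epoch iteration counts follow from it rather than from the size of the data lists; a quick sanity check, with opt.batch_size filled in by a hypothetical value:

batch_size = 4                       # hypothetical; opt.batch_size is not fixed by this diff
train_batches = 5000 // batch_size   # optimizer steps per training epoch
test_batches = 1000                  # test_dataloader uses batch_size=1
print(train_batches, test_batches)   # 1250 1000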