From 88aaa21d5fc9665796eb9fc2dec0933aee4be1cb Mon Sep 17 00:00:00 2001
From: jwangzzz <j96w@qq.com>
Date: Sun, 7 Apr 2019 13:35:13 +0800
Subject: [PATCH] small fix on memory leak and segmentation training

---
 lib/loss.py                             |  4 ++--
 lib/loss_refiner.py                     |  4 ++--
 vanilla_segmentation/data_controller.py | 13 ++++++++-----
 vanilla_segmentation/train.py           |  6 +++---
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/lib/loss.py b/lib/loss.py
index 4ad1064..8f78b88 100755
--- a/lib/loss.py
+++ b/lib/loss.py
@@ -8,9 +8,9 @@ import random
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor
 
-knn = KNearestNeighbor(1)
 
 def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     bs, num_p, _ = pred_c.size()
 
     pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))
@@ -67,7 +67,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points,
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
 
     # print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item())
-
+    del knn
     return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()
 
 
diff --git a/lib/loss_refiner.py b/lib/loss_refiner.py
index d2c2931..6496c9a 100755
--- a/lib/loss_refiner.py
+++ b/lib/loss_refiner.py
@@ -8,9 +8,9 @@ import random
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor
 
-knn = KNearestNeighbor(1)
 
 def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     pred_r = pred_r.view(1, 1, -1)
     pred_t = pred_t.view(1, 1, -1)
     bs, num_p, _ = pred_r.size()
@@ -60,7 +60,7 @@ def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_poin
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
 
     # print('------------> ', dis.item(), idx[0].item())
-
+    del knn
     return dis, new_points.detach(), new_target.detach()
 
 
diff --git a/vanilla_segmentation/data_controller.py b/vanilla_segmentation/data_controller.py
index ed0891e..37189e7 100644
--- a/vanilla_segmentation/data_controller.py
+++ b/vanilla_segmentation/data_controller.py
@@ -15,11 +15,10 @@ from PIL import ImageEnhance
 from PIL import ImageFilter
 
 class SegDataset(data.Dataset):
-    def __init__(self, root_dir, txtlist, use_noise, num=1000):
+    def __init__(self, root_dir, txtlist, use_noise, length):
         self.path = []
         self.real_path = []
         self.use_noise = use_noise
-        self.num = num
         self.root = root_dir
         input_file = open(txtlist)
         while 1:
@@ -33,13 +32,17 @@ class SegDataset(data.Dataset):
                 self.real_path.append(copy.deepcopy(input_line))
         input_file.close()
 
+        self.length = length
+        self.data_len = len(self.path)
         self.back_len = len(self.real_path)
-        self.length = len(self.path)
+
         self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
         self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         self.back_front = np.array([[1 for i in range(640)] for j in range(480)])
 
-    def __getitem__(self, index):
+    def __getitem__(self, idx):
+        index = random.randint(0, self.data_len - 10)
+
         label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[index])))
         meta = scio.loadmat('{0}/{1}-meta.mat'.format(self.root, self.path[index]))
         if not self.use_noise:
@@ -51,7 +54,7 @@ class SegDataset(data.Dataset):
             rgb = Image.open('{0}/{1}-color.png'.format(self.root, self.path[index])).convert("RGB")
             rgb = ImageEnhance.Brightness(rgb).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8))
             rgb = np.array(self.trancolor(rgb))
-            seed = random.randint(10, self.back_len - 10)
+            seed = random.randint(0, self.back_len - 10)
             back = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, self.path[seed])).convert("RGB")))
             back_label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[seed])))
             mask = ma.getmaskarray(ma.masked_equal(label, 0))
diff --git a/vanilla_segmentation/train.py b/vanilla_segmentation/train.py
index 8779e99..83759ff 100644
--- a/vanilla_segmentation/train.py
+++ b/vanilla_segmentation/train.py
@@ -38,10 +38,10 @@ if __name__ == '__main__':
     random.seed(opt.manualSeed)
     torch.manual_seed(opt.manualSeed)
 
-    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True)
+    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers))
-    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False)
-    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(opt.workers))
+    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
+    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers))
 
     print(len(dataset), len(test_dataset))
 
-- 
GitLab