From 88aaa21d5fc9665796eb9fc2dec0933aee4be1cb Mon Sep 17 00:00:00 2001
From: jwangzzz <j96w@qq.com>
Date: Sun, 7 Apr 2019 13:35:13 +0800
Subject: [PATCH] small fix on memory leak and segmentation training

---
 lib/loss.py                             |  4 ++--
 lib/loss_refiner.py                     |  4 ++--
 vanilla_segmentation/data_controller.py | 13 ++++++++-----
 vanilla_segmentation/train.py           |  6 +++---
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/lib/loss.py b/lib/loss.py
index 4ad1064..8f78b88 100755
--- a/lib/loss.py
+++ b/lib/loss.py
@@ -8,9 +8,9 @@ import random
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor
 
-knn = KNearestNeighbor(1)
 
 def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     bs, num_p, _ = pred_c.size()
 
     pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))
@@ -67,7 +67,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points,
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
 
     # print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item())
-
+    del knn
     return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()
 
 
diff --git a/lib/loss_refiner.py b/lib/loss_refiner.py
index d2c2931..6496c9a 100755
--- a/lib/loss_refiner.py
+++ b/lib/loss_refiner.py
@@ -8,9 +8,9 @@ import random
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor
 
-knn = KNearestNeighbor(1)
 
 def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     pred_r = pred_r.view(1, 1, -1)
     pred_t = pred_t.view(1, 1, -1)
     bs, num_p, _ = pred_r.size()
@@ -60,7 +60,7 @@ def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_poin
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
 
     # print('------------> ', dis.item(), idx[0].item())
-
+    del knn
     return dis, new_points.detach(), new_target.detach()
 
 
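The changes to lib/loss.py and lib/loss_refiner.py above address the memory leak the same way in both files: each call to loss_calculation now builds its own KNearestNeighbor(1) and drops it with del knn before returning, rather than keeping a single instance alive at module scope for the whole training run. Below is a minimal sketch of that scoped-lifetime pattern, assuming KNearestNeighbor's CUDA allocations are only reclaimed once the instance is garbage-collected; the helper name scoped_knn_search and the knn(ref, query) call shape are illustrative assumptions, not the module's exact API.

    from lib.knn.__init__ import KNearestNeighbor

    def scoped_knn_search(ref, query):
        # Hypothetical helper mirroring the pattern in the hunks above:
        # construct the KNN module per call instead of caching one at import time.
        knn = KNearestNeighbor(1)
        try:
            return knn(ref, query)  # assumed call shape; see lib/knn for the real interface
        finally:
            del knn                 # drop the reference even if the search raises

A try/finally goes slightly further than the patch's del knn before return, which is skipped when an exception is raised mid-loss; for long training runs that stricter cleanup is the safer variant.
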
diff --git a/vanilla_segmentation/data_controller.py b/vanilla_segmentation/data_controller.py
index ed0891e..37189e7 100644
--- a/vanilla_segmentation/data_controller.py
+++ b/vanilla_segmentation/data_controller.py
@@ -15,11 +15,10 @@ from PIL import ImageEnhance
 from PIL import ImageFilter
 
 class SegDataset(data.Dataset):
-    def __init__(self, root_dir, txtlist, use_noise, num=1000):
+    def __init__(self, root_dir, txtlist, use_noise, length):
         self.path = []
         self.real_path = []
         self.use_noise = use_noise
-        self.num = num
         self.root = root_dir
         input_file = open(txtlist)
         while 1:
@@ -33,13 +32,17 @@
             self.real_path.append(copy.deepcopy(input_line))
         input_file.close()
 
+        self.length = length
+        self.data_len = len(self.path)
         self.back_len = len(self.real_path)
-        self.length = len(self.path)
+
         self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
         self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         self.back_front = np.array([[1 for i in range(640)] for j in range(480)])
 
-    def __getitem__(self, index):
+    def __getitem__(self, idx):
+        index = random.randint(0, self.data_len - 10)
+
         label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[index])))
         meta = scio.loadmat('{0}/{1}-meta.mat'.format(self.root, self.path[index]))
         if not self.use_noise:
@@ -51,7 +54,7 @@
         rgb = Image.open('{0}/{1}-color.png'.format(self.root, self.path[index])).convert("RGB")
         rgb = ImageEnhance.Brightness(rgb).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8))
         rgb = np.array(self.trancolor(rgb))
-        seed = random.randint(10, self.back_len - 10)
+        seed = random.randint(0, self.back_len - 10)
         back = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, self.path[seed])).convert("RGB")))
         back_label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[seed])))
         mask = ma.getmaskarray(ma.masked_equal(label, 0))
diff --git a/vanilla_segmentation/train.py b/vanilla_segmentation/train.py
index 8779e99..83759ff 100644
--- a/vanilla_segmentation/train.py
+++ b/vanilla_segmentation/train.py
@@ -38,10 +38,10 @@ if __name__ == '__main__':
     random.seed(opt.manualSeed)
     torch.manual_seed(opt.manualSeed)
 
-    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True)
+    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers))
-    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False)
-    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(opt.workers))
+    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
+    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers))
 
     print(len(dataset), len(test_dataset))
 
--
GitLab
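
On the segmentation side, the patch decouples epoch length from dataset size: SegDataset now takes an explicit length (5000 for training, 1000 for testing in train.py), and __getitem__ ignores the index the DataLoader passes in, drawing a random frame instead. Below is a standalone sketch of that sampling scheme, assuming SegDataset.__len__ returns self.length (which the self.length = length assignment suggests); the class RandomFrameDataset and its fields are illustrative stand-ins, not the repository's code.

    import random
    from torch.utils.data import Dataset

    class RandomFrameDataset(Dataset):
        def __init__(self, items, length):
            self.items = items               # e.g. the frame paths read from the txt list
            self.length = length             # virtual epoch size: 5000 train / 1000 test
            self.data_len = len(self.items)

        def __len__(self):
            # The DataLoader sees `length` samples per epoch, independent of len(items).
            return self.length

        def __getitem__(self, idx):
            # idx (0 .. length - 1) is ignored; a frame is drawn at random, mirroring
            # the random.randint(0, self.data_len - 10) call in the hunk above.
            index = random.randint(0, self.data_len - 10)
            return self.items[index]

Because every lookup is already random under this scheme, switching the test loader to shuffle=True changes only which virtual indices are requested, not which frames are sampled.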