Commit 88aaa21d authored by jwangzzz

small fix on memory leak and segmentation training

parent dd8ff1a4
In the main pose-estimation loss, the KNearestNeighbor helper moves from module scope into loss_calculation and is released before the function returns:

@@ -8,9 +8,9 @@ import random
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor

-knn = KNearestNeighbor(1)

 def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     bs, num_p, _ = pred_c.size()
     pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))
@@ -67,7 +67,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points,
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
     # print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item())
+    del knn
     return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()
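One plausible reading of the fix: the module-level KNearestNeighbor instance lived for the entire training run, so any state it kept a reference to across calls could never be garbage-collected; constructing it per call and del-ing it before the return makes that state collectable after every iteration. A minimal, self-contained sketch of the failure mode and the scoped pattern, using a hypothetical stand-in Helper class rather than the real CUDA extension:

class Helper:
    """Stand-in for a stateful extension that caches intermediates."""
    def __init__(self):
        self.cache = []

    def __call__(self, x):
        self.cache.append(x)       # simulated per-call buffer growth
        return x * 2

# Leaky pattern: one long-lived instance accumulates state forever.
helper = Helper()
def step_leaky(x):
    return helper(x)

# Fixed pattern (what the diff does): instance scoped to the call.
def step_scoped(x):
    h = Helper()
    out = h(x)
    del h                          # last reference gone; cache is collectable
    return out

for i in range(3):
    step_leaky(i)
print(len(helper.cache))           # -> 3: state outlives every call
print(step_scoped(21))             # -> 42: state died with h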
The refiner loss receives the identical change:

@@ -8,9 +8,9 @@ import random
 import torch.backends.cudnn as cudnn
 from lib.knn.__init__ import KNearestNeighbor

-knn = KNearestNeighbor(1)

 def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
+    knn = KNearestNeighbor(1)
     pred_r = pred_r.view(1, 1, -1)
     pred_t = pred_t.view(1, 1, -1)
     bs, num_p, _ = pred_r.size()
@@ -60,7 +60,7 @@ def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_poin
     new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
     # print('------------> ', dis.item(), idx[0].item())
+    del knn
     return dis, new_points.detach(), new_target.detach()
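Since del knn does the releasing in both files, it is worth being precise about what del buys you here; this is general PyTorch behavior, not something the commit adds. del only drops the Python reference; GPU memory behind any freed tensors returns to PyTorch's caching allocator once their refcounts hit zero, and is only handed back to the driver if you additionally call torch.cuda.empty_cache(). A small illustration:

import torch

t = torch.zeros(1024, 1024)        # add .cuda() on a GPU machine
del t                              # reference dropped; block reusable by the allocator
if torch.cuda.is_available():
    torch.cuda.empty_cache()       # optionally release cached blocks to the driver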
In the segmentation dataset, the epoch length is decoupled from the number of frames on disk: the constructor now takes an explicit length instead of the old num=1000 parameter, and __getitem__ ignores the incoming index and samples a random frame on every fetch. The background-seed range also widens from [10, back_len - 10] to [0, back_len - 10]:

@@ -15,11 +15,10 @@ from PIL import ImageEnhance
 from PIL import ImageFilter

 class SegDataset(data.Dataset):
-    def __init__(self, root_dir, txtlist, use_noise, num=1000):
+    def __init__(self, root_dir, txtlist, use_noise, length):
         self.path = []
         self.real_path = []
         self.use_noise = use_noise
-        self.num = num
         self.root = root_dir

         input_file = open(txtlist)
         while 1:
@@ -33,13 +32,17 @@ class SegDataset(data.Dataset):
             self.real_path.append(copy.deepcopy(input_line))
         input_file.close()

+        self.length = length
+        self.data_len = len(self.path)
         self.back_len = len(self.real_path)
-        self.length = len(self.path)

         self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
         self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         self.back_front = np.array([[1 for i in range(640)] for j in range(480)])

-    def __getitem__(self, index):
+    def __getitem__(self, idx):
+        index = random.randint(0, self.data_len - 10)
         label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[index])))
         meta = scio.loadmat('{0}/{1}-meta.mat'.format(self.root, self.path[index]))

         if not self.use_noise:
@@ -51,7 +54,7 @@ class SegDataset(data.Dataset):
            rgb = Image.open('{0}/{1}-color.png'.format(self.root, self.path[index])).convert("RGB")
            rgb = ImageEnhance.Brightness(rgb).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8))
         rgb = np.array(self.trancolor(rgb))
-        seed = random.randint(10, self.back_len - 10)
+        seed = random.randint(0, self.back_len - 10)
         back = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, self.path[seed])).convert("RGB")))
         back_label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[seed])))
         mask = ma.getmaskarray(ma.masked_equal(label, 0))
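Presumably the class's __len__ (not shown in these hunks) returns self.length, which is what makes the per-epoch size a free parameter; the - 10 margin on the random index mirrors the original code's habit of staying away from the tail of the list. A stripped-down sketch of the pattern, with hypothetical names:

import random
from torch.utils import data

class RandomSampleDataset(data.Dataset):
    """Serves `length` random items per epoch, regardless of corpus size."""

    def __init__(self, paths, length):
        self.paths = paths            # one entry per frame on disk
        self.length = length          # items the DataLoader sees per epoch
        self.data_len = len(paths)

    def __len__(self):
        return self.length            # epoch size, not len(self.paths)

    def __getitem__(self, idx):
        # idx from the sampler is ignored; every fetch is an independent draw
        index = random.randint(0, self.data_len - 1)
        return self.paths[index]

One consequence of this design: shuffle on the DataLoader becomes irrelevant, since every sampler index maps to a fresh random draw anyway.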
The segmentation training script passes the new per-epoch lengths through (5000 training samples, 1000 test samples per epoch) and flips the test loader to shuffle=True:

@@ -38,10 +38,10 @@ if __name__ == '__main__':
     random.seed(opt.manualSeed)
     torch.manual_seed(opt.manualSeed)

-    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True)
+    dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers))
-    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False)
+    test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
-    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(opt.workers))
+    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers))
     print(len(dataset), len(test_dataset))
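Given the dataset change above, the final print should now show 5000 and 1000 regardless of how many frames the data lists contain. A quick sanity check one could run after the loaders are constructed (hypothetical, reusing the script's dataset, dataloader, and opt variables; not part of the commit):

import math

assert len(dataset) == 5000 and len(test_dataset) == 1000
# One epoch over each loader yields a predictable number of batches
# (default drop_last=False, so the count rounds up):
assert len(dataloader) == math.ceil(len(dataset) / opt.batch_size)
assert len(test_dataloader) == 1000   # batch_size=1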