Skip to content
Snippets Groups Projects
Commit 88aaa21d authored by jwangzzz's avatar jwangzzz
Browse files

small fix on memory leak and segmentation training

parent dd8ff1a4
No related branches found
No related tags found
No related merge requests found
......@@ -8,9 +8,9 @@ import random
import torch.backends.cudnn as cudnn
from lib.knn.__init__ import KNearestNeighbor
knn = KNearestNeighbor(1)
def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
knn = KNearestNeighbor(1)
bs, num_p, _ = pred_c.size()
pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))
......@@ -67,7 +67,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points,
new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
# print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item())
del knn
return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach()
......
......@@ -8,9 +8,9 @@ import random
import torch.backends.cudnn as cudnn
from lib.knn.__init__ import KNearestNeighbor
knn = KNearestNeighbor(1)
def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
knn = KNearestNeighbor(1)
pred_r = pred_r.view(1, 1, -1)
pred_t = pred_t.view(1, 1, -1)
bs, num_p, _ = pred_r.size()
......@@ -60,7 +60,7 @@ def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_poin
new_target = torch.bmm((new_target - ori_t), ori_base).contiguous()
# print('------------> ', dis.item(), idx[0].item())
del knn
return dis, new_points.detach(), new_target.detach()
......
......@@ -15,11 +15,10 @@ from PIL import ImageEnhance
from PIL import ImageFilter
class SegDataset(data.Dataset):
def __init__(self, root_dir, txtlist, use_noise, num=1000):
def __init__(self, root_dir, txtlist, use_noise, length):
self.path = []
self.real_path = []
self.use_noise = use_noise
self.num = num
self.root = root_dir
input_file = open(txtlist)
while 1:
......@@ -33,13 +32,17 @@ class SegDataset(data.Dataset):
self.real_path.append(copy.deepcopy(input_line))
input_file.close()
self.length = length
self.data_len = len(self.path)
self.back_len = len(self.real_path)
self.length = len(self.path)
self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)
self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.back_front = np.array([[1 for i in range(640)] for j in range(480)])
def __getitem__(self, index):
def __getitem__(self, idx):
index = random.randint(0, self.data_len - 10)
label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[index])))
meta = scio.loadmat('{0}/{1}-meta.mat'.format(self.root, self.path[index]))
if not self.use_noise:
......@@ -51,7 +54,7 @@ class SegDataset(data.Dataset):
rgb = Image.open('{0}/{1}-color.png'.format(self.root, self.path[index])).convert("RGB")
rgb = ImageEnhance.Brightness(rgb).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8))
rgb = np.array(self.trancolor(rgb))
seed = random.randint(10, self.back_len - 10)
seed = random.randint(0, self.back_len - 10)
back = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, self.path[seed])).convert("RGB")))
back_label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[seed])))
mask = ma.getmaskarray(ma.masked_equal(label, 0))
......
......@@ -38,10 +38,10 @@ if __name__ == '__main__':
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True)
dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers))
test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(opt.workers))
test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers))
print(len(dataset), len(test_dataset))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment