diff --git a/lib/loss.py b/lib/loss.py index 4ad1064741d09f1048169bf162a07ff33790a1dc..8f78b881a0e1532341e836015dbf95569e56b277 100755 --- a/lib/loss.py +++ b/lib/loss.py @@ -8,9 +8,9 @@ import random import torch.backends.cudnn as cudnn from lib.knn.__init__ import KNearestNeighbor -knn = KNearestNeighbor(1) def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list): + knn = KNearestNeighbor(1) bs, num_p, _ = pred_c.size() pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1)) @@ -67,7 +67,7 @@ def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, new_target = torch.bmm((new_target - ori_t), ori_base).contiguous() # print('------------> ', dis[0][which_max[0]].item(), pred_c[0][which_max[0]].item(), idx[0].item()) - + del knn return loss, dis[0][which_max[0]], new_points.detach(), new_target.detach() diff --git a/lib/loss_refiner.py b/lib/loss_refiner.py index d2c2931ccebe756e273a45df39df63a04813fa41..6496c9ad96e154765b514e09f397a3cd1778fa99 100755 --- a/lib/loss_refiner.py +++ b/lib/loss_refiner.py @@ -8,9 +8,9 @@ import random import torch.backends.cudnn as cudnn from lib.knn.__init__ import KNearestNeighbor -knn = KNearestNeighbor(1) def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list): + knn = KNearestNeighbor(1) pred_r = pred_r.view(1, 1, -1) pred_t = pred_t.view(1, 1, -1) bs, num_p, _ = pred_r.size() @@ -60,7 +60,7 @@ def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_poin new_target = torch.bmm((new_target - ori_t), ori_base).contiguous() # print('------------> ', dis.item(), idx[0].item()) - + del knn return dis, new_points.detach(), new_target.detach() diff --git a/vanilla_segmentation/data_controller.py b/vanilla_segmentation/data_controller.py index ed0891e0a723ae0cd83ec44de2771ee9f7482b77..37189e78b394f3f29467c2c2008d85f98cc55bb4 100644 --- a/vanilla_segmentation/data_controller.py +++ b/vanilla_segmentation/data_controller.py @@ -15,11 +15,10 @@ from PIL import ImageEnhance from PIL import ImageFilter class SegDataset(data.Dataset): - def __init__(self, root_dir, txtlist, use_noise, num=1000): + def __init__(self, root_dir, txtlist, use_noise, length): self.path = [] self.real_path = [] self.use_noise = use_noise - self.num = num self.root = root_dir input_file = open(txtlist) while 1: @@ -33,13 +32,17 @@ class SegDataset(data.Dataset): self.real_path.append(copy.deepcopy(input_line)) input_file.close() + self.length = length + self.data_len = len(self.path) self.back_len = len(self.real_path) - self.length = len(self.path) + self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05) self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.back_front = np.array([[1 for i in range(640)] for j in range(480)]) - def __getitem__(self, index): + def __getitem__(self, idx): + index = random.randint(0, self.data_len - 10) + label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[index]))) meta = scio.loadmat('{0}/{1}-meta.mat'.format(self.root, self.path[index])) if not self.use_noise: @@ -51,7 +54,7 @@ class SegDataset(data.Dataset): rgb = Image.open('{0}/{1}-color.png'.format(self.root, self.path[index])).convert("RGB") rgb = ImageEnhance.Brightness(rgb).enhance(1.5).filter(ImageFilter.GaussianBlur(radius=0.8)) rgb = np.array(self.trancolor(rgb)) - seed = random.randint(10, self.back_len - 10) + seed = random.randint(0, self.back_len - 10) back = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, self.path[seed])).convert("RGB"))) back_label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.path[seed]))) mask = ma.getmaskarray(ma.masked_equal(label, 0)) diff --git a/vanilla_segmentation/train.py b/vanilla_segmentation/train.py index 8779e993a2f47f0da8f33647ee23d4347b352eea..83759ff1e74c3c0068a5b7f9225a2147637ffaca 100644 --- a/vanilla_segmentation/train.py +++ b/vanilla_segmentation/train.py @@ -38,10 +38,10 @@ if __name__ == '__main__': random.seed(opt.manualSeed) torch.manual_seed(opt.manualSeed) - dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True) + dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/train_data_list.txt', True, 5000) dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=int(opt.workers)) - test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False) - test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=int(opt.workers)) + test_dataset = SegDataset(opt.dataset_root, '../datasets/ycb/dataset_config/test_data_list.txt', False, 1000) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=int(opt.workers)) print(len(dataset), len(test_dataset))