From 5e6a756eedead026f00d53aa30f74f9d06f2e614 Mon Sep 17 00:00:00 2001 From: Guillaume-Duret <guillaume.duret@ec-lyon.fr> Date: Fri, 19 May 2023 18:43:04 +0200 Subject: [PATCH] run for 8 objects, batch 32, epoch 60 --- datasets/linemod/dataset.py | 19 +++++++++++-------- tools/train.py | 14 +++++++------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/datasets/linemod/dataset.py b/datasets/linemod/dataset.py index 4fccd48..2582a40 100755 --- a/datasets/linemod/dataset.py +++ b/datasets/linemod/dataset.py @@ -24,12 +24,11 @@ import matplotlib.pyplot as plt class PoseDataset(data.Dataset): def __init__(self, mode, num, add_noise, root, noise_trans, refine): - # ["banana1", "kiwi1", "pear2", "strawberry1", "apricot", "orange2", "peach1", "lemon2", "apple2" ] #self.objlist = [1] #TODO - # apple, .... - #self.objlist = [1, 2, 3, 4, 5, 6, 7, 8] - self.objlist = [1] + # apple2 0 , apricot 1 , banana1 2 , kiwi1 3 , lemon2 4 , orange2 5 , peach1 6 , pear2 7 + self.objlist = [1, 2, 3, 4, 5, 6, 7, 8] + #self.objlist = [1, 3, 6, 7, 8] self.mode = mode self.list_rgb = [] @@ -94,8 +93,8 @@ class PoseDataset(data.Dataset): self.border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680] self.num_pt_mesh_large = 500 self.num_pt_mesh_small = 500 - self.symmetry_obj_idx = [1] # TODO - #self.symmetry_obj_idx = [0 ,1, 3, 4, 5, 6] + self.symmetry_obj_idx = [0, 1, 3, 4, 5, 6] # TODO + #self.symmetry_obj_idx = [0, 2, 3] def __getitem__(self, index): @@ -130,8 +129,8 @@ class PoseDataset(data.Dataset): mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0)) if self.mode == 'eval': - #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255))) - mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0] + mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255))) + #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0] else: #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0] mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255))) @@ -193,6 +192,10 @@ class PoseDataset(data.Dataset): model_points = self.pt[obj] / 1000.0 dellist = [j for j in range(0, len(model_points))] + #print("del :", len(dellist)) + #print("len : ", len(model_points)-self.num_pt_mesh_small) + #print("model point : ", len(model_points)) + #print("num_pt_mesh : ", self.num_pt_mesh_small) dellist = random.sample(dellist, len(model_points) - self.num_pt_mesh_small) model_points = np.delete(model_points, dellist, axis=0) diff --git a/tools/train.py b/tools/train.py index 61698f2..3ad78e1 100644 --- a/tools/train.py +++ b/tools/train.py @@ -30,7 +30,7 @@ from lib.utils import setup_logger parser = argparse.ArgumentParser() parser.add_argument('--dataset', type=str, default = 'ycb', help='ycb or linemod') parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir (''YCB_Video_Dataset'' or ''Linemod_preprocessed'')') -parser.add_argument('--batch_size', type=int, default = 8, help='batch size') +parser.add_argument('--batch_size', type=int, default = 32, help='batch size') parser.add_argument('--workers', type=int, default = 10, help='number of data loading workers') parser.add_argument('--lr', default=0.0001, help='learning rate') parser.add_argument('--lr_rate', default=0.3, help='learning rate decay rate') @@ -40,7 +40,7 @@ parser.add_argument('--decay_margin', default=0.016, help='margin to decay lr & parser.add_argument('--refine_margin', default=0.013, help='margin to start the training of iterative refinement') parser.add_argument('--noise_trans', default=0.03, help='range of the random noise of translation added to the training data') parser.add_argument('--iteration', type=int, default = 2, help='number of refinement iterations') -parser.add_argument('--nepoch', type=int, default=500, help='max number of epochs to train') +parser.add_argument('--nepoch', type=int, default=60, help='max number of epochs to train') parser.add_argument('--resume_posenet', type=str, default = '', help='resume PoseNet model') parser.add_argument('--resume_refinenet', type=str, default = '', help='resume PoseRefineNet model') parser.add_argument('--start_epoch', type=int, default = 1, help='which epoch to start') @@ -59,12 +59,12 @@ def main(): opt.log_dir = 'experiments/logs/ycb' #folder to save logs opt.repeat_epoch = 1 #number of repeat times for one epoch training elif opt.dataset == 'linemod': - #opt.num_objects = 8 #TODO - opt.num_objects = 1 #TODO + opt.num_objects = 8 #TODO + #opt.num_objects = 5 #TODO opt.num_points = 500 - opt.outf = 'trained_models/linemod' - opt.log_dir = 'experiments/logs/linemod' - opt.repeat_epoch = 20 + opt.outf = 'trained_models/linemod8' + opt.log_dir = 'experiments/logs/linemod8' + opt.repeat_epoch = 5 else: print('Unknown dataset') return -- GitLab