Skip to content
Snippets Groups Projects
Commit 5e6a756e authored by Guillaume Duret's avatar Guillaume Duret
Browse files

run for 8 objects, batch 32, epoch 60

parent 0c8c9d47
No related branches found
No related tags found
No related merge requests found
...@@ -24,12 +24,11 @@ import matplotlib.pyplot as plt ...@@ -24,12 +24,11 @@ import matplotlib.pyplot as plt
class PoseDataset(data.Dataset): class PoseDataset(data.Dataset):
def __init__(self, mode, num, add_noise, root, noise_trans, refine): def __init__(self, mode, num, add_noise, root, noise_trans, refine):
# ["banana1", "kiwi1", "pear2", "strawberry1", "apricot", "orange2", "peach1", "lemon2", "apple2" ]
#self.objlist = [1] #TODO #self.objlist = [1] #TODO
# apple, .... # apple2 0 , apricot 1 , banana1 2 , kiwi1 3 , lemon2 4 , orange2 5 , peach1 6 , pear2 7
#self.objlist = [1, 2, 3, 4, 5, 6, 7, 8] self.objlist = [1, 2, 3, 4, 5, 6, 7, 8]
self.objlist = [1] #self.objlist = [1, 3, 6, 7, 8]
self.mode = mode self.mode = mode
self.list_rgb = [] self.list_rgb = []
...@@ -94,8 +93,8 @@ class PoseDataset(data.Dataset): ...@@ -94,8 +93,8 @@ class PoseDataset(data.Dataset):
self.border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680] self.border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]
self.num_pt_mesh_large = 500 self.num_pt_mesh_large = 500
self.num_pt_mesh_small = 500 self.num_pt_mesh_small = 500
self.symmetry_obj_idx = [1] # TODO self.symmetry_obj_idx = [0, 1, 3, 4, 5, 6] # TODO
#self.symmetry_obj_idx = [0 ,1, 3, 4, 5, 6] #self.symmetry_obj_idx = [0, 2, 3]
def __getitem__(self, index): def __getitem__(self, index):
...@@ -130,8 +129,8 @@ class PoseDataset(data.Dataset): ...@@ -130,8 +129,8 @@ class PoseDataset(data.Dataset):
mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0)) mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))
if self.mode == 'eval': if self.mode == 'eval':
#mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255))) mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255)))
mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0] #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0]
else: else:
#mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0] #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0]
mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255))) mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255)))
...@@ -193,6 +192,10 @@ class PoseDataset(data.Dataset): ...@@ -193,6 +192,10 @@ class PoseDataset(data.Dataset):
model_points = self.pt[obj] / 1000.0 model_points = self.pt[obj] / 1000.0
dellist = [j for j in range(0, len(model_points))] dellist = [j for j in range(0, len(model_points))]
#print("del :", len(dellist))
#print("len : ", len(model_points)-self.num_pt_mesh_small)
#print("model point : ", len(model_points))
#print("num_pt_mesh : ", self.num_pt_mesh_small)
dellist = random.sample(dellist, len(model_points) - self.num_pt_mesh_small) dellist = random.sample(dellist, len(model_points) - self.num_pt_mesh_small)
model_points = np.delete(model_points, dellist, axis=0) model_points = np.delete(model_points, dellist, axis=0)
......
...@@ -30,7 +30,7 @@ from lib.utils import setup_logger ...@@ -30,7 +30,7 @@ from lib.utils import setup_logger
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default = 'ycb', help='ycb or linemod') parser.add_argument('--dataset', type=str, default = 'ycb', help='ycb or linemod')
parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir (''YCB_Video_Dataset'' or ''Linemod_preprocessed'')') parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir (''YCB_Video_Dataset'' or ''Linemod_preprocessed'')')
parser.add_argument('--batch_size', type=int, default = 8, help='batch size') parser.add_argument('--batch_size', type=int, default = 32, help='batch size')
parser.add_argument('--workers', type=int, default = 10, help='number of data loading workers') parser.add_argument('--workers', type=int, default = 10, help='number of data loading workers')
parser.add_argument('--lr', default=0.0001, help='learning rate') parser.add_argument('--lr', default=0.0001, help='learning rate')
parser.add_argument('--lr_rate', default=0.3, help='learning rate decay rate') parser.add_argument('--lr_rate', default=0.3, help='learning rate decay rate')
...@@ -40,7 +40,7 @@ parser.add_argument('--decay_margin', default=0.016, help='margin to decay lr & ...@@ -40,7 +40,7 @@ parser.add_argument('--decay_margin', default=0.016, help='margin to decay lr &
parser.add_argument('--refine_margin', default=0.013, help='margin to start the training of iterative refinement') parser.add_argument('--refine_margin', default=0.013, help='margin to start the training of iterative refinement')
parser.add_argument('--noise_trans', default=0.03, help='range of the random noise of translation added to the training data') parser.add_argument('--noise_trans', default=0.03, help='range of the random noise of translation added to the training data')
parser.add_argument('--iteration', type=int, default = 2, help='number of refinement iterations') parser.add_argument('--iteration', type=int, default = 2, help='number of refinement iterations')
parser.add_argument('--nepoch', type=int, default=500, help='max number of epochs to train') parser.add_argument('--nepoch', type=int, default=60, help='max number of epochs to train')
parser.add_argument('--resume_posenet', type=str, default = '', help='resume PoseNet model') parser.add_argument('--resume_posenet', type=str, default = '', help='resume PoseNet model')
parser.add_argument('--resume_refinenet', type=str, default = '', help='resume PoseRefineNet model') parser.add_argument('--resume_refinenet', type=str, default = '', help='resume PoseRefineNet model')
parser.add_argument('--start_epoch', type=int, default = 1, help='which epoch to start') parser.add_argument('--start_epoch', type=int, default = 1, help='which epoch to start')
...@@ -59,12 +59,12 @@ def main(): ...@@ -59,12 +59,12 @@ def main():
opt.log_dir = 'experiments/logs/ycb' #folder to save logs opt.log_dir = 'experiments/logs/ycb' #folder to save logs
opt.repeat_epoch = 1 #number of repeat times for one epoch training opt.repeat_epoch = 1 #number of repeat times for one epoch training
elif opt.dataset == 'linemod': elif opt.dataset == 'linemod':
#opt.num_objects = 8 #TODO opt.num_objects = 8 #TODO
opt.num_objects = 1 #TODO #opt.num_objects = 5 #TODO
opt.num_points = 500 opt.num_points = 500
opt.outf = 'trained_models/linemod' opt.outf = 'trained_models/linemod8'
opt.log_dir = 'experiments/logs/linemod' opt.log_dir = 'experiments/logs/linemod8'
opt.repeat_epoch = 20 opt.repeat_epoch = 5
else: else:
print('Unknown dataset') print('Unknown dataset')
return return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment