From 5e6a756eedead026f00d53aa30f74f9d06f2e614 Mon Sep 17 00:00:00 2001
From: Guillaume-Duret <guillaume.duret@ec-lyon.fr>
Date: Fri, 19 May 2023 18:43:04 +0200
Subject: [PATCH] run for 8 objects, batch 32, epoch 60

---
 datasets/linemod/dataset.py | 19 +++++++++++--------
 tools/train.py              | 14 +++++++-------
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/datasets/linemod/dataset.py b/datasets/linemod/dataset.py
index 4fccd48..2582a40 100755
--- a/datasets/linemod/dataset.py
+++ b/datasets/linemod/dataset.py
@@ -24,12 +24,11 @@ import matplotlib.pyplot as plt
 
 class PoseDataset(data.Dataset):
     def __init__(self, mode, num, add_noise, root, noise_trans, refine):
-        # ["banana1", "kiwi1", "pear2", "strawberry1", "apricot", "orange2", "peach1", "lemon2", "apple2" ]
         #self.objlist = [1] #TODO
 
-        # apple, ....
-        #self.objlist = [1, 2, 3, 4, 5, 6, 7, 8]
-        self.objlist = [1]
+        # apple2 0 , apricot 1 , banana1 2 , kiwi1 3 , lemon2 4 , orange2 5 , peach1 6 , pear2 7
+        self.objlist = [1, 2, 3, 4, 5, 6, 7, 8]
+        #self.objlist = [1, 3, 6, 7, 8]
         self.mode = mode
 
         self.list_rgb = []
@@ -94,8 +93,8 @@ class PoseDataset(data.Dataset):
         self.border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]
         self.num_pt_mesh_large = 500
         self.num_pt_mesh_small = 500
-        self.symmetry_obj_idx = [1] # TODO
-        #self.symmetry_obj_idx = [0 ,1, 3, 4, 5, 6]
+        self.symmetry_obj_idx = [0, 1, 3, 4, 5, 6] # TODO
+        #self.symmetry_obj_idx = [0, 2, 3]
 
 
     def __getitem__(self, index):
@@ -130,8 +129,8 @@ class PoseDataset(data.Dataset):
 
         mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))
         if self.mode == 'eval':
-            #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255)))
-            mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0]
+            mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255)))
+            #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0]
         else: 
             #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0]
             mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255)))
@@ -193,6 +192,10 @@ class PoseDataset(data.Dataset):
 
         model_points = self.pt[obj] / 1000.0
         dellist = [j for j in range(0, len(model_points))]
+        #print("del :", len(dellist))
+        #print("len : ", len(model_points)-self.num_pt_mesh_small)
+        #print("model point : ", len(model_points))
+        #print("num_pt_mesh : ", self.num_pt_mesh_small)
         dellist = random.sample(dellist, len(model_points) - self.num_pt_mesh_small)
         model_points = np.delete(model_points, dellist, axis=0)
 
diff --git a/tools/train.py b/tools/train.py
index 61698f2..3ad78e1 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -30,7 +30,7 @@ from lib.utils import setup_logger
 parser = argparse.ArgumentParser()
 parser.add_argument('--dataset', type=str, default = 'ycb', help='ycb or linemod')
 parser.add_argument('--dataset_root', type=str, default = '', help='dataset root dir (''YCB_Video_Dataset'' or ''Linemod_preprocessed'')')
-parser.add_argument('--batch_size', type=int, default = 8, help='batch size')
+parser.add_argument('--batch_size', type=int, default = 32, help='batch size')
 parser.add_argument('--workers', type=int, default = 10, help='number of data loading workers')
 parser.add_argument('--lr', default=0.0001, help='learning rate')
 parser.add_argument('--lr_rate', default=0.3, help='learning rate decay rate')
@@ -40,7 +40,7 @@ parser.add_argument('--decay_margin', default=0.016, help='margin to decay lr &
 parser.add_argument('--refine_margin', default=0.013, help='margin to start the training of iterative refinement')
 parser.add_argument('--noise_trans', default=0.03, help='range of the random noise of translation added to the training data')
 parser.add_argument('--iteration', type=int, default = 2, help='number of refinement iterations')
-parser.add_argument('--nepoch', type=int, default=500, help='max number of epochs to train')
+parser.add_argument('--nepoch', type=int, default=60, help='max number of epochs to train')
 parser.add_argument('--resume_posenet', type=str, default = '',  help='resume PoseNet model')
 parser.add_argument('--resume_refinenet', type=str, default = '',  help='resume PoseRefineNet model')
 parser.add_argument('--start_epoch', type=int, default = 1, help='which epoch to start')
@@ -59,12 +59,12 @@ def main():
         opt.log_dir = 'experiments/logs/ycb' #folder to save logs
         opt.repeat_epoch = 1 #number of repeat times for one epoch training
     elif opt.dataset == 'linemod':
-        #opt.num_objects = 8 #TODO
-        opt.num_objects = 1 #TODO
+        opt.num_objects = 8 #TODO
+        #opt.num_objects = 5 #TODO
         opt.num_points = 500
-        opt.outf = 'trained_models/linemod'
-        opt.log_dir = 'experiments/logs/linemod'
-        opt.repeat_epoch = 20
+        opt.outf = 'trained_models/linemod8'
+        opt.log_dir = 'experiments/logs/linemod8'
+        opt.repeat_epoch = 5
     else:
         print('Unknown dataset')
         return
-- 
GitLab