From 0086b12339a2445d20347db577f9a15e154835fa Mon Sep 17 00:00:00 2001 From: Guillaume-Duret <guillaume.duret@ec-lyon.fr> Date: Sat, 6 May 2023 21:30:29 +0200 Subject: [PATCH] denseffusion 1 object running --- datasets/linemod/dataset.py | 20 +++++++++++++------- tools/train.py | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/datasets/linemod/dataset.py b/datasets/linemod/dataset.py index d578a79..4fccd48 100755 --- a/datasets/linemod/dataset.py +++ b/datasets/linemod/dataset.py @@ -25,6 +25,10 @@ import matplotlib.pyplot as plt class PoseDataset(data.Dataset): def __init__(self, mode, num, add_noise, root, noise_trans, refine): # ["banana1", "kiwi1", "pear2", "strawberry1", "apricot", "orange2", "peach1", "lemon2", "apple2" ] + #self.objlist = [1] #TODO + + # apple, .... + #self.objlist = [1, 2, 3, 4, 5, 6, 7, 8] self.objlist = [1] self.mode = mode @@ -91,6 +95,8 @@ class PoseDataset(data.Dataset): self.num_pt_mesh_large = 500 self.num_pt_mesh_small = 500 self.symmetry_obj_idx = [1] # TODO + #self.symmetry_obj_idx = [0 ,1, 3, 4, 5, 6] + def __getitem__(self, index): img = Image.open(self.list_rgb[index]) @@ -107,16 +113,16 @@ class PoseDataset(data.Dataset): # break # else: - print("---------------------------------------") - print("img : ", ori_img.shape) + #print("---------------------------------------") + #print("img : ", ori_img.shape) #print("obj",obj) - print("label : ", label.shape) - print("depth : ", depth.shape) + #print("label : ", label.shape) + #print("depth : ", depth.shape) - print("rank : ",rank) - print("len self.meta : ", len(self.meta)) - print("len self.meta[obj] : ",len(self.meta[obj])) + #print("rank : ",rank) + #print("len self.meta : ", len(self.meta)) + #print("len self.meta[obj] : ",len(self.meta[obj])) #print("keys : ", self.meta[obj].keys() ) #print("key : ", self.meta[obj]['86156']) #print("self.meta[obj][rank] : ",self.meta[obj][rank] ) diff --git a/tools/train.py b/tools/train.py index a2f2936..61698f2 100644 --- a/tools/train.py +++ b/tools/train.py @@ -64,7 +64,7 @@ def main(): opt.num_points = 500 opt.outf = 'trained_models/linemod' opt.log_dir = 'experiments/logs/linemod' - opt.repeat_epoch = 10 + opt.repeat_epoch = 20 else: print('Unknown dataset') return -- GitLab