From 89a8dcbb35cdc632300997fc43d5330f069e6c88 Mon Sep 17 00:00:00 2001 From: Guillaume-Duret <guillaume.duret@ec-lyon.fr> Date: Fri, 5 May 2023 15:25:26 +0200 Subject: [PATCH] change original dense fusion code for bin picking dataset --- .gitignore | 2 +- datasets/linemod/dataset.py | 60 ++++++++++++------- .../linemod/dataset_config/models_info.yml | 23 +++---- tools/train.py | 7 ++- 4 files changed, 52 insertions(+), 40 deletions(-) mode change 100755 => 100644 datasets/linemod/dataset_config/models_info.yml mode change 100755 => 100644 tools/train.py diff --git a/.gitignore b/.gitignore index 15cbdcf..fced0a2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -datasets/linemod/Linemod_preprocessed +*.png datasets/ycb/YCB_Video_Dataset *.zip *__pycache__ diff --git a/datasets/linemod/dataset.py b/datasets/linemod/dataset.py index 0a3a852..d578a79 100755 --- a/datasets/linemod/dataset.py +++ b/datasets/linemod/dataset.py @@ -19,11 +19,13 @@ import scipy.misc import scipy.io as scio import yaml import cv2 +import matplotlib.pyplot as plt class PoseDataset(data.Dataset): def __init__(self, mode, num, add_noise, root, noise_trans, refine): - self.objlist = [1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15] + # ["banana1", "kiwi1", "pear2", "strawberry1", "apricot", "orange2", "peach1", "lemon2", "apple2" ] + self.objlist = [1] self.mode = mode self.list_rgb = [] @@ -52,10 +54,11 @@ class PoseDataset(data.Dataset): break if input_line[-1:] == '\n': input_line = input_line[:-1] + input_line = input_line.split(".png")[0] self.list_rgb.append('{0}/data/{1}/rgb/{2}.png'.format(self.root, '%02d' % item, input_line)) self.list_depth.append('{0}/data/{1}/depth/{2}.png'.format(self.root, '%02d' % item, input_line)) if self.mode == 'eval': - self.list_label.append('{0}/segnet_results/{1}_label/{2}_label.png'.format(self.root, '%02d' % item, input_line)) + self.list_label.append('{0}/segnet_results/{1}_label/{2}.png'.format(self.root, '%02d' % item, input_line)) else: self.list_label.append('{0}/data/{1}/mask/{2}.png'.format(self.root, '%02d' % item, input_line)) @@ -63,17 +66,19 @@ class PoseDataset(data.Dataset): self.list_rank.append(int(input_line)) meta_file = open('{0}/data/{1}/gt.yml'.format(self.root, '%02d' % item), 'r') - self.meta[item] = yaml.load(meta_file) + #self.meta[item] = yaml.full_load(meta_file) + self.meta[item] = yaml.load(meta_file, Loader=yaml.Loader) self.pt[item] = ply_vtx('{0}/models/obj_{1}.ply'.format(self.root, '%02d' % item)) print("Object {0} buffer loaded".format(item)) self.length = len(self.list_rgb) - self.cam_cx = 325.26110 - self.cam_cy = 242.04899 - self.cam_fx = 572.41140 - self.cam_fy = 573.57043 + # TODO + self.cam_cx = 320.25 # TODO + self.cam_cy = 240.33333333333331 # TODO + self.cam_fx = 543.2527222420504 # TODO + self.cam_fy = 724.3369629894005 # TODO self.xmap = np.array([[j for i in range(640)] for j in range(480)]) self.ymap = np.array([[i for i in range(640)] for j in range(480)]) @@ -85,7 +90,7 @@ class PoseDataset(data.Dataset): self.border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680] self.num_pt_mesh_large = 500 self.num_pt_mesh_small = 500 - self.symmetry_obj_idx = [7, 8] + self.symmetry_obj_idx = [1] # TODO def __getitem__(self, index): img = Image.open(self.list_rgb[index]) @@ -95,20 +100,35 @@ class PoseDataset(data.Dataset): obj = self.list_obj[index] rank = self.list_rank[index] - if obj == 2: - for i in range(0, len(self.meta[obj][rank])): - if self.meta[obj][rank][i]['obj_id'] == 2: - meta = self.meta[obj][rank][i] - break - else: - meta = self.meta[obj][rank][0] + # if obj == 2: + # for i in range(0, len(self.meta[obj][rank])): + # if self.meta[obj][rank][i]['obj_id'] == 2: + # meta = self.meta[obj][rank][i] + # break + # else: + + print("---------------------------------------") + print("img : ", ori_img.shape) + #print("obj",obj) + print("label : ", label.shape) + print("depth : ", depth.shape) + + + print("rank : ",rank) + print("len self.meta : ", len(self.meta)) + print("len self.meta[obj] : ",len(self.meta[obj])) + #print("keys : ", self.meta[obj].keys() ) + #print("key : ", self.meta[obj]['86156']) + #print("self.meta[obj][rank] : ",self.meta[obj][rank] ) + meta = self.meta[obj][f"{rank}"][0] mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0)) if self.mode == 'eval': - mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255))) - else: + #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255))) mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0] - + else: + #mask_label = ma.getmaskarray(ma.masked_equal(label, np.array([255, 255, 255])))[:, :, 0] + mask_label = ma.getmaskarray(ma.masked_equal(label, np.array(255))) mask = mask_label * mask_depth if self.add_noise: @@ -133,6 +153,7 @@ class PoseDataset(data.Dataset): choose = mask[rmin:rmax, cmin:cmax].flatten().nonzero()[0] if len(choose) == 0: + print("RETURN EMPTY IMAGE") cc = torch.LongTensor([0]) return(cc, cc, cc, cc, cc, cc) @@ -186,7 +207,6 @@ class PoseDataset(data.Dataset): #for it in target: # fw.write('{0} {1} {2}\n'.format(it[0], it[1], it[2])) #fw.close() - return torch.from_numpy(cloud.astype(np.float32)), \ torch.LongTensor(choose.astype(np.int32)), \ self.norm(torch.from_numpy(img_masked.astype(np.float32))), \ @@ -216,8 +236,6 @@ img_length = 640 def mask_to_bbox(mask): mask = mask.astype(np.uint8) contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) - - x = 0 y = 0 w = 0 diff --git a/datasets/linemod/dataset_config/models_info.yml b/datasets/linemod/dataset_config/models_info.yml old mode 100755 new mode 100644 index d789267..f1888a4 --- a/datasets/linemod/dataset_config/models_info.yml +++ b/datasets/linemod/dataset_config/models_info.yml @@ -1,15 +1,8 @@ -1: {diameter: 102.09865663, min_x: -37.93430000, min_y: -38.79960000, min_z: -45.88450000, size_x: 75.86860000, size_y: 77.59920000, size_z: 91.76900000} -2: {diameter: 247.50624233, min_x: -107.83500000, min_y: -60.92790000, min_z: -109.70500000, size_x: 215.67000000, size_y: 121.85570000, size_z: 219.41000000} -3: {diameter: 167.35486092, min_x: -83.21620000, min_y: -82.65910000, min_z: -37.23640000, size_x: 166.43240000, size_y: 165.31820000, size_z: 74.47280000} -4: {diameter: 172.49224865, min_x: -68.32970000, min_y: -71.51510000, min_z: -50.24850000, size_x: 136.65940000, size_y: 143.03020000, size_z: 100.49700000} -5: {diameter: 201.40358597, min_x: -50.39580000, min_y: -90.89790000, min_z: -96.86700000, size_x: 100.79160000, size_y: 181.79580000, size_z: 193.73400000} -6: {diameter: 154.54551808, min_x: -33.50540000, min_y: -63.81650000, min_z: -58.72830000, size_x: 67.01070000, size_y: 127.63300000, size_z: 117.45660000} -7: {diameter: 124.26430816, min_x: -58.78990000, min_y: -45.75560000, min_z: -47.31120000, size_x: 117.57980000, size_y: 91.51120000, size_z: 94.62240000} -8: {diameter: 261.47178102, min_x: -114.73800000, min_y: -37.73570000, min_z: -104.00100000, size_x: 229.47600000, size_y: 75.47140000, size_z: 208.00200000} -9: {diameter: 108.99920102, min_x: -52.21460000, min_y: -38.70380000, min_z: -42.84850000, size_x: 104.42920000, size_y: 77.40760000, size_z: 85.69700000} -10: {diameter: 164.62758848, min_x: -75.09230000, min_y: -53.53750000, min_z: -34.62070000, size_x: 150.18460000, size_y: 107.07500000, size_z: 69.24140000} -11: {diameter: 175.88933422, min_x: -18.36050000, min_y: -38.93300000, min_z: -86.40790000, size_x: 36.72110000, size_y: 77.86600000, size_z: 172.81580000} -12: {diameter: 145.54287471, min_x: -50.44390000, min_y: -54.24850000, min_z: -45.40000000, size_x: 100.88780000, size_y: 108.49700000, size_z: 90.80000000} -13: {diameter: 278.07811733, min_x: -129.11300000, min_y: -59.24100000, min_z: -70.56620000, size_x: 258.22600000, size_y: 118.48210000, size_z: 141.13240000} -14: {diameter: 282.60129399, min_x: -101.57300000, min_y: -58.87630000, min_z: -106.55800000, size_x: 203.14600000, size_y: 117.75250000, size_z: 213.11600000} -15: {diameter: 212.35825148, min_x: -46.95910000, min_y: -73.71670000, min_z: -92.37370000, size_x: 93.91810000, size_y: 147.43340000, size_z: 184.74740000} \ No newline at end of file +1: {diameter: 52.94260215554199, min_x: -26.221, min_y: -23.598, min_z: -25.405, size_x: 52.2, size_y: 47.816, size_z: 51.055}, +2: {diameter: 55.64182762994041, min_x: -21.122, min_y: -28.445, min_z: -22.181, size_x: 42.491, size_y: 55.287, size_z: 43.699}, +3: {diameter: 152.6314929798569, min_x: -14.428, min_y: -75.92699, min_z: -23.056, size_x: 29.562, size_y: 151.43699, size_z: 60.727000000000004}, +4: {diameter: 72.65268354988685, min_x: -24.562, min_y: -36.235, min_z: -24.574, size_x: 49.108000000000004, size_y: 72.094, size_z: 49.120000000000005}, +5: {diameter: 72.58068124508065, min_x: -23.86, min_y: -23.435, min_z: -35.375, size_x: 46.884, size_y: 46.866, size_z: 72.48400000000001}, +6: {diameter: 77.41278909198401, min_x: -36.512,min_y: -36.759, min_z: -37.344, size_x: 73.693, size_y: 76.06, size_z: 74.78200000000001}, +7: {diameter: 76.63639853229013, min_x: -37.55, min_y: -35.435, min_z: -36.938, size_x: 74.026, size_y: 71.15899999999999, size_z: 76.622}, +8: {diameter: 129.85637252518683, min_x: -32.724, min_y: -52.947, min_z: -33.292, size_x: 66.189, size_y: 129.07999999999998, size_z: 67.57300000000001} diff --git a/tools/train.py b/tools/train.py old mode 100755 new mode 100644 index 78b8fc8..a2f2936 --- a/tools/train.py +++ b/tools/train.py @@ -20,7 +20,7 @@ import torchvision.datasets as dset import torchvision.transforms as transforms import torchvision.utils as vutils from torch.autograd import Variable -from datasets.ycb.dataset import PoseDataset as PoseDataset_ycb +#from datasets.ycb.dataset import PoseDataset as PoseDataset_ycb from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod from lib.network import PoseNet, PoseRefineNet from lib.loss import Loss @@ -59,11 +59,12 @@ def main(): opt.log_dir = 'experiments/logs/ycb' #folder to save logs opt.repeat_epoch = 1 #number of repeat times for one epoch training elif opt.dataset == 'linemod': - opt.num_objects = 13 + #opt.num_objects = 8 #TODO + opt.num_objects = 1 #TODO opt.num_points = 500 opt.outf = 'trained_models/linemod' opt.log_dir = 'experiments/logs/linemod' - opt.repeat_epoch = 20 + opt.repeat_epoch = 10 else: print('Unknown dataset') return -- GitLab