def mask_to_bbox(mask):
    """Return the bounding rectangle of the largest contour found in *mask*.

    The mask is converted to ``uint8`` and passed to ``cv2.findContours``;
    every contour's upright bounding rectangle is computed and the one with
    the greatest ``width * height`` wins (the first one encountered wins
    ties).

    Args:
        mask: array-like object exposing ``astype`` (presumably a 2-D NumPy
            segmentation mask whose non-zero pixels mark the object --
            TODO confirm against the caller).

    Returns:
        A list ``[x, y, w, h]`` as produced by ``cv2.boundingRect``, or
        ``[0, 0, 0, 0]`` when no contour has a positive-area rectangle
        (e.g. an all-zero mask).
    """
    mask = mask.astype(np.uint8)

    # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3.x but
    # (contours, hierarchy) in OpenCV 2.x and 4.x. The contour list is the
    # second-to-last element in every version, so index from the end instead
    # of unpacking a fixed number of values.
    contours = cv2.findContours(
        mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    x = y = w = h = 0
    for contour in contours:
        tmp_x, tmp_y, tmp_w, tmp_h = cv2.boundingRect(contour)
        # Strict '>' keeps the first rectangle on ties and leaves the
        # all-zero default in place when nothing is found.
        if tmp_w * tmp_h > w * h:
            x, y, w, h = tmp_x, tmp_y, tmp_w, tmp_h
    return [x, y, w, h]