From d4dbab2e3099aeafd0f992429f1096a093e5d2d9 Mon Sep 17 00:00:00 2001
From: liuxingyu <lxy17@foxmail.com>
Date: Wed, 2 Nov 2022 14:01:27 +0800
Subject: [PATCH] lm_dataset_d2: per-image intrinsics, full masks, xyz paths, debug splits

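Load per-image camera intrinsics and depth_scale from scene_camera.json
instead of the fixed cam/cam_type/depth_factor config, use bbox_obj
instead of bbox_visib, load the full (amodal) mask as compressed RLE,
record per-instance xyz_crop paths for non-test splits, and attach
model_info plus symmetry metadata (sym_infos) from the dataset ref.
Also add xyz_prefixes to all splits, add single-obj and single-image
debug splits (debug_im_id), drop the lm_13_all split, switch
evaluator_type from "coco_bop" to "bop", and extend test_vis() to show
per-instance visualizations and xyz maps.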
---
 det/yolox/data/datasets/lm_dataset_d2.py | 409 ++++++++++++++++-------
 1 file changed, 297 insertions(+), 112 deletions(-)

diff --git a/det/yolox/data/datasets/lm_dataset_d2.py b/det/yolox/data/datasets/lm_dataset_d2.py
index 269c1e3..dae9314 100644
--- a/det/yolox/data/datasets/lm_dataset_d2.py
+++ b/det/yolox/data/datasets/lm_dataset_d2.py
@@ -20,7 +20,7 @@ import ref
 
 from lib.pysixd import inout, misc
 from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
-from lib.utils.utils import dprint, lazy_property
+from lib.utils.utils import dprint, iprint, lazy_property
 
 
 logger = logging.getLogger(__name__)
@@ -43,17 +43,16 @@ class LM_Dataset(object):
 
         self.ann_files = data_cfg["ann_files"]  # idx files with image ids
         self.image_prefixes = data_cfg["image_prefixes"]
+        self.xyz_prefixes = data_cfg["xyz_prefixes"]
 
         self.dataset_root = data_cfg["dataset_root"]  # BOP_DATASETS/lm/
+        assert osp.exists(self.dataset_root), self.dataset_root
         self.models_root = data_cfg["models_root"]  # BOP_DATASETS/lm/models
         self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
 
         self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
         self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
-        self.depth_factor = data_cfg["depth_factor"]  # 1000.0
 
-        self.cam_type = data_cfg["cam_type"]
-        self.cam = data_cfg["cam"]  #
         self.height = data_cfg["height"]  # 480
         self.width = data_cfg["width"]  # 640
 
@@ -62,31 +61,8 @@ class LM_Dataset(object):
         self.num_to_load = data_cfg["num_to_load"]  # -1
         self.filter_invalid = data_cfg["filter_invalid"]
         self.filter_scene = data_cfg.get("filter_scene", False)
+        self.debug_im_id = data_cfg.get("debug_im_id", None)
         ##################################################
-        if self.cam is None:
-            assert self.cam_type in ["local", "dataset"]
-            if self.cam_type == "dataset":
-                self.cam = np.array(
-                    [
-                        [572.4114, 0, 325.2611],
-                        [0, 573.57043, 242.04899],
-                        [0, 0, 1],
-                    ]
-                )
-            elif self.cam_type == "local":
-                # self.cam = np.array([[539.8100, 0, 318.2700], [0, 539.8300, 239.5600], [0, 0, 1]])
-                # yapf: disable
-                self.cam = np.array(
-                    [[518.81993115, 0.,           320.50653699],
-                     [0.,           518.86581081, 243.5604188 ],
-                     [0.,           0.,           1.          ]])
-                # yapf: enable
-                # RMS: 0.14046169348724977
-                # camera matrix:
-                # [[518.81993115   0.         320.50653699]
-                # [  0.         518.86581081 243.5604188 ]
-                # [  0.           0.           1.        ]]
-                # distortion coefficients:  [ 0.04147325 -0.21469544 -0.00053707 -0.00251986  0.17406399]
 
         # NOTE: careful! Only the selected objects
         self.cat_ids = [cat_id for cat_id, obj_name in ref.lm_full.id2obj.items() if obj_name in self.objs]
@@ -107,12 +83,11 @@ class LM_Dataset(object):
         hashed_file_name = hashlib.md5(
             (
                 "".join([str(fn) for fn in self.objs])
-                + "dataset_dicts_{}_{}_{}_{}_{}_{}".format(
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
                     self.name,
                     self.dataset_root,
                     self.with_masks,
                     self.with_depth,
-                    self.cam_type,
                     __name__,
                 )
             ).encode("utf-8")
@@ -131,15 +106,17 @@ class LM_Dataset(object):
         logger.info("loading dataset dicts: {}".format(self.name))
         self.num_instances_without_valid_segmentation = 0
         self.num_instances_without_valid_box = 0
-        dataset_dicts = []  #######################################################
+        dataset_dicts = []  # ######################################################
         assert len(self.ann_files) == len(self.image_prefixes), f"{len(self.ann_files)} != {len(self.image_prefixes)}"
+        assert len(self.ann_files) == len(self.xyz_prefixes), f"{len(self.ann_files)} != {len(self.xyz_prefixes)}"
         unique_im_id = 0
-        for ann_file, scene_root in zip(tqdm(self.ann_files), self.image_prefixes):
+        for ann_file, scene_root, xyz_root in zip(tqdm(self.ann_files), self.image_prefixes, self.xyz_prefixes):
             # in LINEMOD, each scene corresponds to one object
             with open(ann_file, "r") as f_ann:
                 indices = [line.strip("\r\n") for line in f_ann.readlines()]  # string ids
             gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
             gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))  # bbox_obj, bbox_visib
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
             for im_id in tqdm(indices):
                 int_im_id = int(im_id)
                 str_im_id = str(int_im_id)
@@ -150,6 +127,13 @@ class LM_Dataset(object):
 
                 scene_id = int(rgb_path.split("/")[-3])
                 scene_im_id = f"{scene_id}/{int_im_id}"
+
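+                # debug mode: keep only the single requested "scene_id/im_id" image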
+                if self.debug_im_id is not None:
+                    if self.debug_im_id != scene_im_id:
+                        continue
+
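+                # per-image intrinsics and depth scale from scene_camera.json;
+                # depth_im / depth_factor gives depth in meters (BOP: depth_im * depth_scale = mm)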
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]
                 if self.filter_scene:
                     if scene_id not in self.cat_ids:
                         continue
@@ -161,7 +145,8 @@ class LM_Dataset(object):
                     "width": self.width,
                     "image_id": unique_im_id,
                     "scene_im_id": scene_im_id,  # for evaluation
-                    "cam": self.cam,
+                    "cam": K,
+                    "depth_factor": depth_factor,
                     "img_type": "real",
                 }
                 unique_im_id += 1
@@ -197,7 +182,7 @@ class LM_Dataset(object):
                     )
                     assert osp.exists(mask_file), mask_file
                     assert osp.exists(mask_visib_file), mask_visib_file
-                    # load mask visib  TODO: load both mask_visib and mask_full
+                    # load mask visib
                     mask_single = mmcv.imread(mask_visib_file, "unchanged")
                     mask_single = mask_single.astype("bool")
                     area = mask_single.sum()
@@ -205,17 +190,32 @@ class LM_Dataset(object):
                         self.num_instances_without_valid_segmentation += 1
                         continue
                     mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
                     inst = {
                         "category_id": cur_label,  # 0-based label
-                        "bbox": bbox_visib,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
                         "bbox_mode": BoxMode.XYWH_ABS,
                         "pose": pose,
                         "quat": quat,
                         "trans": t,
                         "centroid_2d": proj,  # absolute (cx, cy)
                         "segmentation": mask_rle,
-                        "mask_full_file": mask_file,  # TODO: load as mask_full, rle
+                        "mask_full": mask_full_rle,
                     }
+
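+                    # xyz_crop (per-pixel 3D object coordinates) is only loaded for non-test splits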
+                    if "test" not in self.name.lower():
+                        xyz_path = osp.join(xyz_root, f"{int_im_id:06d}_{anno_i:06d}.pkl")
+                        assert osp.exists(xyz_path), xyz_path
+                        inst["xyz_path"] = xyz_path
+
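+                    # per-object entry from models_info.json (diameter, extents, symmetry info)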
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    # TODO: use full mask and full xyz
                     for key in ["bbox3d_and_center"]:
                         inst[key] = self.models[cur_label][key]
                     insts.append(inst)
@@ -247,6 +247,13 @@ class LM_Dataset(object):
         logger.info("Dumped dataset_dicts to {}".format(cache_path))
         return dataset_dicts
 
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
     @lazy_property
     def models(self):
         """Load models into a list."""
@@ -281,9 +288,25 @@ class LM_Dataset(object):
 ########### register datasets ############################################################
 
 
-def get_lm_metadata(obj_names):
-    # task specific metadata
-    meta = {"thing_classes": obj_names}
+def get_lm_metadata(obj_names, ref_key):
+    """Task-specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # keyed by 0-based label
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
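+            # collect symmetry rotations; continuous symmetries are discretized via max_sym_disc_step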
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
     return meta
 
 
@@ -334,12 +357,16 @@ SPLITS_LM = dict(
             )
             for _obj in LM_13_OBJECTS
         ],
+        xyz_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lm_full.obj2id[_obj]),
+            )
+            for _obj in LM_13_OBJECTS
+        ],
         scale_to_meter=0.001,
         with_masks=True,  # (load masks but may not use it)
         with_depth=True,  # (load depth path here, but may not use it)
-        depth_factor=1000.0,
-        cam_type="dataset",
-        cam=ref.lm_full.camera_matrix,
         height=480,
         width=640,
         cache_dir=osp.join(PROJ_ROOT, ".cache"),
@@ -366,44 +393,16 @@ SPLITS_LM = dict(
             osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/test/{:06d}").format(ref.lm_full.obj2id[_obj])
             for _obj in LM_13_OBJECTS
         ],
-        scale_to_meter=0.001,
-        with_masks=True,  # (load masks but may not use it)
-        with_depth=True,  # (load depth path here, but may not use it)
-        depth_factor=1000.0,
-        cam_type="dataset",
-        cam=ref.lm_full.camera_matrix,
-        height=480,
-        width=640,
-        cache_dir=osp.join(PROJ_ROOT, ".cache"),
-        use_cache=True,
-        num_to_load=-1,
-        filter_scene=True,
-        filter_invalid=False,
-        ref_key="lm_full",
-    ),
-    lm_13_all=dict(
-        name="lm_13_all",  # for get all real bboxes
-        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/"),
-        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
-        objs=LM_13_OBJECTS,
-        ann_files=[
+        xyz_prefixes=[
             osp.join(
                 DATASETS_ROOT,
-                "BOP_DATASETS/lm/image_set/{}_{}.txt".format(_obj, "all"),
+                "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lm_full.obj2id[_obj]),
             )
             for _obj in LM_13_OBJECTS
         ],
-        # NOTE: scene root
-        image_prefixes=[
-            osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/test/{:06d}").format(ref.lm_full.obj2id[_obj])
-            for _obj in LM_13_OBJECTS
-        ],
         scale_to_meter=0.001,
         with_masks=True,  # (load masks but may not use it)
         with_depth=True,  # (load depth path here, but may not use it)
-        depth_factor=1000.0,
-        cam_type="dataset",
-        cam=ref.lm_full.camera_matrix,
         height=480,
         width=640,
         cache_dir=osp.join(PROJ_ROOT, ".cache"),
@@ -433,12 +432,16 @@ SPLITS_LM = dict(
             )
             for _obj in LM_OCC_OBJECTS
         ],
+        xyz_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lmo_full.obj2id[_obj]),
+            )
+            for _obj in LM_OCC_OBJECTS
+        ],
         scale_to_meter=0.001,
         with_masks=True,  # (load masks but may not use it)
         with_depth=True,  # (load depth path here, but may not use it)
-        depth_factor=1000.0,
-        cam_type="dataset",
-        cam=ref.lmo_full.camera_matrix,
         height=480,
         width=640,
         cache_dir=osp.join(PROJ_ROOT, ".cache"),
@@ -454,14 +457,16 @@ SPLITS_LM = dict(
         models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
         objs=LM_OCC_OBJECTS,
         ann_files=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/image_set/lmo_no_bop_test.txt")],
-        # NOTE: scene root
         image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+        xyz_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lmo/test/xyz_crop/{:06d}".format(2),
+            )
+        ],
         scale_to_meter=0.001,
         with_masks=True,  # (load masks but may not use it)
         with_depth=True,  # (load depth path here, but may not use it)
-        depth_factor=1000.0,
-        cam_type="dataset",
-        cam=ref.lmo_full.camera_matrix,
         height=480,
         width=640,
         cache_dir=osp.join(PROJ_ROOT, ".cache"),
@@ -479,12 +484,10 @@ SPLITS_LM = dict(
         ann_files=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/image_set/lmo_test.txt")],
         # NOTE: scene root
         image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+        xyz_prefixes=[None],
         scale_to_meter=0.001,
         with_masks=True,  # (load masks but may not use it)
         with_depth=True,  # (load depth path here, but may not use it)
-        depth_factor=1000.0,
-        cam_type="dataset",
-        cam=ref.lmo_full.camera_matrix,
         height=480,
         width=640,
         cache_dir=osp.join(PROJ_ROOT, ".cache"),
@@ -502,12 +505,10 @@ SPLITS_LM = dict(
         ann_files=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/image_set/lmo_bop_test.txt")],
         # NOTE: scene root
         image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+        xyz_prefixes=[None],
         scale_to_meter=0.001,
         with_masks=True,  # (load masks but may not use it)
         with_depth=True,  # (load depth path here, but may not use it)
-        depth_factor=1000.0,
-        cam_type="dataset",
-        cam=ref.lmo_full.camera_matrix,
         height=480,
         width=640,
         cache_dir=osp.join(PROJ_ROOT, ".cache"),
@@ -543,18 +544,21 @@ for obj in ref.lm_full.objects:
                 objs=[obj],  # only this obj
                 ann_files=ann_files,
                 image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/test/{:06d}").format(ref.lm_full.obj2id[obj])],
+                xyz_prefixes=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lm_full.obj2id[obj]),
+                    )
+                ],
                 scale_to_meter=0.001,
                 with_masks=True,  # (load masks but may not use it)
                 with_depth=True,  # (load depth path here, but may not use it)
-                depth_factor=1000.0,
-                cam_type="dataset",
-                cam=ref.lm_full.camera_matrix,
                 height=480,
                 width=640,
                 cache_dir=osp.join(PROJ_ROOT, ".cache"),
                 use_cache=True,
                 num_to_load=-1,
-                filter_invalid=False,
+                filter_invalid=filter_invalid,
                 filter_scene=True,
                 ref_key="lm_full",
             )
@@ -581,12 +585,15 @@ for obj in ref.lmo_full.objects:
                 ],
                 # NOTE: scene root
                 image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+                xyz_prefixes=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lmo/test/xyz_crop/{:06d}".format(2),
+                    )
+                ],
                 scale_to_meter=0.001,
                 with_masks=True,  # (load masks but may not use it)
                 with_depth=True,  # (load depth path here, but may not use it)
-                depth_factor=1000.0,
-                cam_type="dataset",
-                cam=ref.lmo_full.camera_matrix,
                 height=480,
                 width=640,
                 cache_dir=osp.join(PROJ_ROOT, ".cache"),
@@ -597,6 +604,131 @@ for obj in ref.lmo_full.objects:
                 ref_key="lmo_full",
             )
 
+# single obj splits for lmo_test
+for obj in ref.lmo_full.objects:
+    for split in ["test"]:
+        name = "lmo_{}_{}".format(obj, split)
+        if split in ["train", "all"]:  # all is used to train lmo
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM:
+            SPLITS_LM[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+                objs=[obj],
+                ann_files=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lmo/image_set/lmo_test.txt",
+                    )
+                ],
+                # NOTE: scene root
+                image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+                xyz_prefixes=[None],
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_scene=False,
+                filter_invalid=filter_invalid,
+                ref_key="lmo_full",
+            )
+
+# single obj splits for lmo_bop_test
+for obj in ref.lmo_full.objects:
+    for split in ["test"]:
+        name = "lmo_{}_bop_{}".format(obj, split)
+        if split in ["train", "all"]:  # all is used to train lmo
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM:
+            SPLITS_LM[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+                objs=[obj],
+                ann_files=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lmo/image_set/lmo_bop_test.txt",
+                    )
+                ],
+                # NOTE: scene root
+                image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+                xyz_prefixes=[None],
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_scene=False,
+                filter_invalid=filter_invalid,
+                ref_key="lmo_full",
+            )
+
+# ================ add single image dataset for debug =======================================
+debug_im_ids = {
+    "train": {obj: [] for obj in ref.lm_full.objects},
+    "test": {obj: [] for obj in ref.lm_full.objects},
+}
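+# register one split per (obj, im_id); the loader's debug_im_id filter keeps only that image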
+for obj in ref.lm_full.objects:
+    for split in ["train", "test"]:
+        cur_ann_file = osp.join(DATASETS_ROOT, f"BOP_DATASETS/lm/image_set/{obj}_{split}.txt")
+        ann_files = [cur_ann_file]
+
+        im_ids = []
+        with open(cur_ann_file, "r") as f:
+            for line in f:
+                # scene_id(obj_id)/im_id
+                im_ids.append("{}/{}".format(ref.lm_full.obj2id[obj], int(line.strip("\r\n"))))
+
+        debug_im_ids[split][obj] = im_ids
+        for debug_im_id in debug_im_ids[split][obj]:
+            name = "lm_single_{}{}_{}".format(obj, debug_im_id.split("/")[1], split)
+            if name not in SPLITS_LM:
+                SPLITS_LM[name] = dict(
+                    name=name,
+                    dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/"),
+                    models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+                    objs=[obj],  # only this obj
+                    ann_files=ann_files,
+                    image_prefixes=[
+                        osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/test/{:06d}").format(ref.lm_full.obj2id[obj])
+                    ],
+                    xyz_prefixes=[
+                        osp.join(
+                            DATASETS_ROOT,
+                            "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lm_full.obj2id[obj]),
+                        )
+                    ],
+                    scale_to_meter=0.001,
+                    with_masks=True,  # (load masks but may not use it)
+                    with_depth=True,  # (load depth path here, but may not use it)
+                    height=480,
+                    width=640,
+                    cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                    use_cache=True,
+                    num_to_load=-1,
+                    filter_invalid=False,
+                    filter_scene=True,
+                    ref_key="lm_full",
+                    debug_im_id=debug_im_id,  # NOTE: debug im id
+                )
+
 
 def register_with_name_cfg(name, data_cfg=None):
     """Assume pre-defined datasets live in `./datasets`.
@@ -620,8 +752,8 @@ def register_with_name_cfg(name, data_cfg=None):
         ref_key=used_cfg["ref_key"],
         objs=used_cfg["objs"],
         eval_error_types=["ad", "rete", "proj"],
-        evaluator_type="coco_bop",
-        **get_lm_metadata(obj_names=used_cfg["objs"]),
+        evaluator_type="bop",
+        **get_lm_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
     )
 
 
@@ -664,38 +796,91 @@ def test_vis():
         cat_ids = [anno["category_id"] for anno in annos]
         K = d["cam"]
         kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # TODO: visualize pose and keypoints
         labels = [objs[cat_id] for cat_id in cat_ids]
-        img_vis = vis_image_mask_bbox_cv2(img, masks, bboxes=bboxes_xyxy, labels=labels)
-        img_vis_kpts2d = img.copy()
-        for anno_i in range(len(annos)):
-            img_vis_kpts2d = misc.draw_projected_box3d(img_vis_kpts2d, kpts_2d[anno_i])
-        grid_show(
-            [
-                img[:, :, [2, 1, 0]],
-                img_vis[:, :, [2, 1, 0]],
-                img_vis_kpts2d[:, :, [2, 1, 0]],
-                depth,
-            ],
-            [f"img:{d['file_name']}", "vis_img", "img_vis_kpts2d", "depth"],
-            row=2,
-            col=2,
-        )
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            if "test" not in dset_name.lower():
+                xyz_path = annos[_i]["xyz_path"]
+                xyz_info = mmcv.load(xyz_path)
+                x1, y1, x2, y2 = xyz_info["xyxy"]
+                xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
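+                # paste the xyz crop back into a full-resolution map at its xyxy location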
+                xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+                xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+                xyz_show = get_emb_show(xyz)
+                xyz_crop_show = get_emb_show(xyz_crop)
+                img_xyz = img.copy() / 255.0
+                mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+                fg_idx = np.where(mask_xyz != 0)
+                img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+                img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+                img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+                # diff between the visib mask and the xyz foreground mask
+                diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                        # xyz_show,
+                        diff_mask_xyz,
+                        xyz_crop_show,
+                        img_xyz[:, :, [2, 1, 0]],
+                        img_xyz_crop[:, :, [2, 1, 0]],
+                        img_vis_crop,
+                    ],
+                    [
+                        "img",
+                        "vis_img",
+                        "img_vis_kpts2d",
+                        "depth",
+                        "diff_mask_xyz",
+                        "xyz_crop_show",
+                        "img_xyz",
+                        "img_xyz_crop",
+                        "img_vis_crop",
+                    ],
+                    row=3,
+                    col=3,
+                )
+            else:
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                    ],
+                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                    row=2,
+                    col=2,
+                )
 
 
 if __name__ == "__main__":
     """Test the  dataset loader.
 
-    Usage:
-        python -m det.yolov4.datasets.lm_dataset_d2 dataset_name
+    python -m det.yolox.data.datasets.lm_dataset_d2 dataset_name
     """
     from lib.vis_utils.image import grid_show
     from lib.utils.setup_logger import setup_my_logger
+
     import detectron2.data.datasets  # noqa # add pre-defined metadata
-    from core.utils.data_utils import read_image_mmcv
     from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
 
     print("sys.argv:", sys.argv)
     logger = setup_my_logger(name="core")
     register_with_name_cfg(sys.argv[1])
     print("dataset catalog: ", DatasetCatalog.list())
+
     test_vis()
-- 
GitLab