diff --git a/configs/yolox/bop_pbr/yolox_base.py b/configs/yolox/bop_pbr/yolox_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bac49fb1a2864a9aca43d29a881eed5afc149f1
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_base.py
@@ -0,0 +1,229 @@
+import os
+import os.path as osp
+from omegaconf import OmegaConf
+
+import torch
+import detectron2.data.transforms as T
+from detectron2.config import LazyCall as L
+from detectron2.data import get_detection_dataset_dicts
+from detectron2.solver.build import get_default_optimizer_params
+
+# import torch.nn as nn
+
+from det.yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+from det.yolox.data import (
+    # COCODataset,
+    TrainTransform,
+    ValTransform,
+    # YoloBatchSampler,
+    # DataLoader,
+    # InfiniteSampler,
+    MosaicDetection,
+    build_yolox_train_loader,
+    build_yolox_test_loader,
+)
+from det.yolox.data.datasets import Base_DatasetFromList
+from det.yolox.utils import LRScheduler
+
+# from detectron2.evaluation import COCOEvaluator
+# from det.yolox.evaluators import COCOEvaluator
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.lr_scheduler import flat_and_anneal_lr_scheduler
+
+
+# Common training-related configs designed for "tools/lazyconfig_train_net.py".
+# You can use your own train_net.py together with your own set of configs.
+train = dict(
+    # NOTE: need to copy these two lines to get correct dirs
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+    seed=-1,
+    cudnn_deterministic=False,
+    cudnn_benchmark=True,
+    init_checkpoint="",
+    # init_checkpoint="pretrained_models/yolox/yolox_s.pth",
+    resume_from="",
+    # init_checkpoint="detectron2://ImageNetPretrained/MSRA/R-50.pkl",
+    # max_iter=90000,
+    amp=dict(  # options for Automatic Mixed Precision
+        enabled=True,
+    ),
+    grad_clip=dict(  # options for grad clipping
+        enabled=False,
+        clip_type="full_model",  # value, norm, full_model
+        clip_value=1.0,
+        norm_type=2.0,
+    ),
+    ddp=dict(  # options for DistributedDataParallel
+        broadcast_buffers=False,
+        find_unused_parameters=False,
+        fp16_compression=False,
+    ),
+    # NOTE: epoch based period
+    checkpointer=dict(period=1, max_to_keep=10),  # options for PeriodicCheckpointer
+    # eval_period=5000,
+    eval_period=-1,  # epoch based
+    log_period=20,
+    device="cuda",
+    # ...
+    basic_lr_per_img=0.01 / 64.0,  # 1.5625e-4
+    random_size=(14, 26),  # set to None to disable; randomly choose an int in this range and multiply by 32
+    mscale=(0.8, 1.6),
+    ema=True,
+    total_epochs=16,
+    warmup_epochs=5,
+    no_aug_epochs=2,
+    sync_norm_period=10,  # sync norm every n epochs
+    # l1 loss:
+    # 1) use_l1=True and l1_from_scratch=True: use l1 for the whole training phase
+    # 2) use_l1=False: no l1 at all
+    # 3) use_l1=True and l1_from_scratch=False: only use l1 after closing mosaic (YOLOX default)
+    l1_from_scratch=False,
+    use_l1=True,
+    anneal_after_warmup=True,
+    # ...
+    occupy_gpu=False,
+)
+train = OmegaConf.create(train)
+
+
+# OmegaConf.register_new_resolver(
+#      "mul2", lambda x: x*2
+# )
+
+# --------------------------------------------------------------------
+# model
+# --------------------------------------------------------------------
+model = L(YOLOX)(
+    backbone=L(YOLOPAFPN)(
+        depth=1.0,
+        width=1.0,
+        in_channels=[256, 512, 1024],
+    ),
+    head=L(YOLOXHead)(
+        num_classes=1,
+        width="${..backbone.width}",
+        # width="${mul2: ${..backbone.width}}",  # NOTE: do not forget $
+        in_channels="${..backbone.in_channels}",
+    ),
+)
+
+# --------------------------------------------------------------------
+# optimizer
+# --------------------------------------------------------------------
+optimizer = L(torch.optim.SGD)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.01,  # bs=64
+    momentum=0.9,
+    weight_decay=5e-4,
+    nesterov=True,
+)
+
+
+lr_config = L(flat_and_anneal_lr_scheduler)(
+    warmup_method="pow",
+    warmup_pow=2,
+    warmup_factor=0.0,
+    # to be set
+    # optimizer=
+    # total_iters=total_iters,  # to be set
+    # warmup_iters=epoch_len * 3,
+    # anneal_point=5 / (total_epochs - 15),
+    anneal_method="cosine",
+    target_lr_factor=0.05,
+)
+
+
+DATASETS = dict(TRAIN=("",), TEST=("",))
+DATASETS = OmegaConf.create(DATASETS)
+
+
+dataloader = OmegaConf.create()
+dataloader.train = L(build_yolox_train_loader)(
+    dataset=L(Base_DatasetFromList)(
+        split="train",
+        lst=L(get_detection_dataset_dicts)(names=DATASETS.TRAIN),
+        img_size=(640, 640),
+        preproc=L(TrainTransform)(
+            max_labels=50,
+        ),
+    ),
+    aug_wrapper=L(MosaicDetection)(
+        mosaic=True,
+        img_size="${..dataset.img_size}",
+        preproc=L(TrainTransform)(
+            max_labels=120,
+        ),
+        degrees=10.0,
+        translate=0.1,
+        mosaic_scale=(0.1, 2),
+        mixup_scale=(0.5, 1.5),
+        shear=2.0,
+        enable_mixup=True,
+        mosaic_prob=1.0,
+        mixup_prob=1.0,
+    ),
+    # reference_batch_size=64,
+    total_batch_size=64,  # 8x8gpu
+    num_workers=4,
+    pin_memory=True,
+)
+
+
+val = dict(
+    eval_cached=False,
+)
+val = OmegaConf.create(val)
+
+
+test = dict(
+    test_dataset_names=DATASETS.TEST,
+    test_size=(640, 640),  # (height, width)
+    conf_thr=0.01,
+    nms_thr=0.65,
+    num_classes="${model.head.num_classes}",
+    amp_test=False,
+    half_test=True,
+    precise_bn=dict(
+        enabled=False,
+        num_iter=200,
+    ),
+    # fuse_conv_bn=False,
+    fuse_conv_bn=True,
+)
+test = OmegaConf.create(test)
+
+
+# NOTE: for multiple test loaders, just write it as a list
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        # total_batch_size=1,
+        total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
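A minimal sketch (not code from this repo) of how a lazy config like the one above is typically consumed, assuming detectron2's LazyConfig API and the "tools/lazyconfig_train_net.py" pattern referenced at the top of the file; the path and variable names are illustrative:

from detectron2.config import LazyConfig, instantiate

cfg = LazyConfig.load("configs/yolox/bop_pbr/yolox_base.py")

# L(...) nodes stay un-built until instantiate(); "${..backbone.width}"-style
# references are resolved by OmegaConf at that point, so head.width tracks
# backbone.width automatically.
model = instantiate(cfg.model).to(cfg.train.device)

# as the comment in the optimizer block notes, params.model must point at the
# built model before the optimizer itself is instantiated
cfg.optimizer.params.model = model
optimizer = instantiate(cfg.optimizer)

# the "to be set" fields in lr_config (optimizer, total_iters, warmup_iters,
# anneal_point) are presumably filled in by the training script in the same way:
# cfg.lr_config.optimizer = optimizer; cfg.lr_config.total_iters = ...
# scheduler = instantiate(cfg.lr_config)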
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_hb_pbr_hb_test_primesense_bop19.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_hb_pbr_hb_test_primesense_bop19.py
new file mode 100644
index 0000000000000000000000000000000000000000..db0b4d2ea7d42ef2a4aa482178304d3814c3b58c
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_hb_pbr_hb_test_primesense_bop19.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 33
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["hb_pbr_train"]
+DATASETS.TEST = ["hb_test_primesense_bop19"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
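The COLOR_AUG_CODE string above is plain Python source for an imgaug pipeline. A hedged sketch of how such a string could be evaluated (the exact namespace that MosaicDetection uses for this is an assumption, and the string here is a shortened stand-in for the config value):

import numpy as np
import imgaug.augmenters as iaa  # needed by the full config string (iaa.contrast.LinearContrast)
from imgaug.augmenters import (
    Sequential, Sometimes, CoarseDropout, GaussianBlur, Add, Invert,
    Multiply, AdditiveGaussianNoise, pillike,  # pillike.Enhance* also appear in the full string
)

aug_code = (
    "Sequential(["
    "Sometimes(0.4, GaussianBlur((0., 3.))),"
    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
    "], random_order=True)"
)
color_augmentor = eval(aug_code)  # the imported names above must be visible to eval

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
aug_img = color_augmentor(image=img)  # imgaug augmenters are callable on single images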
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_icbin_pbr_icbin_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_icbin_pbr_icbin_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8e8aa763e8b16f10c88c360ebd1a571ae184f50
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_icbin_pbr_icbin_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 2
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["icbin_pbr_train"]
+DATASETS.TEST = ["icbin_bop_test"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
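If the training script applies the linear scaling rule implied by train.basic_lr_per_img in yolox_base.py (an assumption about the trainer, not something shown in this diff), the learning rate for total_batch_size=32 would be derived roughly as below; note that the Ranger override above pins lr=0.001 explicitly instead.

# hypothetical helper illustrating linear lr scaling from basic_lr_per_img = 0.01 / 64.0
def scaled_lr(basic_lr_per_img: float, total_batch_size: int) -> float:
    return basic_lr_per_img * total_batch_size

print(scaled_lr(0.01 / 64.0, 32))  # 0.005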
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_itodd_pbr_itodd_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_itodd_pbr_itodd_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a41f8cd5e74fbcc588ae6ad058c1453e298e00c
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_itodd_pbr_itodd_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 28
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["itodd_pbr_train"]
+DATASETS.TEST = ["itodd_bop_test"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_lmo_pbr_lmo_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_lmo_pbr_lmo_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f33a2b4fac8d974d4a7dcc44a2071e8025ee06ca
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_lmo_pbr_lmo_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 8
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["lmo_pbr_train"]
+DATASETS.TEST = ["lmo_bop_test"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tless_pbr_tless_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tless_pbr_tless_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..570537187e9ac5259063b05dedaea041c9f88a7b
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tless_pbr_tless_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 30
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["tless_pbr_train"]
+DATASETS.TEST = ["tless_bop_test_primesense"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tless_real_pbr_tless_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tless_real_pbr_tless_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b69fe5fc39223a565997c6ccfd37b899e30a02f8
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tless_real_pbr_tless_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 30
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["tless_pbr_train", "tless_primesense_train"]
+DATASETS.TEST = ["tless_bop_test_primesense"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tudl_pbr_tudl_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tudl_pbr_tudl_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd6ab1e4a89f29b693916ba167ce2458bcffdc03
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tudl_pbr_tudl_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 3
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["tudl_pbr_train"]
+DATASETS.TEST = ["tudl_bop_test"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tudl_real_pbr_tudl_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tudl_real_pbr_tudl_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..d21ae408249ea0c8e9e6efca00fe26e2d2b6f1e1
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_tudl_real_pbr_tudl_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 3
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["tudl_pbr_train", "tudl_train_real"]
+DATASETS.TEST = ["tudl_bop_test"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_ycbv_pbr_ycbv_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_ycbv_pbr_ycbv_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c522d21eae2d60a800f3f0f4c8c4cf52b1922ef0
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_ycbv_pbr_ycbv_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 21
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["ycbv_train_pbr"]
+DATASETS.TEST = ["ycbv_bop_test"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
diff --git a/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_ycbv_real_pbr_ycbv_bop_test.py b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_ycbv_real_pbr_ycbv_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f228c7bcb5c67919970eb61945672923b9580e95
--- /dev/null
+++ b/configs/yolox/bop_pbr/yolox_x_640_augCozyAAEhsv_ranger_30_epochs_ycbv_real_pbr_ycbv_bop_test.py
@@ -0,0 +1,112 @@
+import os.path as osp
+
+import torch
+from detectron2.config import LazyCall as L
+from detectron2.solver.build import get_default_optimizer_params
+
+from .yolox_base import train, val, test, model, dataloader, optimizer, lr_config, DATASETS  # noqa
+from det.yolox.data import build_yolox_test_loader, ValTransform
+from det.yolox.data.datasets import Base_DatasetFromList
+from detectron2.data import get_detection_dataset_dicts
+from det.yolox.evaluators import YOLOX_COCOEvaluator
+from lib.torch_utils.solver.ranger import Ranger
+
+train.update(
+    output_dir=osp.abspath(__file__).replace("configs", "output", 1)[0:-3],
+    exp_name=osp.split(osp.abspath(__file__))[1][0:-3],  # .py
+)
+train.amp.enabled = True
+
+model.backbone.depth = 1.33
+model.backbone.width = 1.25
+
+model.head.num_classes = 21
+
+train.init_checkpoint = "pretrained_models/yolox/yolox_x.pth"
+
+# datasets
+DATASETS.TRAIN = ["ycbv_train_pbr", "ycbv_train_real"]
+DATASETS.TEST = ["ycbv_bop_test"]
+
+dataloader.train.dataset.lst.names = DATASETS.TRAIN
+dataloader.train.total_batch_size = 32
+
+# color aug
+dataloader.train.aug_wrapper.COLOR_AUG_PROB = 0.8
+dataloader.train.aug_wrapper.COLOR_AUG_TYPE = "code"
+dataloader.train.aug_wrapper.COLOR_AUG_CODE = (
+    "Sequential(["
+    # Sometimes(0.5, PerspectiveTransform(0.05)),
+    # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+    # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+    "Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),"
+    "Sometimes(0.4, GaussianBlur((0., 3.))),"
+    "Sometimes(0.3, pillike.EnhanceSharpness(factor=(0., 50.))),"
+    "Sometimes(0.3, pillike.EnhanceContrast(factor=(0.2, 50.))),"
+    "Sometimes(0.5, pillike.EnhanceBrightness(factor=(0.1, 6.))),"
+    "Sometimes(0.3, pillike.EnhanceColor(factor=(0., 20.))),"
+    "Sometimes(0.5, Add((-25, 25), per_channel=0.3)),"
+    "Sometimes(0.3, Invert(0.2, per_channel=True)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),"
+    "Sometimes(0.5, Multiply((0.6, 1.4))),"
+    "Sometimes(0.1, AdditiveGaussianNoise(scale=10, per_channel=True)),"
+    "Sometimes(0.5, iaa.contrast.LinearContrast((0.5, 2.2), per_channel=0.3)),"
+    # "Sometimes(0.5, Grayscale(alpha=(0.0, 1.0))),"  # maybe remove for det
+    "], random_order=True)"
+    # cosy+aae
+)
+
+# hsv color aug
+dataloader.train.aug_wrapper.AUG_HSV_PROB = 1.0
+dataloader.train.aug_wrapper.HSV_H = 0.015
+dataloader.train.aug_wrapper.HSV_S = 0.7
+dataloader.train.aug_wrapper.HSV_V = 0.4
+dataloader.train.aug_wrapper.FORMAT = "RGB"
+
+optimizer = L(Ranger)(
+    params=L(get_default_optimizer_params)(
+        # params.model is meant to be set to the model object, before instantiating
+        # the optimizer.
+        weight_decay_norm=0.0,
+        weight_decay_bias=0.0,
+    ),
+    lr=0.001,  # bs=64
+    # momentum=0.9,
+    weight_decay=0,
+    # nesterov=True,
+)
+
+train.total_epochs = 30
+train.no_aug_epochs = 15
+train.checkpointer = dict(period=2, max_to_keep=10)
+
+test.test_dataset_names = DATASETS.TEST
+test.augment = True
+test.scales = (1, 0.75, 0.83, 1.12, 1.25)
+test.conf_thr = 0.001
+
+dataloader.test = [
+    L(build_yolox_test_loader)(
+        dataset=L(Base_DatasetFromList)(
+            split="test",
+            lst=L(get_detection_dataset_dicts)(names=test_dataset_name, filter_empty=False),
+            img_size="${test.test_size}",
+            preproc=L(ValTransform)(
+                legacy=False,
+            ),
+        ),
+        total_batch_size=1,
+        # total_batch_size=64,
+        num_workers=4,
+        pin_memory=True,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
+
+dataloader.evaluator = [
+    L(YOLOX_COCOEvaluator)(
+        dataset_name=test_dataset_name,
+        filter_scene=False,
+    )
+    for test_dataset_name in test.test_dataset_names
+]
\ No newline at end of file
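When DATASETS.TRAIN lists more than one name, as in the real+pbr configs above, detectron2's get_detection_dataset_dicts simply concatenates the dicts of every registered dataset. A toy sketch with made-up dataset names (the BOP names used in these configs must be registered by this repo's dataset code before the loaders are instantiated):

from detectron2.data import DatasetCatalog, get_detection_dataset_dicts

# "toy_a" / "toy_b" are hypothetical registrations used only for illustration
DatasetCatalog.register("toy_a", lambda: [{"file_name": "a.jpg", "annotations": []}])
DatasetCatalog.register("toy_b", lambda: [{"file_name": "b.jpg", "annotations": []}])

dicts = get_detection_dataset_dicts(["toy_a", "toy_b"], filter_empty=False)
assert len(dicts) == 2  # one entry per toy image, concatenated across both datasets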
diff --git a/det/__init__.py b/det/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/det/yolox/__init__.py b/det/yolox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cbc411d419c55098e7d4e24ff0f21caaaf10a1f
--- /dev/null
+++ b/det/yolox/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .utils import configure_module
+
+configure_module()
+
+__version__ = "0.1.0"
diff --git a/det/yolox/data/__init__.py b/det/yolox/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c15c45ed1369e22537cba68e0eddfcf7da10625
--- /dev/null
+++ b/det/yolox/data/__init__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .data_augment import TrainTransform, ValTransform
+from .data_prefetcher import DataPrefetcher
+from .dataloading import (
+    DataLoader,
+    build_yolox_train_loader,
+    build_yolox_batch_data_loader,
+    build_yolox_test_loader,
+)
+from .dataloading import yolox_worker_init_reset_seed as worker_init_reset_seed
+from .datasets import *
+from .samplers import InfiniteSampler, YoloBatchSampler
diff --git a/det/yolox/data/data_augment.py b/det/yolox/data/data_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8d3fcf214f6980db303f0095ea45a049e07a680
--- /dev/null
+++ b/det/yolox/data/data_augment.py
@@ -0,0 +1,259 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+"""Data augmentation functionality. Passed as callable transformations to
+Dataset classes.
+
+The data augmentation procedures were interpreted from @weiliu89's SSD
+paper http://arxiv.org/abs/1512.02325
+"""
+
+import math
+import random
+
+import cv2
+import numpy as np
+
+from det.yolox.utils import xyxy2cxcywh
+
+
+def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5, source_format="BGR"):
+    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
+    if source_format == "RGB":
+        hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2HSV))
+    else:  # default BGR
+        hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
+    dtype = img.dtype  # uint8
+
+    x = np.arange(0, 256, dtype=np.int16)
+    lut_hue = ((x * r[0]) % 180).astype(dtype)
+    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
+    if source_format == "RGB":
+        cv2.cvtColor(img_hsv, cv2.COLOR_HSV2RGB, dst=img)  # no return needed
+    else:
+        cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
+    # Histogram equalization
+    # if random.random() < 0.2:
+    #     for i in range(3):
+    #         img[:, :, i] = cv2.equalizeHist(img[:, :, i])
+
+
+# def augment_hsv(img, hgain=5, sgain=30, vgain=30):
+#     hsv_augs = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain]  # random gains
+#     hsv_augs *= np.random.randint(0, 2, 3)  # random selection of h, s, v
+#     hsv_augs = hsv_augs.astype(np.int16)
+#     img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
+
+#     img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180
+#     img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255)
+#     img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255)
+
+#     cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img)  # no return needed
+
+
+def get_aug_params(value, center=0):
+    if isinstance(value, float):
+        return random.uniform(center - value, center + value)
+    elif len(value) == 2:
+        return random.uniform(value[0], value[1])
+    else:
+        raise ValueError(
+            "Affine params should be either a sequence containing two values "
+            "or a single float value. Got {}".format(value)
+        )
+
+
+def get_affine_matrix(
+    target_size,
+    degrees=10,
+    translate=0.1,
+    scales=0.1,
+    shear=10,
+):
+    twidth, theight = target_size
+
+    # Rotation and Scale
+    angle = get_aug_params(degrees)
+    scale = get_aug_params(scales, center=1.0)
+
+    if scale <= 0.0:
+        raise ValueError("Argument scale should be positive")
+
+    R = cv2.getRotationMatrix2D(angle=angle, center=(0, 0), scale=scale)
+
+    M = np.ones([2, 3])
+    # Shear
+    shear_x = math.tan(get_aug_params(shear) * math.pi / 180)
+    shear_y = math.tan(get_aug_params(shear) * math.pi / 180)
+
+    M[0] = R[0] + shear_y * R[1]
+    M[1] = R[1] + shear_x * R[0]
+
+    # Translation
+    translation_x = get_aug_params(translate) * twidth  # x translation (pixels)
+    translation_y = get_aug_params(translate) * theight  # y translation (pixels)
+
+    M[0, 2] = translation_x
+    M[1, 2] = translation_y
+
+    return M, scale
+
+
+def apply_affine_to_bboxes(targets, target_size, M, scale):
+    num_gts = len(targets)
+
+    # warp corner points
+    twidth, theight = target_size
+    corner_points = np.ones((4 * num_gts, 3))
+    corner_points[:, :2] = targets[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(4 * num_gts, 2)  # x1y1, x2y2, x1y2, x2y1
+    corner_points = corner_points @ M.T  # apply affine transform
+    corner_points = corner_points.reshape(num_gts, 8)
+
+    # create new boxes
+    corner_xs = corner_points[:, 0::2]
+    corner_ys = corner_points[:, 1::2]
+    new_bboxes = (
+        np.concatenate((corner_xs.min(1), corner_ys.min(1), corner_xs.max(1), corner_ys.max(1))).reshape(4, num_gts).T
+    )
+
+    # clip boxes
+    new_bboxes[:, 0::2] = new_bboxes[:, 0::2].clip(0, twidth)
+    new_bboxes[:, 1::2] = new_bboxes[:, 1::2].clip(0, theight)
+
+    targets[:, :4] = new_bboxes
+
+    return targets
+
+
+def random_affine(
+    img,
+    targets=(),
+    target_size=(640, 640),
+    degrees=10,
+    translate=0.1,
+    scales=0.1,
+    shear=10,
+):
+    M, scale = get_affine_matrix(target_size, degrees, translate, scales, shear)
+
+    img = cv2.warpAffine(img, M, dsize=target_size, borderValue=(114, 114, 114))
+
+    # Transform label coordinates
+    if len(targets) > 0:
+        targets = apply_affine_to_bboxes(targets, target_size, M, scale)
+
+    return img, targets
+
+
+def _mirror(image, boxes, prob=0.5):
+    _, width, _ = image.shape
+    if random.random() < prob:
+        image = image[:, ::-1]
+        boxes[:, 0::2] = width - boxes[:, 2::-2]
+    return image, boxes
+
+
+def preproc(img, input_size, swap=(2, 0, 1)):
+    if len(img.shape) == 3:
+        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+    else:
+        padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+    resized_img = cv2.resize(
+        img,
+        (int(img.shape[1] * r), int(img.shape[0] * r)),
+        interpolation=cv2.INTER_LINEAR,
+    ).astype(np.uint8)
+    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+    padded_img = padded_img.transpose(swap)
+    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+    return padded_img, r
+
+
+class TrainTransform:
+    def __init__(self, max_labels=50, flip_prob=0.5, hsv_prob=1.0):
+        self.max_labels = max_labels
+        self.flip_prob = flip_prob
+        self.hsv_prob = hsv_prob
+
+    def __call__(self, image, targets, input_dim):
+        boxes = targets[:, :4].copy()
+        labels = targets[:, 4].copy()
+        if len(boxes) == 0:
+            targets = np.zeros((self.max_labels, 5), dtype=np.float32)
+            image, r_o = preproc(image, input_dim)
+            return image, targets
+
+        image_o = image.copy()
+        targets_o = targets.copy()
+        height_o, width_o, _ = image_o.shape
+        boxes_o = targets_o[:, :4]
+        labels_o = targets_o[:, 4]
+        # bbox_o: [xyxy] to [c_x,c_y,w,h]
+        boxes_o = xyxy2cxcywh(boxes_o)
+
+        if random.random() < self.hsv_prob:
+            augment_hsv(image)
+        image_t, boxes = _mirror(image, boxes, self.flip_prob)
+        height, width, _ = image_t.shape
+        image_t, r_ = preproc(image_t, input_dim)
+        # boxes [xyxy] 2 [cx,cy,w,h]
+        boxes = xyxy2cxcywh(boxes)
+        boxes *= r_
+
+        mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 1
+        boxes_t = boxes[mask_b]
+        labels_t = labels[mask_b]
+
+        if len(boxes_t) == 0:
+            image_t, r_o = preproc(image_o, input_dim)
+            boxes_o *= r_o
+            boxes_t = boxes_o
+            labels_t = labels_o
+
+        labels_t = np.expand_dims(labels_t, 1)
+
+        targets_t = np.hstack((labels_t, boxes_t))
+        padded_labels = np.zeros((self.max_labels, 5))
+        padded_labels[range(len(targets_t))[: self.max_labels]] = targets_t[: self.max_labels]
+        padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
+        return image_t, padded_labels
+
+
+class ValTransform:
+    """Defines the transformations that should be applied to test PIL image for
+    input into the network.
+
+    dimension -> tensorize -> color adj
+
+    Arguments:
+        resize (int): input dimension to SSD
+        rgb_means ((int,int,int)): average RGB of the dataset
+            (104,117,123)
+        swap ((int,int,int)): final order of channels
+
+    Returns:
+        transform (transform) : callable transform to be applied to test/val
+        data
+    """
+
+    def __init__(self, swap=(2, 0, 1), legacy=False):
+        self.swap = swap
+        self.legacy = legacy
+
+    # assume input is cv2 img for now
+    def __call__(self, img, res, input_size):
+        img, _ = preproc(img, input_size, self.swap)
+        if self.legacy:
+            img = img[::-1, :, :].copy()
+            img /= 255.0
+            img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
+            img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
+        return img, np.zeros((1, 5))
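+
+
+if __name__ == "__main__":
+    # Minimal smoke test (a sketch, not part of the training pipeline): exercise
+    # preproc / random_affine / TrainTransform on a synthetic image. It relies on
+    # the module-level imports above (cv2, numpy, random); the image size and the
+    # single box below are arbitrary illustrative values.
+    dummy_img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
+    dummy_targets = np.array([[50.0, 60.0, 200.0, 220.0, 1.0]])  # [x1, y1, x2, y2, cls]
+
+    padded, ratio = preproc(dummy_img, (416, 416))
+    print("preproc:", padded.shape, "resize ratio:", ratio)  # (3, 416, 416)
+
+    warped_img, warped_targets = random_affine(dummy_img.copy(), dummy_targets.copy(), target_size=(640, 640))
+    print("random_affine:", warped_img.shape, warped_targets[:, :4])
+
+    transform = TrainTransform(max_labels=50, flip_prob=0.5, hsv_prob=1.0)
+    img_t, labels_t = transform(dummy_img.copy(), dummy_targets.copy(), (416, 416))
+    print("TrainTransform:", img_t.shape, labels_t.shape)  # (3, 416, 416), (50, 5)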
diff --git a/det/yolox/data/data_prefetcher.py b/det/yolox/data/data_prefetcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ad491858155b55302c1ea4c177ce7cb479e64c5
--- /dev/null
+++ b/det/yolox/data/data_prefetcher.py
@@ -0,0 +1,50 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import torch
+
+
+class DataPrefetcher:
+    """DataPrefetcher is inspired by code of following file:
+
+    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
+    It could speedup your pytorch dataloader. For more information, please check
+    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.
+    """
+
+    def __init__(self, loader):
+        self.loader_iter = iter(loader)
+        self.stream = torch.cuda.Stream()
+        self.input_cuda = self._input_cuda_for_image
+        self.record_stream = DataPrefetcher._record_stream_for_image
+        self.preload()
+
+    def preload(self):
+        try:
+            self.next_input, self.next_target, _, _, _ = next(self.loader_iter)
+        except StopIteration:
+            self.next_input = None
+            self.next_target = None
+            return
+
+        with torch.cuda.stream(self.stream):
+            self.input_cuda()
+            self.next_target = self.next_target.cuda(non_blocking=True)
+
+    def next(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        input = self.next_input
+        target = self.next_target
+        if input is not None:
+            self.record_stream(input)
+        if target is not None:
+            target.record_stream(torch.cuda.current_stream())
+        self.preload()
+        return input, target
+
+    def _input_cuda_for_image(self):
+        self.next_input = self.next_input.cuda(non_blocking=True)
+
+    @staticmethod
+    def _record_stream_for_image(input):
+        input.record_stream(torch.cuda.current_stream())
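+
+
+# Typical usage (sketch): requires a CUDA device and a loader whose batches unpack
+# as (imgs, targets, scene_im_id, img_info, img_id), e.g. the train loader built in
+# dataloading.py. The model call below is purely illustrative.
+#
+#     prefetcher = DataPrefetcher(train_loader)
+#     inps, targets = prefetcher.next()
+#     while inps is not None:
+#         outputs = model(inps, targets)  # hypothetical model step
+#         ...
+#         inps, targets = prefetcher.next()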
diff --git a/det/yolox/data/dataloading.py b/det/yolox/data/dataloading.py
new file mode 100644
index 0000000000000000000000000000000000000000..e186356f9967679c3a3978a1f4dc24d664f98649
--- /dev/null
+++ b/det/yolox/data/dataloading.py
@@ -0,0 +1,275 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import os
+import random
+import uuid
+
+import numpy as np
+
+import torch
+from torch.utils.data.dataloader import DataLoader as torchDataLoader
+from torch.utils.data.dataloader import default_collate
+import operator
+
+from detectron2.data.build import (
+    AspectRatioGroupedDataset,
+    worker_init_reset_seed,
+    trivial_batch_collator,
+    InferenceSampler,
+)
+
+from core.utils.my_comm import get_world_size
+
+from .samplers import YoloBatchSampler, InfiniteSampler
+
+# from .datasets import Base_DatasetFromList
+
+
+class DataLoader(torchDataLoader):
+    """Lightnet dataloader that enables on the fly resizing of the images.
+
+    See :class:`torch.utils.data.DataLoader` for more information on the arguments.
+    Check more on the following website:
+    https://gitlab.com/EAVISE/lightnet/-/blob/master/lightnet/data/_dataloading.py
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__initialized = False
+        shuffle = False
+        sampler = None  # avoid UnboundLocalError when neither args nor kwargs set it
+        batch_sampler = None
+        if len(args) > 5:
+            shuffle = args[2]
+            sampler = args[3]
+            batch_sampler = args[4]
+        elif len(args) > 4:
+            shuffle = args[2]
+            sampler = args[3]
+            if "batch_sampler" in kwargs:
+                batch_sampler = kwargs["batch_sampler"]
+        elif len(args) > 3:
+            shuffle = args[2]
+            if "sampler" in kwargs:
+                sampler = kwargs["sampler"]
+            if "batch_sampler" in kwargs:
+                batch_sampler = kwargs["batch_sampler"]
+        else:
+            if "shuffle" in kwargs:
+                shuffle = kwargs["shuffle"]
+            if "sampler" in kwargs:
+                sampler = kwargs["sampler"]
+            if "batch_sampler" in kwargs:
+                batch_sampler = kwargs["batch_sampler"]
+
+        # Use custom BatchSampler
+        if batch_sampler is None:
+            if sampler is None:
+                if shuffle:
+                    sampler = torch.utils.data.sampler.RandomSampler(self.dataset)
+                    # sampler = torch.utils.data.DistributedSampler(self.dataset)
+                else:
+                    sampler = torch.utils.data.sampler.SequentialSampler(self.dataset)
+            batch_sampler = YoloBatchSampler(
+                sampler,
+                self.batch_size,
+                self.drop_last,
+                input_dimension=self.dataset.input_dim,
+            )
+            # batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations =
+
+        self.batch_sampler = batch_sampler
+
+        self.__initialized = True
+
+    def close_mosaic(self):
+        self.batch_sampler.mosaic = False
+
+
+# def list_collate(batch):
+#     """
+#     Function that collates lists or tuples together into one list (of lists/tuples).
+#     Use this as the collate function in a Dataloader, if you want to have a list of
+#     items as an output, as opposed to tensors (eg. Brambox.boxes).
+#     """
+#     items = list(zip(*batch))
+
+#     for i in range(len(items)):
+#         if isinstance(items[i][0], (list, tuple)):
+#             items[i] = list(items[i])
+#         else:
+#             items[i] = default_collate(items[i])
+
+#     return items
+
+
+def build_yolox_batch_data_loader(
+    dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0, pin_memory=False
+):
+    """
+    Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are:
+    1. support aspect ratio grouping options
+    2. use no "batch collation", because this is common for detection training
+
+    Args:
+        dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
+        sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices
+        total_batch_size, aspect_ratio_grouping, num_workers: see
+            :func:`build_yolox_train_loader`.
+
+    Returns:
+        iterable[list]. Length of each list is the batch size of the current
+            GPU. Each element in the list comes from the dataset.
+    """
+    world_size = get_world_size()
+    assert (
+        total_batch_size > 0 and total_batch_size % world_size == 0
+    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(total_batch_size, world_size)
+
+    batch_size = total_batch_size // world_size
+    if aspect_ratio_grouping:
+        data_loader = torch.utils.data.DataLoader(
+            dataset,
+            sampler=sampler,
+            num_workers=num_workers,
+            batch_sampler=None,
+            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
+            worker_init_fn=worker_init_reset_seed,
+        )  # yield individual mapped dict
+        return AspectRatioGroupedDataset(data_loader, batch_size)
+    else:
+        # batch_sampler = torch.utils.data.sampler.BatchSampler(
+        #     sampler, batch_size, drop_last=True
+        # )  # drop_last so the batch always have the same size
+        if hasattr(dataset, "enable_mosaic"):
+            mosaic = dataset.enable_mosaic
+        else:
+            mosaic = False
+        batch_sampler = YoloBatchSampler(
+            mosaic=mosaic,
+            sampler=sampler,
+            batch_size=batch_size,
+            drop_last=False,  # NOTE: different to d2
+            # input_dimension=dataset.input_dim,
+        )
+        return DataLoader(
+            dataset,
+            num_workers=num_workers,
+            batch_sampler=batch_sampler,
+            # collate_fn=trivial_batch_collator,  # TODO: use this when item is changed to dict
+            worker_init_fn=worker_init_reset_seed,
+            pin_memory=pin_memory,
+        )
+
+
+def build_yolox_train_loader(
+    dataset,
+    *,
+    aug_wrapper,
+    total_batch_size,
+    sampler=None,
+    aspect_ratio_grouping=False,
+    num_workers=0,
+    pin_memory=False,
+    seed=None
+):
+    """Build a dataloader for object detection with some default features. This
+    interface is experimental.
+
+    Args:
+        dataset (torch.utils.data.Dataset): Base_DatasetFromList
+        aug_wrapper (callable): MosaicDetection
+        total_batch_size (int): total batch size across all workers. Batching
+            simply puts data into a list.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
+            indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
+            which coordinates an infinite random shuffle sequence across all workers.
+        aspect_ratio_grouping (bool): whether to group images with similar
+            aspect ratio for efficiency. When enabled, it requires each
+            element in dataset be a dict with keys "width" and "height".
+        num_workers (int): number of parallel data loading workers
+
+    Returns:
+        torch.utils.data.DataLoader:
+            a dataloader. Each batch from it contains ``total_batch_size / num_gpus``
+            elements, each produced by ``dataset`` (i.e. after ``aug_wrapper`` and
+            the dataset's ``preproc``).
+    """
+
+    if aug_wrapper is not None:
+        # MosaicDetection (mosaic, mixup, other augs)
+        dataset = aug_wrapper.init_dataset(dataset)
+
+    if sampler is None:
+        # sampler = TrainingSampler(len(dataset))
+        sampler = InfiniteSampler(len(dataset), seed=0 if seed is None else seed)
+    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
+    return build_yolox_batch_data_loader(
+        dataset,
+        sampler,
+        total_batch_size,
+        aspect_ratio_grouping=aspect_ratio_grouping,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+    )
+
+
+def build_yolox_test_loader(
+    dataset, *, aug_wrapper=None, total_batch_size=1, sampler=None, num_workers=0, pin_memory=False
+):
+    """Similar to `build_detection_train_loader`, but uses a batch size of 1,
+    and :class:`InferenceSampler`. This sampler coordinates all workers to
+    produce the exact set of all samples. This interface is experimental.
+
+    Args:
+        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
+            or a map-style pytorch dataset. They can be obtained by using
+            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+        aug_wrapper (callable): MosaicDetection
+        total_batch_size (int): total batch size across all workers. Batching
+            simply puts data into a list. Default test batch size is 1.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
+            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
+            which splits the dataset across all workers.
+        num_workers (int): number of parallel data loading workers
+
+    Returns:
+        DataLoader: a torch DataLoader, that loads the given detection
+        dataset, with test-time transformation and batching.
+
+    Examples:
+    ::
+        data_loader = build_yolox_test_loader(
+            test_dataset,            # e.g. a Base_DatasetFromList for the test split
+            aug_wrapper=None,
+            total_batch_size=1,
+        )
+    """
+    if aug_wrapper is not None:
+        # MosaicDetection (mosaic, mixup, other augs)
+        dataset = aug_wrapper.init_dataset(dataset)
+
+    world_size = get_world_size()
+    batch_size = total_batch_size // world_size
+    if sampler is None:
+        sampler = InferenceSampler(len(dataset))
+    # Always use 1 image per worker during inference since this is the
+    # standard when reporting inference time in papers.
+    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last=False)
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        # batch_size=batch_size,
+        num_workers=num_workers,
+        batch_sampler=batch_sampler,
+        # collate_fn=trivial_batch_collator,
+        pin_memory=pin_memory,
+    )
+    return data_loader
+
+
+def yolox_worker_init_reset_seed(worker_id):
+    seed = uuid.uuid4().int % 2**32
+    random.seed(seed)
+    torch.set_rng_state(torch.manual_seed(seed).get_state())
+    np.random.seed(seed)
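+
+
+# Usage sketch (not executed on import): build train/test loaders from a
+# Base_DatasetFromList, optionally wrapped by MosaicDetection via ``aug_wrapper``.
+# ``dataset``, ``mosaic_wrapper`` and the numbers below are illustrative; see the
+# project configs for how these objects are actually instantiated.
+#
+#     train_loader = build_yolox_train_loader(
+#         dataset,                     # Base_DatasetFromList instance
+#         aug_wrapper=mosaic_wrapper,  # MosaicDetection instance (or None)
+#         total_batch_size=64,
+#         num_workers=4,
+#         pin_memory=True,
+#     )
+#     test_loader = build_yolox_test_loader(
+#         test_dataset, aug_wrapper=None, total_batch_size=1, num_workers=4
+#     )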
diff --git a/det/yolox/data/datasets/__init__.py b/det/yolox/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..47e5014298340941338970bb60d5e63d9c4db6a8
--- /dev/null
+++ b/det/yolox/data/datasets/__init__.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .coco import COCODataset
+from .coco_classes import COCO_CLASSES
+from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
+from .mosaicdetection import MosaicDetection
+from .voc import VOCDetection
+from .base_data_from_list import Base_DatasetFromList
diff --git a/det/yolox/data/datasets/base_data_from_list.py b/det/yolox/data/datasets/base_data_from_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9ba9c51008fe6010c006eae9469104271683294
--- /dev/null
+++ b/det/yolox/data/datasets/base_data_from_list.py
@@ -0,0 +1,612 @@
+# -*- coding: utf-8 -*-
+import copy
+import hashlib
+import logging
+import os
+import os.path as osp
+import random
+import cv2
+import mmcv
+import numpy as np
+import pickle
+from omegaconf import OmegaConf
+
+from detectron2.data import detection_utils as utils
+from detectron2.structures import BoxMode
+
+from core.utils.data_utils import resize_short_edge, read_image_mmcv
+from core.utils.augment import AugmentRGB
+from core.utils.dataset_utils import flat_dataset_dicts
+from lib.utils.utils import lazy_property
+
+from .datasets_wrapper import Dataset
+
+
+logger = logging.getLogger(__name__)
+
+
+default_input_cfg = OmegaConf.create(
+    dict(
+        img_format="BGR",
+        # depth
+        with_depth=False,
+        aug_depth=False,
+        # bg ----------------
+        bg_type="VOC_table",
+        bg_imgs_root="datasets/VOCdevkit/VOC2012/",
+        num_bg_imgs=10000,
+        change_bg_prob=0.0,  # prob to change bg of real image
+        bg_keep_aspect_ratio=True,
+        # truncation fg (randomly replace some side of fg with bg during replace_bg)
+        truncate_fg=False,
+        # color aug ---------------
+        color_aug_prob=0.0,
+        color_aug_type="AAE",
+        color_aug_code="",
+        # color normalization
+        pixel_mean=[0.0, 0.0, 0.0],  # to [0, 1]
+        pixel_std=[255.0, 255.0, 255.0],
+        # box aug
+        bbox_aug_type="",
+        bbox_aug_scale_ratio=1.0,
+        bbox_aug_shift_ratio=0.0,
+        # box aug dzi
+        dzi_type="none",  # uniform, truncnorm, none, roi10d
+        dzi_pad_scale=1.0,
+        dzi_scale_ratio=0.25,  # wh scale
+        dzi_shift_ratio=0.25,  # center shift
+    )
+)
+
+
+class Base_DatasetFromList(Dataset):
+    """# https://github.com/facebookresearch/detectron2/blob/master/detectron2/
+    data/common.py Wrap a list to a torch Dataset.
+
+    It produces elements of the list as data.
+    """
+
+    def __init__(
+        self,
+        split,
+        lst: list,
+        *,
+        cfg=default_input_cfg,
+        img_size=(416, 416),
+        preproc=None,
+        copy: bool = True,
+        serialize: bool = True,
+        flatten=False,
+    ):
+        """
+        Args:
+            lst (list): a list which contains elements to produce.
+            img_size (tuple): (h, w)
+            copy (bool): whether to deepcopy the element when producing it,
+                so that the result can be modified in place without affecting the
+                source in the list.
+            serialize (bool): whether to hold memory using serialized objects, when
+                enabled, data loader workers can use shared RAM from master
+                process instead of making a copy.
+        """
+        super().__init__(img_size)
+        self.cfg = cfg
+        self.img_size = img_size
+        self.preproc = preproc
+
+        self.split = split  # train | val | test
+        if split == "train" and cfg.color_aug_prob > 0:
+            self.color_augmentor = self._get_color_augmentor(aug_type=cfg.color_aug_type, aug_code=cfg.color_aug_code)
+        else:
+            self.color_augmentor = None
+        # --------------------------------------------------------
+        self._lst = flat_dataset_dicts(lst) if flatten else lst
+        self._copy = copy
+        self._serialize = serialize
+
+        def _serialize(data):
+            buffer = pickle.dumps(data, protocol=-1)
+            return np.frombuffer(buffer, dtype=np.uint8)
+
+        if self._serialize:
+            logger.info("Serializing {} elements to byte tensors and concatenating them all ...".format(len(self._lst)))
+            self._lst = [_serialize(x) for x in self._lst]
+            self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
+            self._addr = np.cumsum(self._addr)
+            self._lst = np.concatenate(self._lst)
+            logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
+
+    def __len__(self):
+        if self._serialize:
+            return len(self._addr)
+        else:
+            return len(self._lst)
+
+    def read_data(self, dataset_dict):
+        raise NotImplementedError("Not implemented")
+
+    def _rand_another(self, idx):
+        pool = [i for i in range(self.__len__()) if i != idx]
+        return np.random.choice(pool)
+
+    def load_anno(self, index):
+        # cfg = self.cfg
+        dataset_dict = self._get_sample_dict(index)
+        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
+        # im annos
+        width = dataset_dict["width"]
+        height = dataset_dict["height"]
+
+        # get target--------------------
+        if dataset_dict.get("annotations", None) != None:
+            annotations = dataset_dict["annotations"]
+            objs = []
+            for obj in annotations:  # filter instances by area ------------------
+                xyxy = BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS)
+                x1 = np.max((xyxy[0], 0))
+                y1 = np.max((xyxy[1], 0))
+                x2 = np.min((xyxy[2], width))
+                y2 = np.min((xyxy[3], height))
+                if "area" in obj:
+                    area = obj["area"]
+                else:
+                    area = (x2 - x1) * (y2 - y1)
+                if area > 0 and x2 >= x1 and y2 >= y1:
+                    obj["clean_bbox"] = [x1, y1, x2, y2]
+                    objs.append(obj)
+
+            num_objs = len(objs)
+
+            res = np.zeros((num_objs, 5))
+
+            for ix, obj in enumerate(objs):
+                _cls = obj["category_id"]  # 0-based
+                res[ix, 0:4] = obj["clean_bbox"]
+                res[ix, 4] = _cls
+
+            r = min(self.img_size[0] / height, self.img_size[1] / width)
+            res[:, :4] *= r
+        elif self.split == "train":
+            raise SystemExit("Failed to load labels.")
+        else:
+            r = min(self.img_size[0] / height, self.img_size[1] / width)
+            res = np.zeros((1, 5))
+
+        resized_info = (int(height * r), int(width * r))
+        return res, resized_info
+
+    def load_resized_img(self, file_name):
+        img = self.load_image(file_name)
+        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * r), int(img.shape[0] * r)),
+            interpolation=cv2.INTER_LINEAR,
+        ).astype(np.uint8)
+        return resized_img
+
+    def load_image(self, file_name):
+        img = read_image_mmcv(file_name, format=self.cfg.img_format)  # BGR
+        assert img is not None
+        return img
+
+    def pull_item(self, index):
+        """Returns the original image and target at an index for mixup.
+
+        Note: not using self.__getitem__(), as any transformations passed in
+        could mess up this functionality.
+
+        Argument:
+            index (int): index of img to show
+        Return:
+            img, target
+        """
+        cfg = self.cfg
+        dataset_dict = self._get_sample_dict(index)
+        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
+        file_name = dataset_dict["file_name"]
+        img = self.load_resized_img(file_name)
+
+        target, resized_info = self.load_anno(index)
+
+        width = dataset_dict["width"]
+        height = dataset_dict["height"]
+        img_info = (height, width)
+
+        scene_im_id = dataset_dict.get("scene_im_id", 0)
+
+        img_id = dataset_dict["image_id"]
+        return img, target.copy(), scene_im_id, img_info, np.array([img_id])
+
+    @Dataset.mosaic_getitem
+    def __getitem__(self, index):
+        img, target, scene_im_id, img_info, img_id = self.pull_item(index)
+
+        if self.preproc is not None:
+            img, target = self.preproc(img, target, self.input_dim)
+        # import ipdb; ipdb.set_trace()
+        return img, target, scene_im_id, img_info, img_id
+
+    def _get_sample_dict(self, idx):
+        if self._serialize:
+            start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
+            end_addr = self._addr[idx].item()
+            bytes = memoryview(self._lst[start_addr:end_addr])
+            dataset_dict = pickle.loads(bytes)
+        elif self._copy:
+            dataset_dict = copy.deepcopy(self._lst[idx])
+        else:
+            dataset_dict = self._lst[idx]
+        return dataset_dict
+
+    def normalize_image(self, image):
+        # image: CHW format
+        cfg = self.cfg
+        pixel_mean = np.array(cfg.pixel_mean).reshape(-1, 1, 1)
+        pixel_std = np.array(cfg.pixel_std).reshape(-1, 1, 1)
+        return (image - pixel_mean) / pixel_std
+
+    def aug_bbox_non_square(self, bbox_xyxy, im_H, im_W):
+        """Similar to DZI, but the resulted bbox is not square, and not enlarged
+        Args:
+            cfg (ConfigDict):
+            bbox_xyxy (np.ndarray): (4,)
+            im_H (int):
+            im_W (int):
+        Returns:
+             augmented bbox (ndarray)
+        """
+        cfg = self.cfg
+        x1, y1, x2, y2 = bbox_xyxy.copy()
+        cx = 0.5 * (x1 + x2)
+        cy = 0.5 * (y1 + y2)
+        bh = y2 - y1
+        bw = x2 - x1
+        if cfg.bbox_aug_type.lower() == "uniform":
+            # different to DZI: scale both w and h
+            scale_ratio = 1 + cfg.bbox_aug_scale_ratio * (2 * np.random.random_sample(2) - 1)  # [1-0.25, 1+0.25]
+            shift_ratio = cfg.bbox_aug_shift_ratio * (2 * np.random.random_sample(2) - 1)  # [-0.25, 0.25]
+            bbox_center = np.array([cx + bw * shift_ratio[0], cy + bh * shift_ratio[1]])  # (h/2, w/2)
+            new_bw = bw * scale_ratio[0]
+            new_bh = bh * scale_ratio[1]
+            x1 = min(max(bbox_center[0] - new_bw / 2, 0), im_W)
+            y1 = min(max(bbox_center[1] - new_bh / 2, 0), im_H)
+            x2 = min(max(bbox_center[0] + new_bw / 2, 0), im_W)
+            y2 = min(max(bbox_center[1] + new_bh / 2, 0), im_H)
+            bbox_auged = np.array([x1, y1, x2, y2])
+        elif cfg.bbox_aug_type.lower() == "roi10d":
+            # shift (x1,y1), (x2,y2) by 15% in each direction
+            _a = -0.15
+            _b = 0.15
+            x1 += bw * (np.random.rand() * (_b - _a) + _a)
+            x2 += bw * (np.random.rand() * (_b - _a) + _a)
+            y1 += bh * (np.random.rand() * (_b - _a) + _a)
+            y2 += bh * (np.random.rand() * (_b - _a) + _a)
+            x1 = min(max(x1, 0), im_W)
+            x2 = min(max(x2, 0), im_W)
+            y1 = min(max(y1, 0), im_H)
+            y2 = min(max(y2, 0), im_H)
+            bbox_auged = np.array([x1, y1, x2, y2])
+        elif cfg.bbox_aug_type.lower() == "truncnorm":
+            raise NotImplementedError("BBOX_AUG_TYPE truncnorm is not implemented yet.")
+        else:
+            bbox_auged = bbox_xyxy.copy()
+        return bbox_auged
+
+    def aug_bbox_DZI(self, cfg, bbox_xyxy, im_H, im_W):
+        """Used for DZI, the augmented box is a square (maybe enlarged)
+        Args:
+            bbox_xyxy (np.ndarray):
+        Returns:
+             center, scale
+        """
+        x1, y1, x2, y2 = bbox_xyxy.copy()
+        cx = 0.5 * (x1 + x2)
+        cy = 0.5 * (y1 + y2)
+        bh = y2 - y1
+        bw = x2 - x1
+        if cfg.dzi_type.lower() == "uniform":
+            scale_ratio = 1 + cfg.dzi_scale_ratio * (2 * np.random.random_sample() - 1)  # [1-0.25, 1+0.25]
+            shift_ratio = cfg.dzi_shift_ratio * (2 * np.random.random_sample(2) - 1)  # [-0.25, 0.25]
+            bbox_center = np.array([cx + bw * shift_ratio[0], cy + bh * shift_ratio[1]])  # (h/2, w/2)
+            scale = max(y2 - y1, x2 - x1) * scale_ratio * cfg.dzi_pad_scale
+        elif cfg.dzi_type.lower() == "roi10d":
+            # shift (x1,y1), (x2,y2) by 15% in each direction
+            _a = -0.15
+            _b = 0.15
+            x1 += bw * (np.random.rand() * (_b - _a) + _a)
+            x2 += bw * (np.random.rand() * (_b - _a) + _a)
+            y1 += bh * (np.random.rand() * (_b - _a) + _a)
+            y2 += bh * (np.random.rand() * (_b - _a) + _a)
+            x1 = min(max(x1, 0), im_W)
+            x2 = min(max(x2, 0), im_W)
+            y1 = min(max(y1, 0), im_H)
+            y2 = min(max(y2, 0), im_H)
+            bbox_center = np.array([0.5 * (x1 + x2), 0.5 * (y1 + y2)])
+            scale = max(y2 - y1, x2 - x1) * cfg.dzi_pad_scale
+        elif cfg.dzi_type.lower() == "truncnorm":
+            raise NotImplementedError("DZI truncnorm not implemented yet.")
+        else:
+            bbox_center = np.array([cx, cy])  # (w/2, h/2)
+            scale = max(y2 - y1, x2 - x1)
+        scale = min(scale, max(im_H, im_W)) * 1.0
+        return bbox_center, scale
+
+    def _get_color_augmentor(self, aug_type="ROI10D", aug_code=None):
+        # fmt: off
+        cfg = self.cfg
+        if aug_type.lower() == "roi10d":
+            color_augmentor = AugmentRGB(
+                brightness_delta=2.5 / 255.,  # 0,
+                lighting_std=0.3,
+                saturation_var=(0.95, 1.05),  # (1, 1),
+                contrast_var=(0.95, 1.05))  # (1, 1))  #
+        elif aug_type.lower() == "aae":
+            import imgaug.augmenters as iaa  # noqa
+            from imgaug.augmenters import (Sequential, SomeOf, OneOf, Sometimes, WithColorspace, WithChannels, Noop,
+                                           Lambda, AssertLambda, AssertShape, Scale, CropAndPad, Pad, Crop, Fliplr,
+                                           Flipud, Superpixels, ChangeColorspace, PerspectiveTransform, Grayscale,
+                                           GaussianBlur, AverageBlur, MedianBlur, Convolve, Sharpen, Emboss, EdgeDetect,
+                                           DirectedEdgeDetect, Add, AddElementwise, AdditiveGaussianNoise, Multiply,
+                                           MultiplyElementwise, Dropout, CoarseDropout, Invert, ContrastNormalization,
+                                           Affine, PiecewiseAffine, ElasticTransformation, pillike, LinearContrast)  # noqa
+            aug_code = """Sequential([
+                # Sometimes(0.5, PerspectiveTransform(0.05)),
+                # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+                # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+                Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),
+                Sometimes(0.5, GaussianBlur(1.2*np.random.rand())),
+                Sometimes(0.5, Add((-25, 25), per_channel=0.3)),
+                Sometimes(0.3, Invert(0.2, per_channel=True)),
+                Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),
+                Sometimes(0.5, Multiply((0.6, 1.4))),
+                Sometimes(0.5, LinearContrast((0.5, 2.2), per_channel=0.3))
+                ], random_order = False)"""
+            # for darker objects, e.g. LM driller: use BOOTSTRAP_RATIO: 16 and weaker augmentation
+            aug_code_weaker = """Sequential([
+                Sometimes(0.4, CoarseDropout( p=0.1, size_percent=0.05) ),
+                # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+                Sometimes(0.5, GaussianBlur(np.random.rand())),
+                Sometimes(0.5, Add((-20, 20), per_channel=0.3)),
+                Sometimes(0.4, Invert(0.20, per_channel=True)),
+                Sometimes(0.5, Multiply((0.7, 1.4), per_channel=0.8)),
+                Sometimes(0.5, Multiply((0.7, 1.4))),
+                Sometimes(0.5, LinearContrast((0.5, 2.0), per_channel=0.3))
+                ], random_order=False)"""
+            color_augmentor = eval(aug_code)
+        elif aug_type.lower() == "code":  # assume imgaug
+            import imgaug.augmenters as iaa
+            from imgaug.augmenters import (Sequential, SomeOf, OneOf, Sometimes, WithColorspace, WithChannels, Noop,
+                                           Lambda, AssertLambda, AssertShape, Scale, CropAndPad, Pad, Crop, Fliplr,
+                                           Flipud, Superpixels, ChangeColorspace, PerspectiveTransform, Grayscale,
+                                           GaussianBlur, AverageBlur, MedianBlur, Convolve, Sharpen, Emboss, EdgeDetect,
+                                           DirectedEdgeDetect, Add, AddElementwise, AdditiveGaussianNoise, Multiply,
+                                           MultiplyElementwise, Dropout, CoarseDropout, Invert, ContrastNormalization,
+                                           Affine, PiecewiseAffine, ElasticTransformation, pillike, LinearContrast, Canny)  # noqa
+            aug_code = cfg.color_aug_code
+            color_augmentor = eval(aug_code)
+        elif aug_type.lower() == 'code_albu':
+            from albumentations import (HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
+                                        Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion,
+                                        HueSaturationValue, IAAAdditiveGaussianNoise, GaussNoise, MotionBlur,
+                                        MedianBlur, IAAPiecewiseAffine, IAASharpen, IAAEmboss, RandomContrast,
+                                        RandomBrightness, Flip, OneOf, Compose, CoarseDropout, RGBShift, RandomGamma,
+                                        RandomBrightnessContrast, JpegCompression, InvertImg)  # noqa
+            aug_code = """Compose([
+                CoarseDropout(max_height=0.05*480, max_holes=0.05*640, p=0.4),
+                OneOf([
+                    IAAAdditiveGaussianNoise(p=0.5),
+                    GaussNoise(p=0.5),
+                ], p=0.2),
+                OneOf([
+                    MotionBlur(p=0.2),
+                    MedianBlur(blur_limit=3, p=0.1),
+                    Blur(blur_limit=3, p=0.1),
+                ], p=0.2),
+                OneOf([
+                    CLAHE(clip_limit=2),
+                    IAASharpen(),
+                    IAAEmboss(),
+                    RandomBrightnessContrast(),
+                ], p=0.3),
+                InvertImg(p=0.2),
+                RGBShift(r_shift_limit=105, g_shift_limit=45, b_shift_limit=40, p=0.5),
+                RandomContrast(limit=0.9, p=0.5),
+                RandomGamma(gamma_limit=(80,120), p=0.5),
+                RandomBrightness(limit=1.2, p=0.5),
+                HueSaturationValue(hue_shift_limit=172, sat_shift_limit=20, val_shift_limit=27, p=0.3),
+                JpegCompression(quality_lower=4, quality_upper=100, p=0.4),
+            ], p=0.8)"""
+            color_augmentor = eval(cfg.color_aug_code)
+        else:
+            color_augmentor = None
+        # fmt: on
+        return color_augmentor
+
+    def _color_aug(self, image, aug_type="ROI10D"):
+        # assume image in [0, 255] uint8
+        if aug_type.lower() == "roi10d":  # need normalized image in [0,1]
+            image = np.asarray(image / 255.0, dtype=np.float32).copy()
+            image = self.color_augmentor.augment(image)
+            image = (image * 255.0 + 0.5).astype(np.uint8)
+            return image
+        elif aug_type.lower() in ["aae", "code"]:
+            # imgaug need uint8
+            return self.color_augmentor.augment_image(image)
+        elif aug_type.lower() in ["code_albu"]:
+            augmented = self.color_augmentor(image=image)
+            return augmented["image"]
+        else:
+            raise ValueError("aug_type: {} is not supported.".format(aug_type))
+
+    @lazy_property
+    def _bg_img_paths(self):
+        logger.info("get bg image paths")
+        cfg = self.cfg
+        # random.choice(bg_img_paths)
+        bg_type = cfg.bg_type
+        bg_root = cfg.bg_imgs_root
+        hashed_file_name = hashlib.md5(
+            ("{}_{}_{}_get_bg_imgs".format(bg_root, cfg.num_bg_imgs, bg_type)).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(".cache/bg_paths_{}_{}.pkl".format(bg_type, hashed_file_name))
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        if osp.exists(cache_path):
+            logger.info("get bg_paths from cache file: {}".format(cache_path))
+            bg_img_paths = mmcv.load(cache_path)
+            logger.info("num bg imgs: {}".format(len(bg_img_paths)))
+            assert len(bg_img_paths) > 0
+            return bg_img_paths
+
+        logger.info("building bg imgs cache {}...".format(bg_type))
+        assert osp.exists(bg_root), f"BG ROOT: {bg_root} does not exist"
+        if bg_type == "coco":
+            img_paths = [
+                osp.join(bg_root, fn.name) for fn in os.scandir(bg_root) if ".png" in fn.name or "jpg" in fn.name
+            ]
+        elif bg_type == "VOC_table":  # used in original deepim
+            VOC_root = bg_root  # path to "VOCdevkit/VOC2012"
+            VOC_image_set_dir = osp.join(VOC_root, "ImageSets/Main")
+            VOC_bg_list_path = osp.join(VOC_image_set_dir, "diningtable_trainval.txt")
+            with open(VOC_bg_list_path, "r") as f:
+                VOC_bg_list = [
+                    line.strip("\r\n").split()[0] for line in f.readlines() if line.strip("\r\n").split()[1] == "1"
+                ]
+            img_paths = [osp.join(VOC_root, "JPEGImages/{}.jpg".format(bg_idx)) for bg_idx in VOC_bg_list]
+        elif bg_type == "VOC":
+            VOC_root = bg_root  # path to "VOCdevkit/VOC2012"
+            img_paths = [
+                osp.join(VOC_root, "JPEGImages", fn.name)
+                for fn in os.scandir(osp.join(bg_root, "JPEGImages"))
+                if ".jpg" in fn.name
+            ]
+        elif bg_type == "SUN2012":
+            img_paths = [
+                osp.join(bg_root, "JPEGImages", fn.name)
+                for fn in os.scandir(osp.join(bg_root, "JPEGImages"))
+                if ".jpg" in fn.name
+            ]
+        else:
+            raise ValueError(f"BG_TYPE: {bg_type} is not supported")
+        assert len(img_paths) > 0, len(img_paths)
+
+        num_bg_imgs = min(len(img_paths), cfg.num_bg_imgs)
+        bg_img_paths = np.random.choice(img_paths, num_bg_imgs)
+
+        mmcv.dump(bg_img_paths, cache_path)
+        logger.info("num bg imgs: {}".format(len(bg_img_paths)))
+        assert len(bg_img_paths) > 0
+        return bg_img_paths
+
+    def replace_bg(self, im, im_mask, return_mask=False, truncate_fg=False):
+        cfg = self.cfg
+        # add background to the image
+        H, W = im.shape[:2]
+        ind = random.randint(0, len(self._bg_img_paths) - 1)
+        filename = self._bg_img_paths[ind]
+        if cfg.get("bg_keep_aspect_ratio", True):
+            bg_img = self.get_bg_image(filename, H, W)
+        else:
+            bg_img = self.get_bg_image_v2(filename, H, W)
+
+        if len(bg_img.shape) != 3:
+            bg_img = np.zeros((H, W, 3), dtype=np.uint8)
+            logger.warning("bad background image: {}".format(filename))
+
+        mask = im_mask.copy().astype(bool)  # np.bool was removed in recent numpy
+        if truncate_fg:
+            mask = self.trunc_mask(im_mask)
+        mask_bg = ~mask
+        im[mask_bg] = bg_img[mask_bg]
+        im = im.astype(np.uint8)
+        if return_mask:
+            return im, mask  # bool fg mask
+        else:
+            return im
+
+    def trunc_mask(self, mask):
+        # return the bool truncated mask
+        mask = mask.copy().astype(bool)
+        nonzeros = np.nonzero(mask.astype(np.uint8))
+        x1, y1 = np.min(nonzeros, axis=1)
+        x2, y2 = np.max(nonzeros, axis=1)
+        c_h = 0.5 * (x1 + x2)
+        c_w = 0.5 * (y1 + y2)
+        rnd = random.random()
+        # print(x1, x2, y1, y2, c_h, c_w, rnd, mask.shape)
+        if rnd < 0.2:  # block upper
+            c_h_ = int(random.uniform(x1, c_h))
+            mask[:c_h_, :] = False
+        elif rnd < 0.4:  # block bottom
+            c_h_ = int(random.uniform(c_h, x2))
+            mask[c_h_:, :] = False
+        elif rnd < 0.6:  # block left
+            c_w_ = int(random.uniform(y1, c_w))
+            mask[:, :c_w_] = False
+        elif rnd < 0.8:  # block right
+            c_w_ = int(random.uniform(c_w, y2))
+            mask[:, c_w_:] = False
+        else:
+            pass
+        return mask
+
+    def get_bg_image(self, filename, imH, imW, channel=3):
+        """keep aspect ratio of bg during resize target image size:
+
+        imHximWxchannel.
+        """
+        cfg = self.cfg
+        target_size = min(imH, imW)
+        max_size = max(imH, imW)
+        real_hw_ratio = float(imH) / float(imW)
+        bg_image = utils.read_image(filename, format=cfg.img_format)
+        bg_h, bg_w, bg_c = bg_image.shape
+        bg_image_resize = np.zeros((imH, imW, channel), dtype="uint8")
+        if (float(imH) / float(imW) < 1 and float(bg_h) / float(bg_w) < 1) or (
+            float(imH) / float(imW) >= 1 and float(bg_h) / float(bg_w) >= 1
+        ):
+            if bg_h >= bg_w:
+                bg_h_new = int(np.ceil(bg_w * real_hw_ratio))
+                if bg_h_new < bg_h:
+                    bg_image_crop = bg_image[0:bg_h_new, 0:bg_w, :]
+                else:
+                    bg_image_crop = bg_image
+            else:
+                bg_w_new = int(np.ceil(bg_h / real_hw_ratio))
+                if bg_w_new < bg_w:
+                    bg_image_crop = bg_image[0:bg_h, 0:bg_w_new, :]
+                else:
+                    bg_image_crop = bg_image
+        else:
+            if bg_h >= bg_w:
+                bg_h_new = int(np.ceil(bg_w * real_hw_ratio))
+                bg_image_crop = bg_image[0:bg_h_new, 0:bg_w, :]
+            else:  # bg_h < bg_w
+                bg_w_new = int(np.ceil(bg_h / real_hw_ratio))
+                # logger.info(bg_w_new)
+                bg_image_crop = bg_image[0:bg_h, 0:bg_w_new, :]
+        bg_image_resize_0 = resize_short_edge(bg_image_crop, target_size, max_size)
+        h, w, c = bg_image_resize_0.shape
+        bg_image_resize[0:h, 0:w, :] = bg_image_resize_0
+        return bg_image_resize
+
+    def get_bg_image_v2(self, filename, imH, imW, channel=3):
+        cfg = self.cfg
+        _bg_img = utils.read_image(filename, format=cfg.img_format)
+        try:
+            # randomly crop a region as background
+            bw = _bg_img.shape[1]
+            bh = _bg_img.shape[0]
+            x1 = np.random.randint(0, int(bw / 3))
+            y1 = np.random.randint(0, int(bh / 3))
+            x2 = np.random.randint(int(2 * bw / 3), bw)
+            y2 = np.random.randint(int(2 * bh / 3), bh)
+            bg_img = cv2.resize(
+                _bg_img[y1:y2, x1:x2],
+                (imW, imH),
+                interpolation=cv2.INTER_LINEAR,
+            )
+        except Exception:
+            bg_img = np.zeros((imH, imW, 3), dtype=np.uint8)
+            logger.warning("bad background image: {}".format(filename))
+        return bg_img
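+
+
+# Usage sketch: Base_DatasetFromList wraps a list of detectron2-style dataset dicts
+# (with "file_name", "width", "height", "image_id" and optional "annotations") and,
+# with serialize=True, stores them in one shared numpy buffer so dataloader workers
+# do not duplicate memory. The dataset name and preproc below are illustrative.
+#
+#     from detectron2.data import DatasetCatalog
+#     from det.yolox.data import TrainTransform
+#
+#     dicts = DatasetCatalog.get("some_registered_train_split")
+#     dataset = Base_DatasetFromList(
+#         "train", dicts, img_size=(640, 640), preproc=TrainTransform(max_labels=50)
+#     )
+#     img, target, scene_im_id, img_info, img_id = dataset[0]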
diff --git a/det/yolox/data/datasets/coco.py b/det/yolox/data/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf422533b50fb37ffe389d1d47e2ace5147827ec
--- /dev/null
+++ b/det/yolox/data/datasets/coco.py
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import os
+from loguru import logger
+
+import cv2
+import numpy as np
+from pycocotools.coco import COCO
+
+from det.yolox.utils.setup_env import get_yolox_datadir
+from .datasets_wrapper import Dataset
+
+
+class COCODataset(Dataset):
+    """COCO dataset class."""
+
+    def __init__(
+        self,
+        data_dir=None,
+        json_file="instances_train2017.json",
+        name="train2017",
+        img_size=(416, 416),
+        preproc=None,
+        cache=False,
+    ):
+        """COCO dataset initialization.
+
+        Annotation data are read into memory by COCO API.
+        Args:
+            data_dir (str): dataset root directory
+            json_file (str): COCO json file name
+            name (str): COCO data name (e.g. 'train2017' or 'val2017')
+            img_size (int): target image size after pre-processing
+            preproc: data augmentation strategy
+        """
+        super().__init__(img_size)
+        if data_dir is None:
+            data_dir = os.path.join(get_yolox_datadir(), "coco")
+        self.data_dir = data_dir
+        self.json_file = json_file
+
+        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
+        self.ids = self.coco.getImgIds()
+        self.class_ids = sorted(self.coco.getCatIds())
+        cats = self.coco.loadCats(self.coco.getCatIds())
+        self._classes = tuple([c["name"] for c in cats])
+        self.imgs = None
+        self.name = name
+        self.img_size = img_size
+        self.preproc = preproc
+        self.annotations = self._load_coco_annotations()
+        if cache:
+            self._cache_images()
+
+    def __len__(self):
+        return len(self.ids)
+
+    def __del__(self):
+        del self.imgs
+
+    def _load_coco_annotations(self):
+        return [self.load_anno_from_ids(_ids) for _ids in self.ids]
+
+    def _cache_images(self):
+        logger.warning(
+            "\n********************************************************************************\n"
+            "You are using cached images in RAM to accelerate training.\n"
+            "This requires large system RAM.\n"
+            "Make sure you have 200G+ RAM and 136G available disk space for training COCO.\n"
+            "********************************************************************************\n"
+        )
+        max_h = self.img_size[0]
+        max_w = self.img_size[1]
+        cache_file = self.data_dir + "/img_resized_cache_" + self.name + ".array"
+        if not os.path.exists(cache_file):
+            logger.info("Caching images for the first time. This might take about 20 minutes for COCO")
+            self.imgs = np.memmap(
+                cache_file,
+                shape=(len(self.ids), max_h, max_w, 3),
+                dtype=np.uint8,
+                mode="w+",
+            )
+            from tqdm import tqdm
+            from multiprocessing.pool import ThreadPool
+
+            NUM_THREADs = min(8, os.cpu_count())
+            loaded_images = ThreadPool(NUM_THREADs).imap(
+                lambda x: self.load_resized_img(x),
+                range(len(self.annotations)),
+            )
+            pbar = tqdm(enumerate(loaded_images), total=len(self.annotations))
+            for k, out in pbar:
+                self.imgs[k][: out.shape[0], : out.shape[1], :] = out.copy()
+            self.imgs.flush()
+            pbar.close()
+        else:
+            logger.warning(
+                "You are using cached imgs! Make sure your dataset is not changed!!\n"
+                "Everytime the self.input_size is changed in your exp file, you need to delete\n"
+                "the cached data and re-generate them.\n"
+            )
+
+        logger.info("Loading cached imgs...")
+        self.imgs = np.memmap(
+            cache_file,
+            shape=(len(self.ids), max_h, max_w, 3),
+            dtype=np.uint8,
+            mode="r+",
+        )
+
+    def load_anno_from_ids(self, id_):
+        im_ann = self.coco.loadImgs(id_)[0]
+        width = im_ann["width"]
+        height = im_ann["height"]
+        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
+        annotations = self.coco.loadAnns(anno_ids)
+        objs = []
+        for obj in annotations:
+            x1 = np.max((0, obj["bbox"][0]))
+            y1 = np.max((0, obj["bbox"][1]))
+            x2 = np.min((width, x1 + np.max((0, obj["bbox"][2]))))
+            y2 = np.min((height, y1 + np.max((0, obj["bbox"][3]))))
+            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
+                obj["clean_bbox"] = [x1, y1, x2, y2]
+                objs.append(obj)
+
+        num_objs = len(objs)
+
+        res = np.zeros((num_objs, 5))
+
+        for ix, obj in enumerate(objs):
+            cls = self.class_ids.index(obj["category_id"])
+            res[ix, 0:4] = obj["clean_bbox"]
+            res[ix, 4] = cls
+
+        r = min(self.img_size[0] / height, self.img_size[1] / width)
+        res[:, :4] *= r
+
+        img_info = (height, width)
+        resized_info = (int(height * r), int(width * r))
+
+        file_name = im_ann["file_name"] if "file_name" in im_ann else "{:012}".format(id_) + ".jpg"
+
+        return (res, img_info, resized_info, file_name)
+
+    def load_anno(self, index):
+        return self.annotations[index][0]
+
+    def load_resized_img(self, index):
+        img = self.load_image(index)
+        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * r), int(img.shape[0] * r)),
+            interpolation=cv2.INTER_LINEAR,
+        ).astype(np.uint8)
+        return resized_img
+
+    def load_image(self, index):
+        file_name = self.annotations[index][3]
+
+        img_file = os.path.join(self.data_dir, self.name, file_name)
+
+        img = cv2.imread(img_file)
+        assert img is not None
+
+        return img
+
+    def pull_item(self, index):
+        id_ = self.ids[index]
+
+        res, img_info, resized_info, _ = self.annotations[index]
+        if self.imgs is not None:
+            pad_img = self.imgs[index]
+            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
+        else:
+            img = self.load_resized_img(index)
+
+        return img, res.copy(), img_info, np.array([id_])
+
+    @Dataset.mosaic_getitem
+    def __getitem__(self, index):
+        """One image / label pair for the given index is picked up and pre-
+        processed.
+
+        Args:
+            index (int): data index
+
+        Returns:
+            img (numpy.ndarray): pre-processed image
+            padded_labels (torch.Tensor): pre-processed label data.
+                The shape is :math:`[max_labels, 5]`.
+                Each label consists of [class, xc, yc, w, h]:
+                    class (float): class index.
+                    xc, yc (float): center of the bbox in pixels of the resized image.
+                    w, h (float): size of the bbox in pixels of the resized image.
+            info_img : tuple of h, w.
+                h, w (int): original shape of the image
+            img_id (int): same as the input index. Used for evaluation.
+        """
+        img, target, img_info, img_id = self.pull_item(index)
+
+        if self.preproc is not None:
+            img, target = self.preproc(img, target, self.input_dim)
+        return img, target, img_info, img_id
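+
+
+# Usage sketch (paths are illustrative): expects the standard COCO layout,
+# i.e. <data_dir>/annotations/<json_file> and images under <data_dir>/<name>/.
+#
+#     dataset = COCODataset(
+#         data_dir="datasets/coco",
+#         json_file="instances_val2017.json",
+#         name="val2017",
+#         img_size=(640, 640),
+#         preproc=None,
+#     )
+#     img, target, img_info, img_id = dataset[0]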
diff --git a/det/yolox/data/datasets/coco_classes.py b/det/yolox/data/datasets/coco_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..17f5cbe6e86ed4fc8378760da71f8349a6406ff1
--- /dev/null
+++ b/det/yolox/data/datasets/coco_classes.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+COCO_CLASSES = (
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "backpack",
+    "umbrella",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "dining table",
+    "toilet",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+)
diff --git a/det/yolox/data/datasets/dataset_factory.py b/det/yolox/data/datasets/dataset_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..b59dc641742564ef3e9ba9b6732eaaf11dceb78f
--- /dev/null
+++ b/det/yolox/data/datasets/dataset_factory.py
@@ -0,0 +1,139 @@
+import logging
+import os.path as osp
+import mmcv
+from detectron2.data import DatasetCatalog
+from . import (
+    lm_syn_imgn,
+    lm_dataset_d2,
+    # lm_syn_egl,
+    lm_pbr,
+    lm_blender,
+    # lm_dataset_crop_d2,
+    ycbv_pbr,
+    ycbv_d2,
+    ycbv_bop_test,
+    hb_pbr,
+    hb_bop_val,
+    hb_bop_test,
+    hb_bench_driller_phone_d2,
+    # duck_frames,
+    # lm_new_duck_pbr,
+    tudl_train_real,
+    tudl_pbr,
+    tudl_bop_test,
+    tless_primesense_train,
+    tless_pbr,
+    tless_bop_test,
+    icbin_pbr,  # TODO:test
+    icbin_bop_test,
+    itodd_pbr,  # TODO:test
+    itodd_d2,
+    itodd_bop_test,
+)  # noqa
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+# from lib.utils.utils import iprint
+
+__all__ = [
+    "register_dataset",
+    "register_datasets",
+    "register_datasets_in_cfg",
+    "get_available_datasets",
+]
+_DSET_MOD_NAMES = [
+    "lm_syn_imgn",
+    "lm_dataset_d2",
+    # "lm_syn_egl",
+    "lm_pbr",
+    "lm_blender",
+    # "lm_dataset_crop_d2",
+    "ycbv_pbr",
+    "ycbv_d2",
+    "ycbv_bop_test",
+    "hb_pbr",
+    "hb_bop_val",
+    "hb_bop_test",
+    "hb_bench_driller_phone_d2",
+    # "duck_frames",
+    # "lm_new_duck_pbr",
+    "tudl_train_real",
+    "tudl_pbr",
+    "tudl_bop_test",
+    "tless_primesense_train",
+    "tless_pbr",
+    "tless_bop_test",
+    "icbin_pbr",
+    "icbin_bop_test",
+    "itodd_pbr",
+    "itodd_d2",
+    "itodd_bop_test",
+]
+
+logger = logging.getLogger(__name__)
+
+
+def register_dataset(mod_name, dset_name, data_cfg=None):
+    """
+    mod_name: a module under core.datasets or other dataset source file imported here
+    dset_name: dataset name
+    data_cfg: dataset config
+    """
+    register_func = eval(mod_name)
+    register_func.register_with_name_cfg(dset_name, data_cfg)
+
+
+def get_available_datasets(mod_name):
+    return eval(mod_name).get_available_datasets()
+
+
+def register_datasets_in_cfg(cfg):
+    for split in [
+        "TRAIN",
+        "TEST",
+        "SS_TRAIN",
+        "TEST_DEBUG",
+        "TRAIN_REAL",
+        "TRAIN2",
+        "TRAIN_SYN_SUP",
+    ]:
+        for name in cfg.DATASETS.get(split, []):
+            if name in DatasetCatalog.list():
+                continue
+            registered = False
+            # try to find in pre-defined datasets
+            # NOTE: it is better to have all datasets pre-defined
+            for _mod_name in _DSET_MOD_NAMES:
+                if name in get_available_datasets(_mod_name):
+                    register_dataset(_mod_name, name, data_cfg=None)
+                    registered = True
+                    break
+            # not among the pre-defined datasets; not recommended
+            if not registered:
+                # try to get mod_name and data_cfg from cfg
+                """load data_cfg and mod_name from file
+                cfg.DATA_CFG[name] = 'path_to_cfg'
+                """
+                assert "DATA_CFG" in cfg and name in cfg.DATA_CFG, "no cfg.DATA_CFG.{}".format(name)
+                assert osp.exists(cfg.DATA_CFG[name])
+                data_cfg = mmcv.load(cfg.DATA_CFG[name])
+                mod_name = data_cfg.pop("mod_name", None)
+                assert mod_name in _DSET_MOD_NAMES, mod_name
+                register_dataset(mod_name, name, data_cfg)
+
+
+def register_datasets(dataset_names):
+    for name in dataset_names:
+        if name in DatasetCatalog.list():
+            continue
+        registered = False
+        # try to find in pre-defined datasets
+        # NOTE: it is better to have all datasets pre-defined
+        for _mod_name in _DSET_MOD_NAMES:
+            if name in get_available_datasets(_mod_name):
+                register_dataset(_mod_name, name, data_cfg=None)
+                registered = True
+                break
+
+        # not among the pre-defined datasets; not recommended
+        if not registered:
+            raise ValueError(f"dataset {name} is not defined")
diff --git a/det/yolox/data/datasets/datasets_wrapper.py b/det/yolox/data/datasets/datasets_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99a8a13825c6e147e422bbafa7bbd2e653eff5e
--- /dev/null
+++ b/det/yolox/data/datasets/datasets_wrapper.py
@@ -0,0 +1,108 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import bisect
+from functools import wraps
+
+from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
+from torch.utils.data.dataset import Dataset as torchDataset
+
+
+class ConcatDataset(torchConcatDataset):
+    def __init__(self, datasets):
+        super(ConcatDataset, self).__init__(datasets)
+        if hasattr(self.datasets[0], "input_dim"):
+            self._input_dim = self.datasets[0].input_dim
+            self.input_dim = self.datasets[0].input_dim
+
+    def pull_item(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
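+        # e.g. with child dataset lengths [3, 5], cumulative_sizes == [3, 8];
+        # a global idx of 4 falls into dataset 1 at local index 1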
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx].pull_item(sample_idx)
+
+
+class MixConcatDataset(torchConcatDataset):
+    def __init__(self, datasets):
+        super(MixConcatDataset, self).__init__(datasets)
+        if hasattr(self.datasets[0], "input_dim"):
+            self._input_dim = self.datasets[0].input_dim
+            self.input_dim = self.datasets[0].input_dim
+
+    def __getitem__(self, index):
+        # `index` can be an int or a tuple like (enable_mosaic, idx, ...) from the sampler
+        if isinstance(index, int):
+            idx = index
+        else:
+            idx = index[1]
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError("absolute value of index should not exceed dataset length")
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        if not isinstance(index, int):
+            index = (index[0], sample_idx, index[2])
+
+        return self.datasets[dataset_idx][index]
+
+
+class Dataset(torchDataset):
+    """This class is a subclass of the base :class:`torch.utils.data.Dataset`,
+    that enables on the fly resizing of the ``input_dim``.
+
+    Args:
+        input_dimension (tuple): (height, width) tuple with default dimensions of the network
+    """
+
+    def __init__(self, input_dimension, mosaic=True):
+        super().__init__()
+        self.__input_dim = input_dimension[:2]
+        self.enable_mosaic = mosaic
+
+    @property
+    def input_dim(self):
+        """Dimension that can be used by transforms to set the correct image
+        size, etc. This allows transforms to have a single source of truth for
+        the input dimension of the network.
+
+        Return:
+            list: Tuple containing the current width,height
+        """
+        if hasattr(self, "_input_dim"):
+            return self._input_dim
+        return self.__input_dim
+
+    @staticmethod
+    def mosaic_getitem(getitem_fn):
+        """Decorator method that needs to be used around the ``__getitem__``
+        method. |br| This decorator enables the on the fly resizing of the
+        ``input_dim`` with our :class:`~lightnet.data.DataLoader` class.
+
+        Example:
+            >>> class CustomSet(ln.data.Dataset):
+            ...     def __len__(self):
+            ...         return 10
+            ...     @ln.data.Dataset.mosaic_getitem
+            ...     def __getitem__(self, index):
+            ...         return self.enable_mosaic
+        """
+
+        @wraps(getitem_fn)
+        def wrapper(self, index):
+            if not isinstance(index, int):
+                self.enable_mosaic = index[0]
+                index = index[1]
+
+            ret_val = getitem_fn(self, index)
+
+            return ret_val
+
+        return wrapper
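+
+
+# Sketch of the decorator in use (illustrative): a mosaic-aware batch sampler yields
+# (enable_mosaic, index) tuples, which the wrapper above unpacks into the flag and the
+# plain integer index.
+#
+#   class CustomSet(Dataset):
+#       def __init__(self):
+#           super().__init__(input_dimension=(640, 640), mosaic=True)
+#
+#       def __len__(self):
+#           return 10
+#
+#       @Dataset.mosaic_getitem
+#       def __getitem__(self, index):
+#           # `index` is already an int here; self.enable_mosaic carries the flag
+#           return self.enable_mosaic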
diff --git a/det/yolox/data/datasets/hb_bench_driller_phone_d2.py b/det/yolox/data/datasets/hb_bench_driller_phone_d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..56811928151885e8253ed5953e6701bdd730f40b
--- /dev/null
+++ b/det/yolox/data/datasets/hb_bench_driller_phone_d2.py
@@ -0,0 +1,615 @@
+# NOTE: different from Self6D-v1 which uses hb-v1, this uses hb_bop conventions
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class HB_BenchDrillerPhone:
+    """a test sequence (test sequence 2) of HomebrewedDB contains 3 objects in
+    linemod."""
+
+    def __init__(self, data_cfg):
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = dataset_root = data_cfg["dataset_root"]
+        self.ann_files = data_cfg["ann_files"]
+        self.models_root = data_cfg["models_root"]  # models_lm
+        self.scale_to_meter = data_cfg["scale_to_meter"]
+
+        # use the images with converted K
+        cam_type = data_cfg["cam_type"]
+        assert cam_type in ["linemod", "hb"], cam_type
+        self.cam_type = cam_type
+        if cam_type == "linemod":  # linemod K
+            self.cam = np.array(
+                [
+                    [572.4114, 0, 325.2611],
+                    [0, 573.57043, 242.04899],
+                    [0, 0, 1],
+                ],
+                dtype="float32",
+            )
+            self.rgb_root = osp.join(dataset_root, "sequence/rgb_lmK")
+            self.depth_root = osp.join(dataset_root, "sequence/depth_lmK")
+            self.mask_visib_root = osp.join(dataset_root, "sequence/mask_visib_lmK")
+        else:  # hb
+            self.cam = np.array(
+                [[537.4799, 0, 318.8965], [0, 536.1447, 238.3781], [0, 0, 1]],
+                dtype="float32",
+            )
+            self.rgb_root = osp.join(dataset_root, "sequence/rgb")
+            self.depth_root = osp.join(dataset_root, "sequence/depth")
+            self.mask_visib_root = osp.join(dataset_root, "sequence/mask_visib")
+        assert osp.exists(self.rgb_root), self.rgb_root
+
+        self.with_masks = data_cfg.get("with_masks", True)
+        self.with_depth = data_cfg.get("with_depth", True)
+
+        self.height = data_cfg["height"]
+        self.width = data_cfg["width"]
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg["filter_invalid"]
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.hb_bdp.id2obj.items() if obj_name in ref.hb_bdp.objects]
+        self.lm_cat_ids = [cat_id for cat_id, obj_name in lm13_id2obj.items() if obj_name in ref.hb_bdp.objects]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.lm_map_id = {k: v for k, v in zip(self.cat_ids, self.lm_cat_ids)}  # hb obj id --> lm obj id
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+        dataset_dicts = []
+        im_id_global = 0
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        # NOTE: converted from gt_v1, obj_id --> obj_id+1
+        gt_path = osp.join(self.dataset_root, "sequence/gt_v2.json")
+        gt_dict = mmcv.load(gt_path)
+
+        # determine which images to load by self.ann_files
+        sel_im_ids = []
+        for ann_file in self.ann_files:
+            with open(ann_file, "r") as f:
+                for line in f:
+                    line = line.strip("\r\n")
+                    cur_im_id = int(line)
+                    if cur_im_id not in sel_im_ids:
+                        sel_im_ids.append(cur_im_id)
+
+        for str_im_id, annos in tqdm(gt_dict.items()):  # str im ids
+            int_im_id = int(str_im_id)
+            if int_im_id not in sel_im_ids:
+                continue
+            rgb_path = osp.join(self.rgb_root, "color_{:06d}.png".format(int_im_id))
+            depth_path = osp.join(self.depth_root, "{:06d}.png".format(int_im_id))
+
+            scene_id = 2  # in the full HB test set, this sequence is scene 2
+            record = {
+                "dataset_name": self.name,
+                "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                "depth_factor": 1 / self.scale_to_meter,
+                "height": self.height,
+                "width": self.width,
+                "image_id": im_id_global,
+                "scene_im_id": "{}/{}".format(scene_id, int_im_id),  # for evaluation
+                "cam": self.cam,
+                "img_type": "real",
+            }
+            im_id_global += 1
+
+            inst_annos = []
+            for anno_i, anno in enumerate(annos):
+                obj_id = anno["obj_id"]
+                cls_name = ref.hb_bdp.id2obj[obj_id]
+                if cls_name not in self.objs:
+                    continue
+                if cls_name not in ref.hb_bdp.objects:  # only support 3 objects
+                    continue
+
+                cur_label = self.cat2label[obj_id]
+                lm_cur_label = self.lm_map_id[obj_id] - 1  # 0-based label
+
+                R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                pose = np.hstack([R, t.reshape(3, 1)])
+                if self.cam_type == "hb":
+                    bbox = anno["obj_bb"]
+                    bbox_mode = BoxMode.XYWH_ABS
+                elif self.cam_type == "linemod":
+                    # get bbox from projected points
+                    bbox = misc.compute_2d_bbox_xyxy_from_pose_v2(
+                        self.models[cur_label]["pts"].astype("float32"),
+                        pose.astype("float32"),
+                        self.cam,
+                        width=self.width,
+                        height=self.height,
+                        clip=True,
+                    )
+                    bbox_mode = BoxMode.XYXY_ABS
+                    x1, y1, x2, y2 = bbox
+                    w = x2 - x1
+                    h = y2 - y1
+                else:
+                    raise ValueError("Wrong cam type: {}".format(self.cam_type))
+
+                if self.filter_invalid:
+                    if h <= 1 or w <= 1:
+                        self.num_instances_without_valid_box += 1
+                        continue
+
+                mask_visib_file = osp.join(
+                    self.mask_visib_root,
+                    "{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                )
+                assert osp.exists(mask_visib_file), mask_visib_file
+                # load mask visib  TODO: load both mask_visib and mask_full
+                mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                area = mask_single.sum()
+                if area < 3:  # filter out too small or nearly invisible instances
+                    self.num_instances_without_valid_segmentation += 1
+                    continue
+                mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                quat = mat2quat(R).astype("float32")
+
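+                # project the 3D object centroid (the translation t) with K to get the
+                # absolute 2D centroid (cx, cy) in pixels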
+                proj = (record["cam"] @ t.T).T
+                proj = proj[:2] / proj[2]
+
+                inst = {
+                    "category_id": lm_cur_label,  # 0-based label
+                    "bbox": bbox,
+                    "bbox_mode": bbox_mode,
+                    "pose": pose,
+                    "quat": quat,
+                    "trans": t,
+                    "centroid_2d": proj,  # absolute (cx, cy)
+                    "segmentation": mask_rle,
+                }
+
+                # NOTE: currently no xyz
+                # if "test" not in self.name:
+                #     xyz_path = osp.join(xyz_root, f"{int_im_id:06d}_{anno_i:06d}.pkl")
+                #     assert osp.exists(xyz_path), xyz_path
+                #     inst["xyz_path"] = xyz_path
+
+                model_info = self.models_info[str(obj_id)]
+                inst["model_info"] = model_info
+                for key in ["bbox3d_and_center"]:
+                    inst[key] = self.models[lm_cur_label][key]
+
+                inst_annos.append(inst)
+            if len(inst_annos) == 0 and self.filter_invalid:  # filter im without anno
+                continue
+            record["annotations"] = inst_annos
+            dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info(
+            "loaded dataset dicts, num_images: {}, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start)
+        )
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(PROJ_ROOT, ".cache", "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # logger.info("load cached object models from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            if obj_name not in ref.hb_bdp.objects:
+                models.append(None)
+                continue
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    "obj_{:06d}.ply".format(ref.hb_bdp.obj2id[obj_name]),
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is computed from the raw (non-centered) vertices;
+            # for BOP models this is fine since the models are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_hb_bdp_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        if obj_name not in data_ref.objects:
+            cur_sym_infos[i] = None
+            continue
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
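+# The dict returned above is unpacked into MetadataCatalog by register_with_name_cfg
+# below; once a split has been registered, downstream code can do, e.g. (sketch):
+#
+#   meta = MetadataCatalog.get("hb_benchvise_driller_phone_test_lmK")
+#   meta.thing_classes      # ["benchvise", "driller", "phone"]
+#   meta.sym_infos[0]       # (N, 3, 3) array of symmetry rotations, or None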
+
+lm13_id2obj = {
+    1: "ape",
+    2: "benchvise",
+    3: "camera",
+    4: "can",
+    5: "cat",
+    6: "driller",
+    7: "duck",
+    8: "eggbox",
+    9: "glue",
+    10: "holepuncher",
+    11: "iron",
+    12: "lamp",
+    13: "phone",
+}  # no bowl, cup
+
+SPLITS_HB_BenchviseDrillerPhone = dict(
+    # TODO: maybe add scene name
+    hb_benchvise_driller_phone_all_lmK=dict(
+        name="hb_benchvise_driller_phone_all_lmK",
+        dataset_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone"),
+        models_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone/models_lm/"),
+        ann_files=[osp.join(DATASETS_ROOT, "hb_bench_driller_phone/image_set/all.txt")],
+        objs=["benchvise", "driller", "phone"],
+        use_cache=True,
+        num_to_load=-1,
+        cam_type="linemod",
+        scale_to_meter=0.001,
+        filter_invalid=False,
+        height=480,
+        width=640,
+        ref_key="hb_bdp",
+    ),
+    hb_benchvise_driller_phone_all=dict(
+        name="hb_benchvise_driller_phone_all",
+        dataset_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone"),
+        models_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone/models_lm/"),
+        ann_files=[osp.join(DATASETS_ROOT, "hb_bench_driller_phone/image_set/all.txt")],
+        objs=["benchvise", "driller", "phone"],
+        use_cache=True,
+        num_to_load=-1,
+        cam_type="hb",  # NOTE: hb K
+        scale_to_meter=0.001,
+        filter_invalid=False,
+        height=480,
+        width=640,
+        ref_key="hb_bdp",
+    ),
+    hb_benchvise_driller_phone_test_lmK=dict(
+        name="hb_benchvise_driller_phone_test_lmK",
+        dataset_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone"),
+        models_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone/models_lm/"),
+        ann_files=[osp.join(DATASETS_ROOT, "hb_bench_driller_phone/image_set/test.txt")],
+        objs=["benchvise", "driller", "phone"],
+        use_cache=True,
+        num_to_load=-1,
+        cam_type="linemod",
+        scale_to_meter=0.001,
+        filter_invalid=False,
+        height=480,
+        width=640,
+        ref_key="hb_bdp",
+    ),
+    hb2lm_benchvise_driller_phone_test_lmK=dict(
+        name="hb2lm_benchvise_driller_phone_test_lmK",
+        dataset_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone"),
+        models_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone/models_lm/"),
+        ann_files=[osp.join(DATASETS_ROOT, "hb_bench_driller_phone/image_set/test.txt")],
+        objs=list(lm13_id2obj.values()),  # pretend to have 13 classes of objs
+        use_cache=True,
+        num_to_load=-1,
+        cam_type="linemod",
+        scale_to_meter=0.001,
+        filter_invalid=False,
+        height=480,
+        width=640,
+        ref_key="hb_bdp",
+    ),
+    hb_benchvise_driller_phone_test=dict(
+        name="hb_benchvise_driller_phone_test",
+        dataset_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone"),
+        models_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone/models_lm/"),
+        ann_files=[osp.join(DATASETS_ROOT, "hb_bench_driller_phone/image_set/test.txt")],
+        objs=["benchvise", "driller", "phone"],
+        use_cache=True,
+        num_to_load=-1,
+        cam_type="hb",
+        scale_to_meter=0.001,
+        filter_invalid=False,
+        height=480,
+        width=640,
+        ref_key="hb_bdp",
+    ),
+)
+
+
+# add varying percent splits
+VARY_PERCENT_SPLITS = [
+    "test100",
+    "train090",
+    "train180",
+    "train270",
+    "train360",
+    "train450",
+    "train540",
+    "train630",
+    "train720",
+    "train810",
+    "train900",
+]
+
+# all objects
+for _split in VARY_PERCENT_SPLITS:
+    for cam_type in ["linemod", "hb"]:
+        K_str = "_lmK" if cam_type == "linemod" else ""
+        name = "hb_benchvise_driller_phone_{}{}".format(_split, K_str)
+        if name not in SPLITS_HB_BenchviseDrillerPhone:
+            SPLITS_HB_BenchviseDrillerPhone[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone"),
+                models_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone/models_lm/"),
+                ann_files=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        f"hb_bench_driller_phone/image_set/{_split}.txt",
+                    )
+                ],
+                objs=["benchvise", "driller", "phone"],
+                use_cache=True,
+                num_to_load=-1,
+                cam_type=cam_type,
+                scale_to_meter=0.001,
+                filter_invalid=False,
+                height=480,
+                width=640,
+                ref_key="hb_bdp",
+            )
+
+# single obj splits
+for obj in ref.hb_bdp.objects:
+    for split in ["test", "train", "all"] + VARY_PERCENT_SPLITS:
+        for cam_type in ["linemod", "hb"]:
+            K_str = "_lmK" if cam_type == "linemod" else ""
+            name = "hb_bdp_{}_{}{}".format(obj, split, K_str)
+            if name not in SPLITS_HB_BenchviseDrillerPhone:
+                SPLITS_HB_BenchviseDrillerPhone[name] = dict(
+                    name=name,
+                    dataset_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone"),
+                    models_root=osp.join(DATASETS_ROOT, "hb_bench_driller_phone/models_lm/"),
+                    ann_files=[
+                        osp.join(
+                            DATASETS_ROOT,
+                            f"hb_bench_driller_phone/image_set/{split}.txt",
+                        )
+                    ],
+                    objs=[obj],
+                    use_cache=True,
+                    num_to_load=-1,
+                    cam_type=cam_type,
+                    scale_to_meter=0.001,
+                    filter_invalid=False,
+                    height=480,
+                    width=640,
+                    ref_key="hb_bdp",
+                )
+
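+# Examples of split names generated by the two loops above:
+#   hb_benchvise_driller_phone_train090_lmK   (all 3 objs, linemod K)
+#   hb_benchvise_driller_phone_train090       (all 3 objs, hb K)
+#   hb_bdp_driller_test_lmK                   (single obj, linemod K)
+#   hb_bdp_phone_test100                      (single obj, hb K)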
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg is required.
+            data_cfg can be set in cfg.DATA_CFG.name
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_HB_BenchviseDrillerPhone:
+        used_cfg = SPLITS_HB_BenchviseDrillerPhone[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, HB_BenchDrillerPhone(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="coco_bop",  # NOTE: should not be bop
+        **get_hb_bdp_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_HB_BenchviseDrillerPhone.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    # usage: python -m det.yolox.data.datasets.hb_bench_driller_phone_d2 <dataset_name>
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                ],
+                ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                row=2,
+                col=2,
+            )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m this_module dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import (
+        vis_image_mask_bbox_cv2,
+        vis_image_bboxes_cv2,
+    )
+    from lib.utils.mask_utils import cocosegm2mask
+    from lib.utils.bbox_utils import xywh_to_xyxy
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+    test_vis()
diff --git a/det/yolox/data/datasets/hb_bop_test.py b/det/yolox/data/datasets/hb_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa3a6aab9d527fe7a582a6e83d2c1ccdf6ad666c
--- /dev/null
+++ b/det/yolox/data/datasets/hb_bop_test.py
@@ -0,0 +1,467 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import ref
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class HB_BOP_TEST_Dataset:
+    """hb bop test."""
+
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+        # all classes are self.objs, but this enables us to evaluate on selected objs
+        self.select_objs = data_cfg.get("select_objs", self.objs)
+
+        self.ann_file = data_cfg["ann_file"]  # json file with scene_id and im_id items
+
+        self.dataset_root = data_cfg["dataset_root"]  # BOP_DATASETS/hb/test
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/hb/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg["cache_dir"]  # .cache
+        self.use_cache = data_cfg["use_cache"]  # True
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg["filter_invalid"]
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.hb.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs)
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        im_id_global = 0
+
+        if True:
+            targets = mmcv.load(self.ann_file)
+            scene_im_ids = [(item["scene_id"], item["im_id"]) for item in targets]
+            scene_im_ids = sorted(list(set(scene_im_ids)))
+
+            # NOTE: currently no gt info available
+            # load infos for each scene
+            # gt_dicts = {}
+            # gt_info_dicts = {}
+            # cam_dicts = {}
+            for scene_id, im_id in scene_im_ids:
+                scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+                # if scene_id not in gt_dicts:
+                #     gt_dicts[scene_id] = mmcv.load(osp.join(scene_root, 'scene_gt.json'))
+                # if scene_id not in gt_info_dicts:
+                #     gt_info_dicts[scene_id] = mmcv.load(osp.join(scene_root, 'scene_gt_info.json'))  # bbox_obj, bbox_visib
+                # if scene_id not in cam_dicts:
+                #     cam_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for scene_id, im_id in tqdm(scene_im_ids):
+                str_im_id = str(im_id)
+                scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(im_id))
+
+                scene_id = int(rgb_path.split("/")[-3])
+
+                # cam = np.array(cam_dicts[scene_id][str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                # depth_factor = 1000.0 / cam_dicts[scene_id][str_im_id]["depth_scale"]
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    # "depth_factor": depth_factor,
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": im_id_global,  # unique image_id in the dataset, for coco evaluation
+                    "scene_im_id": "{}/{}".format(scene_id, im_id),  # for evaluation
+                    # "cam": cam,
+                    "img_type": "real",
+                }
+                im_id_global += 1
+                # insts = []
+                # for anno_i, anno in enumerate(gt_dicts[scene_id][str_im_id]):
+                #     obj_id = anno['obj_id']
+                #     if ref.hb.id2obj[obj_id] not in self.select_objs:
+                #         continue
+                #     cur_label = self.cat2label[obj_id]  # 0-based label
+                #     R = np.array(anno['cam_R_m2c'], dtype='float32').reshape(3, 3)
+                #     t = np.array(anno['cam_t_m2c'], dtype='float32') / 1000.0
+                #     pose = np.hstack([R, t.reshape(3, 1)])
+                #     quat = mat2quat(R).astype('float32')
+                #     allo_q = mat2quat(egocentric_to_allocentric(pose)[:3, :3]).astype('float32')
+
+                #     proj = (record["cam"] @ t.T).T
+                #     proj = proj[:2] / proj[2]
+
+                #     bbox_visib = gt_info_dicts[scene_id][str_im_id][anno_i]['bbox_visib']
+                #     bbox_obj = gt_info_dicts[scene_id][str_im_id][anno_i]['bbox_obj']
+                #     x1, y1, w, h = bbox_visib
+                #     if self.filter_invalid:
+                #         if h <= 1 or w <= 1:
+                #             self.num_instances_without_valid_box += 1
+                #             continue
+
+                #     mask_file = osp.join(scene_root, "mask/{:06d}_{:06d}.png".format(im_id, anno_i))
+                #     mask_visib_file = osp.join(scene_root, "mask_visib/{:06d}_{:06d}.png".format(im_id, anno_i))
+                #     assert osp.exists(mask_file), mask_file
+                #     assert osp.exists(mask_visib_file), mask_visib_file
+                #     # load mask visib
+                #     mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                #     area = mask_single.sum()
+                #     if area < 3:  # filter out too small or nearly invisible instances
+                #         self.num_instances_without_valid_segmentation += 1
+                #         continue
+                #     mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                #    # load mask full
+                #     mask_full = mmcv.imread(mask_file, "unchanged")
+                #     mask_full = mask_full.astype("bool")
+                #     mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                #     inst = {
+                #         'category_id': cur_label,  # 0-based label
+                #         'bbox': bbox_visib,  # TODO: load both bbox_obj and bbox_visib
+                #         'bbox_mode': BoxMode.XYWH_ABS,
+                #         'pose': pose,
+                #         "quat": quat,
+                #         "trans": t,
+                #         "allo_quat": allo_q,
+                #         "centroid_2d": proj,  # absolute (cx, cy)
+                #         "segmentation": mask_rle,
+                #         "mask_full": mask_full_rle,
+                #     }
+                #     for key in [
+                #             "bbox3d_and_center", "fps4_and_center", "fps8_and_center", "fps12_and_center",
+                #             "fps16_and_center", "fps20_and_center"
+                #     ]:
+                #         inst[key] = self.models[cur_label][key]
+                #     insts.append(inst)
+                # if len(insts) == 0:  # filter im without anno
+                #     continue
+                # record['annotations'] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, f"models_{self.name}.pkl")
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(self.models_root, f"obj_{ref.hb.obj2id[obj_name]:06d}.ply"),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is computed from the raw (non-centered) vertices;
+            # for BOP models this is fine since the models are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_hb_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+################################################################################
+# 16 objects for bop19/20
+HB_BOP19_20_OBJS = [ref.hb.id2obj[_i] for _i in [1, 3, 4, 8, 9, 10, 12, 15, 17, 18, 19, 22, 23, 29, 32, 33]]
+
+SPLITS_HB = dict(
+    hb_test_primesense_bop19=dict(
+        name="hb_test_primesense_bop19",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/test"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+        # objs=HB_BOP19_20_OBJS,  # selected 16 objects
+        objs=ref.hb.objects,  # all 33 objects
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/test_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="hb",
+    ),
+    # NOTE: the file for targets is not released yet
+    # hb_test_primesense_all=dict(
+    #     name="hb_test_primesense_all",
+    #     dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/test_primesense"),
+    #     models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+    #     objs=ref.hb.objects,  # all 33 objects
+    #     ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/test_targets_all.json"),
+    #     scale_to_meter=0.001,
+    #     with_masks=True,  # (load masks but may not use it)
+    #     with_depth=True,  # (load depth path here, but may not use it)
+    #     height=480,
+    #     width=640,
+    #     cache_dir=osp.join(PROJ_ROOT, ".cache"),
+    #     use_cache=True,
+    #     num_to_load=-1,
+    #     filter_invalid=False,
+    #     ref_key="hb",
+    # ),
+)
+
+
+# single objs (num_class is from all objs)
+for obj in ref.hb.objects:
+    name = "hb_bop_{}_test_primesense".format(obj)
+    select_objs = [obj]
+    if name not in SPLITS_HB:
+        SPLITS_HB[name] = dict(
+            name=name,
+            dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/test_primesense"),
+            models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+            objs=[obj],  # only this obj
+            select_objs=select_objs,  # selected objects
+            ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/test_targets_bop19.json"),
+            scale_to_meter=0.001,
+            with_masks=True,  # (load masks but may not use it)
+            with_depth=True,  # (load depth path here, but may not use it)
+            height=480,
+            width=640,
+            cache_dir=osp.join(PROJ_ROOT, ".cache"),
+            use_cache=True,
+            num_to_load=-1,
+            filter_invalid=False,
+            ref_key="hb",
+        )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg is required.
+            data_cfg can be set in cfg.DATA_CFG.name
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_HB:
+        used_cfg = SPLITS_HB[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, HB_BOP_TEST_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="hb",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_hb_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_HB.keys())
+
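+# Quick sanity-check sketch (assumes the BOP HomebrewedDB data and the bop19 targets
+# json are present under datasets/BOP_DATASETS/hb):
+#
+#   register_with_name_cfg("hb_test_primesense_bop19")
+#   dicts = DatasetCatalog.get("hb_test_primesense_bop19")
+#   print(len(dicts), dicts[0]["scene_im_id"])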
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        detection_utils.check_image_size(d, img)
+        depth = mmcv.imread(d["depth_file"], "unchanged") / d["depth_factor"]
+
+        imH, imW = img.shape[:2]
+        # annos = d["annotations"]
+        # masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        # bboxes = [anno["bbox"] for anno in annos]
+        # bbox_modes = [anno["bbox_mode"] for anno in annos]
+        # bboxes_xyxy = np.array(
+        #     [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)])
+        # kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        # quats = [anno["quat"] for anno in annos]
+        # transes = [anno["trans"] for anno in annos]
+        # Rs = [quat2mat(quat) for quat in quats]
+        # # 0-based label
+        # cat_ids = [anno["category_id"] for anno in annos]
+        # K = d["cam"]
+        # kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+
+        # # TODO: visualize pose and keypoints
+        # labels = [objs[cat_id] for cat_id in cat_ids]
+        # img_vis = vis_image_bboxes_cv2(img, bboxes=bboxes_xyxy, labels=labels)
+        # img_vis = vis_image_mask_bbox_cv2(img, masks, bboxes=bboxes_xyxy, labels=labels)
+        # img_vis_kpts2d = img.copy()
+        # for anno_i in range(len(annos)):
+        #     img_vis_kpts2d = misc.draw_projected_box3d(img_vis_kpts2d, kpts_2d[anno_i])
+        # grid_show([img[:, :, [2, 1, 0]], img_vis[:, :, [2, 1, 0]], img_vis_kpts2d[:, :, [2, 1, 0]], depth],
+        #           [f"img:{d['file_name']}", "vis_img", "img_vis_kpts2d", 'depth'],
+        #           row=2,
+        #           col=2)
+        grid_show(
+            [img[:, :, [2, 1, 0]], depth],
+            [f"img:{d['file_name']}", "depth"],
+            row=1,
+            col=2,
+        )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m core.gdrn_modeling.datasets.hb_bop_test dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from detectron2.data import detection_utils
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+    test_vis()
diff --git a/det/yolox/data/datasets/hb_bop_val.py b/det/yolox/data/datasets/hb_bop_val.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f0839ae016b44efa348cb0c4d24023374f2f33
--- /dev/null
+++ b/det/yolox/data/datasets/hb_bop_val.py
@@ -0,0 +1,472 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import ref
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class HB_BOP_VAL_Dataset:
+    """hb bop val."""
+
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+        # all classes are self.objs, but this enables us to evaluate on selected objs
+        self.select_objs = data_cfg.get("select_objs", self.objs)
+
+        self.ann_file = data_cfg["ann_file"]  # json file with scene_id and im_id items
+
+        self.dataset_root = data_cfg["dataset_root"]  # BOP_DATASETS/hb/val_primesense
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/hb/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg["filter_invalid"]
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.hb.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        im_id_global = 0
+
+        if True:
+            targets = mmcv.load(self.ann_file)
+            scene_im_ids = [(item["scene_id"], item["im_id"]) for item in targets]
+            scene_im_ids = sorted(list(set(scene_im_ids)))
+
+            # load infos for each scene
+            gt_dicts = {}
+            gt_info_dicts = {}
+            cam_dicts = {}
+            for scene_id, im_id in scene_im_ids:
+                scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+                if scene_id not in gt_dicts:
+                    gt_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+                if scene_id not in gt_info_dicts:
+                    gt_info_dicts[scene_id] = mmcv.load(
+                        osp.join(scene_root, "scene_gt_info.json")
+                    )  # bbox_obj, bbox_visib
+                if scene_id not in cam_dicts:
+                    cam_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for scene_id, im_id in tqdm(scene_im_ids):
+                str_im_id = str(im_id)
+                scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(im_id))
+
+                scene_id = int(rgb_path.split("/")[-3])
+
+                cam = np.array(cam_dicts[scene_id][str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
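+                # BOP depth pngs store png_value * depth_scale millimeters; dividing the raw
+                # png by depth_factor (= 1000 / depth_scale) therefore yields meters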
+                depth_factor = 1000.0 / cam_dicts[scene_id][str_im_id]["depth_scale"]
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "depth_factor": depth_factor,
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": im_id_global,  # unique image_id in the dataset, for coco evaluation
+                    "scene_im_id": "{}/{}".format(scene_id, im_id),  # for evaluation
+                    "cam": cam,
+                    "img_type": "real",
+                }
+                im_id_global += 1
+                insts = []
+                for anno_i, anno in enumerate(gt_dicts[scene_id][str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if ref.hb.id2obj[obj_id] not in self.select_objs:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
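+                    # proj: the object centroid t projected with K, i.e. the absolute 2D center (cx, cy) in pixels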
+
+                    bbox_visib = gt_info_dicts[scene_id][str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dicts[scene_id][str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask_single.sum()
+                    if area < 3:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    # TODO: using full mask and full xyz
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, f"models_{self.name}.pkl")
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(self.models_root, f"obj_{ref.hb.obj2id[obj_name]:06d}.ply"),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not computed from re-centered vertices;
+            # for BOP models this is not a problem since they have already been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_hb_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
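+        # cur_sym_infos maps each 0-based label to a (num_syms, 3, 3) array of symmetry rotations,
+        # or None if the object has no symmetries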
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+################################################################################
+# 16 objects
+HB_BOP19_20_OBJS = [ref.hb.id2obj[_i] for _i in [1, 3, 4, 8, 9, 10, 12, 15, 17, 18, 19, 22, 23, 29, 32, 33]]
+
+SPLITS_HB = dict(
+    hb_val_primesense_bop19=dict(
+        name="hb_val_primesense_bop19",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/val_primesense"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+        objs=HB_BOP19_20_OBJS,  # selected 16 objects
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/val_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="hb",
+    ),
+    hb_val_primesense_all=dict(
+        name="hb_val_primesense_all",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/val_primesense"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+        objs=ref.hb.objects,  # all 33 objects
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/val_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="hb",
+    ),
+)
+
+
+# single objs (num_class is from all objs)
+for obj in ref.hb.objects:
+    name = "hb_bop_{}_val_primesense".format(obj)
+    select_objs = [obj]
+    if name not in SPLITS_HB:
+        SPLITS_HB[name] = dict(
+            name=name,
+            dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/val_primesense"),
+            models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+            objs=[obj],  # only this obj
+            select_objs=select_objs,  # selected objects
+            ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/val_targets_bop19.json"),
+            scale_to_meter=0.001,
+            with_masks=True,  # (load masks but may not use it)
+            with_depth=True,  # (load depth path here, but may not use it)
+            height=480,
+            width=640,
+            cache_dir=osp.join(PROJ_ROOT, ".cache"),
+            use_cache=True,
+            num_to_load=-1,
+            filter_invalid=False,
+            ref_key="hb",
+        )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name,
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg must be provided
+            (data_cfg can be set in cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_HB:
+        used_cfg = SPLITS_HB[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, HB_BOP_VAL_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="hb",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_hb_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_HB.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        detection_utils.check_image_size(d, img)
+        depth = mmcv.imread(d["depth_file"], "unchanged") / d["depth_factor"]
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        # img_vis = vis_image_bboxes_cv2(img, bboxes=bboxes_xyxy, labels=labels)
+        img_vis = vis_image_mask_bbox_cv2(img, masks, bboxes=bboxes_xyxy, labels=labels)
+        img_vis_kpts2d = img.copy()
+        for anno_i in range(len(annos)):
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis_kpts2d, kpts_2d[anno_i])
+        grid_show(
+            [
+                img[:, :, [2, 1, 0]],
+                img_vis[:, :, [2, 1, 0]],
+                img_vis_kpts2d[:, :, [2, 1, 0]],
+                depth,
+            ],
+            [f"img:{d['file_name']}", "vis_img", "img_vis_kpts2d", "depth"],
+            row=2,
+            col=2,
+        )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m core.gdrn_modeling.datasets.hb_bop_val dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from detectron2.data import detection_utils
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+    test_vis()
diff --git a/det/yolox/data/datasets/hb_pbr.py b/det/yolox/data/datasets/hb_pbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..3324cca4fd2ca3b1c828758da4383a9e2b589f54
--- /dev/null
+++ b/det/yolox/data/datasets/hb_pbr.py
@@ -0,0 +1,520 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class HB_PBR_Dataset:
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True;
+        whether they are actually loaded into the dataloader/network is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get(
+            "dataset_root",
+            osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/train_pbr"),
+        )
+        # self.xyz_root = data_cfg.get("xyz_root", osp.join(self.dataset_root, "xyz_crop"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/hb/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]
+        self.width = data_cfg["width"]
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.hb.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
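+        # e.g., with two selected objects whose BOP ids are 2 and 7:
+        #   cat_ids = [2, 7], cat2label = {2: 0, 7: 1}, label2cat = {0: 2, 1: 7}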
+        ##########################################################
+
+        self.scenes = [f"{i:06d}" for i in range(50)]
+        # for debug
+        # self.scenes = [f"{i:06d}" for i in range(1)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data (e.g., images or masks) into memory here,
+        since the annotations of all images are kept in memory at once.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.jpg").format(int_im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
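+                # BOP convention: depth_scale converts raw depth values to mm,
+                # so raw_depth / depth_factor gives depth in meters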
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "syn_pbr",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask_single.sum()
+                    if area <= 64:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
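+                    # compressed RLE keeps the cached dataset dicts light-weight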
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
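+                    # visib_fract: visible fraction of the object from scene_gt_info; defaults to 1.0 if missing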
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                    # NOTE: if online is False, use back-projected xyz coords rather than rendered xyz here
+                    xyz_path = osp.join(
+                        self.dataset_root,
+                        f"{scene_id:06d}/coor_backprj/{int_im_id:06d}_{anno_i:06d}.pkl",
+                    )
+                    # assert osp.exists(xyz_path), xyz_path
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,
+                        "visib_fract": visib_fract,
+                        "xyz_path": xyz_path,
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    # TODO: using full mask and full xyz
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format(self.name))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(self.models_root, f"obj_{ref.hb.obj2id[obj_name]:06d}.ply"),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not computed from re-centered vertices;
+            # for BOP models this is not a problem since they have already been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_hb_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+hb_model_root = "BOP_DATASETS/hb/models/"
+################################################################################
+
+HB_BOP19_20_OBJS = [ref.hb.id2obj[_i] for _i in [1, 3, 4, 8, 9, 10, 12, 15, 17, 18, 19, 22, 23, 29, 32, 33]]
+
+SPLITS_HB_PBR = dict(
+    hb_pbr_train=dict(
+        name="hb_pbr_train",
+        objs=ref.hb.objects,  # all 33 objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="hb",
+    ),
+    hbs_pbr_train=dict(
+        name="hbs_pbr_train",
+        objs=HB_BOP19_20_OBJS,  # selected 16 objects for BOP19/20
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="hb",
+    ),
+)
+
+# single obj splits
+for obj in ref.hb.objects:
+    for split in ["train_pbr"]:
+        name = "hb_{}_{}".format(obj, split)
+        if split in ["train_pbr"]:
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_HB_PBR:
+            SPLITS_HB_PBR[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/train_pbr"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/hb/models"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="hb",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name,
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg must be provided
+            (data_cfg can be set in cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_HB_PBR:
+        used_cfg = SPLITS_HB_PBR[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, HB_PBR_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="hb",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_hb_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_HB_PBR.keys())
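+
+
+# minimal usage sketch (split names come from SPLITS_HB_PBR above):
+#   register_with_name_cfg("hb_pbr_train")
+#   dataset_dicts = DatasetCatalog.get("hb_pbr_train")  # calls HB_PBR_Dataset.__call__
+#   metadata = MetadataCatalog.get("hb_pbr_train")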
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / d["depth_factor"]
+        print("depth before, min: {} max: {}, mean: {}".format(depth.min(), depth.max(), depth.mean()))
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        fg_mask = (sum(masks) > 0).astype("uint8")
+        fg_mask_eroded = mask_erode_cv2(fg_mask)
+        depth_masked = depth * fg_mask_eroded
+        depth_masked[depth_masked > 3] = 0
+        print(
+            "depth masked, min: {} max: {}, mean: {}".format(
+                depth_masked.min(), depth_masked.max(), depth_masked.mean()
+            )
+        )
+
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            # xyz_path = annos[_i]["xyz_path"]
+            # xyz_info = mmcv.load(xyz_path)
+            # x1, y1, x2, y2 = xyz_info["xyxy"]
+            # xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+            # xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+            # xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+            # xyz_show = get_emb_show(xyz)
+            # xyz_crop_show = get_emb_show(xyz_crop)
+            # img_xyz = img.copy() / 255.0
+            # mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+            # fg_idx = np.where(mask_xyz != 0)
+            # img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+            # img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+            # img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+            # # diff mask
+            # diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    heatmap(depth_masked, to_rgb=True),
+                    # xyz_show,
+                    # diff_mask_xyz,
+                    # xyz_crop_show,
+                    # img_xyz[:, :, [2, 1, 0]],
+                    # img_xyz_crop[:, :, [2, 1, 0]],
+                    # img_vis_crop,
+                ],
+                [
+                    "img",
+                    "vis_img",
+                    "img_vis_kpts2d",
+                    "depth_masked",
+                    # "diff_mask_xyz",
+                    # "xyz_crop_show",
+                    # "img_xyz",
+                    # "img_xyz_crop",
+                    # "img_vis_crop",
+                ],
+                row=2,
+                col=2,
+            )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m this_module hb_pbr_train
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+    from lib.utils.mask_utils import mask_erode_cv2
+    from lib.vis_utils.image import heatmap
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/icbin_bop_test.py b/det/yolox/data/datasets/icbin_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e49f29e31928a5f2a32e32da169c5d6069ed5351
--- /dev/null
+++ b/det/yolox/data/datasets/icbin_bop_test.py
@@ -0,0 +1,525 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class ICBIN_BOP_TEST_Dataset(object):
+    """icbin bop test splits."""
+
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True;
+        whether they are actually loaded into the dataloader/network is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/test"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+
+        self.ann_file = data_cfg["ann_file"]  # json file with scene_id and im_id items
+
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/icbin/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.icbin.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data (e.g., images or masks) into memory here,
+        since the annotations of all images are kept in memory at once.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        targets = mmcv.load(self.ann_file)
+
+        scene_im_ids = [(item["scene_id"], item["im_id"]) for item in targets]
+        scene_im_ids = sorted(list(set(scene_im_ids)))
+
+        # load infos for each scene
+        gt_dicts = {}
+        gt_info_dicts = {}
+        cam_dicts = {}
+        for scene_id, im_id in scene_im_ids:
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+            if scene_id not in gt_dicts:
+                gt_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            if scene_id not in gt_info_dicts:
+                gt_info_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))  # bbox_obj, bbox_visib
+            if scene_id not in cam_dicts:
+                cam_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+        for scene_id, int_im_id in tqdm(scene_im_ids):
+            str_im_id = str(int_im_id)
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+
+            gt_dict = gt_dicts[scene_id]
+            gt_info_dict = gt_info_dicts[scene_id]
+            cam_dict = cam_dicts[scene_id]
+
+            rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(int_im_id)
+            assert osp.exists(rgb_path), rgb_path
+
+            depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+            scene_im_id = f"{scene_id}/{int_im_id}"
+
+            K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+            depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+            record = {
+                "dataset_name": self.name,
+                "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                "height": self.height,
+                "width": self.width,
+                "image_id": int_im_id,
+                "scene_im_id": scene_im_id,  # for evaluation
+                "cam": K,
+                "depth_factor": depth_factor,
+                "img_type": "real",  # NOTE: has background
+            }
+            insts = []
+            for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                obj_id = anno["obj_id"]
+                if obj_id not in self.cat_ids:
+                    continue
+                cur_label = self.cat2label[obj_id]  # 0-based label
+                R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                pose = np.hstack([R, t.reshape(3, 1)])
+                quat = mat2quat(R).astype("float32")
+
+                proj = (record["cam"] @ t.T).T
+                proj = proj[:2] / proj[2]
+
+                bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                x1, y1, w, h = bbox_visib
+                if self.filter_invalid:
+                    if h <= 1 or w <= 1:
+                        self.num_instances_without_valid_box += 1
+                        continue
+
+                mask_file = osp.join(
+                    scene_root,
+                    "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                )
+                mask_visib_file = osp.join(
+                    scene_root,
+                    "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                )
+                assert osp.exists(mask_file), mask_file
+                assert osp.exists(mask_visib_file), mask_visib_file
+                # load mask visib
+                mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                mask_single = mask_single.astype("bool")
+                area = mask_single.sum()
+                if area < 3:  # filter out too small or nearly invisible instances
+                    self.num_instances_without_valid_segmentation += 1
+                mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                # load mask full
+                mask_full = mmcv.imread(mask_file, "unchanged")
+                mask_full = mask_full.astype("bool")
+                mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                inst = {
+                    "category_id": cur_label,  # 0-based label
+                    "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                    "bbox_mode": BoxMode.XYWH_ABS,
+                    "pose": pose,
+                    "quat": quat,
+                    "trans": t,
+                    "centroid_2d": proj,  # absolute (cx, cy)
+                    "segmentation": mask_rle,
+                    "mask_full": mask_full_rle,
+                    "visib_fract": visib_fract,
+                    "xyz_path": None,  #  no need for test
+                }
+
+                model_info = self.models_info[str(obj_id)]
+                inst["model_info"] = model_info
+                for key in ["bbox3d_and_center"]:
+                    inst[key] = self.models[cur_label][key]
+                insts.append(inst)
+            if len(insts) == 0:  # filter im without anno
+                continue
+            record["annotations"] = insts
+            dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "There are {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "There are {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.cache_dir, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.icbin.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not computed from re-centered vertices;
+            # for BOP models this is not a problem since they have already been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
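+        # NOTE: num_to_load can be -1 (load everything), so this may not reflect the actual number of loaded dicts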
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_icbin_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+##########################################################################
+
+ICBIN_OBJECTS = ["coffee_cup", "juice_carton"]
+
+SPLITS_ICBIN = dict(
+    icbin_bop_test=dict(
+        name="icbin_bop_test",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/test"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/models"),
+        objs=ICBIN_OBJECTS,
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/test_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="icbin",
+    ),
+)
+
+# single obj splits for icbin bop test
+for obj in ref.icbin.objects:
+    for split in [
+        "bop_test",
+    ]:
+        name = "icbin_{}_{}".format(obj, split)
+        ann_files = [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/icbin/image_set/{}_{}.txt".format(obj, split),
+            )
+        ]
+        if name not in SPLITS_ICBIN:
+            SPLITS_ICBIN[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/models"),
+                objs=[obj],  # only this obj
+                scale_to_meter=0.001,
+                ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/test_targets_bop19.json"),
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=False,
+                ref_key="icbin",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name,
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg must be provided
+            (data_cfg can be set in cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_ICBIN:
+        used_cfg = SPLITS_ICBIN[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, ICBIN_BOP_TEST_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="icbin",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_icbin_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_ICBIN.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            if "test" not in dset_name.lower():
+                xyz_path = annos[_i]["xyz_path"]
+                xyz_info = mmcv.load(xyz_path)
+                x1, y1, x2, y2 = xyz_info["xyxy"]
+                xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+                xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+                xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+                xyz_show = get_emb_show(xyz)
+                xyz_crop_show = get_emb_show(xyz_crop)
+                img_xyz = img.copy() / 255.0
+                mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+                fg_idx = np.where(mask_xyz != 0)
+                img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+                img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+                img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+                # diff mask
+                diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                        # xyz_show,
+                        diff_mask_xyz,
+                        xyz_crop_show,
+                        img_xyz[:, :, [2, 1, 0]],
+                        img_xyz_crop[:, :, [2, 1, 0]],
+                        img_vis_crop,
+                    ],
+                    [
+                        "img",
+                        "vis_img",
+                        "img_vis_kpts2d",
+                        "depth",
+                        "diff_mask_xyz",
+                        "xyz_crop_show",
+                        "img_xyz",
+                        "img_xyz_crop",
+                        "img_vis_crop",
+                    ],
+                    row=3,
+                    col=3,
+                )
+            else:
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                    ],
+                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                    row=2,
+                    col=2,
+                )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    python this_file.py dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/icbin_pbr.py b/det/yolox/data/datasets/icbin_pbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b51fee052064b1148263fd70252d8f845590cc75
--- /dev/null
+++ b/det/yolox/data/datasets/icbin_pbr.py
@@ -0,0 +1,482 @@
+import logging
+import hashlib
+import os
+import os.path as osp
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import time
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class ICBIN_PBR_Dataset:
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objs
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/train_pbr"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.xyz_root = data_cfg.get("xyz_root", osp.join(self.dataset_root, "xyz_crop"))
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/icbin/models_cad
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]
+        self.width = data_cfg["width"]
+
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.icbin.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs)
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+
+        self.scenes = [f"{i:06d}" for i in range(50)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
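+        # NOTE: the cache key only covers the object list, dataset name/root, mask/depth
+        # flags and this module's name, not the annotation files themselves; if the
+        # underlying scene_gt/scene_gt_info data are regenerated, the cached pkl under
+        # cache_dir should be removed manually, otherwise stale dicts are returned below.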
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+
+        dataset_dicts = []
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        # it is slow because of loading and converting masks to rle
+
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.jpg".format(int_im_id))
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "syn_pbr",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(scene_root, "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i))
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib  TODO: load both mask_visib and mask_full
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask_single.sum()
+                    if area < 30:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    xyz_path = osp.join(
+                        self.xyz_root,
+                        f"{scene_id:06d}/{int_im_id:06d}_{anno_i:06d}-xyz.pkl",
+                    )
+                    assert osp.exists(xyz_path), xyz_path
+
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_file,  # TODO: load as mask_full, rle
+                        "visib_fract": visib_fract,
+                        "xyz_path": xyz_path,
+                    }
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(self.models_root, f"obj_{ref.icbin.obj2id[obj_name]:06d}.ply"),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: the bbox3d_and_center is not obtained from centered vertices
+            # for BOP models, not a big problem since they had been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
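+            # bbox3d_and_center is presumably the 8 corners of the 3D bounding box plus the
+            # centroid (a (9, 3) array); this assumption is based on how it is projected and
+            # drawn with misc.draw_projected_box3d in test_vis below.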
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_icbin_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+icbin_model_root = "BOP_DATASETS/icbin/models_cad/"
+################################################################################
+
+
+SPLITS_ICBIN_PBR = dict(
+    icbin_pbr_train=dict(
+        name="icbin_pbr_train",
+        objs=ref.icbin.objects,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/models"),
+        xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/train_pbr/xyz_crop"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="icbin",
+    ),
+)
+
+# single obj splits
+for obj in ref.icbin.objects:
+    for split in ["train"]:
+        name = "icbin_pbr_{}_{}".format(obj, split)
+        if split in ["train"]:
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_ICBIN_PBR:
+            SPLITS_ICBIN_PBR[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/train_pbr"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/models"),
+                xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/icbin/train_pbr/xyz_crop"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="icbin",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name; if it is one of the pre-defined SPLITS, its data_cfg is used.
+        data_cfg: required only when `name` is not a pre-defined split;
+            it can also be provided via cfg.DATA_CFG.name.
+            See the usage sketch after this function.
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_ICBIN_PBR:
+        used_cfg = SPLITS_ICBIN_PBR[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, ICBIN_PBR_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="icbin",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_icbin_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
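+
+
+# A minimal usage sketch (not executed on import): register one of the pre-defined splits
+# above and fetch its dicts through the Detectron2 catalog. The split name is assumed to be
+# a key of SPLITS_ICBIN_PBR; any other name would additionally need an explicit data_cfg.
+#
+#   register_with_name_cfg("icbin_pbr_train")
+#   dicts = DatasetCatalog.get("icbin_pbr_train")  # list of per-image dicts
+#   print(len(dicts), dicts[0]["file_name"], len(dicts[0]["annotations"]))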
+
+
+def get_available_datasets():
+    return list(SPLITS_ICBIN_PBR.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 10000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            xyz_path = annos[_i]["xyz_path"]
+            xyz_info = mmcv.load(xyz_path)
+            x1, y1, x2, y2 = xyz_info["xyxy"]
+            xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+            xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+            xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+            xyz_show = get_emb_show(xyz)
+            xyz_crop_show = get_emb_show(xyz_crop)
+            img_xyz = img.copy() / 255.0
+            mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+            fg_idx = np.where(mask_xyz != 0)
+            img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+            img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+            img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+            # diff mask
+            diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                    # xyz_show,
+                    diff_mask_xyz,
+                    xyz_crop_show,
+                    img_xyz[:, :, [2, 1, 0]],
+                    img_xyz_crop[:, :, [2, 1, 0]],
+                    img_vis_crop,
+                ],
+                [
+                    "img",
+                    "vis_img",
+                    "img_vis_kpts2d",
+                    "depth",
+                    "diff_mask_xyz",
+                    "xyz_crop_show",
+                    "img_xyz",
+                    "img_xyz_crop",
+                    "img_vis_crop",
+                ],
+                row=3,
+                col=3,
+            )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m this_module dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/itodd_bop_test.py b/det/yolox/data/datasets/itodd_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..01012957cccde04b065905d2776f396c534b2529
--- /dev/null
+++ b/det/yolox/data/datasets/itodd_bop_test.py
@@ -0,0 +1,478 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class ITODD_BOP_TEST_Dataset(object):
+    """itodd bop test splits."""
+
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/test"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+
+        self.ann_file = data_cfg["ann_file"]  # json file with scene_id and im_id items
+
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/itodd/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.itodd.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        targets = mmcv.load(self.ann_file)
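+        # Each entry of the BOP test-targets json is expected to look roughly like
+        # {"scene_id": 1, "im_id": 1, "obj_id": 1, "inst_count": 2}; only scene_id and
+        # im_id are used here to enumerate the test images.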
+
+        scene_im_ids = [(item["scene_id"], item["im_id"]) for item in targets]
+        scene_im_ids = sorted(list(set(scene_im_ids)))
+
+        # load infos for each scene
+        # NOTE: currently no gt info available
+        # gt_dicts = {}
+        # gt_info_dicts = {}
+        cam_dicts = {}
+        for scene_id, im_id in scene_im_ids:
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+            # if scene_id not in gt_dicts:
+            #     gt_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            # if scene_id not in gt_info_dicts:
+            #     gt_info_dicts[scene_id] = mmcv.load(
+            #         osp.join(scene_root, "scene_gt_info.json")
+            #     )  # bbox_obj, bbox_visib
+            if scene_id not in cam_dicts:
+                cam_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+        for scene_id, int_im_id in tqdm(scene_im_ids):
+            str_im_id = str(int_im_id)
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+
+            # gt_dict = gt_dicts[scene_id]
+            # gt_info_dict = gt_info_dicts[scene_id]
+            cam_dict = cam_dicts[scene_id]
+
+            rgb_path = osp.join(scene_root, "gray/{:06d}.tif".format(int_im_id))
+            assert osp.exists(rgb_path), rgb_path
+
+            depth_path = osp.join(scene_root, "depth/{:06d}.tif".format(int_im_id))
+
+            scene_im_id = f"{scene_id}/{int_im_id}"
+
+            K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+            depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+            record = {
+                "dataset_name": self.name,
+                "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                "height": self.height,
+                "width": self.width,
+                "image_id": int_im_id,
+                "scene_im_id": scene_im_id,  # for evaluation
+                "cam": K,
+                "depth_factor": depth_factor,
+                "img_type": "real",  # NOTE: has background
+            }
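+            # NOTE: no ground-truth annotations are attached for this BOP test split (the
+            # commented-out block below shows how they would be loaded if scene_gt files
+            # were available), so downstream code should not expect an "annotations" key.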
+            # insts = []
+            # for anno_i, anno in enumerate(gt_dict[str_im_id]):
+            #     obj_id = anno["obj_id"]
+            #     if obj_id not in self.cat_ids:
+            #         continue
+            #     cur_label = self.cat2label[obj_id]  # 0-based label
+            #     R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+            #     t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+            #     pose = np.hstack([R, t.reshape(3, 1)])
+            #     quat = mat2quat(R).astype("float32")
+
+            #     proj = (record["cam"] @ t.T).T
+            #     proj = proj[:2] / proj[2]
+
+            #     bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+            #     bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+            #     x1, y1, w, h = bbox_visib
+            #     if self.filter_invalid:
+            #         if h <= 1 or w <= 1:
+            #             self.num_instances_without_valid_box += 1
+            #             continue
+
+            #     mask_file = osp.join(
+            #         scene_root,
+            #         "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+            #     )
+            #     mask_visib_file = osp.join(
+            #         scene_root,
+            #         "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+            #     )
+            #     assert osp.exists(mask_file), mask_file
+            #     assert osp.exists(mask_visib_file), mask_visib_file
+            #     # load mask visib
+            #     mask_single = mmcv.imread(mask_visib_file, "unchanged")
+            #     mask_single = mask_single.astype("bool")
+            #     area = mask_single.sum()
+            #     if area < 3:  # filter out too small or nearly invisible instances
+            #         self.num_instances_without_valid_segmentation += 1
+            #         continue
+            #     mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+            #     # load mask full
+            #     mask_full = mmcv.imread(mask_file, "unchanged")
+            #     mask_full = mask_full.astype("bool")
+            #     mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+            #     visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+            #     inst = {
+            #         "category_id": cur_label,  # 0-based label
+            #         "bbox": bbox_visib,  # TODO: load both bbox_obj and bbox_visib
+            #         "bbox_mode": BoxMode.XYWH_ABS,
+            #         "pose": pose,
+            #         "quat": quat,
+            #         "trans": t,
+            #         "centroid_2d": proj,  # absolute (cx, cy)
+            #         "segmentation": mask_rle,
+            #         "mask_full": mask_full_rle,
+            #         "visib_fract": visib_fract,
+            #         "xyz_path": None, #  no need for test
+            #     }
+
+            #     model_info = self.models_info[str(obj_id)]
+            #     inst["model_info"] = model_info
+            #     for key in ["bbox3d_and_center"]:
+            #         inst[key] = self.models[cur_label][key]
+            #     insts.append(inst)
+            # if len(insts) == 0:  # filter im without anno
+            #     continue
+            # record["annotations"] = insts
+            dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "There are {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "There are {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.cache_dir, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.itodd.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: the bbox3d_and_center is not obtained from centered vertices
+            # for BOP models, not a big problem since they had been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_itodd_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+##########################################################################
+
+
+SPLITS_ITODD = dict(
+    itodd_bop_test=dict(
+        name="itodd_bop_test",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/test"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/models"),
+        objs=ref.itodd.objects,
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/test_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=960,
+        width=1280,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="itodd",
+    ),
+)
+
+# single obj splits for itodd bop test
+for obj in ref.itodd.objects:
+    for split in ["bop_test"]:
+        name = "itodd_{}_{}".format(obj, split)
+        ann_files = [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/itodd/image_set/{}_{}.txt".format(obj, split),
+            )
+        ]
+        if name not in SPLITS_ITODD:
+            SPLITS_ITODD[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/test"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/models"),
+                objs=[obj],  # only this obj
+                scale_to_meter=0.001,
+                ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/test_targets_bop19.json"),
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=960,
+                width=1280,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=False,
+                ref_key="itodd",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name; if it is one of the pre-defined SPLITS, its data_cfg is used.
+        data_cfg: required only when `name` is not a pre-defined split;
+            it can also be provided via cfg.DATA_CFG.name.
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_ITODD:
+        used_cfg = SPLITS_ITODD[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, ITODD_BOP_TEST_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="itodd",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_itodd_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_ITODD.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        detection_utils.check_image_size(d, img)
+        depth = mmcv.imread(d["depth_file"], "unchanged") / d["depth_factor"]
+
+        imH, imW = img.shape[:2]
+        # annos = d["annotations"]
+        # masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        # bboxes = [anno["bbox"] for anno in annos]
+        # bbox_modes = [anno["bbox_mode"] for anno in annos]
+        # bboxes_xyxy = np.array(
+        #     [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)])
+        # kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        # quats = [anno["quat"] for anno in annos]
+        # transes = [anno["trans"] for anno in annos]
+        # Rs = [quat2mat(quat) for quat in quats]
+        # # 0-based label
+        # cat_ids = [anno["category_id"] for anno in annos]
+        # K = d["cam"]
+        # kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+
+        # # TODO: visualize pose and keypoints
+        # labels = [objs[cat_id] for cat_id in cat_ids]
+        # img_vis = vis_image_bboxes_cv2(img, bboxes=bboxes_xyxy, labels=labels)
+        # img_vis = vis_image_mask_bbox_cv2(img, masks, bboxes=bboxes_xyxy, labels=labels)
+        # img_vis_kpts2d = img.copy()
+        # for anno_i in range(len(annos)):
+        #     img_vis_kpts2d = misc.draw_projected_box3d(img_vis_kpts2d, kpts_2d[anno_i])
+        # grid_show([img[:, :, [2, 1, 0]], img_vis[:, :, [2, 1, 0]], img_vis_kpts2d[:, :, [2, 1, 0]], depth],
+        #           [f"img:{d['file_name']}", "vis_img", "img_vis_kpts2d", 'depth'],
+        #           row=2,
+        #           col=2)
+        grid_show(
+            [img[:, :, [2, 1, 0]], depth],
+            [f"img:{d['file_name']}", "depth"],
+            row=1,
+            col=2,
+        )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    python this_file.py dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+    from detectron2.data import detection_utils
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/itodd_d2.py b/det/yolox/data/datasets/itodd_d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b77f0d5fad90b9656c7b6383096bcf3c328c30c
--- /dev/null
+++ b/det/yolox/data/datasets/itodd_d2.py
@@ -0,0 +1,483 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+
+import ref
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class ITODD_Dataset:
+    """itodd."""
+
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+        # all classes are self.objs, but this enables us to evaluate on selected objs
+
+        self.dataset_root = data_cfg["dataset_root"]  # BOP_DATASETS/itodd/val
+        self.scene_ids = data_cfg.get("scene_ids", [1])
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/itodd/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+
+        self.height = data_cfg["height"]  # 960
+        self.width = data_cfg["width"]  # 1280
+
+        self.cache_dir = data_cfg["cache_dir"]  # .cache
+        self.use_cache = data_cfg["use_cache"]  # True
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg["filter_invalid"]
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.itodd.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs)
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+
+        dataset_dicts = []  #######################################################
+
+        for scene_id in self.scene_ids:
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                im_id = int(str_im_id)
+
+                im_path = osp.join(scene_root, "gray/{:06d}.tif".format(im_id))
+                assert osp.exists(im_path), im_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.tif".format(im_id))
+
+                scene_id = int(im_path.split("/")[-3])
+
+                cam = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(im_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "depth_factor": depth_factor,
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": im_id,  # unique image_id in the dataset, for coco evaluation
+                    "scene_im_id": "{}/{}".format(scene_id, im_id),  # for evaluation
+                    "cam": cam,
+                    "depth_factor": depth_factor,
+                    "img_type": "real",
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(scene_root, "mask/{:06d}_{:06d}.png".format(im_id, anno_i))
+                    mask_visib_file = osp.join(scene_root, "mask_visib/{:06d}_{:06d}.png".format(im_id, anno_i))
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib  TODO: load both mask_visib and mask_full
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask_single.sum()
+                    if area < 3:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_visib,
+                        "bbox_obj": bbox_obj,
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_file,  # TODO: load as mask_full, rle
+                        "visib_fract": visib_fract,
+                        "xyz_path": None,  # not needed for test
+                    }
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info(
+            "loaded dataset dicts, num_images: {}, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start)
+        )
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, f"models_{self.name}.pkl")
+        if osp.exists(cache_path) and self.use_cache:
+            # logger.info("load cached object models from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(self.models_root, f"obj_{ref.itodd.obj2id[obj_name]:06d}.ply"),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: the bbox3d_and_center is not obtained from centered vertices
+            # for BOP models, not a big problem since they had been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_itodd_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+##########################################################################
+
+SPLITS_ITODD = dict(
+    itodd_val=dict(
+        name="itodd_val",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/val"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/models"),
+        objs=ref.itodd.objects,  # selected objects
+        scene_ids=[1],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=960,
+        width=1280,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="itodd",
+    ),
+)
+
+
+# single objs (num_class is from all objs)
+for obj in ref.itodd.objects:
+    name = "itodd_{}_val".format(obj)
+    select_objs = [obj]
+    if name not in SPLITS_ITODD:
+        SPLITS_ITODD[name] = dict(
+            name=name,
+            dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/val"),
+            models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/models"),
+            objs=ref.itodd.objects,
+            scene_ids=[1],
+            scale_to_meter=0.001,
+            with_masks=True,  # (load masks but may not use it)
+            with_depth=True,  # (load depth path here, but may not use it)
+            height=960,
+            width=1280,
+            cache_dir=osp.join(PROJ_ROOT, ".cache"),
+            use_cache=True,
+            num_to_load=-1,
+            filter_invalid=False,
+            ref_key="itodd",
+        )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name; if it is one of the pre-defined SPLITS, its data_cfg is used.
+        data_cfg: required only when `name` is not a pre-defined split;
+            it can also be provided via cfg.DATA_CFG.name.
+            See the registration sketch after this function.
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_ITODD:
+        used_cfg = SPLITS_ITODD[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, ITODD_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="itodd",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",  # TODO: add bop evaluator
+        **get_itodd_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_ITODD.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            if "val" not in dset_name.lower():
+                xyz_path = annos[_i]["xyz_path"]
+                xyz_info = mmcv.load(xyz_path)
+                x1, y1, x2, y2 = xyz_info["xyxy"]
+                xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+                xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+                xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+                xyz_show = get_emb_show(xyz)
+                xyz_crop_show = get_emb_show(xyz_crop)
+                img_xyz = img.copy() / 255.0
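+                # foreground mask: pixels whose stored xyz coordinates are non-zero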
+                mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+                fg_idx = np.where(mask_xyz != 0)
+                img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+                img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+                img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+                # diff mask
+                diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                        # xyz_show,
+                        diff_mask_xyz,
+                        xyz_crop_show,
+                        img_xyz[:, :, [2, 1, 0]],
+                        img_xyz_crop[:, :, [2, 1, 0]],
+                        img_vis_crop,
+                    ],
+                    [
+                        "img",
+                        "vis_img",
+                        "img_vis_kpts2d",
+                        "depth",
+                        "diff_mask_xyz",
+                        "xyz_crop_show",
+                        "img_xyz",
+                        "img_xyz_crop",
+                        "img_vis_crop",
+                    ],
+                    row=3,
+                    col=3,
+                )
+            else:
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                    ],
+                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                    row=2,
+                    col=2,
+                )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    python this_file.py dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/itodd_pbr.py b/det/yolox/data/datasets/itodd_pbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0db0787d911a5a60304df69d013e782c1fe3d9a
--- /dev/null
+++ b/det/yolox/data/datasets/itodd_pbr.py
@@ -0,0 +1,493 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class ITODD_PBR_Dataset:
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/train_pbr"))
+        self.xyz_root = data_cfg.get("xyz_root", osp.join(self.dataset_root, "xyz_crop"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/itodd/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 960
+        self.width = data_cfg["width"]  # 1280
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.itodd.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
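+        # the PBR split is organized as 50 scene folders (000000-000049)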
+        self.scenes = [f"{i:06d}" for i in range(50)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
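+        # the cache key hashes the selected objects plus the loader config, so a
+        # different object subset or dataset root yields a separate cache file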
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.jpg").format(int_im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # e.g. 10000; raw_depth / depth_factor gives meters
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "syn_pbr",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
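+                    # project the object translation with K to get the absolute 2D centroid (cx, cy)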
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    mask_single = mask_single.astype("bool")
+                    area = mask_single.sum()
+                    if area < 30:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                    xyz_path = osp.join(
+                        self.xyz_root,
+                        f"{scene_id:06d}/{int_im_id:06d}_{anno_i:06d}-xyz.pkl",
+                    )
+                    # assert osp.exists(xyz_path), xyz_path
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,
+                        "visib_fract": visib_fract,
+                        "xyz_path": xyz_path,
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.itodd.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is computed from the raw (uncentered) vertices;
+            # for BOP models this is fine since the models are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_itodd_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+################################################################################
+
+
+SPLITS_ITODD_PBR = dict(
+    itodd_pbr_train=dict(
+        name="itodd_pbr_train",
+        objs=ref.itodd.objects,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/models"),
+        xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/train_pbr/xyz_crop"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=960,
+        width=1280,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="itodd",
+    )
+)
+
+# single obj splits
+for obj in ref.itodd.objects:
+    for split in ["train"]:
+        name = "itodd_pbr_{}_{}".format(obj, split)
+        if split in ["train"]:
+            filter_invalid = True
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_ITODD_PBR:
+            SPLITS_ITODD_PBR[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/train_pbr"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/models"),
+                xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/itodd/train_pbr/xyz_crop"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=960,
+                width=1280,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="itodd",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg must be provided
+            (it can be set in cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_ITODD_PBR:
+        used_cfg = SPLITS_ITODD_PBR[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, ITODD_PBR_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_itodd_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
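+# minimal usage sketch (using a split name registered in SPLITS_ITODD_PBR above):
+#   register_with_name_cfg("itodd_pbr_train")
+#   dataset_dicts = DatasetCatalog.get("itodd_pbr_train")
+#   meta = MetadataCatalog.get("itodd_pbr_train")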
+
+def get_available_datasets():
+    return list(SPLITS_ITODD_PBR.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 10000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            xyz_path = annos[_i]["xyz_path"]
+            xyz_info = mmcv.load(xyz_path)
+            x1, y1, x2, y2 = xyz_info["xyxy"]
+            xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+            xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+            xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+            xyz_show = get_emb_show(xyz)
+            xyz_crop_show = get_emb_show(xyz_crop)
+            img_xyz = img.copy() / 255.0
+            mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+            fg_idx = np.where(mask_xyz != 0)
+            img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+            img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+            img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+            # diff mask
+            diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                    # xyz_show,
+                    diff_mask_xyz,
+                    xyz_crop_show,
+                    img_xyz[:, :, [2, 1, 0]],
+                    img_xyz_crop[:, :, [2, 1, 0]],
+                    img_vis_crop,
+                ],
+                [
+                    "img",
+                    "vis_img",
+                    "img_vis_kpts2d",
+                    "depth",
+                    "diff_mask_xyz",
+                    "xyz_crop_show",
+                    "img_xyz",
+                    "img_xyz_crop",
+                    "img_vis_crop",
+                ],
+                row=3,
+                col=3,
+            )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m this_module dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/lm_blender.py b/det/yolox/data/datasets/lm_blender.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3ce35fafbc0122fd066c63f7e72db85b55cd6c
--- /dev/null
+++ b/det/yolox/data/datasets/lm_blender.py
@@ -0,0 +1,512 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import ref
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, lazy_property
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class LM_BLENDER_Dataset(object):
+    """lm blender data, from pvnet-rendering."""
+
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.ann_files = data_cfg["ann_files"]  # json files with image ids and pose/bbox
+        self.image_prefixes = data_cfg["image_prefixes"]
+
+        self.dataset_root = data_cfg["dataset_root"]  # lm_renders_blender/
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/lm/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+        self.depth_factor = data_cfg["depth_factor"]  # 1000.0
+
+        self.cam = data_cfg["cam"]  # camera intrinsic matrix K (may be None)
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg["cache_dir"]  # .cache
+        self.use_cache = data_cfg["use_cache"]  # True
+        # sample uniformly to get n items
+        self.n_per_obj = data_cfg.get("n_per_obj", 10000)
+        self.filter_invalid = data_cfg["filter_invalid"]
+        ##################################################
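+        # fall back to the default LINEMOD camera intrinsics if no cam is provided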
+        if self.cam is None:
+            self.cam = np.array([[572.4114, 0, 325.2611], [0, 573.57043, 242.04899], [0, 0, 1]])
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.lm_full.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):  # LM_BLENDER
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    self.n_per_obj,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  #######################################################
+        assert len(self.ann_files) == len(self.image_prefixes), f"{len(self.ann_files)} != {len(self.image_prefixes)}"
+
+        for ann_file, scene_root in zip(tqdm(self.ann_files), self.image_prefixes):
+            # each scene is an object
+            assert osp.exists(ann_file), ann_file
+            scene_gt_dict = mmcv.load(ann_file)
+            # sample uniformly (equal space)
+            indices = list(scene_gt_dict.keys())
+            if self.n_per_obj > 0:
+                sample_num = min(self.n_per_obj, len(scene_gt_dict))
+                sel_indices_idx = np.linspace(0, len(scene_gt_dict) - 1, sample_num, dtype=np.int32)
+                sel_indices = [indices[int(_i)] for _i in sel_indices_idx]
+            else:
+                sel_indices = indices
+
+            for str_im_id in tqdm(sel_indices):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "{}.jpg").format(str_im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "{}_depth_opengl.png".format(str_im_id))
+
+                obj_name = osp.basename(ann_file).split("_")[0]  # obj_gt.json
+                obj_id = ref.lm_full.obj2id[obj_name]
+                if obj_name not in self.objs:
+                    continue
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": f"{obj_id}/{int_im_id}",
+                    "cam": self.cam,
+                    "img_type": "syn_blender",  # has bg
+                }
+
+                cur_label = self.obj2label[obj_name]  # 0-based label
+                anno = scene_gt_dict[str_im_id][0]  # only one object
+                R = np.array(anno["cam_R_m2c"]).reshape(3, 3)
+                t = np.array(anno["cam_t_m2c"]).reshape(-1) / 1000
+                pose = np.hstack([R, t.reshape(3, 1)])
+                quat = mat2quat(R).astype("float32")
+                proj = (record["cam"] @ t.T).T
+                proj = proj[:2] / proj[2]
+
+                bbox_visib = anno["bbox_visib"]
+                x1, y1, w, h = bbox_visib
+                if self.filter_invalid:
+                    if h <= 1 or w <= 1:
+                        self.num_instances_without_valid_box += 1
+                        continue
+
+                mask_path = osp.join(scene_root, "{}_mask_opengl.png".format(str_im_id))
+                mask = mmcv.imread(mask_path, "unchanged")
+                mask = (mask > 0).astype(np.uint8)
+
+                area = mask.sum()
+                if area < 3:  # filter out too small or nearly invisible instances
+                    self.num_instances_without_valid_segmentation += 1
+                    continue
+                mask_rle = binary_mask_to_rle(mask, compressed=True)
+
+                visib_fract = anno.get("visib_fract", 1.0)
+                inst = {
+                    "category_id": cur_label,  # 0-based label
+                    "bbox": bbox_visib,  # TODO: load both bbox_obj and bbox_visib
+                    "bbox_mode": BoxMode.XYWH_ABS,
+                    "pose": pose,
+                    "quat": quat,
+                    "trans": t,
+                    "centroid_2d": proj,  # absolute (cx, cy)
+                    "segmentation": mask_rle,
+                    "visib_fract": visib_fract,
+                }
+
+                model_info = self.models_info[str(obj_id)]
+                inst["model_info"] = model_info
+                for key in ["bbox3d_and_center"]:
+                    inst[key] = self.models[cur_label][key]
+                record["annotations"] = [inst]
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        # if self.num_to_load > 0:
+        #     self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+        #     random.shuffle(dataset_dicts)
+        #     dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info(
+            "loaded dataset dicts, num_images: {}, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start)
+        )
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.lm_full.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is computed from the raw (uncentered) vertices;
+            # for BOP models this is fine since the models are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_lm_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+LM_13_OBJECTS = [
+    "ape",
+    "benchvise",
+    "camera",
+    "can",
+    "cat",
+    "driller",
+    "duck",
+    "eggbox",
+    "glue",
+    "holepuncher",
+    "iron",
+    "lamp",
+    "phone",
+]  # no bowl, cup
+LM_OCC_OBJECTS = [
+    "ape",
+    "can",
+    "cat",
+    "driller",
+    "duck",
+    "eggbox",
+    "glue",
+    "holepuncher",
+]
+################################################################################
+
+SPLITS_LM_BLENDER = dict(
+    lm_blender_13_train=dict(
+        name="lm_blender_13_train",  # BB8 training set
+        dataset_root=osp.join(DATASETS_ROOT, "lm_renders_blender/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+        objs=LM_13_OBJECTS,  # selected objects
+        ann_files=[
+            osp.join(
+                DATASETS_ROOT,
+                "lm_renders_blender/renders/{}_gt.json".format(_obj),
+            )
+            for _obj in LM_13_OBJECTS
+        ],
+        image_prefixes=[
+            osp.join(DATASETS_ROOT, "lm_renders_blender/renders/{}".format(_obj)) for _obj in LM_13_OBJECTS
+        ],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        depth_factor=1000.0,
+        cam=ref.lm_full.camera_matrix,
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        n_per_obj=-1,  # num per class, -1 for all 10k
+        filter_invalid=False,
+        ref_key="lm_full",
+    ),
+    lmo_blender_train=dict(
+        name="lmo_blender_train",
+        dataset_root=osp.join(DATASETS_ROOT, "lm_renders_blender/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+        objs=LM_OCC_OBJECTS,  # selected objects
+        ann_files=[
+            osp.join(
+                DATASETS_ROOT,
+                "lm_renders_blender/renders/{}_gt.json".format(_obj),
+            )
+            for _obj in LM_OCC_OBJECTS
+        ],
+        image_prefixes=[
+            osp.join(DATASETS_ROOT, "lm_renders_blender/renders/{}".format(_obj)) for _obj in LM_OCC_OBJECTS
+        ],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        depth_factor=1000.0,
+        cam=ref.lmo_full.camera_matrix,
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        n_per_obj=-1,  # n per class, -1 for all 10k
+        filter_invalid=False,
+        ref_key="lmo_full",
+    ),
+)
+
+# single obj splits
+for obj in ref.lm_full.objects:
+    for split in ["train"]:
+        for name_prefix in ["lm", "lmo"]:
+            name = "{}_blender_{}_{}".format(name_prefix, obj, split)
+            ref_key = f"{name_prefix}_full"
+            ann_files = [
+                osp.join(
+                    DATASETS_ROOT,
+                    "lm_renders_blender/renders/{}_gt.json".format(obj),
+                )
+            ]
+            if split in ["train"]:
+                filter_invalid = True
+            else:
+                raise ValueError("{}".format(split))
+            if name not in SPLITS_LM_BLENDER:
+                SPLITS_LM_BLENDER[name] = dict(
+                    name=name,
+                    dataset_root=osp.join(DATASETS_ROOT, "lm_renders_blender/"),
+                    models_root=osp.join(DATASETS_ROOT, f"BOP_DATASETS/{name_prefix}/models"),
+                    objs=[obj],  # only this obj
+                    ann_files=ann_files,
+                    image_prefixes=[osp.join(DATASETS_ROOT, f"lm_renders_blender/renders/{obj}")],
+                    scale_to_meter=0.001,
+                    with_masks=True,  # (load masks but may not use it)
+                    with_depth=True,  # (load depth path here, but may not use it)
+                    depth_factor=1000.0,
+                    cam=ref.__dict__[ref_key].camera_matrix,
+                    height=480,
+                    width=640,
+                    cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                    use_cache=True,
+                    n_per_obj=-1,
+                    filter_invalid=False,
+                    ref_key=ref_key,
+                )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg must be provided
+            (it can be set in cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_LM_BLENDER:
+        used_cfg = SPLITS_LM_BLENDER[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, LM_BLENDER_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="coco_bop",
+        **get_lm_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_LM_BLENDER.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        anno = d["annotations"][0]  # only one instance per image
+        imH, imW = img.shape[:2]
+        mask = cocosegm2mask(anno["segmentation"], imH, imW)
+        bbox = anno["bbox"]
+        bbox_mode = anno["bbox_mode"]
+        bbox_xyxy = np.array(BoxMode.convert(bbox, bbox_mode, BoxMode.XYXY_ABS))
+        kpt3d = anno["bbox3d_and_center"]
+        quat = anno["quat"]
+        trans = anno["trans"]
+        R = quat2mat(quat)
+        # 0-based label
+        cat_id = anno["category_id"]
+        K = d["cam"]
+        kpt_2d = misc.project_pts(kpt3d, K, R, trans)
+        # # TODO: visualize pose and keypoints
+        label = objs[cat_id]
+        # img_vis = vis_image_bboxes_cv2(img, bboxes=bboxes_xyxy, labels=labels)
+        img_vis = vis_image_mask_bbox_cv2(img, [mask], bboxes=[bbox_xyxy], labels=[label])
+        img_vis_kpt2d = img.copy()
+        img_vis_kpt2d = misc.draw_projected_box3d(
+            img_vis_kpt2d,
+            kpt_2d,
+            middle_color=None,
+            bottom_color=(128, 128, 128),
+        )
+
+        grid_show(
+            [
+                img[:, :, [2, 1, 0]],
+                img_vis[:, :, [2, 1, 0]],
+                img_vis_kpt2d[:, :, [2, 1, 0]],
+                depth,
+            ],
+            ["img", "vis_img", "img_vis_kpts2d", "depth"],
+            row=2,
+            col=2,
+        )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m core.datasets.lm_blender dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+    test_vis()
diff --git a/det/yolox/data/datasets/lm_dataset_d2.py b/det/yolox/data/datasets/lm_dataset_d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..dae93146b3afde3ce0abddcc28150ca76d21634d
--- /dev/null
+++ b/det/yolox/data/datasets/lm_dataset_d2.py
@@ -0,0 +1,886 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class LM_Dataset(object):
+    """lm splits."""
+
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.ann_files = data_cfg["ann_files"]  # idx files with image ids
+        self.image_prefixes = data_cfg["image_prefixes"]
+        self.xyz_prefixes = data_cfg["xyz_prefixes"]
+
+        self.dataset_root = data_cfg["dataset_root"]  # BOP_DATASETS/lm/
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/lm/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg["filter_invalid"]
+        self.filter_scene = data_cfg.get("filter_scene", False)
+        self.debug_im_id = data_cfg.get("debug_im_id", None)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.lm_full.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):  # LM_Dataset
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        assert len(self.ann_files) == len(self.image_prefixes), f"{len(self.ann_files)} != {len(self.image_prefixes)}"
+        assert len(self.ann_files) == len(self.xyz_prefixes), f"{len(self.ann_files)} != {len(self.xyz_prefixes)}"
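+        # image ids restart in every LM scene, so keep a running counter for a unique image_id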
+        unique_im_id = 0
+        for ann_file, scene_root, xyz_root in zip(tqdm(self.ann_files), self.image_prefixes, self.xyz_prefixes):
+            # linemod each scene is an object
+            with open(ann_file, "r") as f_ann:
+                indices = [line.strip("\r\n") for line in f_ann.readlines()]  # string ids
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))  # bbox_obj, bbox_visib
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+            for im_id in tqdm(indices):
+                int_im_id = int(im_id)
+                str_im_id = str(int_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(int_im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_id = int(rgb_path.split("/")[-3])
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                if self.debug_im_id is not None:
+                    if self.debug_im_id != scene_im_id:
+                        continue
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]
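+                # in LM, each scene id corresponds to an object id; filter_scene keeps only scenes of the selected objects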
+                if self.filter_scene:
+                    if scene_id not in self.cat_ids:
+                        continue
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": unique_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "real",
+                }
+                unique_im_id += 1
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    mask_single = mask_single.astype("bool")
+                    area = mask_single.sum()
+                    if area < 3:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,
+                    }
+
+                    if "test" not in self.name.lower():
+                        # if True:
+                        xyz_path = osp.join(xyz_root, f"{int_im_id:06d}_{anno_i:06d}.pkl")
+                        assert osp.exists(xyz_path), xyz_path
+                        inst["xyz_path"] = xyz_path
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    # TODO: using full mask and full xyz
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.cache_dir, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.lm_full.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: the bbox3d_and_center is not obtained from centered vertices;
+            # for BOP models this is not a big problem since they have already been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_lm_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
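+# A hedged sketch of what get_lm_metadata returns (labels are 0-based; the array
+# shape assumes misc.get_symmetry_transformations' usual output):
+#   meta["thing_classes"]     -> list of selected object names, e.g. ["ape", ...]
+#   meta["sym_infos"][label]  -> (N, 3, 3) array of symmetry rotations, or None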
+
+LM_13_OBJECTS = [
+    "ape",
+    "benchvise",
+    "camera",
+    "can",
+    "cat",
+    "driller",
+    "duck",
+    "eggbox",
+    "glue",
+    "holepuncher",
+    "iron",
+    "lamp",
+    "phone",
+]  # no bowl, cup
+LM_OCC_OBJECTS = [
+    "ape",
+    "can",
+    "cat",
+    "driller",
+    "duck",
+    "eggbox",
+    "glue",
+    "holepuncher",
+]
+################################################################################
+
+SPLITS_LM = dict(
+    lm_13_train=dict(
+        name="lm_13_train",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+        objs=LM_13_OBJECTS,  # selected objects
+        ann_files=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/image_set/{}_{}.txt".format(_obj, "train"),
+            )
+            for _obj in LM_13_OBJECTS
+        ],
+        image_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/test/{:06d}".format(ref.lm_full.obj2id[_obj]),
+            )
+            for _obj in LM_13_OBJECTS
+        ],
+        xyz_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lm_full.obj2id[_obj]),
+            )
+            for _obj in LM_13_OBJECTS
+        ],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_scene=True,
+        filter_invalid=True,
+        ref_key="lm_full",
+    ),
+    lm_13_test=dict(
+        name="lm_13_test",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+        objs=LM_13_OBJECTS,
+        ann_files=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/image_set/{}_{}.txt".format(_obj, "test"),
+            )
+            for _obj in LM_13_OBJECTS
+        ],
+        # NOTE: scene root
+        image_prefixes=[
+            osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/test/{:06d}").format(ref.lm_full.obj2id[_obj])
+            for _obj in LM_13_OBJECTS
+        ],
+        xyz_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lm_full.obj2id[_obj]),
+            )
+            for _obj in LM_13_OBJECTS
+        ],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_scene=True,
+        filter_invalid=False,
+        ref_key="lm_full",
+    ),
+    lmo_train=dict(
+        name="lmo_train",
+        # use lm real all (8 objects) to train for lmo
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+        objs=LM_OCC_OBJECTS,  # selected objects
+        ann_files=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/image_set/{}_{}.txt".format(_obj, "all"),
+            )
+            for _obj in LM_OCC_OBJECTS
+        ],
+        image_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/test/{:06d}".format(ref.lmo_full.obj2id[_obj]),
+            )
+            for _obj in LM_OCC_OBJECTS
+        ],
+        xyz_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lmo_full.obj2id[_obj]),
+            )
+            for _obj in LM_OCC_OBJECTS
+        ],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_scene=True,
+        filter_invalid=True,
+        ref_key="lmo_full",
+    ),
+    lmo_NoBopTest_train=dict(
+        name="lmo_NoBopTest_train",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+        objs=LM_OCC_OBJECTS,
+        ann_files=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/image_set/lmo_no_bop_test.txt")],
+        image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+        xyz_prefixes=[
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lmo/test/xyz_crop/{:06d}".format(2),
+            )
+        ],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_scene=False,
+        filter_invalid=True,
+        ref_key="lmo_full",
+    ),
+    lmo_test=dict(
+        name="lmo_test",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+        objs=LM_OCC_OBJECTS,
+        ann_files=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/image_set/lmo_test.txt")],
+        # NOTE: scene root
+        image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+        xyz_prefixes=[None],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_scene=False,
+        filter_invalid=False,
+        ref_key="lmo_full",
+    ),
+    lmo_bop_test=dict(
+        name="lmo_bop_test",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+        objs=LM_OCC_OBJECTS,
+        ann_files=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/image_set/lmo_bop_test.txt")],
+        # NOTE: scene root
+        image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+        xyz_prefixes=[None],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_scene=False,
+        filter_invalid=False,
+        ref_key="lmo_full",
+    ),
+)
+
+# single obj splits for lm real
+for obj in ref.lm_full.objects:
+    for split in ["train", "test", "all"]:
+        name = "lm_real_{}_{}".format(obj, split)
+        ann_files = [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/lm/image_set/{}_{}.txt".format(obj, split),
+            )
+        ]
+        if split in ["train", "all"]:  # all is used to train lmo
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM:
+            SPLITS_LM[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+                objs=[obj],  # only this obj
+                ann_files=ann_files,
+                image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/test/{:06d}").format(ref.lm_full.obj2id[obj])],
+                xyz_prefixes=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lm_full.obj2id[obj]),
+                    )
+                ],
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                filter_scene=True,
+                ref_key="lm_full",
+            )
+
+# single obj splits for lmo_NoBopTest_train
+for obj in ref.lmo_full.objects:
+    for split in ["train"]:
+        name = "lmo_NoBopTest_{}_{}".format(obj, split)
+        if split in ["train"]:
+            filter_invalid = True
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM:
+            SPLITS_LM[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+                objs=[obj],
+                ann_files=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lmo/image_set/lmo_no_bop_test.txt",
+                    )
+                ],
+                # NOTE: scene root
+                image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+                xyz_prefixes=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lmo/test/xyz_crop/{:06d}".format(2),
+                    )
+                ],
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_scene=False,
+                filter_invalid=filter_invalid,
+                ref_key="lmo_full",
+            )
+
+# single obj splits for lmo_test
+for obj in ref.lmo_full.objects:
+    for split in ["test"]:
+        name = "lmo_{}_{}".format(obj, split)
+        if split in ["train", "all"]:  # all is used to train lmo
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM:
+            SPLITS_LM[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+                objs=[obj],
+                ann_files=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lmo/image_set/lmo_test.txt",
+                    )
+                ],
+                # NOTE: scene root
+                image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+                xyz_prefixes=[None],
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_scene=False,
+                filter_invalid=False,
+                ref_key="lmo_full",
+            )
+
+# single obj splits for lmo_bop_test
+for obj in ref.lmo_full.objects:
+    for split in ["test"]:
+        name = "lmo_{}_bop_{}".format(obj, split)
+        if split in ["train", "all"]:  # all is used to train lmo
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM:
+            SPLITS_LM[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+                objs=[obj],
+                ann_files=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/lmo/image_set/lmo_bop_test.txt",
+                    )
+                ],
+                # NOTE: scene root
+                image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/test/{:06d}").format(2)],
+                xyz_prefixes=[None],
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_scene=False,
+                filter_invalid=False,
+                ref_key="lmo_full",
+            )
+
+# ================ add single image dataset for debug =======================================
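+# Each debug split is registered per image with a name of the form
+# "lm_single_{obj}{im_id}_{split}", where im_id comes from the per-object image_set txt files.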
+debug_im_ids = {
+    "train": {obj: [] for obj in ref.lm_full.objects},
+    "test": {obj: [] for obj in ref.lm_full.objects},
+}
+for obj in ref.lm_full.objects:
+    for split in ["train", "test"]:
+        cur_ann_file = osp.join(DATASETS_ROOT, f"BOP_DATASETS/lm/image_set/{obj}_{split}.txt")
+        ann_files = [cur_ann_file]
+
+        im_ids = []
+        with open(cur_ann_file, "r") as f:
+            for line in f:
+                # scene_id(obj_id)/im_id
+                im_ids.append("{}/{}".format(ref.lm_full.obj2id[obj], int(line.strip("\r\n"))))
+
+        debug_im_ids[split][obj] = im_ids
+        for debug_im_id in debug_im_ids[split][obj]:
+            name = "lm_single_{}{}_{}".format(obj, debug_im_id.split("/")[1], split)
+            if name not in SPLITS_LM:
+                SPLITS_LM[name] = dict(
+                    name=name,
+                    dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/"),
+                    models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+                    objs=[obj],  # only this obj
+                    ann_files=ann_files,
+                    image_prefixes=[
+                        osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/test/{:06d}").format(ref.lm_full.obj2id[obj])
+                    ],
+                    xyz_prefixes=[
+                        osp.join(
+                            DATASETS_ROOT,
+                            "BOP_DATASETS/lm/test/xyz_crop/{:06d}".format(ref.lm_full.obj2id[obj]),
+                        )
+                    ],
+                    scale_to_meter=0.001,
+                    with_masks=True,  # (load masks but may not use it)
+                    with_depth=True,  # (load depth path here, but may not use it)
+                    height=480,
+                    width=640,
+                    cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                    use_cache=True,
+                    num_to_load=-1,
+                    filter_invalid=False,
+                    filter_scene=True,
+                    ref_key="lm_full",
+                    debug_im_id=debug_im_id,  # NOTE: debug im id
+                )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name.
+        data_cfg: if `name` is one of the pre-defined SPLITS, its data_cfg is used;
+            otherwise data_cfg is required
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_LM:
+        used_cfg = SPLITS_LM[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, LM_Dataset(used_cfg))
+    # evaluation-related metadata (eval_error_types, evaluator_type, etc.)
+    MetadataCatalog.get(name).set(
+        id="linemod",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_lm_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_LM.keys())
+
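+# Minimal usage sketch (hedged; split names assume the SPLITS_LM table above):
+#   register_with_name_cfg("lm_13_train")
+#   dicts = DatasetCatalog.get("lm_13_train")   # list of Detectron2-style image dicts
+#   meta = MetadataCatalog.get("lm_13_train")   # thing_classes, sym_infos, objs, ...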
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            if "test" not in dset_name.lower():
+                xyz_path = annos[_i]["xyz_path"]
+                xyz_info = mmcv.load(xyz_path)
+                x1, y1, x2, y2 = xyz_info["xyxy"]
+                xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+                xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+                xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+                xyz_show = get_emb_show(xyz)
+                xyz_crop_show = get_emb_show(xyz_crop)
+                img_xyz = img.copy() / 255.0
+                mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+                fg_idx = np.where(mask_xyz != 0)
+                img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+                img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+                img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+                # diff mask
+                diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                        # xyz_show,
+                        diff_mask_xyz,
+                        xyz_crop_show,
+                        img_xyz[:, :, [2, 1, 0]],
+                        img_xyz_crop[:, :, [2, 1, 0]],
+                        img_vis_crop,
+                    ],
+                    [
+                        "img",
+                        "vis_img",
+                        "img_vis_kpts2d",
+                        "depth",
+                        "diff_mask_xyz",
+                        "xyz_crop_show",
+                        "img_xyz",
+                        "img_xyz_crop",
+                        "img_vis_crop",
+                    ],
+                    row=3,
+                    col=3,
+                )
+            else:
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                    ],
+                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                    row=2,
+                    col=2,
+                )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    python this_file.py dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/lm_pbr.py b/det/yolox/data/datasets/lm_pbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f42489c60e6b2b71fa486aaad86f145f4befb46
--- /dev/null
+++ b/det/yolox/data/datasets/lm_pbr.py
@@ -0,0 +1,561 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class LM_PBR_Dataset:
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here;
+        whether they are actually fed into the dataloader/network is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/train_pbr"))
+        self.xyz_root = data_cfg.get("xyz_root", osp.join(self.dataset_root, "xyz_crop"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/lm/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.lm_full.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+        self.scenes = [f"{i:06d}" for i in range(50)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.jpg").format(int_im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000 for PBR; raw depth / depth_factor -> meters
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "syn_pbr",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    mask_single = mask_single.astype("bool")
+                    area = mask_single.sum()
+                    if area < 30:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                    xyz_path = osp.join(
+                        self.xyz_root,
+                        f"{scene_id:06d}/{int_im_id:06d}_{anno_i:06d}-xyz.pkl",
+                    )
+                    assert osp.exists(xyz_path), xyz_path
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,
+                        "visib_fract": visib_fract,
+                        "xyz_path": xyz_path,
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.lm_full.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: the bbox3d_and_center is not obtained from centered vertices;
+            # for BOP models this is not a big problem since they have already been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_lm_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+LM_13_OBJECTS = [
+    "ape",
+    "benchvise",
+    "camera",
+    "can",
+    "cat",
+    "driller",
+    "duck",
+    "eggbox",
+    "glue",
+    "holepuncher",
+    "iron",
+    "lamp",
+    "phone",
+]  # no bowl, cup
+LM_OCC_OBJECTS = [
+    "ape",
+    "can",
+    "cat",
+    "driller",
+    "duck",
+    "eggbox",
+    "glue",
+    "holepuncher",
+]
+lm_model_root = "BOP_DATASETS/lm/models/"
+lmo_model_root = "BOP_DATASETS/lmo/models/"
+################################################################################
+
+
+SPLITS_LM_PBR = dict(
+    lm_pbr_13_train=dict(
+        name="lm_pbr_13_train",
+        objs=LM_13_OBJECTS,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+        xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/train_pbr/xyz_crop"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="lm_full",
+    ),
+    lmo_pbr_train=dict(
+        name="lmo_pbr_train",
+        objs=LM_OCC_OBJECTS,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+        xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/train_pbr/xyz_crop"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="lmo_full",
+    ),
+)
+
+# single obj splits
+for obj in ref.lm_full.objects:
+    for split in ["train"]:
+        name = "lm_pbr_{}_{}".format(obj, split)
+        if split in ["train"]:
+            filter_invalid = True
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM_PBR:
+            SPLITS_LM_PBR[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/train_pbr"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+                xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/train_pbr/xyz_crop"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="lm_full",
+            )
+
+# lmo single objs
+for obj in ref.lmo_full.objects:
+    for split in ["train"]:
+        name = "lmo_pbr_{}_{}".format(obj, split)
+        if split in ["train"]:
+            filter_invalid = True
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM_PBR:
+            SPLITS_LM_PBR[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/train_pbr"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/models"),
+                xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lmo/train_pbr/xyz_crop"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="lmo_full",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name.
+        data_cfg: if `name` is one of the pre-defined SPLITS, its data_cfg is used;
+            otherwise data_cfg is required
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_LM_PBR:
+        used_cfg = SPLITS_LM_PBR[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, LM_PBR_Dataset(used_cfg))
+    # evaluation-related metadata (eval_error_types, evaluator_type, etc.)
+    MetadataCatalog.get(name).set(
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_lm_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_LM_PBR.keys())
+
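+# Minimal usage sketch (hedged; split names assume the SPLITS_LM_PBR table above):
+#   register_with_name_cfg("lm_pbr_13_train")
+#   dicts = DatasetCatalog.get("lm_pbr_13_train")   # list of Detectron2-style image dicts
+#   meta = MetadataCatalog.get("lm_pbr_13_train")   # thing_classes, sym_infos, objs, ...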
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 10000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            xyz_path = annos[_i]["xyz_path"]
+            xyz_info = mmcv.load(xyz_path)
+            x1, y1, x2, y2 = xyz_info["xyxy"]
+            xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+            xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+            xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+            xyz_show = get_emb_show(xyz)
+            xyz_crop_show = get_emb_show(xyz_crop)
+            img_xyz = img.copy() / 255.0
+            mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+            fg_idx = np.where(mask_xyz != 0)
+            img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+            img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+            img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+            # diff mask
+            diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                    # xyz_show,
+                    diff_mask_xyz,
+                    xyz_crop_show,
+                    img_xyz[:, :, [2, 1, 0]],
+                    img_xyz_crop[:, :, [2, 1, 0]],
+                    img_vis_crop,
+                ],
+                [
+                    "img",
+                    "vis_img",
+                    "img_vis_kpts2d",
+                    "depth",
+                    "diff_mask_xyz",
+                    "xyz_crop_show",
+                    "img_xyz",
+                    "img_xyz_crop",
+                    "img_vis_crop",
+                ],
+                row=3,
+                col=3,
+            )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m this_module dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/lm_syn_imgn.py b/det/yolox/data/datasets/lm_syn_imgn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1732c29d4aa24b2c618b47cfc9beabe45246b8e3
--- /dev/null
+++ b/det/yolox/data/datasets/lm_syn_imgn.py
@@ -0,0 +1,473 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import ref
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import (
+    binary_mask_to_rle,
+    cocosegm2mask,
+    mask2bbox_xywh,
+)
+from lib.utils.utils import dprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class LM_SYN_IMGN_Dataset(object):
+    """lm synthetic data, imgn(imagine) from DeepIM."""
+
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here;
+        whether they are actually fed into the dataloader/network is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.ann_files = data_cfg["ann_files"]  # idx files with image ids
+        self.image_prefixes = data_cfg["image_prefixes"]
+
+        self.dataset_root = data_cfg["dataset_root"]  # lm_imgn
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/lm/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+        self.depth_factor = data_cfg["depth_factor"]  # 1000.0
+
+        self.cam = data_cfg["cam"]  #
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg["cache_dir"]  # .cache
+        self.use_cache = data_cfg["use_cache"]  # True
+        # sample uniformly to get n items
+        self.n_per_obj = data_cfg.get("n_per_obj", 1000)
+        self.filter_invalid = data_cfg["filter_invalid"]
+        self.filter_scene = data_cfg.get("filter_scene", False)
+        ##################################################
+        if self.cam is None:
+            # default LINEMOD camera intrinsics: fx=572.4114, fy=573.57043, cx=325.2611, cy=242.04899
+            self.cam = np.array([[572.4114, 0, 325.2611], [0, 573.57043, 242.04899], [0, 0, 1]])
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.lm_full.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):  # LM_SYN_IMGN_Dataset
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    self.n_per_obj,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.dataset_root,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  #######################################################
+        assert len(self.ann_files) == len(self.image_prefixes), f"{len(self.ann_files)} != {len(self.image_prefixes)}"
+        for ann_file, scene_root in zip(self.ann_files, self.image_prefixes):
+            # linemod each scene is an object
+            with open(ann_file, "r") as f_ann:
+                indices = [line.strip("\r\n").split()[-1] for line in f_ann.readlines()]  # string ids
+            # sample uniformly (equal space)
+            if self.n_per_obj > 0:
+                sample_num = min(self.n_per_obj, len(indices))
+                sel_indices_idx = np.linspace(0, len(indices) - 1, sample_num, dtype=np.int32)
+                sel_indices = [indices[int(_i)] for _i in sel_indices_idx]
+            else:
+                sel_indices = indices
+
+            for im_id in tqdm(sel_indices):
+                rgb_path = osp.join(scene_root, "{}-color.png").format(im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "{}-depth.png".format(im_id))
+
+                obj_name = im_id.split("/")[0]
+                if obj_name == "benchviseblue":
+                    obj_name = "benchvise"
+                obj_id = ref.lm_full.obj2id[obj_name]
+                if self.filter_scene:
+                    if obj_name not in self.objs:
+                        continue
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": im_id.split("/")[-1],
+                    "scene_im_id": im_id,
+                    "cam": self.cam,
+                    "img_type": "syn",
+                }
+
+                cur_label = self.obj2label[obj_name]  # 0-based label
+                pose_path = osp.join(scene_root, "{}-pose.txt".format(im_id))
+                pose = np.loadtxt(pose_path, skiprows=1)
+                R = pose[:3, :3]
+                t = pose[:3, 3]
+                quat = mat2quat(R).astype("float32")
+                proj = (record["cam"] @ t.T).T
+                proj = proj[:2] / proj[2]
+
+                depth = mmcv.imread(depth_path, "unchanged") / 1000.0
+                mask = (depth > 0).astype(np.uint8)
+
+                bbox_obj = mask2bbox_xywh(mask)
+                x1, y1, w, h = bbox_obj
+                if self.filter_invalid:
+                    if h <= 1 or w <= 1:
+                        self.num_instances_without_valid_box += 1
+                        continue
+                area = mask.sum()
+                if area < 3:  # filter out too small or nearly invisible instances
+                    self.num_instances_without_valid_segmentation += 1
+                    continue
+                mask_rle = binary_mask_to_rle(mask, compressed=True)
+
+                inst = {
+                    "category_id": cur_label,  # 0-based label
+                    "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                    "bbox_mode": BoxMode.XYWH_ABS,
+                    "pose": pose,
+                    "quat": quat,
+                    "trans": t,
+                    "centroid_2d": proj,  # absolute (cx, cy)
+                    "segmentation": mask_rle,
+                }
+
+                model_info = self.models_info[str(obj_id)]
+                inst["model_info"] = model_info
+                # TODO: using full mask
+                for key in ["bbox3d_and_center"]:
+                    inst[key] = self.models[cur_label][key]
+                record["annotations"] = [inst]
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        # if self.num_to_load > 0:
+        #     self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+        #     random.shuffle(dataset_dicts)
+        #     dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info(
+            "loaded dataset dicts, num_images: {}, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start)
+        )
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
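+    # NOTE (illustrative, not from the original code): the "centroid_2d" stored in
+    # each instance above is the perspective projection of the translation t under
+    # the camera matrix K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]:
+    #   proj = K @ t                      -> (fx*tx + cx*tz, fy*ty + cy*tz, tz)
+    #   centroid_2d = proj[:2] / proj[2]  -> (fx*tx/tz + cx, fy*ty/tz + cy)
+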
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # logger.info("load cached object models from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.lm_full.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not recomputed from re-centered vertices;
+            # for BOP models this is not an issue since they are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        # return 1
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_lm_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
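+# Illustrative example (assumption, not from the original code): for
+# obj_names=["ape", "eggbox"], get_lm_metadata would return roughly
+#   {"thing_classes": ["ape", "eggbox"],
+#    "sym_infos": {0: None, 1: <Nx3x3 float32 array of symmetry rotations>}}
+# assuming eggbox's models_info entry lists discrete symmetries (as it does for
+# LINEMOD in BOP) while ape's lists none.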
+
+LM_13_OBJECTS = [
+    "ape",
+    "benchvise",
+    "camera",
+    "can",
+    "cat",
+    "driller",
+    "duck",
+    "eggbox",
+    "glue",
+    "holepuncher",
+    "iron",
+    "lamp",
+    "phone",
+]  # no bowl, cup
+################################################################################
+
+SPLITS_LM_IMGN_13 = dict(
+    lm_imgn_13_train_1k_per_obj=dict(
+        name="lm_imgn_13_train_1k_per_obj",  # BB8 training set
+        dataset_root=osp.join(DATASETS_ROOT, "lm_imgn/"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+        objs=LM_13_OBJECTS,  # selected objects
+        ann_files=[
+            osp.join(
+                DATASETS_ROOT,
+                "lm_imgn/image_set/{}_{}.txt".format("train", _obj),
+            )
+            for _obj in LM_13_OBJECTS
+        ],
+        image_prefixes=[osp.join(DATASETS_ROOT, "lm_imgn/imgn") for _obj in LM_13_OBJECTS],
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        depth_factor=1000.0,
+        cam=ref.lm_full.camera_matrix,
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        n_per_obj=1000,  # 1000 per class
+        filter_scene=True,
+        filter_invalid=False,
+        ref_key="lm_full",
+    )
+)
+
+# single obj splits
+for obj in ref.lm_full.objects:
+    for split in ["train"]:
+        name = "lm_imgn_13_{}_{}_1k".format(obj, split)
+        ann_files = [osp.join(DATASETS_ROOT, "lm_imgn/image_set/{}_{}.txt".format(split, obj))]
+        if split in ["train"]:
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_LM_IMGN_13:
+            SPLITS_LM_IMGN_13[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "lm_imgn/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/lm/models"),
+                objs=[obj],  # only this obj
+                ann_files=ann_files,
+                image_prefixes=[osp.join(DATASETS_ROOT, "lm_imgn/imgn/")],
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                depth_factor=1000.0,
+                cam=ref.lm_full.camera_matrix,
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                n_per_obj=1000,
+                filter_invalid=filter_invalid,  # computed above (True for train splits)
+                filter_scene=True,
+                ref_key="lm_full",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset_name,
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise a data_cfg must be provided
+            (it can also be set in cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_LM_IMGN_13:
+        used_cfg = SPLITS_LM_IMGN_13[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, LM_SYN_IMGN_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="coco_bop",
+        **get_lm_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
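+# Usage sketch (illustrative, not part of the original file):
+#   register_with_name_cfg("lm_imgn_13_train_1k_per_obj")
+#   dicts = DatasetCatalog.get("lm_imgn_13_train_1k_per_obj")  # list of per-image dicts
+#   meta = MetadataCatalog.get("lm_imgn_13_train_1k_per_obj")  # thing_classes, sym_infos, ...
+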
+
+def get_available_datasets():
+    return list(SPLITS_LM_IMGN_13.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        anno = d["annotations"][0]  # only one instance per image
+        imH, imW = img.shape[:2]
+        mask = cocosegm2mask(anno["segmentation"], imH, imW)
+        bbox = anno["bbox"]
+        bbox_mode = anno["bbox_mode"]
+        bbox_xyxy = np.array(BoxMode.convert(bbox, bbox_mode, BoxMode.XYXY_ABS))
+        kpt3d = anno["bbox3d_and_center"]
+        quat = anno["quat"]
+        trans = anno["trans"]
+        R = quat2mat(quat)
+        # 0-based label
+        cat_id = anno["category_id"]
+        K = d["cam"]
+        kpt_2d = misc.project_pts(kpt3d, K, R, trans)
+        # # TODO: visualize pose and keypoints
+        label = objs[cat_id]
+        # img_vis = vis_image_bboxes_cv2(img, bboxes=bboxes_xyxy, labels=labels)
+        img_vis = vis_image_mask_bbox_cv2(img, [mask], bboxes=[bbox_xyxy], labels=[label])
+        img_vis_kpt2d = img.copy()
+        img_vis_kpt2d = misc.draw_projected_box3d(
+            img_vis_kpt2d,
+            kpt_2d,
+            middle_color=None,
+            bottom_color=(128, 128, 128),
+        )
+
+        grid_show(
+            [
+                img[:, :, [2, 1, 0]],
+                img_vis[:, :, [2, 1, 0]],
+                img_vis_kpt2d[:, :, [2, 1, 0]],
+                depth,
+            ],
+            ["img", "vis_img", "img_vis_kpts2d", "depth"],
+            row=2,
+            col=2,
+        )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m det.yolox.data.datasets.lm_syn_imgn dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+    test_vis()
diff --git a/det/yolox/data/datasets/mosaicdetection.py b/det/yolox/data/datasets/mosaicdetection.py
new file mode 100644
index 0000000000000000000000000000000000000000..04dac64f1b85c786756ec3afcc148d099326d176
--- /dev/null
+++ b/det/yolox/data/datasets/mosaicdetection.py
@@ -0,0 +1,389 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import random
+
+import cv2
+import numpy as np
+from core.utils.my_comm import get_local_rank
+from core.utils.augment import AugmentRGB
+
+from det.yolox.utils import adjust_box_anns
+
+from ..data_augment import random_affine, augment_hsv
+from .datasets_wrapper import Dataset
+
+
+def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
+    # TODO update doc
+    # index0 to top left part of image
+    if mosaic_index == 0:
+        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
+        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
+    # index1 to top right part of image
+    elif mosaic_index == 1:
+        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
+        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
+    # index2 to bottom left part of image
+    elif mosaic_index == 2:
+        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
+        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
+    # index3 to bottom right part of image
+    elif mosaic_index == 3:
+        x1, y1, x2, y2 = (
+            xc,
+            yc,
+            min(xc + w, input_w * 2),
+            min(input_h * 2, yc + h),
+        )  # noqa
+        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
+    return (x1, y1, x2, y2), small_coord
+
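+# Worked example (illustrative, not from the original code): with
+# input_h = input_w = 640, mosaic center (xc, yc) = (300, 200) and a resized
+# source image of size (h, w) = (360, 480), mosaic_index == 0 gives
+#   large coords (x1, y1, x2, y2) = (0, 0, 300, 200)   # top-left mosaic region
+#   small_coord (s_x1, s_y1, s_x2, s_y2) = (180, 160, 480, 360)
+# i.e. the bottom-right 300x200 crop of the source image is pasted into the
+# top-left corner of the 2*input sized mosaic canvas.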
+
+class MosaicDetection(Dataset):
+    """Detection dataset wrapper that performs mixup for normal dataset."""
+
+    def __init__(
+        self,
+        img_size,
+        mosaic=True,
+        preproc=None,
+        degrees=10.0,
+        translate=0.1,
+        mosaic_scale=(0.5, 1.5),
+        mixup_scale=(0.5, 1.5),
+        shear=2.0,
+        enable_mixup=True,
+        mosaic_prob=1.0,
+        mixup_prob=1.0,
+        COLOR_AUG_PROB=0.0,
+        COLOR_AUG_TYPE="",
+        COLOR_AUG_CODE=(),
+        AUG_HSV_PROB=0,
+        HSV_H=0,
+        HSV_S=0,
+        HSV_V=0,
+        FORMAT="RGB",
+        *args
+    ):
+        """
+
+        Args:
+            img_size (tuple): (h, w)
+            mosaic (bool): enable mosaic augmentation or not.
+            preproc (func):
+            degrees (float):
+            translate (float):
+            mosaic_scale (tuple):
+            mixup_scale (tuple):
+            shear (float):
+            enable_mixup (bool):
+            *args(tuple) : Additional arguments for mixup random sampler.
+        """
+        super().__init__(img_size, mosaic=mosaic)
+        self.preproc = preproc
+        self.degrees = degrees
+        self.translate = translate
+        self.scale = mosaic_scale
+        self.shear = shear
+        self.mixup_scale = mixup_scale
+        self.enable_mosaic = mosaic
+        self.enable_mixup = enable_mixup
+        self.mosaic_prob = mosaic_prob
+        self.mixup_prob = mixup_prob
+        self.local_rank = get_local_rank()
+
+        # color aug config
+        self.color_aug_prob = COLOR_AUG_PROB
+        self.color_aug_type = COLOR_AUG_TYPE
+        self.color_aug_code = COLOR_AUG_CODE
+
+        # hsv aug config
+        self.aug_hsv_prob = AUG_HSV_PROB
+        self.hsv_h = HSV_H
+        self.hsv_s = HSV_S
+        self.hsv_v = HSV_V
+        self.img_format = FORMAT
+
+        if self.color_aug_prob > 0:
+            self.color_augmentor = self._get_color_augmentor(aug_type=self.color_aug_type, aug_code=self.color_aug_code)
+        else:
+            self.color_augmentor = None
+
+    def init_dataset(self, dataset):
+        # dataset(Dataset) : Pytorch dataset object.
+        self._dataset = dataset
+        return self
+
+    def __len__(self):
+        return len(self._dataset)
+
+    @Dataset.mosaic_getitem
+    def __getitem__(self, idx):
+        if self.enable_mosaic and random.random() < self.mosaic_prob:
+            mosaic_labels = []
+            input_dim = self._dataset.input_dim
+            input_h, input_w = input_dim[0], input_dim[1]
+
+            # yc, xc = s, s  # mosaic center x, y
+            yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
+            xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))
+
+            # 3 additional image indices
+            indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]
+
+            for i_mosaic, index in enumerate(indices):
+                img, _labels, scene_im_id, _, img_id = self._dataset.pull_item(index)
+                h0, w0 = img.shape[:2]  # orig hw
+                scale = min(1.0 * input_h / h0, 1.0 * input_w / w0)
+                img = cv2.resize(
+                    img,
+                    (int(w0 * scale), int(h0 * scale)),
+                    interpolation=cv2.INTER_LINEAR,
+                )
+                # generate output mosaic image
+                (h, w, c) = img.shape[:3]
+                if i_mosaic == 0:
+                    mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)
+
+                # prefix l refers to the large (mosaic) image, s to the small (source) image in mosaic aug.
+                (l_x1, l_y1, l_x2, l_y2), (
+                    s_x1,
+                    s_y1,
+                    s_x2,
+                    s_y2,
+                ) = get_mosaic_coordinate(mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w)
+
+                mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
+                padw, padh = l_x1 - s_x1, l_y1 - s_y1
+
+                labels = _labels.copy()
+                # Normalized xywh to pixel xyxy format
+                if _labels.size > 0:
+                    labels[:, 0] = scale * _labels[:, 0] + padw
+                    labels[:, 1] = scale * _labels[:, 1] + padh
+                    labels[:, 2] = scale * _labels[:, 2] + padw
+                    labels[:, 3] = scale * _labels[:, 3] + padh
+                mosaic_labels.append(labels)
+
+            if len(mosaic_labels):
+                mosaic_labels = np.concatenate(mosaic_labels, 0)
+                np.clip(mosaic_labels[:, 0], 0, 2 * input_w, out=mosaic_labels[:, 0])
+                np.clip(mosaic_labels[:, 1], 0, 2 * input_h, out=mosaic_labels[:, 1])
+                np.clip(mosaic_labels[:, 2], 0, 2 * input_w, out=mosaic_labels[:, 2])
+                np.clip(mosaic_labels[:, 3], 0, 2 * input_h, out=mosaic_labels[:, 3])
+
+            mosaic_img, mosaic_labels = random_affine(
+                mosaic_img,
+                mosaic_labels,
+                target_size=(input_w, input_h),
+                degrees=self.degrees,
+                translate=self.translate,
+                scales=self.scale,
+                shear=self.shear,
+            )
+
+            # -----------------------------------------------------------------
+            # CopyPaste: https://arxiv.org/abs/2012.07177
+            # -----------------------------------------------------------------
+            if self.enable_mixup and not len(mosaic_labels) == 0 and random.random() < self.mixup_prob:
+                mosaic_img, mosaic_labels = self.mixup(mosaic_img, mosaic_labels, self.input_dim)
+            mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim)
+            img_info = (mix_img.shape[1], mix_img.shape[0])
+
+            # Augment colorspace
+            dtype = mix_img.dtype
+            mix_img = mix_img.transpose(2, 1, 0).astype(np.uint8).copy()
+            # cv2.imwrite(f'output/transposed.png', mix_img)
+            if np.random.rand() < self.aug_hsv_prob:
+                augment_hsv(
+                    mix_img,
+                    hgain=self.hsv_h,
+                    sgain=self.hsv_s,
+                    vgain=self.hsv_v,
+                    source_format=self.img_format,
+                )
+
+            # color augment
+            if self.color_aug_prob > 0 and self.color_augmentor is not None:
+                if np.random.rand() < self.color_aug_prob:
+                    mix_img = self._color_aug(mix_img, self.color_aug_type)
+            mix_img = mix_img.transpose(2, 1, 0).astype(dtype).copy()
+            # cv2.imwrite('output/transposed_back.png', mix_img.astype(dtype).copy())
+
+            # -----------------------------------------------------------------
+            # img_info and img_id are not used for training.
+            # They are also hard to define for a mosaic image.
+            # -----------------------------------------------------------------
+            return mix_img, padded_labels, scene_im_id, img_info, img_id
+
+        else:
+            self._dataset._input_dim = self.input_dim
+            img, label, scene_im_id, img_info, img_id = self._dataset.pull_item(idx)
+            img, label = self.preproc(img, label, self.input_dim)
+            return img, label, scene_im_id, img_info, img_id
+
+    def mixup(self, origin_img, origin_labels, input_dim):
+        jit_factor = random.uniform(*self.mixup_scale)
+        FLIP = random.uniform(0, 1) > 0.5
+        cp_labels = []
+        while len(cp_labels) == 0:
+            cp_index = random.randint(0, self.__len__() - 1)
+            cp_labels = self._dataset.load_anno(cp_index)
+        img, cp_labels, _, _, _ = self._dataset.pull_item(cp_index)
+
+        if len(img.shape) == 3:
+            cp_img = np.ones((input_dim[0], input_dim[1], 3), dtype=np.uint8) * 114
+        else:
+            cp_img = np.ones(input_dim, dtype=np.uint8) * 114
+        cp_scale_ratio = min(input_dim[0] / img.shape[0], input_dim[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * cp_scale_ratio), int(img.shape[0] * cp_scale_ratio)),
+            interpolation=cv2.INTER_LINEAR,
+        )
+
+        cp_img[: int(img.shape[0] * cp_scale_ratio), : int(img.shape[1] * cp_scale_ratio)] = resized_img
+
+        cp_img = cv2.resize(
+            cp_img,
+            (int(cp_img.shape[1] * jit_factor), int(cp_img.shape[0] * jit_factor)),
+        )
+        cp_scale_ratio *= jit_factor
+
+        if FLIP:
+            cp_img = cp_img[:, ::-1, :]
+
+        origin_h, origin_w = cp_img.shape[:2]
+        target_h, target_w = origin_img.shape[:2]
+        padded_img = np.zeros((max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8)
+        padded_img[:origin_h, :origin_w] = cp_img
+
+        x_offset, y_offset = 0, 0
+        if padded_img.shape[0] > target_h:
+            y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
+        if padded_img.shape[1] > target_w:
+            x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
+        padded_cropped_img = padded_img[y_offset : y_offset + target_h, x_offset : x_offset + target_w]
+
+        cp_bboxes_origin_np = adjust_box_anns(cp_labels[:, :4].copy(), cp_scale_ratio, 0, 0, origin_w, origin_h)
+        if FLIP:
+            cp_bboxes_origin_np[:, 0::2] = origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1]
+        cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
+        cp_bboxes_transformed_np[:, 0::2] = np.clip(cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w)
+        cp_bboxes_transformed_np[:, 1::2] = np.clip(cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h)
+
+        cls_labels = cp_labels[:, 4:5].copy()
+        box_labels = cp_bboxes_transformed_np
+        labels = np.hstack((box_labels, cls_labels))
+        origin_labels = np.vstack((origin_labels, labels))
+        origin_img = origin_img.astype(np.float32)
+        origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)
+
+        return origin_img.astype(np.uint8), origin_labels
+
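+    # Summary note (illustrative, not from the original code): mixup() above follows
+    # a CopyPaste-style blend -- a second sample is resized to input_dim, jittered by
+    # a factor drawn from mixup_scale, optionally flipped, padded and randomly cropped
+    # to the mosaic size, then averaged pixel-wise with the mosaic image:
+    #   out = 0.5 * origin_img + 0.5 * padded_cropped_img
+    # while its rescaled, offset-corrected boxes are appended to origin_labels.
+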
+    def _get_color_augmentor(self, aug_type="ROI10D", aug_code=None):
+        # fmt: off
+        if aug_type.lower() == "roi10d":
+            color_augmentor = AugmentRGB(
+                brightness_delta=2.5 / 255.,  # 0,
+                lighting_std=0.3,
+                saturation_var=(0.95, 1.05),  #(1, 1),
+                contrast_var=(0.95, 1.05))  # (1, 1))  #
+        elif aug_type.lower() == "aae":
+            import imgaug.augmenters as iaa  # noqa
+            from imgaug.augmenters import (Sequential, SomeOf, OneOf, Sometimes, WithColorspace, WithChannels, Noop,
+                                           Lambda, AssertLambda, AssertShape, Scale, CropAndPad, Pad, Crop, Fliplr,
+                                           Flipud, Superpixels, ChangeColorspace, PerspectiveTransform, Grayscale,
+                                           GaussianBlur, AverageBlur, MedianBlur, Convolve, Sharpen, Emboss, EdgeDetect,
+                                           DirectedEdgeDetect, Add, AddElementwise, AdditiveGaussianNoise, Multiply,
+                                           MultiplyElementwise, Dropout, CoarseDropout, Invert, ContrastNormalization,
+                                           Affine, PiecewiseAffine, ElasticTransformation, pillike, LinearContrast)  # noqa
+            aug_code = """Sequential([
+                # Sometimes(0.5, PerspectiveTransform(0.05)),
+                # Sometimes(0.5, CropAndPad(percent=(-0.05, 0.1))),
+                # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+                Sometimes(0.5, CoarseDropout( p=0.2, size_percent=0.05) ),
+                Sometimes(0.5, GaussianBlur(1.2*np.random.rand())),
+                Sometimes(0.5, Add((-25, 25), per_channel=0.3)),
+                Sometimes(0.3, Invert(0.2, per_channel=True)),
+                Sometimes(0.5, Multiply((0.6, 1.4), per_channel=0.5)),
+                Sometimes(0.5, Multiply((0.6, 1.4))),
+                Sometimes(0.5, LinearContrast((0.5, 2.2), per_channel=0.3))
+                ], random_order = False)"""
+            # for darker objects, e.g. LM driller: use BOOTSTRAP_RATIO: 16 and weaker augmentation
+            aug_code_weaker = """Sequential([
+                Sometimes(0.4, CoarseDropout( p=0.1, size_percent=0.05) ),
+                # Sometimes(0.5, Affine(scale=(1.0, 1.2))),
+                Sometimes(0.5, GaussianBlur(np.random.rand())),
+                Sometimes(0.5, Add((-20, 20), per_channel=0.3)),
+                Sometimes(0.4, Invert(0.20, per_channel=True)),
+                Sometimes(0.5, Multiply((0.7, 1.4), per_channel=0.8)),
+                Sometimes(0.5, Multiply((0.7, 1.4))),
+                Sometimes(0.5, LinearContrast((0.5, 2.0), per_channel=0.3))
+                ], random_order=False)"""
+            color_augmentor = eval(aug_code)
+        elif aug_type.lower() == "code":  # assume imgaug
+            import imgaug.augmenters as iaa
+            from imgaug.augmenters import (Sequential, SomeOf, OneOf, Sometimes, WithColorspace, WithChannels, Noop,
+                                           Lambda, AssertLambda, AssertShape, Scale, CropAndPad, Pad, Crop, Fliplr,
+                                           Flipud, Superpixels, ChangeColorspace, PerspectiveTransform, Grayscale,
+                                           GaussianBlur, AverageBlur, MedianBlur, Convolve, Sharpen, Emboss, EdgeDetect,
+                                           DirectedEdgeDetect, Add, AddElementwise, AdditiveGaussianNoise, Multiply,
+                                           MultiplyElementwise, Dropout, CoarseDropout, Invert, ContrastNormalization,
+                                           Affine, PiecewiseAffine, ElasticTransformation, pillike, LinearContrast, Canny)  # noqa
+            aug_code = self.color_aug_code
+            color_augmentor = eval(aug_code)
+        elif aug_type.lower() == 'code_albu':
+            from albumentations import (HorizontalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
+                                        Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion,
+                                        HueSaturationValue, IAAAdditiveGaussianNoise, GaussNoise, MotionBlur,
+                                        MedianBlur, IAAPiecewiseAffine, IAASharpen, IAAEmboss, RandomContrast,
+                                        RandomBrightness, Flip, OneOf, Compose, CoarseDropout, RGBShift, RandomGamma,
+                                        RandomBrightnessContrast, JpegCompression, InvertImg)  # noqa
+            aug_code = """Compose([
+                CoarseDropout(max_height=0.05*480, max_holes=0.05*640, p=0.4),
+                OneOf([
+                    IAAAdditiveGaussianNoise(p=0.5),
+                    GaussNoise(p=0.5),
+                ], p=0.2),
+                OneOf([
+                    MotionBlur(p=0.2),
+                    MedianBlur(blur_limit=3, p=0.1),
+                    Blur(blur_limit=3, p=0.1),
+                ], p=0.2),
+                OneOf([
+                    CLAHE(clip_limit=2),
+                    IAASharpen(),
+                    IAAEmboss(),
+                    RandomBrightnessContrast(),
+                ], p=0.3),
+                InvertImg(p=0.2),
+                RGBShift(r_shift_limit=105, g_shift_limit=45, b_shift_limit=40, p=0.5),
+                RandomContrast(limit=0.9, p=0.5),
+                RandomGamma(gamma_limit=(80,120), p=0.5),
+                RandomBrightness(limit=1.2, p=0.5),
+                HueSaturationValue(hue_shift_limit=172, sat_shift_limit=20, val_shift_limit=27, p=0.3),
+                JpegCompression(quality_lower=4, quality_upper=100, p=0.4),
+            ], p=0.8)"""
+            color_augmentor = eval(self.color_aug_code)
+        else:
+            color_augmentor = None
+        # fmt: on
+        return color_augmentor
+
+    def _color_aug(self, image, aug_type="ROI10D"):
+        # assume image in [0, 255] uint8
+        if aug_type.lower() == "roi10d":  # need normalized image in [0,1]
+            image = np.asarray(image / 255.0, dtype=np.float32).copy()
+            image = self.color_augmentor.augment(image)
+            image = (image * 255.0 + 0.5).astype(np.uint8)
+            return image
+        elif aug_type.lower() in ["aae", "code"]:
+            # imgaug need uint8
+            return self.color_augmentor.augment_image(image)
+        elif aug_type.lower() in ["code_albu"]:
+            augmented = self.color_augmentor(image=image)
+            return augmented["image"]
+        else:
+            raise ValueError("aug_type: {} is not supported.".format(aug_type))
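+
+
+# Minimal usage sketch (illustrative, with assumed names; not part of the original
+# file). MosaicDetection is meant to wrap a YOLOX-style dataset that exposes
+# pull_item(), load_anno() and input_dim:
+#
+#   base_ds = SomeYoloxDataset(...)       # hypothetical base dataset
+#   train_ds = MosaicDetection(
+#       img_size=(640, 640),
+#       mosaic=True,
+#       preproc=some_train_transform,     # hypothetical preproc callable
+#       mosaic_prob=1.0,
+#       mixup_prob=1.0,
+#   ).init_dataset(base_ds)
+#   img, labels, scene_im_id, img_info, img_id = train_ds[0]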
diff --git a/det/yolox/data/datasets/tless_bop_test.py b/det/yolox/data/datasets/tless_bop_test.py
new file mode 100755
index 0000000000000000000000000000000000000000..c6b66490477dde3e100e6a713149c5ef09299943
--- /dev/null
+++ b/det/yolox/data/datasets/tless_bop_test.py
@@ -0,0 +1,530 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class TLESS_BOP_TEST_Dataset(object):
+    """tless bop test."""
+
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here;
+        whether they are actually loaded into the dataloader/network
+        is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.ann_file = data_cfg["ann_file"]  # json file with scene_id and im_id items
+
+        self.dataset_root = data_cfg["dataset_root"]  # BOP_DATASETS/tless/test_primesense
+        assert osp.exists(self.dataset_root), self.dataset_root
+
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/tless/models_cad
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+
+        self.height = data_cfg["height"]
+        self.width = data_cfg["width"]
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.tless.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))  # obj name -> 0-based label
+        ##########################################################
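+        # Illustrative example (assumption: ref.tless.id2obj maps e.g. 1 -> "obj_01"
+        # and 5 -> "obj_05"): with objs=["obj_01", "obj_05"] this yields
+        #   cat_ids   = [1, 5]
+        #   cat2label = {1: 0, 5: 1}   # dataset obj_id -> contiguous 0-based label
+        #   label2cat = {0: 1, 1: 5}
+        #   obj2label = {"obj_01": 0, "obj_05": 1}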
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        targets = mmcv.load(self.ann_file)
+        scene_im_ids = [(item["scene_id"], item["im_id"]) for item in targets]
+        scene_im_ids = sorted(list(set(scene_im_ids)))
+
+        # load infos for each scene
+        gt_dicts = {}
+        gt_info_dicts = {}
+        cam_dicts = {}
+        for scene_id, im_id in scene_im_ids:
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+            if scene_id not in gt_dicts:
+                gt_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            if scene_id not in gt_info_dicts:
+                gt_info_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))  # bbox_obj, bbox_visib
+            if scene_id not in cam_dicts:
+                cam_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+        for scene_id, int_im_id in tqdm(scene_im_ids):
+            str_im_id = str(int_im_id)
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+
+            gt_dict = gt_dicts[scene_id]
+            gt_info_dict = gt_info_dicts[scene_id]
+            cam_dict = cam_dicts[scene_id]
+
+            rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(int_im_id)
+            assert osp.exists(rgb_path), rgb_path
+
+            depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+            scene_im_id = f"{scene_id}/{int_im_id}"
+
+            K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+            depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+            record = {
+                "dataset_name": self.name,
+                "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                "height": self.height,
+                "width": self.width,
+                "image_id": int_im_id,
+                "scene_im_id": scene_im_id,  # for evaluation
+                "cam": K,
+                "depth_factor": depth_factor,
+                "img_type": "real",  # NOTE: has background
+            }
+            insts = []
+            for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                obj_id = anno["obj_id"]
+                if obj_id not in self.cat_ids:
+                    continue
+                cur_label = self.cat2label[obj_id]  # 0-based label
+                R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                pose = np.hstack([R, t.reshape(3, 1)])
+                quat = mat2quat(R).astype("float32")
+
+                proj = (record["cam"] @ t.T).T
+                proj = proj[:2] / proj[2]
+
+                bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                x1, y1, w, h = bbox_visib
+                if self.filter_invalid:
+                    if h <= 1 or w <= 1:
+                        self.num_instances_without_valid_box += 1
+                        continue
+
+                mask_file = osp.join(
+                    scene_root,
+                    "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                )
+                mask_visib_file = osp.join(
+                    scene_root,
+                    "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                )
+                assert osp.exists(mask_file), mask_file
+                assert osp.exists(mask_visib_file), mask_visib_file
+                # load mask visib
+                mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                mask_single = mask_single.astype("bool")
+                area = mask_single.sum()
+                if area < 3:  # filter out too small or nearly invisible instances
+                    self.num_instances_without_valid_segmentation += 1
+                mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                # load mask full
+                mask_full = mmcv.imread(mask_file, "unchanged")
+                mask_full = mask_full.astype("bool")
+                mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                inst = {
+                    "category_id": cur_label,  # 0-based label
+                    "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                    "bbox_mode": BoxMode.XYWH_ABS,
+                    "pose": pose,
+                    "quat": quat,
+                    "trans": t,
+                    "centroid_2d": proj,  # absolute (cx, cy)
+                    "segmentation": mask_rle,
+                    "mask_full": mask_full_rle,
+                    "visib_fract": visib_fract,
+                    "xyz_path": None,  #  no need for test
+                }
+
+                model_info = self.models_info[str(obj_id)]
+                inst["model_info"] = model_info
+                for key in ["bbox3d_and_center"]:
+                    inst[key] = self.models[cur_label][key]
+                insts.append(inst)
+            if len(insts) == 0:  # filter im without anno
+                continue
+            record["annotations"] = insts
+            dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "There are {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "There are {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # logger.info("load cached object models from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(self.models_root, f"obj_{ref.tless.obj2id[obj_name]:06d}.ply"),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not recomputed from re-centered vertices;
+            # for BOP models this is not an issue since they are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_tless_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+################################################################################
+
+SPLITS_TLESS = dict(
+    tless_bop_test_primesense=dict(
+        name="tless_bop_test_primesense",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/test_primesense"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/models_cad"),
+        objs=ref.tless.objects,  # selected objects
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/test_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=540,
+        width=720,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="tless",
+    ),
+    tless_bop_test_primesense_alignK=dict(
+        name="tless_bop_test_primesense_alignK",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/test_primesense_alignK"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/models_cad"),
+        objs=ref.tless.objects,  # selected objects
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/test_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=540,
+        width=720,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="tless",
+    ),
+)
+
+
+# single objs (num_class is from all objs)
+for obj in ref.tless.objects:
+    name = "tless_{}_bop_test_primesense".format(obj)
+    select_objs = [obj]
+    if name not in SPLITS_TLESS:
+        SPLITS_TLESS[name] = dict(
+            name=name,
+            dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/test_primesense"),
+            models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/models_cad"),
+            objs=[obj],
+            ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/test_targets_bop19.json"),
+            scale_to_meter=0.001,
+            with_masks=True,  # (load masks but may not use it)
+            with_depth=True,  # (load depth path here, but may not use it)
+            height=540,
+            width=720,
+            cache_dir=osp.join(PROJ_ROOT, ".cache"),
+            use_cache=True,
+            num_to_load=-1,
+            filter_invalid=False,
+            ref_key="tless",
+        )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset_name,
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise a data_cfg must be provided
+            (it can also be set in cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_TLESS:
+        used_cfg = SPLITS_TLESS[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, TLESS_BOP_TEST_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="tless",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_tless_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_TLESS.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            if "test" not in dset_name.lower():
+                xyz_path = annos[_i]["xyz_path"]
+                xyz_info = mmcv.load(xyz_path)
+                x1, y1, x2, y2 = xyz_info["xyxy"]
+                xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+                xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+                xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+                xyz_show = get_emb_show(xyz)
+                xyz_crop_show = get_emb_show(xyz_crop)
+                img_xyz = img.copy() / 255.0
+                mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+                fg_idx = np.where(mask_xyz != 0)
+                img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+                img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+                img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+                # diff mask
+                diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                        # xyz_show,
+                        diff_mask_xyz,
+                        xyz_crop_show,
+                        img_xyz[:, :, [2, 1, 0]],
+                        img_xyz_crop[:, :, [2, 1, 0]],
+                        img_vis_crop,
+                    ],
+                    [
+                        "img",
+                        "vis_img",
+                        "img_vis_kpts2d",
+                        "depth",
+                        "diff_mask_xyz",
+                        "xyz_crop_show",
+                        "img_xyz",
+                        "img_xyz_crop",
+                        "img_vis_crop",
+                    ],
+                    row=3,
+                    col=3,
+                )
+            else:
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                    ],
+                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                    row=2,
+                    col=2,
+                )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    python this_file.py dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/tless_pbr.py b/det/yolox/data/datasets/tless_pbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d221d6df655268974f6971d1e72d46dc48d1f60
--- /dev/null
+++ b/det/yolox/data/datasets/tless_pbr.py
@@ -0,0 +1,482 @@
+import logging
+import hashlib
+import os
+import os.path as osp
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import time
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class TLESS_PBR_Dataset:
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here;
+        whether they are actually loaded into the dataloader/network
+        is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objs
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/train_pbr"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.xyz_root = data_cfg.get("xyz_root", osp.join(self.dataset_root, "xyz_crop"))
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/tless/models_cad
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]
+        self.width = data_cfg["width"]
+
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.tless.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs)
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+
+        self.scenes = [f"{i:06d}" for i in range(50)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+
+        dataset_dicts = []
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        # it is slow because of loading and converting masks to rle
+
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.jpg").format(int_im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
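+                # depth_factor converts raw depth values to meters: depth_m = raw_depth / depth_factor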
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "syn_pbr",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
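+                    # project the object centroid: K @ t gives homogeneous pixel coords,
+                    # normalized by depth to obtain the absolute 2D center (cx, cy)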
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(scene_root, "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i))
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib  TODO: load both mask_visib and mask_full
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask_single.sum()
+                    if area < 30:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    xyz_path = osp.join(
+                        self.xyz_root,
+                        f"{scene_id:06d}/{int_im_id:06d}_{anno_i:06d}-xyz.pkl",
+                    )
+                    # assert osp.exists(xyz_path), xyz_path
+
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_file,  # TODO: load as mask_full, rle
+                        "visib_fract": visib_fract,
+                        "xyz_path": xyz_path,
+                    }
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(self.models_root, f"obj_{ref.tless.obj2id[obj_name]:06d}.ply"),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not computed from explicitly centered vertices;
+            # this is fine for BOP models since their vertices are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_tless_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
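+            # keep only the rotation parts of the symmetry transforms;
+            # sym_info has shape (num_syms, 3, 3)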
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+tless_model_root = "BOP_DATASETS/tless/models_cad/"
+################################################################################
+
+
+SPLITS_TLESS_PBR = dict(
+    tless_pbr_train=dict(
+        name="tless_pbr_train",
+        objs=ref.tless.objects,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/models_cad"),
+        xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/train_pbr/xyz_crop"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=540,
+        width=720,
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="tless",
+    ),
+)
+
+# single obj splits
+for obj in ref.tless.objects:
+    for split in ["train"]:
+        name = "tless_pbr_{}_{}".format(obj, split)
+        if split in ["train"]:
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_TLESS_PBR:
+            SPLITS_TLESS_PBR[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/train_pbr"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/models_cad"),
+                xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/train_pbr/xyz_crop"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=540,
+                width=720,
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="tless",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name
+        data_cfg: if name is one of the pre-defined SPLITS, the pre-defined
+            data_cfg is used; otherwise data_cfg must be provided
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_TLESS_PBR:
+        used_cfg = SPLITS_TLESS_PBR[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, TLESS_PBR_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="tless",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_tless_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
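+# Example usage (a minimal sketch; "tless_pbr_train" is one of the pre-defined splits above):
+#   register_with_name_cfg("tless_pbr_train")
+#   dataset_dicts = DatasetCatalog.get("tless_pbr_train")
+#   metadata = MetadataCatalog.get("tless_pbr_train")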
+
+def get_available_datasets():
+    return list(SPLITS_TLESS_PBR.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 10000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            xyz_path = annos[_i]["xyz_path"]
+            xyz_info = mmcv.load(xyz_path)
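+            # xyz_info holds a dense coordinate crop ("xyz_crop") and its bbox ("xyxy") in the full image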
+            x1, y1, x2, y2 = xyz_info["xyxy"]
+            xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+            xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+            xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+            xyz_show = get_emb_show(xyz)
+            xyz_crop_show = get_emb_show(xyz_crop)
+            img_xyz = img.copy() / 255.0
+            mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+            fg_idx = np.where(mask_xyz != 0)
+            img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+            img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+            img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+            # diff mask
+            diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                    # xyz_show,
+                    diff_mask_xyz,
+                    xyz_crop_show,
+                    img_xyz[:, :, [2, 1, 0]],
+                    img_xyz_crop[:, :, [2, 1, 0]],
+                    img_vis_crop,
+                ],
+                [
+                    "img",
+                    "vis_img",
+                    "img_vis_kpts2d",
+                    "depth",
+                    "diff_mask_xyz",
+                    "xyz_crop_show",
+                    "img_xyz",
+                    "img_xyz_crop",
+                    "img_vis_crop",
+                ],
+                row=3,
+                col=3,
+            )
+
+
+if __name__ == "__main__":
+    """Test the dataset loader.
+
+    Usage:
+        python -m this_module dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/tless_primesense_train.py b/det/yolox/data/datasets/tless_primesense_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b16e1b042ba805e09c932930db8bba1aef9b1e3
--- /dev/null
+++ b/det/yolox/data/datasets/tless_primesense_train.py
@@ -0,0 +1,446 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, lazy_property
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class TLESS_PRIMESENSE_TRAIN_Dataset:
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here;
+        whether they are actually loaded into the dataloader/network
+        is decided later in the data pipeline.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get(
+            "dataset_root",
+            osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/train_primesense"),
+        )
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/tless/models_cad
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 400
+        self.width = data_cfg["width"]  # 400
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.tless.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
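+        # train_primesense provides scenes 000001-000030, one scene per T-LESS object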
+        self.scenes = [f"{i:06d}" for i in range(1, 31)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+            # import ipdb;ipdb.set_trace()
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(int_im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+                # import ipdb;ipdb.set_trace()
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "real",  # NOTE: real primesense training images
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib  TODO: load both mask_visib and mask_full
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask_single.sum()
+                    if area <= 64:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full_file": mask_file,  # TODO: load as mask_full, rle
+                        "visib_fract": visib_fract,
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    # TODO: using full mask
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format(self.name))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.tless.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not computed from explicitly centered vertices;
+            # this is fine for BOP models since their vertices are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # e.g., 400/400 = 1.0 for the primesense train splits
+
+
+########### register datasets ############################################################
+
+
+def get_tless_primesense_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+tless_primesense_model_root = "BOP_DATASETS/tless/models_cad/"
+################################################################################
+
+SPLITS_TLESS_PRIMESENSE_TRAIN = dict(
+    tless_primesense_train=dict(
+        name="tless_primesense_train",
+        objs=ref.tless.objects,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/train_primesense"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/models_cad"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=400,
+        width=400,
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="tless",
+    )
+)
+
+# single obj splits
+for obj in ref.tless.objects:
+    for split in ["train_primesense"]:
+        name = "tless_{}_{}".format(obj, split)
+        if split in ["train_primesense"]:
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_TLESS_PRIMESENSE_TRAIN:
+            SPLITS_TLESS_PRIMESENSE_TRAIN[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/train_primesense"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tless/models_cad"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=400,
+                width=400,
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="tless",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name
+        data_cfg: if name is one of the pre-defined SPLITS, the pre-defined
+            data_cfg is used; otherwise data_cfg must be provided
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_TLESS_PRIMESENSE_TRAIN:
+        used_cfg = SPLITS_TLESS_PRIMESENSE_TRAIN[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, TLESS_PRIMESENSE_TRAIN_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="tless",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="coco_bop",
+        **get_tless_primesense_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_TLESS_PRIMESENSE_TRAIN.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 10000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                ],
+                ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                row=2,
+                col=2,
+            )
+
+
+if __name__ == "__main__":
+    """Test the dataset loader.
+
+    Usage:
+        python -m det.yolox.data.datasets.tless_primesense_train tless_primesense_train
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/tudl_bop_test.py b/det/yolox/data/datasets/tudl_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e77a6a1930b90e795879d8ade997ed10a02a95c
--- /dev/null
+++ b/det/yolox/data/datasets/tudl_bop_test.py
@@ -0,0 +1,525 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class TUDL_BOP_TEST_Dataset(object):
+    """tudl bop test splits."""
+
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here;
+        whether they are actually loaded into the dataloader/network
+        is decided later in the data pipeline.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/test"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+
+        self.ann_file = data_cfg["ann_file"]  # json file with scene_id and im_id items
+
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/tudl/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.tudl.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
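+        # the BOP test targets file lists the (scene_id, im_id) pairs to evaluate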
+        targets = mmcv.load(self.ann_file)
+
+        scene_im_ids = [(item["scene_id"], item["im_id"]) for item in targets]
+        scene_im_ids = sorted(list(set(scene_im_ids)))
+
+        # load infos for each scene
+        gt_dicts = {}
+        gt_info_dicts = {}
+        cam_dicts = {}
+        for scene_id, im_id in scene_im_ids:
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+            if scene_id not in gt_dicts:
+                gt_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            if scene_id not in gt_info_dicts:
+                gt_info_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))  # bbox_obj, bbox_visib
+            if scene_id not in cam_dicts:
+                cam_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+        for scene_id, int_im_id in tqdm(scene_im_ids):
+            str_im_id = str(int_im_id)
+            scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+
+            gt_dict = gt_dicts[scene_id]
+            gt_info_dict = gt_info_dicts[scene_id]
+            cam_dict = cam_dicts[scene_id]
+
+            rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(int_im_id)
+            assert osp.exists(rgb_path), rgb_path
+
+            depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+            scene_im_id = f"{scene_id}/{int_im_id}"
+
+            K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+            depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+            record = {
+                "dataset_name": self.name,
+                "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                "height": self.height,
+                "width": self.width,
+                "image_id": int_im_id,
+                "scene_im_id": scene_im_id,  # for evaluation
+                "cam": K,
+                "depth_factor": depth_factor,
+                "img_type": "real",  # NOTE: has background
+            }
+            insts = []
+            for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                obj_id = anno["obj_id"]
+                if obj_id not in self.cat_ids:
+                    continue
+                cur_label = self.cat2label[obj_id]  # 0-based label
+                R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                pose = np.hstack([R, t.reshape(3, 1)])
+                quat = mat2quat(R).astype("float32")
+
+                proj = (record["cam"] @ t.T).T
+                proj = proj[:2] / proj[2]
+
+                bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                x1, y1, w, h = bbox_visib
+                if self.filter_invalid:
+                    if h <= 1 or w <= 1:
+                        self.num_instances_without_valid_box += 1
+                        continue
+
+                mask_file = osp.join(
+                    scene_root,
+                    "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                )
+                mask_visib_file = osp.join(
+                    scene_root,
+                    "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                )
+                assert osp.exists(mask_file), mask_file
+                assert osp.exists(mask_visib_file), mask_visib_file
+                # load mask visib
+                mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                mask_single = mask_single.astype("bool")
+                area = mask_single.sum()
+                if area < 3:  # filter out too small or nearly invisible instances
+                    self.num_instances_without_valid_segmentation += 1
+                mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                # load mask full
+                mask_full = mmcv.imread(mask_file, "unchanged")
+                mask_full = mask_full.astype("bool")
+                mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                inst = {
+                    "category_id": cur_label,  # 0-based label
+                    "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                    "bbox_mode": BoxMode.XYWH_ABS,
+                    "pose": pose,
+                    "quat": quat,
+                    "trans": t,
+                    "centroid_2d": proj,  # absolute (cx, cy)
+                    "segmentation": mask_rle,
+                    "mask_full": mask_full_rle,
+                    "visib_fract": visib_fract,
+                    "xyz_path": None,  # not needed for test
+                }
+
+                model_info = self.models_info[str(obj_id)]
+                inst["model_info"] = model_info
+                for key in ["bbox3d_and_center"]:
+                    inst[key] = self.models[cur_label][key]
+                insts.append(inst)
+            if len(insts) == 0:  # filter im without anno
+                continue
+            record["annotations"] = insts
+            dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "There are {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "There are {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.cache_dir, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.tudl.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not computed from explicitly centered vertices;
+            # this is fine for BOP models since their vertices are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_tudl_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+##########################################################################
+
+TUDL_OBJECTS = ["dragon", "frog", "can"]
+
+SPLITS_TUDL = dict(
+    tudl_bop_test=dict(
+        name="tudl_bop_test",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/test"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/models"),
+        objs=TUDL_OBJECTS,
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/test_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="tudl",
+    ),
+)
+
+# single obj splits for tudl bop test
+for obj in ref.tudl.objects:
+    for split in [
+        "bop_test",
+    ]:
+        name = "tudl_{}_{}".format(obj, split)
+        ann_files = [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/tudl/image_set/{}_{}.txt".format(obj, split),
+            )
+        ]
+        if name not in SPLITS_TUDL:
+            SPLITS_TUDL[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/test"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/models"),
+                objs=[obj],  # only this obj
+                scale_to_meter=0.001,
+                ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/test_targets_bop19.json"),
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=False,
+                ref_key="tudl",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name
+        data_cfg: if name is one of the pre-defined SPLITS, the pre-defined
+            data_cfg is used; otherwise data_cfg must be provided
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_TUDL:
+        used_cfg = SPLITS_TUDL[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, TUDL_BOP_TEST_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="tudl",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_tudl_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
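+# Example usage (a minimal sketch; "tudl_bop_test" is the pre-defined BOP19 test split above):
+#   register_with_name_cfg("tudl_bop_test")
+#   dataset_dicts = DatasetCatalog.get("tudl_bop_test")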
+
+def get_available_datasets():
+    return list(SPLITS_TUDL.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            if "test" not in dset_name.lower():
+                xyz_path = annos[_i]["xyz_path"]
+                xyz_info = mmcv.load(xyz_path)
+                x1, y1, x2, y2 = xyz_info["xyxy"]
+                xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+                xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+                xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+                xyz_show = get_emb_show(xyz)
+                xyz_crop_show = get_emb_show(xyz_crop)
+                img_xyz = img.copy() / 255.0
+                mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+                fg_idx = np.where(mask_xyz != 0)
+                img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+                img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+                img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+                # diff mask
+                diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                        # xyz_show,
+                        diff_mask_xyz,
+                        xyz_crop_show,
+                        img_xyz[:, :, [2, 1, 0]],
+                        img_xyz_crop[:, :, [2, 1, 0]],
+                        img_vis_crop,
+                    ],
+                    [
+                        "img",
+                        "vis_img",
+                        "img_vis_kpts2d",
+                        "depth",
+                        "diff_mask_xyz",
+                        "xyz_crop_show",
+                        "img_xyz",
+                        "img_xyz_crop",
+                        "img_vis_crop",
+                    ],
+                    row=3,
+                    col=3,
+                )
+            else:
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                    ],
+                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                    row=2,
+                    col=2,
+                )
+
+
+if __name__ == "__main__":
+    """Test the dataset loader.
+
+    python this_file.py dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/tudl_dataset_d2.py b/det/yolox/data/datasets/tudl_dataset_d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..40c3a2fc1ea368346a6a90301be23f82faa5835f
--- /dev/null
+++ b/det/yolox/data/datasets/tudl_dataset_d2.py
@@ -0,0 +1,508 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class TUDL_Dataset(object):
+    """tudl splits."""
+
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here; whether they are
+        actually loaded into the dataloader/network is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/test"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/tudl/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.tudl.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+        self.scenes = [f"{i:06d}" for i in range(1, 4)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
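+        # Descriptive note: the md5 key above hashes the selected objects, the dataset
+        # name and root, the with_masks/with_depth flags and this module's __name__,
+        # so changing any of them writes a fresh dataset_dicts_*.pkl instead of
+        # silently reusing a stale cache.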
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.png".format(int_im_id))
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "syn_pbr",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    mask_single = mask_single.astype("bool")
+                    area = mask_single.sum()
+                    if area < 30:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,
+                        "visib_fract": visib_fract,
+                        "xyz_path": None,  #  no need for test
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.cache_dir, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.tudl.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is computed from the raw (non-recentered) vertices;
+            # for BOP models this is fine since the models are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_tudl_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
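+# A rough sketch of what get_tudl_metadata returns (values illustrative; whether a
+# given TUDL object carries sym_infos depends on models_info.json):
+#   {"thing_classes": ["dragon", "frog", "can"],
+#    "sym_infos": {0: None or np.ndarray of shape (num_sym, 3, 3), 1: ..., 2: ...}}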
+
+
+##########################################################################
+
+TUDL_OBJECTS = ["dragon", "frog", "can"]
+
+# TODO: add real train
+SPLITS_TUDL = dict(
+    tudl_bop_test=dict(
+        name="tudl_bop_test",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/test"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/models"),
+        objs=TUDL_OBJECTS,
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_scene=True,
+        filter_invalid=False,
+        ref_key="tudl",
+    ),
+)
+
+# single obj splits for tudl bop test
+for obj in ref.tudl.objects:
+    for split in [
+        "bop_test",
+    ]:
+        name = "tudl_{}_{}".format(obj, split)
+        ann_files = [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/tudl/image_set/{}_{}.txt".format(obj, split),
+            )
+        ]
+        if name not in SPLITS_TUDL:
+            SPLITS_TUDL[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/models"),
+                objs=[obj],  # only this obj
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=False,
+                filter_scene=True,
+                ref_key="tudl",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset_name,
+        data_cfg: if name is in existing SPLITS, use pre-defined data_cfg
+            otherwise requires data_cfg
+            data_cfg can be set in cfg.DATA_CFG.name
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_TUDL:
+        used_cfg = SPLITS_TUDL[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, TUDL_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="linemod",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_tudl_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
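+# Usage sketch (hedged; the split name comes from SPLITS_TUDL above):
+#   register_with_name_cfg("tudl_bop_test")
+#   dicts = DatasetCatalog.get("tudl_bop_test")     # list of per-image record dicts
+#   meta = MetadataCatalog.get("tudl_bop_test")     # thing_classes, sym_infos, objs, ...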
+
+
+def get_available_datasets():
+    return list(SPLITS_TUDL.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            if "test" not in dset_name.lower():
+                xyz_path = annos[_i]["xyz_path"]
+                xyz_info = mmcv.load(xyz_path)
+                x1, y1, x2, y2 = xyz_info["xyxy"]
+                xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+                xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+                xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+                xyz_show = get_emb_show(xyz)
+                xyz_crop_show = get_emb_show(xyz_crop)
+                img_xyz = img.copy() / 255.0
+                mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+                fg_idx = np.where(mask_xyz != 0)
+                img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+                img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+                img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+                # diff mask
+                diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                        # xyz_show,
+                        diff_mask_xyz,
+                        xyz_crop_show,
+                        img_xyz[:, :, [2, 1, 0]],
+                        img_xyz_crop[:, :, [2, 1, 0]],
+                        img_vis_crop,
+                    ],
+                    [
+                        "img",
+                        "vis_img",
+                        "img_vis_kpts2d",
+                        "depth",
+                        "diff_mask_xyz",
+                        "xyz_crop_show",
+                        "img_xyz",
+                        "img_xyz_crop",
+                        "img_vis_crop",
+                    ],
+                    row=3,
+                    col=3,
+                )
+            else:
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                    ],
+                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                    row=2,
+                    col=2,
+                )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    python this_file.py dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/tudl_pbr.py b/det/yolox/data/datasets/tudl_pbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..eca140242342ec03c3e5eb83290eb6499ed5b8ba
--- /dev/null
+++ b/det/yolox/data/datasets/tudl_pbr.py
@@ -0,0 +1,494 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+
+import ref
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class TUDL_PBR_Dataset:
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here; whether they are
+        actually loaded into the dataloader/network is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get("dataset_root", osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/train_pbr"))
+        self.xyz_root = data_cfg.get("xyz_root", osp.join(self.dataset_root, "xyz_crop"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/tudl/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.tudl.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+        self.scenes = [f"{i:06d}" for i in range(50)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.jpg".format(int_im_id))
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
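+                # BOP convention: raw_depth * depth_scale is in mm, so
+                # raw_depth / depth_factor gives depth in meters
+                # (e.g. depth_factor == 10000 when depth_scale == 0.1).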
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "syn_pbr",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    mask_single = mask_single.astype("bool")
+                    area = mask_single.sum()
+                    if area < 30:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                    xyz_path = osp.join(
+                        self.xyz_root,
+                        f"{scene_id:06d}/{int_im_id:06d}_{anno_i:06d}-xyz.pkl",
+                    )
+                    # assert osp.exists(xyz_path), xyz_path
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,
+                        "visib_fract": visib_fract,
+                        # "xyz_path": xyz_path,
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format("_".join(self.objs)))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.tudl.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is computed from the raw (non-recentered) vertices;
+            # for BOP models this is fine since the models are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def __len__(self):
+        return self.num_to_load
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_tudl_metadata(obj_names, ref_key):
+    """task specific metadata."""
+
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+TUDL_OBJECTS = ["dragon", "frog", "can"]
+################################################################################
+
+
+SPLITS_TUDL_PBR = dict(
+    tudl_pbr_train=dict(
+        name="tudl_pbr_train",
+        objs=TUDL_OBJECTS,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/models"),
+        xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/train_pbr/xyz_crop"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="tudl",
+    )
+)
+
+# single obj splits
+for obj in ref.tudl.objects:
+    for split in ["train"]:
+        name = "tudl_pbr_{}_{}".format(obj, split)
+        if split in ["train"]:
+            filter_invalid = True
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_TUDL_PBR:
+            SPLITS_TUDL_PBR[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/train_pbr"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/models"),
+                xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/train_pbr/xyz_crop"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="tudl",
+            )
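+# The loop above adds one split per object, named "tudl_pbr_{obj}_train"
+# (e.g. "tudl_pbr_dragon_train"), each reusing the train_pbr root but loading
+# only that object's annotations.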
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset_name,
+        data_cfg: if name is in existing SPLITS, use pre-defined data_cfg
+            otherwise requires data_cfg
+            data_cfg can be set in cfg.DATA_CFG.name
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_TUDL_PBR:
+        used_cfg = SPLITS_TUDL_PBR[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, TUDL_PBR_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_tudl_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_TUDL_PBR.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 10000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            xyz_path = annos[_i]["xyz_path"]
+            xyz_info = mmcv.load(xyz_path)
+            x1, y1, x2, y2 = xyz_info["xyxy"]
+            xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+            xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+            xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+            xyz_show = get_emb_show(xyz)
+            xyz_crop_show = get_emb_show(xyz_crop)
+            img_xyz = img.copy() / 255.0
+            mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+            fg_idx = np.where(mask_xyz != 0)
+            img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+            img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+            img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+            # diff mask
+            diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                    # xyz_show,
+                    diff_mask_xyz,
+                    xyz_crop_show,
+                    img_xyz[:, :, [2, 1, 0]],
+                    img_xyz_crop[:, :, [2, 1, 0]],
+                    img_vis_crop,
+                ],
+                [
+                    "img",
+                    "vis_img",
+                    "img_vis_kpts2d",
+                    "depth",
+                    "diff_mask_xyz",
+                    "xyz_crop_show",
+                    "img_xyz",
+                    "img_xyz_crop",
+                    "img_vis_crop",
+                ],
+                row=3,
+                col=3,
+            )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m this_module dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/datasets/tudl_train_real.py b/det/yolox/data/datasets/tudl_train_real.py
new file mode 100644
index 0000000000000000000000000000000000000000..ead4c5d8d418dc68396ff0f98ee0ae3277f4dad9
--- /dev/null
+++ b/det/yolox/data/datasets/tudl_train_real.py
@@ -0,0 +1,447 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, lazy_property
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class TUDL_TRAIN_REAL_Dataset:
+    def __init__(self, data_cfg):
+        """
+        with_depth and with_masks default to True here; whether they are
+        actually loaded into the dataloader/network is decided later.
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get(
+            "dataset_root",
+            osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/train_real"),
+        )
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/tudl/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.tudl.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+        self.scenes = [f"{i:06d}" for i in range(1, 4)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.png".format(int_im_id))
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "real",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib  TODO: load both mask_visib and mask_full
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    mask_single = mask_single.astype("bool")  # keep area/RLE consistent with the other loaders
+                    area = mask_single.sum()
+                    if area < 30:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full_file": mask_file,  # TODO: load as mask_full, rle
+                        "visib_fract": visib_fract,
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    # TODO: using full mask
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format(self.name))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.tudl.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is computed from the raw (non-recentered) vertices;
+            # for BOP models this is fine since the models are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_tudl_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+tudl_model_root = "BOP_DATASETS/tudl/models/"
+################################################################################
+
+TUDL_OBJECTS = ["dragon", "frog", "can"]
+
+SPLITS_TUDL_TRAIN_REAL = dict(
+    tudl_train_real=dict(
+        name="tudl_train_real",
+        objs=TUDL_OBJECTS,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/train_real"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/models"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="tudl",
+    )
+)
+
+# single obj splits
+for obj in ref.tudl.objects:
+    for split in ["train_real"]:
+        name = "tudl_{}_{}".format(obj, split)
+        if split in ["train_real"]:
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_TUDL_TRAIN_REAL:
+            SPLITS_TUDL_TRAIN_REAL[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/train_real"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/tudl/models"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks but may not use it)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="tudl",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset_name
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg is required
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_TUDL_TRAIN_REAL:
+        used_cfg = SPLITS_TUDL_TRAIN_REAL[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, TUDL_TRAIN_REAL_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="tudl",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="coco_bop",
+        **get_tudl_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_TUDL_TRAIN_REAL.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 10000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                ],
+                ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                row=2,
+                col=2,
+            )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m det.yolov4.datasets.tudl_train_real tudl_train_real
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
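The module above registers each split lazily through detectron2's DatasetCatalog/MetadataCatalog: registration stores a callable, and the per-image dicts are only built (and cached to `.cache/`) on the first `DatasetCatalog.get()`. A minimal sketch of how a registered split might be consumed, assuming the module has been imported and the TUDL data exists under `datasets/`; `inspect_split` is a hypothetical helper, not part of this diff:

```python
# Illustrative sketch only (not part of the diff). Assumes register_with_name_cfg()
# from the module above has already been called for the given split name.
from detectron2.data import DatasetCatalog, MetadataCatalog


def inspect_split(name="tudl_train_real"):
    dicts = DatasetCatalog.get(name)   # invokes the registered callable, builds/loads the dicts
    meta = MetadataCatalog.get(name)   # thing_classes, sym_infos, eval_error_types, ...
    record = dicts[0]
    print(len(dicts), "images;", len(record["annotations"]), "annotations in the first image")
    print("classes:", meta.thing_classes)
    return record
```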
diff --git a/det/yolox/data/datasets/voc.py b/det/yolox/data/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7b35d6a3d0b1231fd0c339ce78064400450b3c0
--- /dev/null
+++ b/det/yolox/data/datasets/voc.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Code are based on
+# https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
+# Copyright (c) Francisco Massa.
+# Copyright (c) Ellis Brown, Max deGroot.
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+import os.path
+import pickle
+import xml.etree.ElementTree as ET
+from loguru import logger
+
+import cv2
+import numpy as np
+
+from det.yolox.evaluators.voc_eval import voc_eval
+
+from .datasets_wrapper import Dataset
+from .voc_classes import VOC_CLASSES
+
+
+class AnnotationTransform(object):
+
+    """Transforms a VOC annotation into a Tensor of bbox coords and label index
+    Initilized with a dictionary lookup of classnames to indexes
+    Arguments:
+        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
+            (default: alphabetic indexing of VOC's 20 classes)
+        keep_difficult (bool, optional): keep difficult instances or not
+            (default: False)
+        height (int): height
+        width (int): width
+    """
+
+    def __init__(self, class_to_ind=None, keep_difficult=True):
+        self.class_to_ind = class_to_ind or dict(zip(VOC_CLASSES, range(len(VOC_CLASSES))))
+        self.keep_difficult = keep_difficult
+
+    def __call__(self, target):
+        """
+        Arguments:
+            target (annotation) : the target annotation to be made usable
+                will be an ET.Element
+        Returns:
+            a list containing lists of bounding boxes  [bbox coords, class name]
+        """
+        res = np.empty((0, 5))
+        for obj in target.iter("object"):
+            difficult = obj.find("difficult")
+            if difficult is not None:
+                difficult = int(difficult.text) == 1
+            else:
+                difficult = False
+            if not self.keep_difficult and difficult:
+                continue
+            name = obj.find("name").text.strip()
+            bbox = obj.find("bndbox")
+
+            pts = ["xmin", "ymin", "xmax", "ymax"]
+            bndbox = []
+            for i, pt in enumerate(pts):
+                cur_pt = int(bbox.find(pt).text) - 1
+                # scale height or width
+                # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
+                bndbox.append(cur_pt)
+            label_idx = self.class_to_ind[name]
+            bndbox.append(label_idx)
+            res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]
+            # img_id = target.find('filename').text[:-4]
+
+        width = int(target.find("size").find("width").text)
+        height = int(target.find("size").find("height").text)
+        img_info = (height, width)
+
+        return res, img_info
+
+
+class VOCDetection(Dataset):
+
+    """
+    VOC Detection Dataset Object
+    input is image, target is annotation
+    Args:
+        root (string): filepath to VOCdevkit folder.
+        image_set (string): imageset to use (eg. 'train', 'val', 'test')
+        transform (callable, optional): transformation to perform on the
+            input image
+        target_transform (callable, optional): transformation to perform on the
+            target `annotation`
+            (eg: take in caption string, return tensor of word indices)
+        dataset_name (string, optional): which dataset to load
+            (default: 'VOC2007')
+    """
+
+    def __init__(
+        self,
+        data_dir,
+        image_sets=[("2007", "trainval"), ("2012", "trainval")],
+        img_size=(416, 416),
+        preproc=None,
+        target_transform=AnnotationTransform(),
+        dataset_name="VOC0712",
+        cache=False,
+    ):
+        super().__init__(img_size)
+        self.root = data_dir
+        self.image_set = image_sets
+        self.img_size = img_size
+        self.preproc = preproc
+        self.target_transform = target_transform
+        self.name = dataset_name
+        self._annopath = os.path.join("%s", "Annotations", "%s.xml")
+        self._imgpath = os.path.join("%s", "JPEGImages", "%s.jpg")
+        self._classes = VOC_CLASSES
+        self.ids = list()
+        for (year, name) in image_sets:
+            self._year = year
+            rootpath = os.path.join(self.root, "VOC" + year)
+            for line in open(os.path.join(rootpath, "ImageSets", "Main", name + ".txt")):
+                self.ids.append((rootpath, line.strip()))
+
+        self.annotations = self._load_coco_annotations()
+        self.imgs = None
+        if cache:
+            self._cache_images()
+
+    def __len__(self):
+        return len(self.ids)
+
+    def _load_coco_annotations(self):
+        return [self.load_anno_from_ids(_ids) for _ids in range(len(self.ids))]
+
+    def _cache_images(self):
+        logger.warning(
+            "\n********************************************************************************\n"
+            "You are using cached images in RAM to accelerate training.\n"
+            "This requires large system RAM.\n"
+            "Make sure you have 60G+ RAM and 19G available disk space for training VOC.\n"
+            "********************************************************************************\n"
+        )
+        max_h = self.img_size[0]
+        max_w = self.img_size[1]
+        cache_file = self.root + "/img_resized_cache_" + self.name + ".array"
+        if not os.path.exists(cache_file):
+            logger.info("Caching images for the first time. This might take about 3 minutes for VOC")
+            self.imgs = np.memmap(
+                cache_file,
+                shape=(len(self.ids), max_h, max_w, 3),
+                dtype=np.uint8,
+                mode="w+",
+            )
+            from tqdm import tqdm
+            from multiprocessing.pool import ThreadPool
+
+            NUM_THREADS = min(8, os.cpu_count())
+            loaded_images = ThreadPool(NUM_THREADS).imap(
+                lambda x: self.load_resized_img(x),
+                range(len(self.annotations)),
+            )
+            pbar = tqdm(enumerate(loaded_images), total=len(self.annotations))
+            for k, out in pbar:
+                self.imgs[k][: out.shape[0], : out.shape[1], :] = out.copy()
+            self.imgs.flush()
+            pbar.close()
+        else:
+            logger.warning(
+                "You are using cached imgs! Make sure your dataset is not changed!!\n"
+                "Everytime the self.input_size is changed in your exp file, you need to delete\n"
+                "the cached data and re-generate them.\n"
+            )
+
+        logger.info("Loading cached imgs...")
+        self.imgs = np.memmap(
+            cache_file,
+            shape=(len(self.ids), max_h, max_w, 3),
+            dtype=np.uint8,
+            mode="r+",
+        )
+
+    def load_anno_from_ids(self, index):
+        img_id = self.ids[index]
+        target = ET.parse(self._annopath % img_id).getroot()
+
+        assert self.target_transform is not None
+        res, img_info = self.target_transform(target)
+        height, width = img_info
+
+        r = min(self.img_size[0] / height, self.img_size[1] / width)
+        res[:, :4] *= r
+        resized_info = (int(height * r), int(width * r))
+
+        return (res, img_info, resized_info)
+
+    def load_anno(self, index):
+        return self.annotations[index][0]
+
+    def load_resized_img(self, index):
+        img = self.load_image(index)
+        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * r), int(img.shape[0] * r)),
+            interpolation=cv2.INTER_LINEAR,
+        ).astype(np.uint8)
+
+        return resized_img
+
+    def load_image(self, index):
+        img_id = self.ids[index]
+        img = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
+        assert img is not None
+
+        return img
+
+    def pull_item(self, index):
+        """Returns the original image and target at an index for mixup
+        Note: not using self.__getitem__(), as any transformations passed in
+        could mess up this functionality.
+        Argument:
+            index (int): index of img to show
+        Return:
+            img, target
+        """
+        if self.imgs is not None:
+            target, img_info, resized_info = self.annotations[index]
+            pad_img = self.imgs[index]
+            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
+        else:
+            img = self.load_resized_img(index)
+            target, img_info, _ = self.annotations[index]
+
+        return img, target, img_info, index
+
+    @Dataset.mosaic_getitem
+    def __getitem__(self, index):
+        img, target, img_info, img_id = self.pull_item(index)
+
+        if self.preproc is not None:
+            img, target = self.preproc(img, target, self.input_dim)
+
+        return img, target, img_info, img_id
+
+    def evaluate_detections(self, all_boxes, output_dir=None):
+        """all_boxes is a list of length number-of-classes.
+
+        Each list element is a list of length number-of-images.
+        Each of those list elements is either an empty list []
+        or a numpy array of detection.
+        all_boxes[class][image] = [] or np.array of shape #dets x 5
+        """
+        self._write_voc_results_file(all_boxes)
+        IouTh = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
+        mAPs = []
+        for iou in IouTh:
+            mAP = self._do_python_eval(output_dir, iou)
+            mAPs.append(mAP)
+
+        print("--------------------------------------------------------------")
+        print("map_5095:", np.mean(mAPs))
+        print("map_50:", mAPs[0])
+        print("--------------------------------------------------------------")
+        return np.mean(mAPs), mAPs[0]
+
+    def _get_voc_results_file_template(self):
+        filename = "comp4_det_test" + "_{:s}.txt"
+        filedir = os.path.join(self.root, "results", "VOC" + self._year, "Main")
+        if not os.path.exists(filedir):
+            os.makedirs(filedir)
+        path = os.path.join(filedir, filename)
+        return path
+
+    def _write_voc_results_file(self, all_boxes):
+        for cls_ind, cls in enumerate(VOC_CLASSES):
+            if cls == "__background__":
+                continue
+            print("Writing {} VOC results file".format(cls))
+            filename = self._get_voc_results_file_template().format(cls)
+            with open(filename, "wt") as f:
+                for im_ind, index in enumerate(self.ids):
+                    index = index[1]
+                    dets = all_boxes[cls_ind][im_ind]
+                    if len(dets) == 0:
+                        continue
+                    for k in range(dets.shape[0]):
+                        f.write(
+                            "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format(
+                                index,
+                                dets[k, -1],
+                                dets[k, 0] + 1,
+                                dets[k, 1] + 1,
+                                dets[k, 2] + 1,
+                                dets[k, 3] + 1,
+                            )
+                        )
+
+    def _do_python_eval(self, output_dir="output", iou=0.5):
+        rootpath = os.path.join(self.root, "VOC" + self._year)
+        name = self.image_set[0][1]
+        annopath = os.path.join(rootpath, "Annotations", "{:s}.xml")
+        imagesetfile = os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
+        cachedir = os.path.join(self.root, "annotations_cache", "VOC" + self._year, name)
+        if not os.path.exists(cachedir):
+            os.makedirs(cachedir)
+        aps = []
+        # The PASCAL VOC metric changed in 2010
+        use_07_metric = int(self._year) < 2010
+        print("Eval IoU : {:.2f}".format(iou))
+        if output_dir is not None and not os.path.isdir(output_dir):
+            os.mkdir(output_dir)
+        for i, cls in enumerate(VOC_CLASSES):
+
+            if cls == "__background__":
+                continue
+
+            filename = self._get_voc_results_file_template().format(cls)
+            rec, prec, ap = voc_eval(
+                filename,
+                annopath,
+                imagesetfile,
+                cls,
+                cachedir,
+                ovthresh=iou,
+                use_07_metric=use_07_metric,
+            )
+            aps += [ap]
+            if iou == 0.5:
+                print("AP for {} = {:.4f}".format(cls, ap))
+            if output_dir is not None:
+                with open(os.path.join(output_dir, cls + "_pr.pkl"), "wb") as f:
+                    pickle.dump({"rec": rec, "prec": prec, "ap": ap}, f)
+        if iou == 0.5:
+            print("Mean AP = {:.4f}".format(np.mean(aps)))
+            print("~~~~~~~~")
+            print("Results:")
+            for ap in aps:
+                print("{:.3f}".format(ap))
+            print("{:.3f}".format(np.mean(aps)))
+            print("~~~~~~~~")
+            print("")
+            print("--------------------------------------------------------------")
+            print("Results computed with the **unofficial** Python eval code.")
+            print("Results should be very close to the official MATLAB eval code.")
+            print("Recompute with `./tools/reval.py --matlab ...` for your paper.")
+            print("-- Thanks, The Management")
+            print("--------------------------------------------------------------")
+
+        return np.mean(aps)
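`evaluate_detections` above runs `voc_eval` once per class at each IoU threshold from 0.50 to 0.95 in 0.05 steps, averages the per-threshold class means into `map_5095`, and reports the first entry (IoU=0.5) as `map_50`. A small sketch of that aggregation with made-up per-threshold values (the toy numbers are not from any real run):

```python
# Illustrative sketch only (not part of the diff): the IoU grid and averaging
# used by evaluate_detections. Each entry of per_iou_map would come from one
# _do_python_eval(output_dir, iou) call; the values below are invented.
import numpy as np

iou_thresholds = np.linspace(0.5, 0.95, 10, endpoint=True)  # 0.50, 0.55, ..., 0.95
per_iou_map = [0.72, 0.70, 0.68, 0.65, 0.61, 0.55, 0.47, 0.38, 0.26, 0.12]  # toy values
map_5095 = float(np.mean(per_iou_map))  # COCO-style AP averaged over IoU thresholds
map_50 = per_iou_map[0]                 # PASCAL-style AP at IoU=0.5
print(f"map_5095={map_5095:.3f}  map_50={map_50:.3f}")
```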
diff --git a/det/yolox/data/datasets/voc_classes.py b/det/yolox/data/datasets/voc_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..89354b3fdb19195f63f76ed56c86565323de5434
--- /dev/null
+++ b/det/yolox/data/datasets/voc_classes.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+# VOC_CLASSES = ( '__background__', # always index 0
+VOC_CLASSES = (
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+)
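When no `class_to_ind` mapping is supplied, `AnnotationTransform` in voc.py builds the default lookup from this tuple by position. A tiny sketch of that default lookup, using only the first three entries for illustration:

```python
# Illustrative sketch only (not part of the diff): the default class-name -> index
# lookup that AnnotationTransform derives from VOC_CLASSES when class_to_ind is None.
classes = ("aeroplane", "bicycle", "bird")  # first three entries, truncated for brevity
class_to_ind = dict(zip(classes, range(len(classes))))
assert class_to_ind["bird"] == 2  # tuple position becomes the 0-based label index
```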
diff --git a/det/yolox/data/datasets/ycbv_bop_test.py b/det/yolox/data/datasets/ycbv_bop_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9a9324ef3bfaeb485a7e798548fc6a0695615ad
--- /dev/null
+++ b/det/yolox/data/datasets/ycbv_bop_test.py
@@ -0,0 +1,454 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class YCBV_BOP_TEST_Dataset:
+    """ycbv bop test."""
+
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+        # all classes are self.objs, but this enables us to evaluate on selected objs
+        self.select_objs = data_cfg.get("select_objs", self.objs)
+
+        self.ann_file = data_cfg["ann_file"]  # json file with scene_id and im_id items
+
+        self.dataset_root = data_cfg["dataset_root"]  # BOP_DATASETS/ycbv/test
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/ycbv/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg["filter_invalid"]
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.ycbv.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        im_id_global = 0
+
+        if True:
+            targets = mmcv.load(self.ann_file)
+            scene_im_ids = [(item["scene_id"], item["im_id"]) for item in targets]
+            scene_im_ids = sorted(list(set(scene_im_ids)))
+
+            # load infos for each scene
+            gt_dicts = {}
+            gt_info_dicts = {}
+            cam_dicts = {}
+            for scene_id, im_id in scene_im_ids:
+                scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+                if scene_id not in gt_dicts:
+                    gt_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+                if scene_id not in gt_info_dicts:
+                    gt_info_dicts[scene_id] = mmcv.load(
+                        osp.join(scene_root, "scene_gt_info.json")
+                    )  # bbox_obj, bbox_visib
+                if scene_id not in cam_dicts:
+                    cam_dicts[scene_id] = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for scene_id, im_id in tqdm(scene_im_ids):
+                str_im_id = str(im_id)
+                scene_root = osp.join(self.dataset_root, f"{scene_id:06d}")
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.png").format(im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(im_id))
+
+                scene_id = int(rgb_path.split("/")[-3])
+
+                cam = np.array(cam_dicts[scene_id][str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dicts[scene_id][str_im_id]["depth_scale"]
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "depth_factor": depth_factor,
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": im_id_global,  # unique image_id in the dataset, for coco evaluation
+                    "scene_im_id": "{}/{}".format(scene_id, im_id),  # for evaluation
+                    "cam": cam,
+                    "img_type": "real",
+                }
+                im_id_global += 1
+                insts = []
+                for anno_i, anno in enumerate(gt_dicts[scene_id][str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if ref.ycbv.id2obj[obj_id] not in self.select_objs:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dicts[scene_id][str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dicts[scene_id][str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask_single.sum()
+                    if area < 3:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,  # TODO: load as mask_full, rle
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    # TODO: using full mask and full xyz
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, f"models_{self.name}.pkl")
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.ycbv.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: the bbox3d_and_center is not obtained from centered vertices
+            # for BOP models, not a big problem since they had been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_ycbv_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+################################################################################
+
+SPLITS_YCBV = dict(
+    ycbv_bop_test=dict(
+        name="ycbv_bop_test",
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/test"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/models"),
+        objs=ref.ycbv.objects,  # selected objects
+        ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/test_targets_bop19.json"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks but may not use it)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        cache_dir=osp.join(PROJ_ROOT, ".cache"),
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=False,
+        ref_key="ycbv",
+    )
+)
+
+
+# single objs (num_class is from all objs)
+for obj in ref.ycbv.objects:
+    name = "ycbv_bop_{}_test".format(obj)
+    select_objs = [obj]
+    if name not in SPLITS_YCBV:
+        SPLITS_YCBV[name] = dict(
+            name=name,
+            dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/test"),
+            models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/models"),
+            objs=[obj],  # only this obj
+            select_objs=select_objs,  # selected objects
+            ann_file=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/test_targets_bop19.json"),
+            scale_to_meter=0.001,
+            with_masks=True,  # (load masks but may not use it)
+            with_depth=True,  # (load depth path here, but may not use it)
+            height=480,
+            width=640,
+            cache_dir=osp.join(PROJ_ROOT, ".cache"),
+            use_cache=True,
+            num_to_load=-1,
+            filter_invalid=False,
+            ref_key="ycbv",
+        )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset_name
+        data_cfg: if name is in the existing SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg is required
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_YCBV:
+        used_cfg = SPLITS_YCBV[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, YCBV_BOP_TEST_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="ycbv",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_ycbv_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_YCBV.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / d["depth_factor"]
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        # img_vis = vis_image_bboxes_cv2(img, bboxes=bboxes_xyxy, labels=labels)
+        img_vis = vis_image_mask_bbox_cv2(img, masks, bboxes=bboxes_xyxy, labels=labels)
+        img_vis_kpts2d = img.copy()
+        for anno_i in range(len(annos)):
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis_kpts2d, kpts_2d[anno_i])
+        grid_show(
+            [
+                img[:, :, [2, 1, 0]],
+                img_vis[:, :, [2, 1, 0]],
+                img_vis_kpts2d[:, :, [2, 1, 0]],
+                depth,
+            ],
+            [f"img:{d['file_name']}", "vis_img", "img_vis_kpts2d", "depth"],
+            row=2,
+            col=2,
+        )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m core.datasets.ycbv_bop_test dataset_name
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from core.utils.data_utils import read_image_mmcv
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+    test_vis()
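`test_vis` above projects each instance's 3D box corners and center into the image with `misc.project_pts`, and the loader computes the 2D centroid as `K @ t` followed by a perspective divide. A self-contained sketch of that pinhole projection with illustrative values; `project_points` is a hypothetical stand-in written in the same spirit, not the library function itself (the intrinsics happen to match the default uw camera hard-coded in ycbv_d2.py below, but serve only as an example here):

```python
# Illustrative sketch only (not part of the diff): generic pinhole projection of
# object-frame points, mirroring how K, R, t are used by the dataset/test code.
import numpy as np


def project_points(pts_3d, K, R, t):
    """Project Nx3 object-frame points to Nx2 pixel coordinates."""
    pts_cam = pts_3d @ R.T + t          # object frame -> camera frame
    uvw = pts_cam @ K.T                 # apply intrinsics
    return uvw[:, :2] / uvw[:, 2:3]     # perspective divide


K = np.array([[1066.778, 0.0, 312.9869],
              [0.0, 1067.487, 241.3109],
              [0.0, 0.0, 1.0]], dtype=np.float32)
R = np.eye(3, dtype=np.float32)
t = np.array([0.0, 0.0, 0.5], dtype=np.float32)            # 0.5 m in front of the camera
corner = np.array([[0.05, 0.02, 0.0]], dtype=np.float32)   # one 3D keypoint, in meters
print(project_points(corner, K, R, t))                     # ~[[419.66, 284.01]]
```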
diff --git a/det/yolox/data/datasets/ycbv_d2.py b/det/yolox/data/datasets/ycbv_d2.py
new file mode 100755
index 0000000000000000000000000000000000000000..6c763fc0ef48b28e85ca3cddc070e5e0a43bbda2
--- /dev/null
+++ b/det/yolox/data/datasets/ycbv_d2.py
@@ -0,0 +1,740 @@
+import hashlib
+import copy
+import logging
+import os
+import os.path as osp
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class YCBV_Dataset:
+    """use image_sets(scene/image_id) and image root to get data; Here we use
+    bop models, which are center aligned and have some offsets compared to
+    original models."""
+
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.ann_files = data_cfg["ann_files"]  # provide scene/im_id list
+        self.image_prefixes = data_cfg["image_prefixes"]  # image root
+
+        self.dataset_root = data_cfg["dataset_root"]  # BOP_DATASETS/ycbv/
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/ycbv/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]  # True (load masks but may not use it)
+        self.with_depth = data_cfg["with_depth"]  # True (load depth path here, but may not use it)
+        self.with_xyz = data_cfg["with_xyz"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg["filter_invalid"]
+
+        self.align_K_by_change_pose = data_cfg.get("align_K_by_change_pose", False)
+        # default: 0000~0059 and synt
+        self.cam = np.array(
+            [
+                [1066.778, 0.0, 312.9869],
+                [0.0, 1067.487, 241.3109],
+                [0.0, 0.0, 1.0],
+            ],
+            dtype="float32",
+        )
+        # 0060~0091
+        # cmu_cam = np.array([[1077.836, 0.0, 323.7872], [0.0, 1078.189, 279.6921], [0.0, 0.0, 1.0]], dtype='float32')
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.ycbv.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
+    def _load_from_idx_file(self, idx_file, image_root):
+        """
+        idx_file: the scene/image ids
+        image_root/scene contains:
+            scene_gt.json
+            scene_gt_info.json
+            scene_camera.json
+        """
+        xyz_root = osp.join(image_root, "xyz_crop")
+        scene_gt_dicts = {}
+        scene_gt_info_dicts = {}
+        scene_cam_dicts = {}
+        scene_im_ids = []  # store tuples of (scene_id, im_id)
+        with open(idx_file, "r") as f:
+            for line in f:
+                line_split = line.strip("\r\n").split("/")
+                scene_id = int(line_split[0])
+                im_id = int(line_split[1])
+                scene_im_ids.append((scene_id, im_id))
+                if scene_id not in scene_gt_dicts:
+                    scene_gt_file = osp.join(image_root, f"{scene_id:06d}/scene_gt.json")
+                    assert osp.exists(scene_gt_file), scene_gt_file
+                    scene_gt_dicts[scene_id] = mmcv.load(scene_gt_file)
+
+                if scene_id not in scene_gt_info_dicts:
+                    scene_gt_info_file = osp.join(image_root, f"{scene_id:06d}/scene_gt_info.json")
+                    assert osp.exists(scene_gt_info_file), scene_gt_info_file
+                    scene_gt_info_dicts[scene_id] = mmcv.load(scene_gt_info_file)
+
+                if scene_id not in scene_cam_dicts:
+                    scene_cam_file = osp.join(image_root, f"{scene_id:06d}/scene_camera.json")
+                    assert osp.exists(scene_cam_file), scene_cam_file
+                    scene_cam_dicts[scene_id] = mmcv.load(scene_cam_file)
+        ######################################################
+        scene_im_ids = sorted(scene_im_ids)  # sort to make it reproducible
+        dataset_dicts = []
+
+        num_instances_without_valid_segmentation = 0
+        num_instances_without_valid_box = 0
+
+        for (scene_id, im_id) in tqdm(scene_im_ids):
+            rgb_path = osp.join(image_root, f"{scene_id:06d}/rgb/{im_id:06d}.png")
+            assert osp.exists(rgb_path), rgb_path
+            str_im_id = str(im_id)
+
+            scene_im_id = f"{scene_id}/{im_id}"
+
+            # for ycbv/tless, load cam K from image infos
+            cam_anno = np.array(scene_cam_dicts[scene_id][str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+            adapt_this_K = False
+            if self.align_K_by_change_pose:
+                if (cam_anno != self.cam).any():
+                    adapt_this_K = True
+                    cam_anno_ori = cam_anno.copy()
+                    cam_anno = self.cam
+
+            depth_factor = 1000.0 / scene_cam_dicts[scene_id][str_im_id]["depth_scale"]
+            # dprint(record['cam'])
+            if "/train_synt/" in rgb_path:
+                img_type = "syn"
+            else:
+                img_type = "real"
+            record = {
+                "dataset_name": self.name,
+                "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                "height": self.height,
+                "width": self.width,
+                "image_id": self._unique_im_id,
+                "scene_im_id": scene_im_id,  # for evaluation
+                "cam": cam_anno,  # self.cam,
+                "depth_factor": depth_factor,
+                "img_type": img_type,
+            }
+
+            if self.with_depth:
+                depth_file = osp.join(image_root, f"{scene_id:06d}/depth/{im_id:06d}.png")
+                assert osp.exists(depth_file), depth_file
+                record["depth_file"] = osp.relpath(depth_file, PROJ_ROOT)
+
+            insts = []
+            anno_dict_list = scene_gt_dicts[scene_id][str(im_id)]
+            info_dict_list = scene_gt_info_dicts[scene_id][str(im_id)]
+            for anno_i, anno in enumerate(anno_dict_list):
+                info = info_dict_list[anno_i]
+                obj_id = anno["obj_id"]
+                if obj_id not in self.cat_ids:
+                    continue
+                # 0-based label now
+                cur_label = self.cat2label[obj_id]
+                ################ pose ###########################
+                R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                trans = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0  # mm->m
+                pose = np.hstack([R, trans.reshape(3, 1)])
+                if adapt_this_K:
+                    # pose_uw = inv(K_uw) @ K_cmu @ pose_cmu
+                    pose = np.linalg.inv(cam_anno) @ cam_anno_ori @ pose
+                    # R = pose[:3, :3]
+                    trans = pose[:3, 3]
+
+                quat = mat2quat(pose[:3, :3])
+
+                ############# bbox ############################
+                bbox = info["bbox_obj"]
+                x1, y1, w, h = bbox
+                x2 = x1 + w
+                y2 = y1 + h
+                x1 = max(min(x1, self.width), 0)
+                y1 = max(min(y1, self.height), 0)
+                x2 = max(min(x2, self.width), 0)
+                y2 = max(min(y2, self.height), 0)
+                bbox = [x1, y1, x2, y2]
+                if self.filter_invalid:
+                    bw = bbox[2] - bbox[0]
+                    bh = bbox[3] - bbox[1]
+                    if bh <= 1 or bw <= 1:
+                        num_instances_without_valid_box += 1
+                        continue
+
+                ############## mask #######################
+                if self.with_masks:  # either list[list[float]] or dict(RLE)
+                    mask_visib_file = osp.join(
+                        image_root,
+                        f"{scene_id:06d}/mask_visib/{im_id:06d}_{anno_i:06d}.png",
+                    )
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    mask = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask.sum()
+                    if area < 30 and self.filter_invalid:
+                        num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask)
+
+                    mask_full_file = osp.join(
+                        image_root,
+                        f"{scene_id:06d}/mask/{im_id:06d}_{anno_i:06d}.png",
+                    )
+                    assert osp.exists(mask_full_file), mask_full_file
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_full_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                proj = (self.cam @ trans.T).T  # NOTE: use self.cam here
+                proj = proj[:2] / proj[2]
+
+                inst = {
+                    "category_id": cur_label,  # 0-based label
+                    "bbox": bbox,  # TODO: load both bbox_obj and bbox_visib
+                    "bbox_mode": BoxMode.XYXY_ABS,
+                    "pose": pose,
+                    "quat": quat,
+                    "trans": trans,
+                    "centroid_2d": proj,  # absolute (cx, cy)
+                    "segmentation": mask_rle,
+                    "mask_full": mask_full_rle,
+                }
+
+                if self.with_xyz:
+                    xyz_path = osp.join(
+                        xyz_root,
+                        f"{scene_id:06d}/{im_id:06d}_{anno_i:06d}-xyz.pkl",
+                    )
+                    # assert osp.exists(xyz_path), xyz_path
+                    inst["xyz_path"] = xyz_path
+
+                model_info = self.models_info[str(obj_id)]
+                inst["model_info"] = model_info
+                # TODO: using full mask and full xyz
+                for key in ["bbox3d_and_center"]:
+                    inst[key] = self.models[cur_label][key]
+                insts.append(inst)
+            if len(insts) == 0:  # and self.filter_invalid:
+                continue
+            record["annotations"] = insts
+            dataset_dicts.append(record)
+            self._unique_im_id += 1
+
+        if num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    num_instances_without_valid_segmentation
+                )
+            )
+        if num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(num_instances_without_valid_box)
+            )
+        return dataset_dicts
+
+    def __call__(self):  # YCBV_Dataset
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    self.with_xyz,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        t_start = time.perf_counter()
+        dataset_dicts = []
+        self._unique_im_id = 0
+        for ann_file, image_root in zip(self.ann_files, self.image_prefixes):
+            # logger.info("loading coco json: {}".format(ann_file))
+            dataset_dicts.extend(self._load_from_idx_file(ann_file, image_root))
+
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format(self.name))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.ycbv.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: the bbox3d_and_center is not obtained from centered vertices
+            # for BOP models, not a big problem since they had been centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_ycbv_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+ycbv_model_root = "BOP_DATASETS/ycbv/models/"
+################################################################################
+default_cfg = dict(
+    # name="ycbv_train_real",
+    dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/"),
+    models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/models"),  # models_simple
+    objs=ref.ycbv.objects,  # all objects
+    # NOTE: this contains all classes
+    # ann_files=[osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/image_sets/train.txt")],
+    # image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_real")],
+    scale_to_meter=0.001,
+    with_masks=True,  # (load masks but may not use it)
+    with_depth=True,  # (load depth path here, but may not use it)
+    with_xyz=True,
+    height=480,
+    width=640,
+    align_K_by_change_pose=False,
+    cache_dir=osp.join(PROJ_ROOT, ".cache"),
+    use_cache=True,
+    num_to_load=-1,
+    filter_invalid=True,
+    ref_key="ycbv",
+)
+SPLITS_YCBV = {}
+update_cfgs = {
+    "ycbv_train_real": {
+        "ann_files": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/image_sets/train.txt")],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_real")],
+    },
+    "ycbv_train_real_aligned_Kuw": {
+        "ann_files": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/image_sets/train.txt")],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_real")],
+        "align_K_by_change_pose": True,
+    },
+    "ycbv_train_real_uw": {
+        "ann_files": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/image_sets/train_real_uw.txt")],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_real")],
+    },
+    "ycbv_train_real_uw_every10": {
+        "ann_files": [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/ycbv/image_sets/train_real_uw_every10.txt",
+            )
+        ],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_real")],
+    },
+    "ycbv_train_real_cmu": {
+        "ann_files": [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/ycbv/image_sets/train_real_cmu.txt",
+            )
+        ],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_real")],
+    },
+    "ycbv_train_real_cmu_aligned_Kuw": {
+        "ann_files": [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/ycbv/image_sets/train_real_cmu.txt",
+            )
+        ],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_real")],
+        "align_K_by_change_pose": True,
+    },
+    "ycbv_train_synt": {
+        "ann_files": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/image_sets/train_synt.txt")],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_synt")],
+    },
+    "ycbv_train_synt_50k": {
+        "ann_files": [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/ycbv/image_sets/train_synt_50k.txt",
+            )
+        ],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_synt")],
+    },
+    "ycbv_train_synt_30k": {
+        "ann_files": [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/ycbv/image_sets/train_synt_30k.txt",
+            )
+        ],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_synt")],
+    },
+    "ycbv_train_synt_100": {
+        "ann_files": [
+            osp.join(
+                DATASETS_ROOT,
+                "BOP_DATASETS/ycbv/image_sets/train_synt_100.txt",
+            )
+        ],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_synt")],
+    },
+    "ycbv_test": {
+        "ann_files": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/image_sets/keyframe.txt")],
+        "image_prefixes": [osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/test")],
+        "with_xyz": False,
+        "filter_invalid": False,
+    },
+}
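+# build each named split config by copying default_cfg and applying the per-split overrides below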
+for name, update_cfg in update_cfgs.items():
+    used_cfg = copy.deepcopy(default_cfg)
+    used_cfg["name"] = name
+    used_cfg.update(update_cfg)
+    num_to_load = -1
+    if "_100" in name:
+        num_to_load = 100
+    used_cfg["num_to_load"] = num_to_load
+    SPLITS_YCBV[name] = used_cfg
+
+# single object splits ######################################################
+for obj in ref.ycbv.objects:
+    for split in [
+        "train_real",
+        "train_real_aligned_Kuw",
+        "train_real_uw",
+        "train_real_uw_every10",
+        "train_real_cmu",
+        "train_real_cmu_aligned_Kuw",
+        "train_synt",
+        "train_synt_30k",
+        "test",
+    ]:
+        name = "ycbv_{}_{}".format(obj, split)
+        if split in [
+            "train_real",
+            "train_real_aligned_Kuw",
+            "train_real_uw",
+            "train_real_uw_every10",
+            "train_real_cmu",
+            "train_real_cmu_aligned_Kuw",
+            "train_synt",
+            "train_synt_30k",
+        ]:
+            filter_invalid = True
+            with_xyz = True
+        elif split in ["test"]:
+            filter_invalid = False
+            with_xyz = False
+        else:
+            raise ValueError("{}".format(split))
+
+        if split in ["train_real_aligned_Kuw", "train_real_cmu_aligned_Kuw"]:
+            align_K_by_change_pose = True
+        else:
+            align_K_by_change_pose = False
+
+        split_idx_file_dict = {
+            "train_real": ("train_real", "train.txt"),
+            "train_real_aligned_Kuw": ("train_real", "train.txt"),
+            "train_real_uw": ("train_real", "train_real_uw.txt"),
+            "train_real_uw_every10": (
+                "train_real",
+                "train_real_uw_every10.txt",
+            ),
+            "train_real_cmu": ("train_real", "train_real_cmu.txt"),
+            "train_real_cmu_aligned_Kuw": ("train_real", "train_real_cmu.txt"),
+            "train_synt": ("train_synt", "train_synt.txt"),
+            "train_synt_30k": ("train_synt", "train_synt_30k.txt"),
+            "test": ("test", "keyframe.txt"),
+        }
+        root_name, idx_file = split_idx_file_dict[split]
+
+        if name not in SPLITS_YCBV:
+            SPLITS_YCBV[name] = dict(
+                name=name,
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/models"),
+                objs=[obj],
+                ann_files=[
+                    osp.join(
+                        DATASETS_ROOT,
+                        "BOP_DATASETS/ycbv/image_sets/{}".format(idx_file),
+                    )
+                ],
+                image_prefixes=[osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/{}".format(root_name))],
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks, but may not use them)
+                with_depth=True,  # (load depth path here, but may not use it)
+                with_xyz=with_xyz,
+                height=480,
+                width=640,
+                align_K_by_change_pose=align_K_by_change_pose,
+                cache_dir=osp.join(PROJ_ROOT, ".cache"),
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="ycbv",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name.
+        data_cfg: if name is one of the pre-defined SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg must be provided
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_YCBV:
+        used_cfg = SPLITS_YCBV[name]
+    else:
+        assert (
+            data_cfg is not None
+        ), f"dataset name {name} is not registered. available datasets: {list(SPLITS_YCBV.keys())}"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, YCBV_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="ycbv",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_ycbv_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_YCBV.keys())
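+
+# Usage sketch (illustrative): register a pre-defined split and read its dicts/metadata;
+# "ycbv_train_real" is one of the keys defined in SPLITS_YCBV above.
+#   register_with_name_cfg("ycbv_train_real")
+#   dicts = DatasetCatalog.get("ycbv_train_real")
+#   meta = MetadataCatalog.get("ycbv_train_real")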
+
+
+#### tests ###############################################
+def test_vis():
+    # python -m core.datasets.ycbv_d2 ycbv_test
+    dataset_name = sys.argv[1]
+    meta = MetadataCatalog.get(dataset_name)
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dataset_name)
+    with_xyz = False if "test" in dataset_name else True
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/ycbv_test-data-vis"
+    os.makedirs(dirname, exist_ok=True)
+    objs = meta.objs
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 1000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+        # # TODO: visualize pose and keypoints
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            if with_xyz:
+                xyz_path = annos[_i]["xyz_path"]
+                xyz_info = mmcv.load(xyz_path)
+                x1, y1, x2, y2 = xyz_info["xyxy"]
+                xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+                xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+                xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+                xyz_show = get_emb_show(xyz)
+                xyz_crop_show = get_emb_show(xyz_crop)
+                img_xyz = img.copy() / 255.0
+                mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+                fg_idx = np.where(mask_xyz != 0)
+                img_xyz[fg_idx[0], fg_idx[1], :] = (
+                    0.5 * xyz_show[fg_idx[0], fg_idx[1], :3] + 0.5 * img_xyz[fg_idx[0], fg_idx[1], :]
+                )
+                img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+                img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+                # diff mask
+                diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                        # xyz_show,
+                        diff_mask_xyz,
+                        xyz_crop_show,
+                        img_xyz[:, :, [2, 1, 0]],
+                        img_xyz_crop[:, :, [2, 1, 0]],
+                        img_vis_crop[:, :, ::-1],
+                    ],
+                    [
+                        "img",
+                        "vis_img",
+                        "img_vis_kpts2d",
+                        "depth",
+                        "diff_mask_xyz",
+                        "xyz_crop_show",
+                        "img_xyz",
+                        "img_xyz_crop",
+                        "img_vis_crop",
+                    ],
+                    row=3,
+                    col=3,
+                )
+            else:
+                grid_show(
+                    [
+                        img[:, :, [2, 1, 0]],
+                        img_vis[:, :, [2, 1, 0]],
+                        img_vis_kpts2d[:, :, [2, 1, 0]],
+                        depth,
+                    ],
+                    ["img", "vis_img", "img_vis_kpts2d", "depth"],
+                    row=2,
+                    col=2,
+                )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m this_module dataset_name
+        "dataset_name" can be any pre-registered ones
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+    test_vis()
diff --git a/det/yolox/data/datasets/ycbv_pbr.py b/det/yolox/data/datasets/ycbv_pbr.py
new file mode 100644
index 0000000000000000000000000000000000000000..734e99e9b669cf906de057a9d0d803b69efe360d
--- /dev/null
+++ b/det/yolox/data/datasets/ycbv_pbr.py
@@ -0,0 +1,493 @@
+import hashlib
+import logging
+import os
+import os.path as osp
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+PROJ_ROOT = osp.normpath(osp.join(cur_dir, "../../../.."))
+sys.path.insert(0, PROJ_ROOT)
+import time
+from collections import OrderedDict
+import mmcv
+import numpy as np
+from tqdm import tqdm
+from transforms3d.quaternions import mat2quat, quat2mat
+import ref
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.structures import BoxMode
+from lib.pysixd import inout, misc
+from lib.utils.mask_utils import binary_mask_to_rle, cocosegm2mask
+from lib.utils.utils import dprint, iprint, lazy_property
+
+
+logger = logging.getLogger(__name__)
+DATASETS_ROOT = osp.normpath(osp.join(PROJ_ROOT, "datasets"))
+
+
+class YCBV_PBR_Dataset:
+    def __init__(self, data_cfg):
+        """
+        Set with_depth and with_masks default to True,
+        and decide whether to load them into dataloader/network later
+        with_masks:
+        """
+        self.name = data_cfg["name"]
+        self.data_cfg = data_cfg
+
+        self.objs = data_cfg["objs"]  # selected objects
+
+        self.dataset_root = data_cfg.get(
+            "dataset_root",
+            osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_pbr"),
+        )
+        self.xyz_root = data_cfg.get("xyz_root", osp.join(self.dataset_root, "xyz_crop"))
+        assert osp.exists(self.dataset_root), self.dataset_root
+        self.models_root = data_cfg["models_root"]  # BOP_DATASETS/ycbv/models
+        self.scale_to_meter = data_cfg["scale_to_meter"]  # 0.001
+
+        self.with_masks = data_cfg["with_masks"]
+        self.with_depth = data_cfg["with_depth"]
+
+        self.height = data_cfg["height"]  # 480
+        self.width = data_cfg["width"]  # 640
+
+        self.cache_dir = data_cfg.get("cache_dir", osp.join(PROJ_ROOT, ".cache"))  # .cache
+        self.use_cache = data_cfg.get("use_cache", True)
+        self.num_to_load = data_cfg["num_to_load"]  # -1
+        self.filter_invalid = data_cfg.get("filter_invalid", True)
+        ##################################################
+
+        # NOTE: careful! Only the selected objects
+        self.cat_ids = [cat_id for cat_id, obj_name in ref.ycbv.id2obj.items() if obj_name in self.objs]
+        # map selected objs to [0, num_objs-1]
+        self.cat2label = {v: i for i, v in enumerate(self.cat_ids)}  # id_map
+        self.label2cat = {label: cat for cat, label in self.cat2label.items()}
+        self.obj2label = OrderedDict((obj, obj_id) for obj_id, obj in enumerate(self.objs))
+        ##########################################################
+
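+        # BOP train_pbr for ycbv provides 50 scene folders: 000000 ... 000049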
+        self.scenes = [f"{i:06d}" for i in range(50)]
+
+    def __call__(self):
+        """Load light-weight instance annotations of all images into a list of
+        dicts in Detectron2 format.
+
+        Do not load heavy data into memory in this file, since we will
+        load the annotations of all images into memory.
+        """
+        # cache the dataset_dicts to avoid loading masks from files
+        hashed_file_name = hashlib.md5(
+            (
+                "".join([str(fn) for fn in self.objs])
+                + "dataset_dicts_{}_{}_{}_{}_{}".format(
+                    self.name,
+                    self.dataset_root,
+                    self.with_masks,
+                    self.with_depth,
+                    __name__,
+                )
+            ).encode("utf-8")
+        ).hexdigest()
+        cache_path = osp.join(
+            self.cache_dir,
+            "dataset_dicts_{}_{}.pkl".format(self.name, hashed_file_name),
+        )
+
+        if osp.exists(cache_path) and self.use_cache:
+            logger.info("load cached dataset dicts from {}".format(cache_path))
+            return mmcv.load(cache_path)
+
+        t_start = time.perf_counter()
+
+        logger.info("loading dataset dicts: {}".format(self.name))
+        self.num_instances_without_valid_segmentation = 0
+        self.num_instances_without_valid_box = 0
+        dataset_dicts = []  # ######################################################
+        # it is slow because of loading and converting masks to rle
+        for scene in tqdm(self.scenes):
+            scene_id = int(scene)
+            scene_root = osp.join(self.dataset_root, scene)
+
+            gt_dict = mmcv.load(osp.join(scene_root, "scene_gt.json"))
+            gt_info_dict = mmcv.load(osp.join(scene_root, "scene_gt_info.json"))
+            cam_dict = mmcv.load(osp.join(scene_root, "scene_camera.json"))
+
+            for str_im_id in tqdm(gt_dict, postfix=f"{scene_id}"):
+                int_im_id = int(str_im_id)
+                rgb_path = osp.join(scene_root, "rgb/{:06d}.jpg").format(int_im_id)
+                assert osp.exists(rgb_path), rgb_path
+
+                depth_path = osp.join(scene_root, "depth/{:06d}.png".format(int_im_id))
+
+                scene_im_id = f"{scene_id}/{int_im_id}"
+
+                K = np.array(cam_dict[str_im_id]["cam_K"], dtype=np.float32).reshape(3, 3)
+                depth_factor = 1000.0 / cam_dict[str_im_id]["depth_scale"]  # 10000
+
+                record = {
+                    "dataset_name": self.name,
+                    "file_name": osp.relpath(rgb_path, PROJ_ROOT),
+                    "depth_file": osp.relpath(depth_path, PROJ_ROOT),
+                    "height": self.height,
+                    "width": self.width,
+                    "image_id": int_im_id,
+                    "scene_im_id": scene_im_id,  # for evaluation
+                    "cam": K,
+                    "depth_factor": depth_factor,
+                    "img_type": "syn_pbr",  # NOTE: has background
+                }
+                insts = []
+                for anno_i, anno in enumerate(gt_dict[str_im_id]):
+                    obj_id = anno["obj_id"]
+                    if obj_id not in self.cat_ids:
+                        continue
+                    cur_label = self.cat2label[obj_id]  # 0-based label
+                    R = np.array(anno["cam_R_m2c"], dtype="float32").reshape(3, 3)
+                    t = np.array(anno["cam_t_m2c"], dtype="float32") / 1000.0
+                    pose = np.hstack([R, t.reshape(3, 1)])
+                    quat = mat2quat(R).astype("float32")
+
+                    proj = (record["cam"] @ t.T).T
+                    proj = proj[:2] / proj[2]
+
+                    bbox_visib = gt_info_dict[str_im_id][anno_i]["bbox_visib"]
+                    bbox_obj = gt_info_dict[str_im_id][anno_i]["bbox_obj"]
+                    x1, y1, w, h = bbox_visib
+                    if self.filter_invalid:
+                        if h <= 1 or w <= 1:
+                            self.num_instances_without_valid_box += 1
+                            continue
+
+                    mask_file = osp.join(
+                        scene_root,
+                        "mask/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    mask_visib_file = osp.join(
+                        scene_root,
+                        "mask_visib/{:06d}_{:06d}.png".format(int_im_id, anno_i),
+                    )
+                    assert osp.exists(mask_file), mask_file
+                    assert osp.exists(mask_visib_file), mask_visib_file
+                    # load mask visib  TODO: load both mask_visib and mask_full
+                    mask_single = mmcv.imread(mask_visib_file, "unchanged")
+                    area = mask_single.sum()
+                    if area <= 64:  # filter out too small or nearly invisible instances
+                        self.num_instances_without_valid_segmentation += 1
+                        continue
+                    mask_rle = binary_mask_to_rle(mask_single, compressed=True)
+
+                    # load mask full
+                    mask_full = mmcv.imread(mask_file, "unchanged")
+                    mask_full = mask_full.astype("bool")
+                    mask_full_rle = binary_mask_to_rle(mask_full, compressed=True)
+
+                    visib_fract = gt_info_dict[str_im_id][anno_i].get("visib_fract", 1.0)
+
+                    xyz_path = osp.join(
+                        self.xyz_root,
+                        f"{scene_id:06d}/{int_im_id:06d}_{anno_i:06d}-xyz.pkl",
+                    )
+                    # assert osp.exists(xyz_path), xyz_path
+                    inst = {
+                        "category_id": cur_label,  # 0-based label
+                        "bbox": bbox_obj,  # TODO: load both bbox_obj and bbox_visib
+                        "bbox_mode": BoxMode.XYWH_ABS,
+                        "pose": pose,
+                        "quat": quat,
+                        "trans": t,
+                        "centroid_2d": proj,  # absolute (cx, cy)
+                        "segmentation": mask_rle,
+                        "mask_full": mask_full_rle,  # TODO: load as mask_full, rle
+                        "visib_fract": visib_fract,
+                        "xyz_path": xyz_path,
+                    }
+
+                    model_info = self.models_info[str(obj_id)]
+                    inst["model_info"] = model_info
+                    # TODO: using full mask and full xyz
+                    for key in ["bbox3d_and_center"]:
+                        inst[key] = self.models[cur_label][key]
+                    insts.append(inst)
+                if len(insts) == 0:  # filter im without anno
+                    continue
+                record["annotations"] = insts
+                dataset_dicts.append(record)
+
+        if self.num_instances_without_valid_segmentation > 0:
+            logger.warning(
+                "Filtered out {} instances without valid segmentation. "
+                "There might be issues in your dataset generation process.".format(
+                    self.num_instances_without_valid_segmentation
+                )
+            )
+        if self.num_instances_without_valid_box > 0:
+            logger.warning(
+                "Filtered out {} instances without valid box. "
+                "There might be issues in your dataset generation process.".format(self.num_instances_without_valid_box)
+            )
+        ##########################################################################
+        if self.num_to_load > 0:
+            self.num_to_load = min(int(self.num_to_load), len(dataset_dicts))
+            dataset_dicts = dataset_dicts[: self.num_to_load]
+        logger.info("loaded {} dataset dicts, using {}s".format(len(dataset_dicts), time.perf_counter() - t_start))
+
+        mmcv.mkdir_or_exist(osp.dirname(cache_path))
+        mmcv.dump(dataset_dicts, cache_path, protocol=4)
+        logger.info("Dumped dataset_dicts to {}".format(cache_path))
+        return dataset_dicts
+
+    @lazy_property
+    def models_info(self):
+        models_info_path = osp.join(self.models_root, "models_info.json")
+        assert osp.exists(models_info_path), models_info_path
+        models_info = mmcv.load(models_info_path)  # key is str(obj_id)
+        return models_info
+
+    @lazy_property
+    def models(self):
+        """Load models into a list."""
+        cache_path = osp.join(self.models_root, "models_{}.pkl".format(self.name))
+        if osp.exists(cache_path) and self.use_cache:
+            # dprint("{}: load cached object models from {}".format(self.name, cache_path))
+            return mmcv.load(cache_path)
+
+        models = []
+        for obj_name in self.objs:
+            model = inout.load_ply(
+                osp.join(
+                    self.models_root,
+                    f"obj_{ref.ycbv.obj2id[obj_name]:06d}.ply",
+                ),
+                vertex_scale=self.scale_to_meter,
+            )
+            # NOTE: bbox3d_and_center is not computed from re-centered vertices;
+            # for BOP models this is fine since the models are already centered
+            model["bbox3d_and_center"] = misc.get_bbox3d_and_center(model["pts"])
+
+            models.append(model)
+        logger.info("cache models to {}".format(cache_path))
+        mmcv.dump(models, cache_path, protocol=4)
+        return models
+
+    def image_aspect_ratio(self):
+        return self.width / self.height  # 4/3
+
+
+########### register datasets ############################################################
+
+
+def get_ycbv_metadata(obj_names, ref_key):
+    """task specific metadata."""
+    data_ref = ref.__dict__[ref_key]
+
+    cur_sym_infos = {}  # label based key
+    loaded_models_info = data_ref.get_models_info()
+
+    for i, obj_name in enumerate(obj_names):
+        obj_id = data_ref.obj2id[obj_name]
+        model_info = loaded_models_info[str(obj_id)]
+        if "symmetries_discrete" in model_info or "symmetries_continuous" in model_info:
+            sym_transforms = misc.get_symmetry_transformations(model_info, max_sym_disc_step=0.01)
+            sym_info = np.array([sym["R"] for sym in sym_transforms], dtype=np.float32)
+        else:
+            sym_info = None
+        cur_sym_infos[i] = sym_info
+
+    meta = {"thing_classes": obj_names, "sym_infos": cur_sym_infos}
+    return meta
+
+
+ycbv_model_root = "BOP_DATASETS/ycbv/models/"
+################################################################################
+
+
+SPLITS_YCBV_PBR = dict(
+    ycbv_train_pbr=dict(
+        name="ycbv_train_pbr",
+        objs=ref.ycbv.objects,  # selected objects
+        dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_pbr"),
+        models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/models"),
+        xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_pbr/xyz_crop"),
+        scale_to_meter=0.001,
+        with_masks=True,  # (load masks, but may not use them)
+        with_depth=True,  # (load depth path here, but may not use it)
+        height=480,
+        width=640,
+        use_cache=True,
+        num_to_load=-1,
+        filter_invalid=True,
+        ref_key="ycbv",
+    )
+)
+
+# single obj splits
+for obj in ref.ycbv.objects:
+    for split in ["train_pbr"]:
+        name = "ycbv_{}_{}".format(obj, split)
+        if split in ["train_pbr"]:
+            filter_invalid = True
+        elif split in ["test"]:
+            filter_invalid = False
+        else:
+            raise ValueError("{}".format(split))
+        if name not in SPLITS_YCBV_PBR:
+            SPLITS_YCBV_PBR[name] = dict(
+                name=name,
+                objs=[obj],  # only this obj
+                dataset_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_pbr"),
+                models_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/models"),
+                xyz_root=osp.join(DATASETS_ROOT, "BOP_DATASETS/ycbv/train_pbr/xyz_crop"),
+                scale_to_meter=0.001,
+                with_masks=True,  # (load masks, but may not use them)
+                with_depth=True,  # (load depth path here, but may not use it)
+                height=480,
+                width=640,
+                use_cache=True,
+                num_to_load=-1,
+                filter_invalid=filter_invalid,
+                ref_key="ycbv",
+            )
+
+
+def register_with_name_cfg(name, data_cfg=None):
+    """Assume pre-defined datasets live in `./datasets`.
+
+    Args:
+        name: dataset name.
+        data_cfg: if name is one of the pre-defined SPLITS, the pre-defined data_cfg is used;
+            otherwise data_cfg must be provided
+            (it can be set via cfg.DATA_CFG.name)
+    """
+    dprint("register dataset: {}".format(name))
+    if name in SPLITS_YCBV_PBR:
+        used_cfg = SPLITS_YCBV_PBR[name]
+    else:
+        assert data_cfg is not None, f"dataset name {name} is not registered"
+        used_cfg = data_cfg
+    DatasetCatalog.register(name, YCBV_PBR_Dataset(used_cfg))
+    # something like eval_types
+    MetadataCatalog.get(name).set(
+        id="ycbv",  # NOTE: for pvnet to determine module
+        ref_key=used_cfg["ref_key"],
+        objs=used_cfg["objs"],
+        eval_error_types=["ad", "rete", "proj"],
+        evaluator_type="bop",
+        **get_ycbv_metadata(obj_names=used_cfg["objs"], ref_key=used_cfg["ref_key"]),
+    )
+
+
+def get_available_datasets():
+    return list(SPLITS_YCBV_PBR.keys())
+
+
+#### tests ###############################################
+def test_vis():
+    dset_name = sys.argv[1]
+    assert dset_name in DatasetCatalog.list()
+
+    meta = MetadataCatalog.get(dset_name)
+    dprint("MetadataCatalog: ", meta)
+    objs = meta.objs
+
+    t_start = time.perf_counter()
+    dicts = DatasetCatalog.get(dset_name)
+    logger.info("Done loading {} samples with {:.3f}s.".format(len(dicts), time.perf_counter() - t_start))
+
+    dirname = "output/{}-data-vis".format(dset_name)
+    os.makedirs(dirname, exist_ok=True)
+    for d in dicts:
+        img = read_image_mmcv(d["file_name"], format="BGR")
+        depth = mmcv.imread(d["depth_file"], "unchanged") / 10000.0
+
+        imH, imW = img.shape[:2]
+        annos = d["annotations"]
+        masks = [cocosegm2mask(anno["segmentation"], imH, imW) for anno in annos]
+        bboxes = [anno["bbox"] for anno in annos]
+        bbox_modes = [anno["bbox_mode"] for anno in annos]
+        bboxes_xyxy = np.array(
+            [BoxMode.convert(box, box_mode, BoxMode.XYXY_ABS) for box, box_mode in zip(bboxes, bbox_modes)]
+        )
+        kpts_3d_list = [anno["bbox3d_and_center"] for anno in annos]
+        quats = [anno["quat"] for anno in annos]
+        transes = [anno["trans"] for anno in annos]
+        Rs = [quat2mat(quat) for quat in quats]
+        # 0-based label
+        cat_ids = [anno["category_id"] for anno in annos]
+        K = d["cam"]
+        kpts_2d = [misc.project_pts(kpt3d, K, R, t) for kpt3d, R, t in zip(kpts_3d_list, Rs, transes)]
+
+        labels = [objs[cat_id] for cat_id in cat_ids]
+        for _i in range(len(annos)):
+            img_vis = vis_image_mask_bbox_cv2(
+                img,
+                masks[_i : _i + 1],
+                bboxes=bboxes_xyxy[_i : _i + 1],
+                labels=labels[_i : _i + 1],
+            )
+            img_vis_kpts2d = misc.draw_projected_box3d(img_vis.copy(), kpts_2d[_i])
+            xyz_path = annos[_i]["xyz_path"]
+            xyz_info = mmcv.load(xyz_path)
+            x1, y1, x2, y2 = xyz_info["xyxy"]
+            xyz_crop = xyz_info["xyz_crop"].astype(np.float32)
+            xyz = np.zeros((imH, imW, 3), dtype=np.float32)
+            xyz[y1 : y2 + 1, x1 : x2 + 1, :] = xyz_crop
+            xyz_show = get_emb_show(xyz)
+            xyz_crop_show = get_emb_show(xyz_crop)
+            img_xyz = img.copy() / 255.0
+            mask_xyz = ((xyz[:, :, 0] != 0) | (xyz[:, :, 1] != 0) | (xyz[:, :, 2] != 0)).astype("uint8")
+            fg_idx = np.where(mask_xyz != 0)
+            img_xyz[fg_idx[0], fg_idx[1], :] = xyz_show[fg_idx[0], fg_idx[1], :3]
+            img_xyz_crop = img_xyz[y1 : y2 + 1, x1 : x2 + 1, :]
+            img_vis_crop = img_vis[y1 : y2 + 1, x1 : x2 + 1, :]
+            # diff mask
+            diff_mask_xyz = np.abs(masks[_i] - mask_xyz)[y1 : y2 + 1, x1 : x2 + 1]
+
+            grid_show(
+                [
+                    img[:, :, [2, 1, 0]],
+                    img_vis[:, :, [2, 1, 0]],
+                    img_vis_kpts2d[:, :, [2, 1, 0]],
+                    depth,
+                    # xyz_show,
+                    diff_mask_xyz,
+                    xyz_crop_show,
+                    img_xyz[:, :, [2, 1, 0]],
+                    img_xyz_crop[:, :, [2, 1, 0]],
+                    img_vis_crop,
+                ],
+                [
+                    "img",
+                    "vis_img",
+                    "img_vis_kpts2d",
+                    "depth",
+                    "diff_mask_xyz",
+                    "xyz_crop_show",
+                    "img_xyz",
+                    "img_xyz_crop",
+                    "img_vis_crop",
+                ],
+                row=3,
+                col=3,
+            )
+
+
+if __name__ == "__main__":
+    """Test the  dataset loader.
+
+    Usage:
+        python -m this_module ycbv_pbr_train
+    """
+    from lib.vis_utils.image import grid_show
+    from lib.utils.setup_logger import setup_my_logger
+
+    import detectron2.data.datasets  # noqa # add pre-defined metadata
+    from lib.vis_utils.image import vis_image_mask_bbox_cv2
+    from core.utils.utils import get_emb_show
+    from core.utils.data_utils import read_image_mmcv
+
+    print("sys.argv:", sys.argv)
+    logger = setup_my_logger(name="core")
+    register_with_name_cfg(sys.argv[1])
+    print("dataset catalog: ", DatasetCatalog.list())
+
+    test_vis()
diff --git a/det/yolox/data/samplers.py b/det/yolox/data/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..93dea1d3466711c1d58a29ee435fafedc9624438
--- /dev/null
+++ b/det/yolox/data/samplers.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import itertools
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import BatchSampler as torchBatchSampler
+from torch.utils.data.sampler import Sampler
+
+
+class YoloBatchSampler(torchBatchSampler):
+    """This batch sampler will generate mini-batches of (mosaic, index) tuples
+    from another sampler.
+
+    It works just like the
+    :class:`torch.utils.data.sampler.BatchSampler`, but it will turn
+    on/off the mosaic aug.
+    """
+
+    def __init__(self, *args, mosaic=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.mosaic = mosaic
+
+    def __iter__(self):
+        for batch in super().__iter__():
+            yield [(self.mosaic, idx) for idx in batch]
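+
+# Usage sketch (illustrative; assumes a dataset such as MosaicDetection whose
+# __getitem__ accepts the (mosaic_flag, index) tuples yielded above):
+#   sampler = InfiniteSampler(size=len(dataset), shuffle=True)
+#   batch_sampler = YoloBatchSampler(sampler, batch_size=8, drop_last=False, mosaic=True)
+#   loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)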
+
+
+class InfiniteSampler(Sampler):
+    """In training, we only care about the "infinite stream" of training data.
+
+    So this sampler produces an infinite stream of indices and all
+    workers cooperate to correctly shuffle the indices and sample
+    different indices. The sampler in each worker effectively produces
+    `indices[worker_id::num_workers]` where `indices` is an infinite
+    stream of indices consisting of `shuffle(range(size)) +
+    shuffle(range(size)) + ...` (if shuffle is True) or `range(size) +
+    range(size) + ...` (if shuffle is False).
+    """
+
+    def __init__(
+        self,
+        size: int,
+        shuffle: bool = True,
+        seed: Optional[int] = 0,
+        rank=0,
+        world_size=1,
+    ):
+        """
+        Args:
+            size (int): the total number of data of the underlying dataset to sample from
+            shuffle (bool): whether to shuffle the indices or not
+            seed (int): the initial seed of the shuffle. Must be the same
+                across all workers.
+        """
+        self._size = size
+        assert size > 0
+        self._shuffle = shuffle
+        self._seed = int(seed)
+
+        if dist.is_available() and dist.is_initialized():
+            self._rank = dist.get_rank()
+            self._world_size = dist.get_world_size()
+        else:
+            self._rank = rank
+            self._world_size = world_size
+
+    def __iter__(self):
+        start = self._rank
+        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)
+
+    def _infinite_indices(self):
+        g = torch.Generator()
+        g.manual_seed(self._seed)
+        while True:
+            if self._shuffle:
+                yield from torch.randperm(self._size, generator=g)
+            else:
+                yield from torch.arange(self._size)
+
+    def __len__(self):
+        return self._size // self._world_size
diff --git a/det/yolox/engine/__init__.py b/det/yolox/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/det/yolox/engine/launch.py b/det/yolox/engine/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a299337846ebbdbdedb318e60e6f9552d80615e
--- /dev/null
+++ b/det/yolox/engine/launch.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Code is based on
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import sys
+from datetime import timedelta
+from loguru import logger
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import det.yolox.utils.dist as comm
+
+__all__ = ["launch"]
+
+
+DEFAULT_TIMEOUT = timedelta(minutes=30)
+
+
+def _find_free_port():
+    """Find an available port of current machine / node."""
+    import socket
+
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    # Binding to port 0 will cause the OS to find an available port for us
+    sock.bind(("", 0))
+    port = sock.getsockname()[1]
+    sock.close()
+    # NOTE: there is still a chance the port could be taken by other processes.
+    return port
+
+
+def launch(
+    main_func,
+    num_gpus_per_machine,
+    num_machines=1,
+    machine_rank=0,
+    backend="nccl",
+    dist_url=None,
+    args=(),
+    timeout=DEFAULT_TIMEOUT,
+):
+    """
+    Args:
+        main_func: a function that will be called by `main_func(*args)`
+        num_gpus_per_machine (int): the number of GPUs per machine
+        num_machines (int): the total number of machines
+        machine_rank (int): the rank of this machine (one per machine)
+        dist_url (str): url to connect to for distributed training, including protocol
+                       e.g. "tcp://127.0.0.1:8686".
+                       Can be set to "auto" to automatically select a free port on localhost
+        args (tuple): arguments passed to main_func
+    """
+    world_size = num_machines * num_gpus_per_machine
+    if world_size > 1:
+        # https://github.com/pytorch/pytorch/pull/14391
+        # TODO prctl in spawned processes
+
+        if dist_url == "auto":
+            assert num_machines == 1, "dist_url=auto not supported for multi-machine jobs."
+            port = _find_free_port()
+            dist_url = f"tcp://127.0.0.1:{port}"
+
+        start_method = "spawn"
+        cache = vars(args[1]).get("cache", False)
+
+        # To use numpy memmap for caching images into RAM, we have to use the fork start method
+        if cache:
+            assert sys.platform != "win32", (
+                "As Windows platform doesn't support fork method, " "do not add --cache in your training command."
+            )
+            start_method = "fork"
+
+        mp.start_processes(
+            _distributed_worker,
+            nprocs=num_gpus_per_machine,
+            args=(
+                main_func,
+                world_size,
+                num_gpus_per_machine,
+                machine_rank,
+                backend,
+                dist_url,
+                args,
+            ),
+            daemon=False,
+            start_method=start_method,
+        )
+    else:
+        main_func(*args)
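+
+# Usage sketch (hypothetical `main`, `exp`, and argparse `args`; the real entry scripts
+# wire this up). In the multi-process path, args[1] is assumed to be namespace-like,
+# since its `cache` attribute is read via vars(args[1]) above.
+#   launch(main, num_gpus_per_machine=2, num_machines=1, machine_rank=0,
+#          backend="nccl", dist_url="auto", args=(exp, args))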
+
+
+def _distributed_worker(
+    local_rank,
+    main_func,
+    world_size,
+    num_gpus_per_machine,
+    machine_rank,
+    backend,
+    dist_url,
+    args,
+    timeout=DEFAULT_TIMEOUT,
+):
+    assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
+    global_rank = machine_rank * num_gpus_per_machine + local_rank
+    logger.info("Rank {} initialization finished.".format(global_rank))
+    try:
+        dist.init_process_group(
+            backend=backend,
+            init_method=dist_url,
+            world_size=world_size,
+            rank=global_rank,
+            timeout=timeout,
+        )
+    except Exception:
+        logger.error("Process group URL: {}".format(dist_url))
+        raise
+
+    # Setup the local process group (which contains ranks within the same machine)
+    assert comm._LOCAL_PROCESS_GROUP is None
+    num_machines = world_size // num_gpus_per_machine
+    for i in range(num_machines):
+        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
+        pg = dist.new_group(ranks_on_i)
+        if i == machine_rank:
+            comm._LOCAL_PROCESS_GROUP = pg
+
+    # synchronize is needed here to prevent a possible timeout after calling init_process_group
+    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
+    comm.synchronize()
+
+    assert num_gpus_per_machine <= torch.cuda.device_count()
+    torch.cuda.set_device(local_rank)
+
+    main_func(*args)
diff --git a/det/yolox/engine/trainer.py b/det/yolox/engine/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6e0586cd073c0f2d6474c159413f39f2b409401
--- /dev/null
+++ b/det/yolox/engine/trainer.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import datetime
+import os
+import time
+from loguru import logger
+
+import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from det.yolox.data import DataPrefetcher
+from det.yolox.utils import (
+    MeterBuffer,
+    ModelEMA,
+    all_reduce_norm,
+    get_local_rank,
+    get_model_info,
+    get_rank,
+    get_world_size,
+    gpu_mem_usage,
+    is_parallel,
+    load_ckpt,
+    occupy_mem,
+    save_checkpoint,
+    setup_logger,
+    synchronize,
+)
+
+
+class Trainer:
+    def __init__(self, exp, args):
+        # __init__ only defines some basic attributes; other attributes like the model
+        # and optimizer are built in the before_train method.
+        self.exp = exp
+        self.args = args
+
+        # training related attr
+        self.max_epoch = exp.max_epoch
+        self.amp_training = args.fp16
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+        self.is_distributed = get_world_size() > 1
+        self.rank = get_rank()
+        self.local_rank = get_local_rank()
+        self.device = "cuda:{}".format(self.local_rank)
+        self.use_model_ema = exp.ema
+
+        # data/dataloader related attr
+        self.data_type = torch.float16 if args.fp16 else torch.float32
+        self.input_size = exp.input_size
+        self.best_ap = 0
+
+        # metric record
+        self.meter = MeterBuffer(window_size=exp.print_interval)
+        self.file_name = os.path.join(exp.output_dir, args.experiment_name)
+
+        if self.rank == 0:
+            os.makedirs(self.file_name, exist_ok=True)
+
+        setup_logger(
+            self.file_name,
+            distributed_rank=self.rank,
+            filename="train_log.txt",
+            mode="a",
+        )
+
+    def train(self):
+        self.before_train()
+        try:
+            self.train_in_epoch()
+        except Exception:
+            raise
+        finally:
+            self.after_train()
+
+    def train_in_epoch(self):
+        for self.epoch in range(self.start_epoch, self.max_epoch):
+            self.before_epoch()
+            self.train_in_iter()
+            self.after_epoch()
+
+    def train_in_iter(self):
+        for self.iter in range(self.max_iter):
+            self.before_iter()
+            self.train_one_iter()
+            self.after_iter()
+
+    def train_one_iter(self):
+        iter_start_time = time.time()
+
+        inps, targets = self.prefetcher.next()
+        inps = inps.to(self.data_type)
+        targets = targets.to(self.data_type)
+        targets.requires_grad = False
+        inps, targets = self.exp.preprocess(inps, targets, self.input_size)
+        data_end_time = time.time()
+
+        with torch.cuda.amp.autocast(enabled=self.amp_training):
+            outputs = self.model(inps, targets)
+
+        loss = outputs["total_loss"]
+
+        self.optimizer.zero_grad()
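+        # AMP: scale the loss before backward; scaler.step() unscales the grads and calls
+        # optimizer.step() (skipping it on inf/NaN), and update() adjusts the scale factor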
+        self.scaler.scale(loss).backward()
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+
+        if self.use_model_ema:
+            self.ema_model.update(self.model)
+
+        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
+        for param_group in self.optimizer.param_groups:
+            param_group["lr"] = lr
+
+        iter_end_time = time.time()
+        self.meter.update(
+            iter_time=iter_end_time - iter_start_time,
+            data_time=data_end_time - iter_start_time,
+            lr=lr,
+            **outputs,
+        )
+
+    def before_train(self):
+        logger.info("args: {}".format(self.args))
+        logger.info("exp value:\n{}".format(self.exp))
+
+        # model related init
+        torch.cuda.set_device(self.local_rank)
+        model = self.exp.get_model()
+        logger.info("Model Summary: {}".format(get_model_info(model, self.exp.test_size)))
+        model.to(self.device)
+
+        # solver related init
+        self.optimizer = self.exp.get_optimizer(self.args.batch_size)
+
+        # value of start_epoch will be set in `resume_train`
+        model = self.resume_train(model)
+
+        # data related init
+        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
+        self.train_loader = self.exp.get_data_loader(
+            batch_size=self.args.batch_size,
+            is_distributed=self.is_distributed,
+            no_aug=self.no_aug,
+            cache_img=self.args.cache,
+        )
+        logger.info("init prefetcher, this might take one minute or less...")
+        self.prefetcher = DataPrefetcher(self.train_loader)
+        # max_iter means iters per epoch
+        self.max_iter = len(self.train_loader)
+
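+        # linear scaling rule: lr = basic_lr_per_img * total batch size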
+        self.lr_scheduler = self.exp.get_lr_scheduler(self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter)
+        if self.args.occupy:
+            occupy_mem(self.local_rank)
+
+        if self.is_distributed:
+            model = DDP(model, device_ids=[self.local_rank], broadcast_buffers=False)
+
+        if self.use_model_ema:
+            self.ema_model = ModelEMA(model, 0.9998)
+            self.ema_model.updates = self.max_iter * self.start_epoch
+
+        self.model = model
+        self.model.train()
+
+        self.evaluator = self.exp.get_evaluator(batch_size=self.args.batch_size, is_distributed=self.is_distributed)
+        # Tensorboard logger
+        if self.rank == 0:
+            self.tblogger = SummaryWriter(self.file_name)
+
+        logger.info("Training start...")
+        logger.info("\n{}".format(model))
+
+    def after_train(self):
+        logger.info("Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100))
+
+    def before_epoch(self):
+        logger.info("---> start train epoch{}".format(self.epoch + 1))
+
+        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
+            logger.info("--->No mosaic aug now!")
+            self.train_loader.close_mosaic()
+            logger.info("--->Add additional L1 loss now!")
+            if self.is_distributed:
+                self.model.module.head.use_l1 = True
+            else:
+                self.model.head.use_l1 = True
+            self.exp.eval_interval = 1
+            if not self.no_aug:
+                self.save_ckpt(ckpt_name="last_mosaic_epoch")
+
+    def after_epoch(self):
+        self.save_ckpt(ckpt_name="latest")
+
+        if (self.epoch + 1) % self.exp.eval_interval == 0:
+            all_reduce_norm(self.model)
+            self.evaluate_and_save_model()
+
+    def before_iter(self):
+        pass
+
+    def after_iter(self):
+        """`after_iter` contains two parts of logic:
+
+        * log information
+        * reset setting of resize
+        """
+        # log needed information
+        if (self.iter + 1) % self.exp.print_interval == 0:
+            # TODO check ETA logic
+            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
+            eta_seconds = self.meter["iter_time"].global_avg * left_iters
+            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))
+
+            progress_str = "epoch: {}/{}, iter: {}/{}".format(
+                self.epoch + 1, self.max_epoch, self.iter + 1, self.max_iter
+            )
+            loss_meter = self.meter.get_filtered_meter("loss")
+            loss_str = ", ".join(["{}: {:.1f}".format(k, v.latest) for k, v in loss_meter.items()])
+
+            time_meter = self.meter.get_filtered_meter("time")
+            time_str = ", ".join(["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()])
+
+            logger.info(
+                "{}, mem: {:.0f}Mb, {}, {}, lr: {:.3e}".format(
+                    progress_str,
+                    gpu_mem_usage(),
+                    time_str,
+                    loss_str,
+                    self.meter["lr"].latest,
+                )
+                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
+            )
+            self.meter.clear_meters()
+
+        # random resizing
+        if (self.progress_in_iter + 1) % 10 == 0:
+            self.input_size = self.exp.random_resize(self.train_loader, self.epoch, self.rank, self.is_distributed)
+
+    @property
+    def progress_in_iter(self):
+        return self.epoch * self.max_iter + self.iter
+
+    def resume_train(self, model):
+        if self.args.resume:
+            logger.info("resume training")
+            if self.args.ckpt is None:
+                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth")
+            else:
+                ckpt_file = self.args.ckpt
+
+            ckpt = torch.load(ckpt_file, map_location=self.device)
+            # resume the model/optimizer state dict
+            model.load_state_dict(ckpt["model"])
+            self.optimizer.load_state_dict(ckpt["optimizer"])
+            # resume the training states variables
+            start_epoch = self.args.start_epoch - 1 if self.args.start_epoch is not None else ckpt["start_epoch"]
+            self.start_epoch = start_epoch
+            logger.info("loaded checkpoint '{}' (epoch {})".format(self.args.resume, self.start_epoch))  # noqa
+        else:
+            if self.args.ckpt is not None:
+                logger.info("loading checkpoint for fine tuning")
+                ckpt_file = self.args.ckpt
+                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
+                model = load_ckpt(model, ckpt)
+            self.start_epoch = 0
+
+        return model
+
+    def evaluate_and_save_model(self):
+        if self.use_model_ema:
+            evalmodel = self.ema_model.ema
+        else:
+            evalmodel = self.model
+            if is_parallel(evalmodel):
+                evalmodel = evalmodel.module
+
+        ap50_95, ap50, summary = self.exp.eval(evalmodel, self.evaluator, self.is_distributed)
+        self.model.train()
+        if self.rank == 0:
+            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
+            self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
+            logger.info("\n" + summary)
+        synchronize()
+
+        self.save_ckpt("last_epoch", ap50_95 > self.best_ap)
+        self.best_ap = max(self.best_ap, ap50_95)
+
+    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
+        if self.rank == 0:
+            save_model = self.ema_model.ema if self.use_model_ema else self.model
+            logger.info("Save weights to {}".format(self.file_name))
+            ckpt_state = {
+                "start_epoch": self.epoch + 1,
+                "model": save_model.state_dict(),
+                "optimizer": self.optimizer.state_dict(),
+            }
+            save_checkpoint(
+                ckpt_state,
+                update_best_ckpt,
+                self.file_name,
+                ckpt_name,
+            )
diff --git a/det/yolox/engine/yolox_inference.py b/det/yolox/engine/yolox_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..99d59a636cfe9526c2bf51e7e641156a94bf1a66
--- /dev/null
+++ b/det/yolox/engine/yolox_inference.py
@@ -0,0 +1,229 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import datetime
+import logging
+import time
+from collections import OrderedDict, abc
+from contextlib import ExitStack, contextmanager
+from typing import List, Union
+
+from tqdm import tqdm
+from omegaconf import OmegaConf
+import torch
+from torch import nn
+from torch.cuda.amp import autocast
+
+from detectron2.utils.logger import log_every_n_seconds
+from detectron2.evaluation import DatasetEvaluator, DatasetEvaluators, inference_context
+
+from core.utils.my_comm import get_world_size, is_main_process
+from det.yolox.utils import (
+    gather,
+    postprocess,
+    synchronize,
+    time_synchronized,
+    xyxy2xywh,
+)
+
+logger = logging.getLogger(__name__)
+
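+# Usage sketch (illustrative; `model`, `val_loader`, and `evaluator` come from the
+# surrounding eval code, and num_classes=21 is only an example value for YCB-V):
+#   results = yolox_inference_on_dataset(
+#       model, val_loader, evaluator, half_test=True,
+#       test_cfg=OmegaConf.create(dict(test_size=(640, 640), conf_thr=0.01, nms_thr=0.65, num_classes=21)),
+#   )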
+
+def yolox_inference_on_dataset(
+    model,
+    data_loader,
+    evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None],
+    amp_test=False,
+    half_test=False,
+    trt_file=None,
+    decoder=None,
+    test_cfg=OmegaConf.create(
+        dict(
+            test_size=(640, 640),
+            conf_thr=0.01,
+            nms_thr=0.65,
+            num_classes=80,
+        )
+    ),
+    val_cfg=OmegaConf.create(
+        dict(
+            eval_cached=False,
+        )
+    ),
+):
+    """Run model on the data_loader and evaluate the metrics with evaluator.
+    Also benchmark the inference speed of `model.__call__` accurately. The
+    model will be used in eval mode.
+
+    Args:
+        model (callable): a callable which takes an object from
+            `data_loader` and returns some outputs.
+
+            If it's an nn.Module, it will be temporarily set to `eval` mode.
+            If you wish to evaluate a model in `training` mode instead, you can
+            wrap the given model and override its behavior of `.eval()` and `.train()`.
+        data_loader: an iterable object with a length.
+            The elements it generates will be the inputs to the model.
+        evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark,
+            but don't want to do any evaluation.
+
+    Returns:
+        The return value of `evaluator.evaluate()`
+    """
+    num_devices = get_world_size()
+    assert int(half_test) + int(amp_test) <= 1, "half_test and amp_test cannot both be set"
+    logger.info(f"half_test: {half_test}, amp_test: {amp_test}")
+
+    logger.info("Start inference on {} batches".format(len(data_loader)))
+
+    cfg = test_cfg
+
+    total = len(data_loader)  # inference data loader must have a fixed length
+    if evaluator is None:
+        # create a no-op evaluator
+        evaluator = DatasetEvaluators([])
+    if isinstance(evaluator, abc.MutableSequence):
+        evaluator = DatasetEvaluators(evaluator)
+
+    if val_cfg.get("eval_cached", False):
+        results = evaluator.evaluate(eval_cached=True)
+        # An evaluator may return None when not in main process.
+        # Replace it by an empty dict instead to make it easier for downstream code to handle
+        if results is None:
+            results = {}
+        return results
+
+    evaluator.reset()
+    num_warmup = min(5, total - 1)
+    start_time = time.perf_counter()
+    total_data_time = 0
+    total_compute_time = 0
+    total_nms_time = 0
+    total_eval_time = 0
+    iters_record = 0
+    augment = cfg.get("augment", False)
+    with ExitStack() as stack:
+        if isinstance(model, nn.Module):
+            stack.enter_context(inference_context(model))
+        stack.enter_context(torch.no_grad())
+
+        tensor_type = torch.cuda.HalfTensor if (half_test or amp_test) else torch.cuda.FloatTensor
+        if half_test:
+            model = model.half()
+
+        if trt_file is not None:
+            from torch2trt import TRTModule
+
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(trt_file))
+
+            x = torch.ones(1, 3, cfg.test_size[0], cfg.test_size[1]).cuda()
+            model(x)
+            model = model_trt
+
+        progress_bar = tqdm if is_main_process() else iter
+
+        start_data_time = time_synchronized()
+        for idx, inputs in enumerate(progress_bar(data_loader)):
+            imgs, _, scene_im_ids, info_imgs, ids = inputs
+            imgs = imgs.type(tensor_type)
+
+            compute_time = 0
+
+            # the original check skipped the last iter since the batch size might not be
+            # enough for batch inference; here timing is recorded for every iter.
+            # is_time_record = idx < len(data_loader) - 1
+            is_time_record = idx < len(data_loader)
+
+            total_data_time += time.perf_counter() - start_data_time
+            if idx == num_warmup and is_time_record:
+                start_time = time.perf_counter()
+                total_data_time = 0
+                total_compute_time = 0
+                total_nms_time = 0
+                total_eval_time = 0
+                iters_record = 0
+
+            if is_time_record:
+                start_compute_time = time.perf_counter()
+
+            if trt_file is not None:
+                det_preds = model(imgs)
+                outputs = {"det_preds": det_preds}
+            else:
+                # outputs = model(imgs)
+                outputs = model(imgs, augment=augment, cfg=cfg)
+
+            if decoder is not None:
+                outputs["det_preds"] = decoder(outputs["det_preds"], dtype=outputs.type())
+            if is_time_record:
+                infer_end_time = time_synchronized()
+                total_compute_time += infer_end_time - start_compute_time
+
+            # import ipdb; ipdb.set_trace()
+            outputs["det_preds"] = postprocess(outputs["det_preds"], cfg.num_classes, cfg.conf_thr, cfg.nms_thr)
+            if is_time_record:
+                nms_end_time = time_synchronized()
+                total_nms_time += nms_end_time - infer_end_time
+                compute_time = nms_end_time - start_compute_time
+                outputs["time"] = compute_time
+
+            evaluator.process(outputs, scene_im_ids, info_imgs, ids, cfg)
+            if is_time_record:
+                total_eval_time += time.perf_counter() - nms_end_time
+                # iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
+                iters_record += 1
+
+            data_ms_per_iter = total_data_time / iters_record * 1000
+            compute_ms_per_iter = total_compute_time / iters_record * 1000
+            nms_ms_per_iter = total_nms_time / iters_record * 1000
+            eval_ms_per_iter = total_eval_time / iters_record * 1000
+            total_ms_per_iter = (time.perf_counter() - start_time) / iters_record * 1000
+            if idx >= num_warmup * 2 or compute_ms_per_iter > 5000:
+                eta = datetime.timedelta(seconds=int(total_ms_per_iter / 1000 * (total - idx - 1)))
+                log_every_n_seconds(
+                    logging.WARN,
+                    (
+                        f"Inference done {idx + 1}/{total}. "
+                        f"Dataloading: {data_ms_per_iter:.4f} ms/iter. "
+                        f"Inference: {compute_ms_per_iter:.4f} ms/iter. "
+                        f"NMS: {nms_ms_per_iter:.4f} ms/iter. "
+                        f"Eval: {eval_ms_per_iter:.4f} ms/iter. "
+                        f"Total: {total_ms_per_iter:.4f} ms/iter. "
+                        f"ETA={eta}"
+                    ),
+                    n=5,
+                    name=__name__,
+                )
+            start_data_time = time.perf_counter()
+
+    # Measure the time only for this worker (before the synchronization barrier)
+    total_time = time.perf_counter() - start_time
+    total_time_str = str(datetime.timedelta(seconds=total_time))
+
+    # NOTE this format is parsed by grep
+    logger.info(
+        "Total inference time: {} ({:.3f} ms / iter per device, on {} devices)".format(
+            total_time_str, total_time * 1000 / (total - num_warmup), num_devices
+        )
+    )
+    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
+    logger.info(
+        "Total inference pure compute time: {} ({:.3f} ms / iter per device, on {} devices)".format(
+            total_compute_time_str,
+            total_compute_time * 1000 / (total - num_warmup),
+            num_devices,
+        )
+    )
+    total_nms_time_str = str(datetime.timedelta(seconds=int(total_nms_time)))
+    logger.info(
+        "Total inference nms time: {} ({:.3f} ms / iter per device, on {} devices)".format(
+            total_nms_time_str,
+            total_nms_time * 1000 / (total - num_warmup),
+            num_devices,
+        )
+    )
+
+    results = evaluator.evaluate()
+    # An evaluator may return None when not in main process.
+    # Replace it by an empty dict instead to make it easier for downstream code to handle
+    if results is None:
+        results = {}
+    return results
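+
+
+# The snippet below is a minimal, hypothetical usage sketch of
+# `yolox_inference_on_dataset`; the model, loader, and evaluator construction
+# are placeholders, not part of this module:
+#
+#   test_cfg = OmegaConf.create(
+#       dict(test_size=(640, 640), conf_thr=0.01, nms_thr=0.65, num_classes=80)
+#   )
+#   results = yolox_inference_on_dataset(
+#       model,                   # an nn.Module already on cuda (DDP-wrapped if needed)
+#       test_loader,             # a finite-length test data loader
+#       evaluator=my_evaluator,  # e.g. a YOLOX_COCOEvaluator instance, or None to only benchmark
+#       half_test=True,          # fp16 inference; mutually exclusive with amp_test
+#       test_cfg=test_cfg,
+#   )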
diff --git a/det/yolox/engine/yolox_setup.py b/det/yolox/engine/yolox_setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..7412fafe50037b30ba132158064a10de51fea702
--- /dev/null
+++ b/det/yolox/engine/yolox_setup.py
@@ -0,0 +1,174 @@
+import argparse
+
+# from loguru import logger
+import os
+import os.path as osp
+import sys
+import weakref
+from collections import OrderedDict
+from typing import Optional
+import torch
+from fvcore.nn.precise_bn import get_bn_modules
+from omegaconf import OmegaConf
+from torch.nn.parallel import DistributedDataParallel
+
+import detectron2.data.transforms as T
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import CfgNode, LazyConfig
+from detectron2.data import (
+    MetadataCatalog,
+    build_detection_test_loader,
+    build_detection_train_loader,
+)
+from detectron2.evaluation import (
+    DatasetEvaluator,
+    inference_on_dataset,
+    print_csv_format,
+    verify_results,
+)
+from detectron2.modeling import build_model
+from detectron2.solver import build_lr_scheduler, build_optimizer
+
+from detectron2.utils.collect_env import collect_env_info
+from detectron2.utils.env import seed_all_rng
+from detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
+from detectron2.utils.file_io import PathManager
+
+import mmcv
+import PIL
+
+from lib.utils.setup_logger import setup_my_logger
+from lib.utils.setup_logger_loguru import setup_logger
+from lib.utils.time_utils import get_time_str
+from lib.utils.config_utils import try_get_key
+import core.utils.my_comm as comm
+from core.utils.my_writer import (
+    MyCommonMetricPrinter,
+    MyJSONWriter,
+    MyTensorboardXWriter,
+)
+
+
+def _highlight(code, filename):
+    try:
+        import pygments
+    except ImportError:
+        return code
+
+    from pygments.lexers import Python3Lexer, YamlLexer
+    from pygments.formatters import Terminal256Formatter
+
+    lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
+    code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
+    return code
+
+
+def default_yolox_setup(cfg, args):
+    """NOTE: compared to d2,
+        1) logger has line number;
+        2) more project related logger names;
+        3) setup mmcv image backend
+    Perform some basic common setups at the beginning of a job, including:
+
+    1. Set up the detectron2 logger
+    2. Log basic information about environment, cmdline arguments, and config
+    3. Backup the config to the output directory
+
+    Args:
+        cfg (CfgNode or omegaconf.DictConfig): the full config to be used
+        args (argparse.Namespace): the command line arguments to be logged
+    """
+    output_dir = try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
+    if comm.is_main_process() and output_dir:
+        PathManager.mkdirs(output_dir)
+
+    rank = comm.get_rank()
+    # filename = osp.join(output_dir, f"log_{get_time_str()}.txt")
+    # setup_logger(output_dir, distributed_rank=rank, filename=filename, mode="a")
+    setup_my_logger(output_dir, distributed_rank=rank, name="fvcore")
+    setup_my_logger(output_dir, distributed_rank=rank, name="mylib")
+    setup_my_logger(output_dir, distributed_rank=rank, name="core")
+    setup_my_logger(output_dir, distributed_rank=rank, name="det")
+    setup_my_logger(output_dir, distributed_rank=rank, name="detectron2")
+    setup_my_logger(output_dir, distributed_rank=rank, name="ref")
+    setup_my_logger(output_dir, distributed_rank=rank, name="tests")
+    setup_my_logger(output_dir, distributed_rank=rank, name="tools")
+    setup_my_logger(output_dir, distributed_rank=rank, name="__main__")
+    logger = setup_my_logger(output_dir, distributed_rank=rank, name=__name__)
+
+    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
+    logger.info("Environment info:\n" + collect_env_info())
+
+    logger.info("Command line arguments: " + str(args))
+    if hasattr(args, "config_file") and args.config_file != "":
+        logger.info(
+            "Contents of args.config_file={}:\n{}".format(
+                args.config_file,
+                _highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
+            )
+        )
+
+    if comm.is_main_process() and output_dir:
+        # Note: some of our scripts may expect the existence of
+        # config.yaml in output directory
+        path = os.path.join(output_dir, "config.yaml")
+        if isinstance(cfg, CfgNode):
+            logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
+            with PathManager.open(path, "w") as f:
+                f.write(cfg.dump())
+        else:
+            LazyConfig.save(cfg, path)
+        logger.info("Full config saved to {}".format(path))
+
+    # make sure each worker has a different, yet deterministic seed if specified
+    seed = try_get_key(cfg, "SEED", "train.seed", default=-1)
+    logger.info(f"seed: {seed}")
+    seed_all_rng(None if seed < 0 else seed + rank)
+
+    cudnn_deterministic = try_get_key(cfg, "CUDNN_DETERMINISTIC", "train.cudnn_deterministic", default=False)
+    if cudnn_deterministic:
+        torch.backends.cudnn.deterministic = True
+        logger.warning(
+            "You have turned on the CUDNN deterministic setting, "
+            "which can slow down your training considerably! You may see unexpected behavior "
+            "when restarting from checkpoints."
+        )
+
+    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
+    # a typical validation set.
+    if not (hasattr(args, "eval_only") and args.eval_only):  # currently only used for train
+        torch.backends.cudnn.benchmark = try_get_key(cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False)
+
+    # set mmcv backend
+    mmcv_backend = try_get_key(cfg, "mmcv_backend", "MMCV_BACKEND", default="cv2")
+    if mmcv_backend == "pillow" and "post" not in PIL.__version__:
+        logger.warning("Consider installing pillow-simd!")
+    logger.info(f"Used mmcv backend: {mmcv_backend}")
+    mmcv.use_backend(mmcv_backend)
+
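+# A hypothetical launch-script sketch for `default_yolox_setup`; the argument
+# parser and config path are placeholders, not part of this module:
+#
+#   from detectron2.config import LazyConfig
+#   from detectron2.engine import default_argument_parser
+#
+#   args = default_argument_parser().parse_args()
+#   cfg = LazyConfig.load(args.config_file)
+#   cfg = LazyConfig.apply_overrides(cfg, args.opts)
+#   default_yolox_setup(cfg, args)
+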
+
+def default_yolox_writers(output_dir: str, max_iter: Optional[int] = None, backup=True):
+    """
+    Build a list of :class:`EventWriter` to be used.
+    It now consists of a :class:`CommonMetricPrinter`,
+    :class:`TensorboardXWriter` and :class:`JSONWriter`.
+
+    Args:
+        output_dir: directory to store JSON metrics and tensorboard events
+        max_iter: the total number of iterations
+        backup: if True, move any existing tensorboard event files into a ``tb_old`` subdir first
+
+    Returns:
+        list[EventWriter]: a list of :class:`EventWriter` objects.
+    """
+    tb_logdir = osp.join(output_dir, "tb")
+    mmcv.mkdir_or_exist(tb_logdir)
+    if backup and comm.is_main_process():
+        old_tb_logdir = osp.join(output_dir, "tb_old")
+        mmcv.mkdir_or_exist(old_tb_logdir)
+        os.system("mv -v {} {}".format(osp.join(tb_logdir, "events.*"), old_tb_logdir))
+    return [
+        # It may not always print what you want to see, since it prints "common" metrics only.
+        MyCommonMetricPrinter(max_iter),
+        MyJSONWriter(os.path.join(output_dir, "metrics.json")),
+        MyTensorboardXWriter(tb_logdir, backend="tensorboardx"),
+    ]
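+
+
+# Brief usage sketch: these writers are typically wrapped in a periodic-writer hook,
+# e.g. (illustrative, assuming `output_dir` and `max_iter` are defined by the caller):
+#
+#   writers = default_yolox_writers(output_dir, max_iter, backup=True)
+#   writer_hook = MyPeriodicWriter(writers, period=20)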
diff --git a/det/yolox/engine/yolox_train_test_plain.py b/det/yolox/engine/yolox_train_test_plain.py
new file mode 100644
index 0000000000000000000000000000000000000000..9027695c9bf48b416cce2387d9b0306ff4f5e4ea
--- /dev/null
+++ b/det/yolox/engine/yolox_train_test_plain.py
@@ -0,0 +1,186 @@
+# TODO: just use plain train loop
+import time
+import os.path as osp
+import logging
+from collections import OrderedDict
+from collections.abc import Sequence
+from detectron2.engine import (
+    AMPTrainer,
+    SimpleTrainer,
+    default_writers,
+    hooks,
+)
+from detectron2.engine.defaults import create_ddp_model
+from detectron2.data.build import AspectRatioGroupedDataset
+from detectron2.data import MetadataCatalog
+from detectron2.utils.events import EventStorage
+from detectron2.config import LazyConfig, instantiate
+from detectron2.evaluation import print_csv_format
+
+import core.utils.my_comm as comm
+from core.utils.my_writer import MyPeriodicWriter
+from core.utils.my_checkpoint import MyCheckpointer
+from det.yolox.data import DataPrefetcher
+from det.yolox.utils import (
+    MeterBuffer,
+    ModelEMA,
+    all_reduce_norm,
+    get_model_info,
+    get_rank,
+    get_world_size,
+    gpu_mem_usage,
+    load_ckpt,
+    occupy_mem,
+    save_checkpoint,
+    setup_logger,
+    synchronize,
+)
+from .yolox_inference import yolox_inference_on_dataset
+from .yolox_setup import default_yolox_writers
+
+
+logger = logging.getLogger(__name__)
+
+
+def do_test_yolox(cfg, model, use_all_reduce_norm=False):
+    if "evaluator" not in cfg.dataloader:
+        logger.warning("no evaluator in cfg.dataloader, do nothing!")
+        return
+
+    if use_all_reduce_norm:
+        all_reduce_norm(model)
+
+    if not isinstance(cfg.dataloader.test, Sequence):
+        test_dset_name = cfg.dataloader.test.dataset.lst.names
+        if not isinstance(test_dset_name, str):
+            test_dset_name = ",".join(test_dset_name)
+        cfg.dataloader.evaluator.output_dir = osp.join(cfg.train.output_dir, "inference", test_dset_name)
+        ret = yolox_inference_on_dataset(
+            model,
+            instantiate(cfg.dataloader.test),
+            evaluator=instantiate(cfg.dataloader.evaluator),
+            amp_test=cfg.test.amp_test,
+            half_test=cfg.test.half_test,
+            test_cfg=cfg.test,
+        )
+        logger.info("Evaluation results for {} in csv format:".format(test_dset_name))
+        print_csv_format(ret)
+        return ret
+    else:
+        results = OrderedDict()
+        for loader_cfg, eval_cfg in zip(cfg.dataloader.test, cfg.dataloader.evaluator):
+            test_dset_name = loader_cfg.dataset.lst.names
+            if not isinstance(test_dset_name, str):
+                test_dset_name = ",".join(test_dset_name)
+            eval_cfg.output_dir = osp.join(cfg.train.output_dir, "inference", test_dset_name)
+            ret_i = yolox_inference_on_dataset(
+                model,
+                instantiate(loader_cfg),
+                evaluator=instantiate(eval_cfg),
+                amp_test=cfg.test.amp_test,
+                half_test=cfg.test.half_test,
+                test_cfg=cfg.test,
+            )
+            logger.info("Evaluation results for {} in csv format:".format(test_dset_name))
+            print_csv_format(ret_i)
+            results[test_dset_name] = ret_i
+        return results
+
+
+def do_train_yolox(args, cfg):
+    """
+    Args:
+        cfg: an object with the following attributes:
+            model: instantiate to a module
+            dataloader.{train,test}: instantiate to dataloaders
+            dataloader.evaluator: instantiate to evaluator for test set
+            optimizer: instantiate to an optimizer
+            lr_multiplier: instantiate to a fvcore scheduler
+            train: other misc config defined in `configs/common/train.py`, including:
+                output_dir (str)
+                init_checkpoint (str)
+                amp.enabled (bool)
+                max_iter (int)
+                eval_period, log_period (int)
+                device (str)
+                checkpointer (dict)
+                ddp (dict)
+    """
+    model = instantiate(cfg.model)
+    logger.info("Model:\n{}".format(model))
+    model.to(cfg.train.device)
+
+    cfg.optimizer.params.model = model
+    optim = instantiate(cfg.optimizer)
+
+    # TODO: support train2 and train2_ratio
+    train_loader = instantiate(cfg.dataloader.train)
+    ims_per_batch = cfg.dataloader.train.total_batch_size
+    # only using train to determine iters_per_epoch
+    if isinstance(train_loader, AspectRatioGroupedDataset):
+        dataset_len = len(train_loader.dataset.dataset)
+        iters_per_epoch = dataset_len // ims_per_batch
+    else:
+        dataset_len = len(train_loader.dataset)
+        iters_per_epoch = dataset_len // ims_per_batch
+    max_iter = cfg.lr_config.total_epochs * iters_per_epoch
+    cfg.train.max_iter = max_iter
+    cfg.train.no_aug_iters = cfg.train.no_aug_epochs * iters_per_epoch
+    cfg.train.warmup_iters = cfg.train.warmup_epochs * iters_per_epoch
+    logger.info("ims_per_batch: {}".format(ims_per_batch))
+    logger.info("dataset length: {}".format(dataset_len))
+    logger.info("iters per epoch: {}".format(iters_per_epoch))
+    logger.info("total iters: {}".format(max_iter))
+
+    anneal_point = cfg.lr_config.get("anneal_point", 0)
+    if cfg.train.anneal_after_warmup:
+        anneal_point = min(
+            anneal_point + cfg.train.warmup_epochs / (cfg.train.total_epochs - cfg.train.no_aug_epochs),
+            1.0,
+        )
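+    # e.g., with warmup_epochs=5, total_epochs=15, no_aug_epochs=2 and anneal_point=0
+    # (illustrative numbers), this becomes min(0 + 5 / 13, 1.0), i.e. about 0.385,
+    # so LR annealing only starts after the warmup portion of the schedule.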
+    cfg.lr_config.update(
+        optimizer=optim,
+        total_iters=max_iter - cfg.train.no_aug_iters,  # exclude no aug iters
+        warmup_iters=cfg.train.warmup_iters,
+        anneal_point=anneal_point,
+    )
+
+    model = create_ddp_model(model, **cfg.train.ddp)
+    trainer = (AMPTrainer if cfg.train.amp.enabled else SimpleTrainer)(model, train_loader, optim)
+    checkpointer = MyCheckpointer(
+        model,
+        cfg.train.output_dir,
+        optimizer=optim,
+        trainer=trainer,
+        save_to_disk=comm.is_main_process(),
+    )
+    trainer.register_hooks(
+        [
+            hooks.IterationTimer(),
+            hooks.LRScheduler(scheduler=instantiate(cfg.lr_config)),
+            hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) if comm.is_main_process() else None,
+            hooks.EvalHook(cfg.train.eval_period, lambda: do_test_yolox(cfg, model)),
+            MyPeriodicWriter(
+                default_yolox_writers(cfg.train.output_dir, cfg.train.max_iter),
+                period=cfg.train.log_period,
+            )
+            if comm.is_main_process()
+            else None,
+        ]
+    )
+
+    checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume)
+    if args.resume and checkpointer.has_checkpoint():
+        # The checkpoint stores the training iteration that just finished, thus we start
+        # at the next iteration
+        start_iter = trainer.iter + 1
+    else:
+        start_iter = 0
+    trainer.train(start_iter, cfg.train.max_iter)
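+
+
+# Assuming `cfg` was loaded with LazyConfig and prepared via default_yolox_setup
+# (see yolox_setup.py), a hypothetical launcher could dispatch to the functions
+# above like this (sketch only, not part of this module):
+#
+#   if args.eval_only:
+#       model = instantiate(cfg.model)
+#       model.to(cfg.train.device)
+#       MyCheckpointer(model).load(cfg.train.init_checkpoint)
+#       do_test_yolox(cfg, model)
+#   else:
+#       do_train_yolox(args, cfg)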
diff --git a/det/yolox/engine/yolox_trainer.py b/det/yolox/engine/yolox_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc7558e61f43dfbbf853ac4a02a45a6608c1a205
--- /dev/null
+++ b/det/yolox/engine/yolox_trainer.py
@@ -0,0 +1,657 @@
+import copy
+import time
+import random
+import numpy as np
+import logging
+import os
+import os.path as osp
+import sys
+from typing import List, Mapping, Optional
+import weakref
+from collections import OrderedDict
+from typing import Optional, Sequence
+
+import core.utils.my_comm as comm
+import detectron2.data.transforms as T
+import torch
+from torch import nn
+from torch.cuda.amp import autocast, GradScaler
+import torch.distributed as dist
+from core.utils.my_writer import MyPeriodicWriter
+from det.yolox.utils import (
+    MeterBuffer,
+    ModelEMA,
+    all_reduce_norm,
+    get_model_info,
+    get_rank,
+    get_world_size,
+    gpu_mem_usage,
+    load_ckpt,
+    occupy_mem,
+    is_parallel,
+    save_checkpoint,
+    setup_logger,
+    synchronize,
+)
+from detectron2.utils.events import EventStorage, get_event_storage
+from detectron2.config.instantiate import instantiate
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.build import AspectRatioGroupedDataset
+from detectron2.engine import create_ddp_model, hooks
+from detectron2.engine.train_loop import TrainerBase
+from detectron2.evaluation import DatasetEvaluator, print_csv_format, verify_results
+from detectron2.utils.collect_env import collect_env_info
+from detectron2.utils.env import seed_all_rng
+
+from fvcore.nn.precise_bn import get_bn_modules
+from lib.utils.setup_logger import setup_my_logger, log_first_n
+from omegaconf import OmegaConf
+from torch.nn.parallel import DistributedDataParallel, DataParallel
+
+from lib.torch_utils.solver.grad_clip_d2 import maybe_add_gradient_clipping
+from lib.utils.config_utils import try_get_key
+from core.utils.my_checkpoint import MyCheckpointer
+from det.yolox.data import DataPrefetcher
+from .yolox_inference import yolox_inference_on_dataset
+from .yolox_setup import default_yolox_writers
+
+
+logger = logging.getLogger(__name__)
+
+
+class YOLOX_DefaultTrainer(TrainerBase):
+    """A trainer with default training logic. It does the following:
+
+    1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
+       defined by the given config. Create a LR scheduler defined by the config.
+    2. Load the last checkpoint or `cfg.train.init_checkpoint`, if exists, when
+       `resume_or_load` is called.
+    3. Register a few common hooks defined by the config.
+
+    It is created to simplify the **standard model training workflow** and reduce code boilerplate
+    for users who only need the standard training workflow, with standard features.
+    It means this class makes *many assumptions* about your training logic that
+    may easily become invalid in new research. In fact, any assumptions beyond those made in the
+    :class:`SimpleTrainer` are too much for research.
+
+    The code of this class has been annotated about restrictive assumptions it makes.
+    When they do not work for you, you're encouraged to:
+
+    1. Overwrite methods of this class, OR:
+    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
+       nothing else. You can then add your own hooks if needed. OR:
+    3. Write your own training loop similar to `tools/plain_train_net.py`.
+
+    See the :doc:`/tutorials/training` tutorials for more details.
+
+    Note that the behavior of this class, like other functions/classes in
+    this file, is not stable, since it is meant to represent the "common default behavior".
+    It is only guaranteed to work well with the standard models and training workflow in detectron2.
+    To obtain more stable behavior, write your own training logic with other public APIs.
+
+    Examples:
+    ::
+        trainer = YOLOX_DefaultTrainer(cfg)
+        trainer.resume_or_load()  # load last checkpoint or train.init_checkpoint
+        trainer.train()
+
+    Attributes:
+        scheduler:
+        checkpointer (MyCheckpointer):
+        cfg (omegaconf.DictConfig):
+    """
+
+    def __init__(self, cfg):
+        """
+        Args:
+            cfg (omegaconf.DictConfig): the lazy config used to build the trainer
+        """
+        super().__init__()
+        # if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
+        #     setup_my_logger(name="det")
+
+        # Assume these objects must be constructed in this order.
+        model = self.build_model(cfg)
+        optimizer = self.build_optimizer(cfg, model)
+
+        self.data_type = torch.float16 if cfg.train.amp.enabled else torch.float32
+
+        train_loader_cfg = cfg.dataloader.train
+        self.input_size = train_loader_cfg.dataset.img_size
+        train_loader = instantiate(train_loader_cfg)
+
+        # TODO: support train2 and train2_ratio
+        ims_per_batch = train_loader_cfg.total_batch_size
+        # only using train to determine iters_per_epoch
+        if isinstance(train_loader, AspectRatioGroupedDataset):
+            dataset_len = len(train_loader.dataset.dataset)
+            iters_per_epoch = dataset_len // ims_per_batch
+        else:
+            dataset_len = len(train_loader.dataset)
+            iters_per_epoch = dataset_len // ims_per_batch
+        max_iter = cfg.train.total_epochs * iters_per_epoch
+        cfg.train.iters_per_epoch = iters_per_epoch
+        cfg.train.max_iter = max_iter
+        cfg.train.no_aug_iters = cfg.train.no_aug_epochs * iters_per_epoch
+        cfg.train.warmup_iters = cfg.train.warmup_epochs * iters_per_epoch
+        logger.info("ims_per_batch: {}".format(ims_per_batch))
+        logger.info("dataset length: {}".format(dataset_len))
+        logger.info("iters per epoch: {}".format(iters_per_epoch))
+        logger.info("total iters: {}".format(max_iter))
+
+        cfg.train.eval_period = cfg.train.eval_period * cfg.train.iters_per_epoch
+        cfg.train.checkpointer.period = cfg.train.checkpointer.period * cfg.train.iters_per_epoch
+
+        OmegaConf.set_readonly(cfg, True)  # freeze config
+        self.cfg = cfg
+
+        model = create_ddp_model(model, broadcast_buffers=False)
+
+        self.use_model_ema = cfg.train.ema
+        if self.use_model_ema:
+            self.ema_model = ModelEMA(model, 0.9998)
+
+        amp_ckpt_data = {}
+        self.init_model_loader_optimizer_amp(model, train_loader, optimizer, amp_ckpt_data)
+        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
+        self.checkpointer = MyCheckpointer(
+            # Assume you want to save checkpoints together with logs/statistics
+            model,
+            cfg.train.output_dir,
+            # trainer=weakref.proxy(self),
+            optimizer=self.optimizer,
+            scheduler=self.scheduler,
+            save_to_disk=comm.is_main_process(),
+            **amp_ckpt_data,
+        )
+        self.start_iter = 0
+        self.max_iter = cfg.train.max_iter
+
+        self.register_hooks(self.build_hooks())
+
+    def init_model_loader_optimizer_amp(self, model, train_loader, optimizer, amp_ckpt_data={}, train_loader2=None):
+        amp_cfg = self.cfg.train.amp
+        if amp_cfg.enabled:
+            logger.info("Using pytorch amp")
+            unsupported = "AMPTrainer does not support single-process multi-device training!"
+            if isinstance(model, DistributedDataParallel):
+                assert not (model.device_ids and len(model.device_ids) > 1), unsupported
+            assert not isinstance(model, DataParallel), unsupported
+
+        self.grad_scaler = GradScaler(enabled=amp_cfg.enabled)
+        amp_ckpt_data["grad_scaler"] = self.grad_scaler
+        self.init_model_loader_optimizer_simple(model, train_loader, optimizer, train_loader2=train_loader2)
+
+    def init_model_loader_optimizer_simple(self, model, train_loader, optimizer, train_loader2=None):
+        model.train()
+
+        self.model = model
+
+        self.data_loader = train_loader
+        self._data_loader_iter = iter(train_loader)
+        logger.info("init prefetcher, this might take one minute or less...")
+        self.prefetcher = DataPrefetcher(train_loader)
+
+        self.data_loader2 = train_loader2
+        self._data_loader_iter2 = None
+        self.prefetcher2 = None
+        if train_loader2 is not None:
+            self._data_loader_iter2 = iter(train_loader2)
+            logger.info("init prefetcher2, this might take one minute or less...")
+            self.prefetcher2 = DataPrefetcher(train_loader2)
+
+        self.optimizer = optimizer
+
+    def resume_or_load(self, resume=True):
+        """NOTE: should run before train()
+        if resume from a middle/last ckpt but want to reset iteration,
+        remove the iteration key from the ckpt first
+        """
+        if resume:
+            # NOTE: --resume always from last_checkpoint
+            iter_saved = self.checkpointer.resume_or_load("", resume=True).get("iteration", -1)
+        else:
+            if self.cfg.train.resume_from != "":
+                # resume_from a given ckpt
+                iter_saved = self.checkpointer.load(self.cfg.train.resume_from).get("iteration", -1)
+            else:
+                # load from a given ckpt
+                # iter_saved = self.checkpointer.load(self.cfg.train.init_checkpoint).get("iteration", -1)
+                iter_saved = self.checkpointer.resume_or_load(self.cfg.train.init_checkpoint, resume=resume).get(
+                    "iteration", -1
+                )
+        self.start_iter = iter_saved + 1
+
+    def build_hooks(self):
+        """Build a list of default hooks, including timing, evaluation,
+        checkpointing, lr scheduling, precise BN, writing events.
+
+        Returns:
+            list[HookBase]:
+        """
+        cfg = self.cfg
+        train_loader_cfg = copy.deepcopy(cfg.dataloader.train)
+        if OmegaConf.is_readonly(train_loader_cfg):
+            OmegaConf.set_readonly(train_loader_cfg, False)
+
+        train_loader_cfg.num_workers = 0  # save some memory and time for PreciseBN
+
+        ret = [
+            hooks.IterationTimer(),
+            hooks.LRScheduler(),
+            hooks.PreciseBN(
+                # Run at the same freq as (but before) evaluation.
+                cfg.train.eval_period,
+                self.model,
+                # Build a new data loader to not affect training
+                instantiate(train_loader_cfg),
+                cfg.test.precise_bn.num_iter,
+            )
+            if cfg.test.precise_bn.enabled and get_bn_modules(self.model)
+            else None,
+        ]
+
+        # Do PreciseBN before checkpointer, because it updates the model and needs to
+        # be saved by checkpointer.
+        # This is not always the best: if checkpointing has a different frequency,
+        # some checkpoints may have more precise statistics than others.
+        if comm.is_main_process():
+            ret.append(hooks.PeriodicCheckpointer(self.checkpointer, **cfg.train.checkpointer))
+
+        def test_and_save_results():
+            # TODO: check this ema
+            if self.use_model_ema:
+                evalmodel = self.ema_model.ema
+            else:
+                evalmodel = self.model
+                if is_parallel(evalmodel):
+                    evalmodel = evalmodel.module
+            self._last_eval_results = self.test(self.cfg, evalmodel)
+            return self._last_eval_results
+
+        # Do evaluation after checkpointer, because then if it fails,
+        # we can use the saved checkpoint to debug.
+        ret.append(hooks.EvalHook(cfg.train.eval_period, test_and_save_results))
+
+        if comm.is_main_process():
+            # Here the default print/log frequency of each writer is used.
+            # run writers in the end, so that evaluation metrics are written
+            ret.append(MyPeriodicWriter(self.build_writers(), period=cfg.train.log_period))
+        return ret
+
+    def build_writers(self):
+        """Build a list of writers to be used using :func:`default_writers()`.
+        If you'd like a different list of writers, you can overwrite it in your
+        trainer.
+
+        Returns:
+            list[EventWriter]: a list of :class:`EventWriter` objects.
+        """
+        return default_yolox_writers(self.cfg.train.output_dir, self.max_iter)
+
+    def before_train(self):
+        if try_get_key(self.cfg, "train.occupy_gpu", default=False):
+            occupy_mem(comm.get_local_rank())
+        super().before_train()  # for hooks
+
+        if self.start_iter >= self.max_iter - self.cfg.train.no_aug_iters:
+            self.close_mosaic()
+            if self.cfg.train.use_l1 and self.cfg.train.l1_from_scratch is False:
+                self.enable_l1()
+            OmegaConf.set_readonly(self.cfg, False)
+            logger.info(f"sync norm period changed from {self.cfg.train.sync_norm_period} to 1")
+            self.cfg.train.sync_norm_period = 1  # sync norm every epoch when mosaic is closed
+            OmegaConf.set_readonly(self.cfg, True)
+
+        if self.cfg.train.use_l1 and self.cfg.train.l1_from_scratch is True:
+            self.enable_l1()
+
+        if self.use_model_ema:
+            self.ema_model.updates = self.start_iter
+
+    def train(self):
+        """Run training.
+
+        Returns:
+            OrderedDict of results, if evaluation is enabled. Otherwise None.
+        """
+        logger.info(
+            f"total batch size: {self.cfg.dataloader.train.total_batch_size}, num_gpus: {comm.get_world_size()}"
+        )
+        super().train(self.start_iter, self.max_iter)
+        # if len(self.cfg.test.expected_results) and comm.is_main_process():
+        #     assert hasattr(
+        #         self, "_last_eval_results"
+        #     ), "No evaluation results obtained during training!"
+        #     verify_results(self.cfg, self._last_eval_results)
+        #     return self._last_eval_results
+
+    def enable_l1(self):
+        logger.info("--->Add additional L1 loss now!")
+        if comm.get_world_size() > 1:
+            self.model.module.head.use_l1 = True
+        else:
+            self.model.head.use_l1 = True
+
+    def close_mosaic(self):
+        logger.info("--->No mosaic aug now!")
+        self.data_loader.close_mosaic()
+
+    def before_step(self):
+        super().before_step()
+
+        self.epoch = self.iter // self.cfg.train.iters_per_epoch
+        if self.iter == self.max_iter - self.cfg.train.no_aug_iters:
+            self.close_mosaic()
+            if self.cfg.train.use_l1 and self.cfg.train.l1_from_scratch is False:
+                self.enable_l1()
+            OmegaConf.set_readonly(self.cfg, False)
+            logger.info(f"sync norm period changed from {self.cfg.train.sync_norm_period} to 1")
+            self.cfg.train.sync_norm_period = 1  # sync norm every epoch when mosaic is closed
+            OmegaConf.set_readonly(self.cfg, True)
+            if comm.is_main_process():
+                self.checkpointer.save(
+                    name=f"last_mosaic_epoch{self.epoch}_iter{self.iter}",
+                    iteration=self.iter,
+                )
+
+    def run_step(self):
+        assert self.model.training, "[YOLOX_DefaultTrainer] model was changed to eval mode!"
+
+        # log_first_n(logging.INFO, f"running iter: {self.iter}", n=5)  # for debug
+
+        start = time.perf_counter()  # get data --------------------------
+        # inps, targets, _, _ = next(self._data_loader_iter)
+        # inps, targets, scene_im_id, _, img_id = next(self._data_loader_iter)
+        inps, targets = self.prefetcher.next()
+        inps = inps.to(self.data_type)
+        targets = targets.to(self.data_type)
+        targets.requires_grad = False
+        inps, targets = self.preprocess(inps, targets, self.input_size)
+        data_time = time.perf_counter() - start
+
+        with autocast(enabled=self.cfg.train.amp.enabled):
+            out_dict, loss_dict = self.model(inps, targets)
+            if isinstance(loss_dict, torch.Tensor):
+                losses = loss_dict
+                loss_dict = {"total_loss": loss_dict}
+            else:
+                losses = sum(loss_dict.values())
+
+        # vis image
+        # from det.yolox.utils.visualize import vis_train
+        # vis_train(inps, targets, self.cfg)
+
+        # optimizer step ------------------------------------------------
+        self.optimizer.zero_grad()
+        self.grad_scaler.scale(losses).backward()
+
+        # write metrics before opt step
+        self._write_metrics(loss_dict, data_time)
+
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
+
+        if self.use_model_ema:
+            self.ema_model.update(self.model)
+
+        # log_first_n(logging.INFO, f"done iter: {self.iter}", n=5)  # for debug
+
+    def after_step(self):
+        for h in self._hooks:
+            # NOTE: hack to save ema model
+            if isinstance(h, hooks.PeriodicCheckpointer) and self.use_model_ema:
+                h.checkpointer.model = self.ema_model.ema
+            h.after_step()
+
+        # sync norm
+        if self.cfg.train.sync_norm_period > 0:
+            if (self.epoch + 1) % self.cfg.train.sync_norm_period == 0:
+                all_reduce_norm(self.model)
+
+        # random resizing
+        if self.cfg.train.random_size is not None and self.iter % 10 == 0:
+            is_distributed = comm.get_world_size() > 1
+            self.input_size = self.random_resize(self.data_loader, self.epoch, comm.get_rank(), is_distributed)
+
+    def _write_metrics(
+        self,
+        loss_dict: Mapping[str, torch.Tensor],
+        data_time: float,
+        prefix: str = "",
+    ) -> None:
+        """
+        Args:
+            loss_dict (dict): dict of scalar losses
+                loss keys should contain the substring "loss";
+                other scalar metrics may also be included.
+            data_time (float): time taken by the dataloader iteration
+            prefix (str): prefix for logging keys
+        """
+        metrics_dict = {}
+        for k, v in loss_dict.items():
+            if isinstance(v, torch.Tensor):
+                metrics_dict[k] = v.detach().cpu().item()
+            else:
+                metrics_dict[k] = v  # assume float/int
+        metrics_dict["data_time"] = data_time
+
+        # Gather metrics among all workers for logging
+        # This assumes we do DDP-style training, which is currently the only
+        # supported method in detectron2.
+        all_metrics_dict = comm.gather(metrics_dict)
+
+        if comm.is_main_process():
+            storage = get_event_storage()
+
+            # data_time among workers can have high variance. The actual latency
+            # caused by data_time is the maximum among workers.
+            data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
+            storage.put_scalar("data_time", data_time)
+
+            storage.put_scalar("epoch", self.epoch)  # NOTE: added
+
+            # average the rest metrics
+            metrics_dict = {k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()}
+            # NOTE: filter losses
+            total_losses_reduced = sum([v for k, v in metrics_dict.items() if "loss" in k])
+            if not np.isfinite(total_losses_reduced):
+                raise FloatingPointError(
+                    f"Loss became infinite or NaN at iteration={storage.iter}!\n" f"loss_dict = {metrics_dict}"
+                )
+
+            storage.put_scalar("{}total_loss".format(prefix), total_losses_reduced)
+            if len(metrics_dict) > 1:
+                storage.put_scalars(**metrics_dict)
+
+    def preprocess(self, inputs, targets, tsize):
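+        # Rescale the image batch (and the matching box coordinates in `targets`)
+        # from the current input size to `tsize`. This assumes the YOLOX label layout
+        # (class, cx, cy, w, h): x-like values sit at indices 1, 3 and y-like values
+        # at indices 2, 4, which is what the strided assignments below rely on.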
+        scale_y = tsize[0] / self.input_size[0]
+        scale_x = tsize[1] / self.input_size[1]
+        if scale_x != 1 or scale_y != 1:
+            inputs = nn.functional.interpolate(inputs, size=tsize, mode="bilinear", align_corners=False)
+            targets[..., 1::2] = targets[..., 1::2] * scale_x
+            targets[..., 2::2] = targets[..., 2::2] * scale_y
+        return inputs, targets
+
+    def random_resize(self, data_loader, epoch, rank, is_distributed):
+        # randomly choose an int and multiply by 32; the aspect ratio matches the current input size
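+        # e.g., for a 640x640 base input (size_factor = 1.0) and a drawn value of 18,
+        # the new size is (32 * 18, 32 * 18) = (576, 576) -- illustrative numbers only;
+        # rank 0 draws the size and broadcasts it so all workers resize consistently.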
+        tensor = torch.LongTensor(2).cuda()
+
+        if rank == 0:
+            size_factor = self.input_size[1] * 1.0 / self.input_size[0]  # w/h
+            size = random.randint(*self.cfg.train.random_size)
+            size = (int(32 * size), 32 * int(size * size_factor))
+            tensor[0] = size[0]
+            tensor[1] = size[1]
+
+        if is_distributed:
+            dist.barrier()
+            dist.broadcast(tensor, 0)
+
+        input_size = (tensor[0].item(), tensor[1].item())
+        return input_size
+
+    @classmethod
+    def build_model(cls, cfg, verbose=True):
+        """
+        Returns:
+            torch.nn.Module:
+
+        It now instantiates the lazy config ``cfg.model``.
+        Overwrite it if you'd like a different model.
+        """
+        model = instantiate(cfg.model)
+        if verbose:
+            logger.info("Model:\n{}".format(model))
+        model.to(cfg.train.device)
+        return model
+
+    @classmethod
+    def build_optimizer(cls, cfg, model):
+        """
+        Returns:
+            torch.optim.Optimizer:
+        """
+        cfg.optimizer.params.model = model
+        optimizer = instantiate(cfg.optimizer)
+        optimizer = maybe_add_gradient_clipping(cfg, optimizer)
+        return optimizer
+
+    @classmethod
+    def build_lr_scheduler(cls, cfg, optimizer):
+        anneal_point = cfg.lr_config.get("anneal_point", 0)
+        if cfg.train.anneal_after_warmup:
+            anneal_point = min(
+                anneal_point + cfg.train.warmup_epochs / (cfg.train.total_epochs - cfg.train.no_aug_epochs),
+                1.0,
+            )
+        OmegaConf.set_readonly(cfg, False)
+        cfg.lr_config.update(
+            optimizer=optimizer,
+            total_iters=cfg.train.max_iter - cfg.train.no_aug_iters,  # exclude no aug iters
+            warmup_iters=cfg.train.warmup_iters,
+            anneal_point=anneal_point,
+        )
+        OmegaConf.set_readonly(cfg, True)
+        return instantiate(cfg.lr_config)
+
+    @classmethod
+    def test_single(cls, cfg, model, evaluator=None):
+        """
+        Args:
+            cfg (CfgNode):
+            model (nn.Module):
+        Returns:
+            dict: a dict of result metrics
+        """
+        test_dset_name = cfg.dataloader.test.dataset.lst.names
+        if not isinstance(test_dset_name, str):
+            test_dset_name = ",".join(test_dset_name)
+        if OmegaConf.is_readonly(cfg):
+            OmegaConf.set_readonly(cfg, False)
+        cfg.dataloader.evaluator.output_dir = osp.join(cfg.train.output_dir, "inference", test_dset_name)
+        OmegaConf.set_readonly(cfg, True)
+        if evaluator is None:
+            evaluator = instantiate(cfg.dataloader.evaluator)
+        cls.auto_set_test_batch_size(cfg.dataloader.test)
+        ret = yolox_inference_on_dataset(
+            model,
+            instantiate(cfg.dataloader.test),
+            evaluator=evaluator,
+            amp_test=cfg.test.amp_test,
+            half_test=cfg.test.half_test,
+            test_cfg=cfg.test,
+            val_cfg=cfg.val,
+        )
+        if comm.is_main_process():
+            assert isinstance(ret, dict), "Evaluator must return a dict on the main process. Got {} instead.".format(
+                ret
+            )
+            logger.info("Evaluation results for {} in csv format:".format(test_dset_name))
+            print_csv_format(ret)
+        return ret
+
+    @classmethod
+    def test(cls, cfg, model, evaluators=None):
+        """
+        Args:
+            cfg (CfgNode):
+            model (nn.Module):
+            evaluators (list[DatasetEvaluator] or None): if None, will call
+                :meth:`build_evaluator`. Otherwise, must have the same length as
+                test_loaders
+        Returns:
+            dict: a dict of result metrics
+        """
+        loader_cfgs = cfg.dataloader.test
+        if not isinstance(loader_cfgs, Sequence):
+            return cls.test_single(cfg, model, evaluator=evaluators)
+
+        if isinstance(evaluators, DatasetEvaluator):
+            evaluators = [evaluators]
+
+        if evaluators is not None:
+            assert len(loader_cfgs) == len(evaluators), "{} != {}".format(len(loader_cfgs), len(evaluators))
+        else:
+            evaluator_cfgs = cfg.dataloader.evaluator
+            assert isinstance(evaluator_cfgs, Sequence)
+            assert len(loader_cfgs) == len(evaluator_cfgs), "{} != {}".format(len(loader_cfgs), len(evaluator_cfgs))
+
+        results = OrderedDict()
+        for idx, loader_cfg in enumerate(loader_cfgs):
+            cls.auto_set_test_batch_size(loader_cfg)
+            test_loader = instantiate(loader_cfg)
+
+            test_dset_name = loader_cfg.dataset.lst.names
+            if not isinstance(test_dset_name, str):
+                test_dset_name = ",".join(test_dset_name)
+
+            # When evaluators are passed in as arguments,
+            # implicitly assume that evaluators can be created before data_loader.
+            if evaluators is not None:
+                evaluator = evaluators[idx]
+            else:
+                try:
+                    eval_cfg = evaluator_cfgs[idx]
+                    if OmegaConf.is_readonly(eval_cfg):
+                        OmegaConf.set_readonly(eval_cfg, False)
+                    eval_cfg.output_dir = osp.join(cfg.train.output_dir, "inference", test_dset_name)
+                    OmegaConf.set_readonly(eval_cfg, True)
+                    evaluator = instantiate(eval_cfg)
+                except NotImplementedError:
+                    logger.warning("No evaluator found. Use `DefaultTrainer.test(evaluators=)` instead")
+                    results[test_dset_name] = {}
+                    continue
+            ret_i = yolox_inference_on_dataset(
+                model,
+                test_loader,
+                evaluator=evaluator,
+                amp_test=cfg.test.amp_test,
+                half_test=cfg.test.half_test,
+                test_cfg=cfg.test,
+                val_cfg=cfg.val,
+            )
+            results[test_dset_name] = ret_i
+            if comm.is_main_process():
+                assert isinstance(
+                    ret_i, dict
+                ), "Evaluator must return a dict on the main process. Got {} instead.".format(ret_i)
+                logger.info("Evaluation results for {} in csv format:".format(test_dset_name))
+                print_csv_format(ret_i)
+
+        if len(results) == 1:
+            results = list(results.values())[0]
+        return results
+
+    @classmethod
+    def auto_set_test_batch_size(cls, loader_cfg):
+        test_batch_size = loader_cfg.total_batch_size
+        n_gpus = comm.get_world_size()
+        if test_batch_size % n_gpus != 0:
+            OmegaConf.set_readonly(loader_cfg, False)
+            new_batch_size = int(np.ceil(test_batch_size / n_gpus) * n_gpus)
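+            # e.g., total_batch_size=16 on 3 GPUs -> ceil(16 / 3) * 3 = 18, so every
+            # GPU gets an equal per-device batch of 6 (illustrative numbers).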
+            loader_cfg.total_batch_size = new_batch_size
+            logger.info(
+                "test total batch size reset from {} to {}, n_gpus: {}".format(test_batch_size, new_batch_size, n_gpus)
+            )
+            OmegaConf.set_readonly(loader_cfg, True)
diff --git a/det/yolox/evaluators/__init__.py b/det/yolox/evaluators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d77d4519cd3c04f9682e0fb2ae83abcb65881e98
--- /dev/null
+++ b/det/yolox/evaluators/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .coco_evaluator import COCOEvaluator
+from .voc_evaluator import VOCEvaluator
+from .yolox_coco_evaluator import YOLOX_COCOEvaluator
diff --git a/det/yolox/evaluators/coco_evaluator.py b/det/yolox/evaluators/coco_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..420aa91c089e5953d800a23c5b45693a6a52027a
--- /dev/null
+++ b/det/yolox/evaluators/coco_evaluator.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import contextlib
+import io
+import itertools
+import json
+import tempfile
+import time
+
+from loguru import logger
+from tqdm import tqdm
+
+import torch
+
+from det.yolox.utils import (
+    gather,
+    is_main_process,
+    postprocess,
+    synchronize,
+    time_synchronized,
+    xyxy2xywh,
+)
+
+
+class COCOEvaluator:
+    """COCO AP Evaluation class.
+
+    All the data in the val2017 dataset are processed and evaluated by
+    COCO API.
+    """
+
+    def __init__(self, img_size, conf_thr, nms_thr, num_classes, testdev=False):
+        """
+        Args:
+
+            img_size (tuple): image size (h, w) after preprocessing; images are
+                resized/padded to this shape before inference.
+            conf_thr (float): confidence threshold ranging from 0 to 1, which
+                is defined in the config file.
+            nms_thr (float): IoU threshold of non-max suppression ranging from 0 to 1.
+        """
+        self.img_size = img_size
+        self.conf_thr = conf_thr
+        self.nms_thr = nms_thr
+        self.num_classes = num_classes
+        self.testdev = testdev
+
+    def set_dataloader(self, dataloader):
+        # dataloader (Dataloader): evaluate dataloader.
+        self.dataloader = dataloader
+
+    def evaluate(
+        self,
+        model,
+        distributed=False,
+        half=False,
+        trt_file=None,
+        decoder=None,
+        test_size=None,
+    ):
+        """COCO average precision (AP) Evaluation. Iterate inference on the
+        test dataset and the results are evaluated by COCO API.
+
+        NOTE: This function will change training mode to False, please save states if needed.
+
+        Args:
+            model : model to evaluate.
+
+        Returns:
+            ap50_95 (float) : COCO AP of IoU=50:95
+            ap50 (float) : COCO AP of IoU=50
+            summary (str): summary info of evaluation.
+        """
+        # TODO half to amp_test
+        tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
+        model = model.eval()
+        if half:
+            model = model.half()
+        ids = []
+        data_list = []
+        progress_bar = tqdm if is_main_process() else iter
+
+        inference_time = 0
+        nms_time = 0
+        n_samples = max(len(self.dataloader) - 1, 1)
+
+        if trt_file is not None:
+            from torch2trt import TRTModule
+
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(trt_file))
+
+            x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
+            model(x)
+            model = model_trt
+
+        for cur_iter, (imgs, _, info_imgs, ids) in enumerate(progress_bar(self.dataloader)):
+            with torch.no_grad():
+                imgs = imgs.type(tensor_type)
+
+                # skip the last iter since the batch size might not be enough for batch inference
+                is_time_record = cur_iter < len(self.dataloader) - 1
+                if is_time_record:
+                    start = time.perf_counter()
+
+                if trt_file is not None:
+                    outputs = model(imgs)
+                else:
+                    outputs = model(imgs)["det_preds"]
+
+                if decoder is not None:
+                    outputs = decoder(outputs, dtype=outputs.type())
+
+                if is_time_record:
+                    infer_end = time_synchronized()
+                    inference_time += infer_end - start
+
+                outputs = postprocess(outputs, self.num_classes, self.conf_thr, self.nms_thr)
+                if is_time_record:
+                    nms_end = time_synchronized()
+                    nms_time += nms_end - infer_end
+
+            data_list.extend(self.convert_to_coco_format(outputs, info_imgs, ids))
+
+        statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
+        if distributed:
+            data_list = gather(data_list, dst=0)
+            data_list = list(itertools.chain(*data_list))
+            torch.distributed.reduce(statistics, dst=0)
+
+        eval_results = self.evaluate_prediction(data_list, statistics)
+        synchronize()
+        return eval_results
+
+    def convert_to_coco_format(self, outputs, info_imgs, ids):
+        data_list = []
+        for (output, img_h, img_w, img_id) in zip(outputs, info_imgs[0], info_imgs[1], ids):
+            if output is None:
+                continue
+            output = output.cpu()
+
+            bboxes = output[:, 0:4]
+
+            # undo the preprocessing resize: map boxes back to the original image scale
+            scale = min(self.img_size[0] / float(img_h), self.img_size[1] / float(img_w))
+            bboxes /= scale
+            bboxes = xyxy2xywh(bboxes)
+
+            cls = output[:, 6]
+            scores = output[:, 4] * output[:, 5]
+            for ind in range(bboxes.shape[0]):
+                label = self.dataloader.dataset.class_ids[int(cls[ind])]
+                pred_data = {
+                    "image_id": int(img_id),
+                    "category_id": label,
+                    "bbox": bboxes[ind].numpy().tolist(),
+                    "score": scores[ind].numpy().item(),
+                    "segmentation": [],
+                }  # COCO json format
+                data_list.append(pred_data)
+        return data_list
+
+    def evaluate_prediction(self, data_dict, statistics):
+        if not is_main_process():
+            return 0, 0, None
+
+        logger.info("Evaluate in main process...")
+
+        annType = ["segm", "bbox", "keypoints"]
+
+        inference_time = statistics[0].item()
+        nms_time = statistics[1].item()
+        n_samples = statistics[2].item()
+
+        a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
+        a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)
+
+        time_info = ", ".join(
+            [
+                "Average {} time: {:.2f} ms".format(k, v)
+                for k, v in zip(
+                    ["forward", "NMS", "inference"],
+                    [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
+                )
+            ]
+        )
+
+        info = time_info + "\n"
+
+        # Evaluate the Dt (detection) json comparing with the ground truth
+        if len(data_dict) > 0:
+            cocoGt = self.dataloader.dataset.coco
+            # TODO: since pycocotools can't process dict in py36, write data to json file.
+            if self.testdev:
+                json.dump(data_dict, open("./yolox_testdev_2017.json", "w"))
+                cocoDt = cocoGt.loadRes("./yolox_testdev_2017.json")
+            else:
+                _, tmp = tempfile.mkstemp()
+                json.dump(data_dict, open(tmp, "w"))
+                cocoDt = cocoGt.loadRes(tmp)
+            try:
+                from detectron2.evaluation.fast_eval_api import COCOeval_opt as COCOeval
+            except ImportError:
+                from pycocotools.cocoeval import COCOeval
+
+                logger.warning("Use standard COCOeval.")
+
+            cocoEval = COCOeval(cocoGt, cocoDt, annType[1])
+            cocoEval.evaluate()
+            cocoEval.accumulate()
+            redirect_string = io.StringIO()
+            with contextlib.redirect_stdout(redirect_string):
+                cocoEval.summarize()
+            info += redirect_string.getvalue()
+            return cocoEval.stats[0], cocoEval.stats[1], info
+        else:
+            return 0, 0, info
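+
+
+# A minimal, hypothetical usage sketch (the dataloader construction is a placeholder;
+# it only needs `.dataset.coco`, `.dataset.class_ids` and `.batch_size` as used above):
+#
+#   evaluator = COCOEvaluator(img_size=(640, 640), conf_thr=0.01, nms_thr=0.65, num_classes=80)
+#   evaluator.set_dataloader(val_loader)
+#   ap50_95, ap50, summary = evaluator.evaluate(model, distributed=False, half=False)
+#   print(summary)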
diff --git a/det/yolox/evaluators/voc_eval.py b/det/yolox/evaluators/voc_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..3570191de16ff96054d53e65c1e233203070a956
--- /dev/null
+++ b/det/yolox/evaluators/voc_eval.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Code are based on
+# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py
+# Copyright (c) Bharath Hariharan.
+# Copyright (c) Megvii, Inc. and its affiliates.
+import os
+import pickle
+import xml.etree.ElementTree as ET
+
+import numpy as np
+
+
+def parse_rec(filename):
+    """Parse a PASCAL VOC xml file."""
+    tree = ET.parse(filename)
+    objects = []
+    for obj in tree.findall("object"):
+        obj_struct = {}
+        obj_struct["name"] = obj.find("name").text
+        obj_struct["pose"] = obj.find("pose").text
+        obj_struct["truncated"] = int(obj.find("truncated").text)
+        obj_struct["difficult"] = int(obj.find("difficult").text)
+        bbox = obj.find("bndbox")
+        obj_struct["bbox"] = [
+            int(bbox.find("xmin").text),
+            int(bbox.find("ymin").text),
+            int(bbox.find("xmax").text),
+            int(bbox.find("ymax").text),
+        ]
+        objects.append(obj_struct)
+
+    return objects
+
+
+def voc_ap(rec, prec, use_07_metric=False):
+    """ap = voc_ap(rec, prec, [use_07_metric])
+    Compute VOC AP given precision and recall.
+    If use_07_metric is true, uses the
+    VOC 07 11 point method (default:False).
+    """
+    if use_07_metric:
+        # 11 point metric
+        ap = 0.0
+        for t in np.arange(0.0, 1.1, 0.1):
+            if np.sum(rec >= t) == 0:
+                p = 0
+            else:
+                p = np.max(prec[rec >= t])
+            ap = ap + p / 11.0
+    else:
+        # correct AP calculation
+        # first append sentinel values at the end
+        mrec = np.concatenate(([0.0], rec, [1.0]))
+        mpre = np.concatenate(([0.0], prec, [0.0]))
+
+        # compute the precision envelope
+        for i in range(mpre.size - 1, 0, -1):
+            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+        # to calculate area under PR curve, look for points
+        # where X axis (recall) changes value
+        i = np.where(mrec[1:] != mrec[:-1])[0]
+
+        # and sum (\Delta recall) * prec
+        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+    return ap
+
+
+def voc_eval(
+    detpath,
+    annopath,
+    imagesetfile,
+    classname,
+    cachedir,
+    ovthresh=0.5,
+    use_07_metric=False,
+):
+    # first load gt
+    if not os.path.isdir(cachedir):
+        os.mkdir(cachedir)
+    cachefile = os.path.join(cachedir, "annots.pkl")
+    # read list of images
+    with open(imagesetfile, "r") as f:
+        lines = f.readlines()
+    imagenames = [x.strip() for x in lines]
+
+    if not os.path.isfile(cachefile):
+        # load annots
+        recs = {}
+        for i, imagename in enumerate(imagenames):
+            recs[imagename] = parse_rec(annopath.format(imagename))
+            if i % 100 == 0:
+                print("Reading annotation for {:d}/{:d}".format(i + 1, len(imagenames)))
+        # save
+        print("Saving cached annotations to {:s}".format(cachefile))
+        with open(cachefile, "wb") as f:
+            pickle.dump(recs, f)
+    else:
+        # load
+        with open(cachefile, "rb") as f:
+            recs = pickle.load(f)
+
+    # extract gt objects for this class
+    class_recs = {}
+    npos = 0
+    for imagename in imagenames:
+        R = [obj for obj in recs[imagename] if obj["name"] == classname]
+        bbox = np.array([x["bbox"] for x in R])
+        difficult = np.array([x["difficult"] for x in R]).astype(bool)  # np.bool is removed in NumPy >= 1.24
+        det = [False] * len(R)
+        npos = npos + sum(~difficult)
+        class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}
+
+    # read dets
+    detfile = detpath.format(classname)
+    with open(detfile, "r") as f:
+        lines = f.readlines()
+
+    if len(lines) == 0:
+        return 0, 0, 0
+
+    splitlines = [x.strip().split(" ") for x in lines]
+    image_ids = [x[0] for x in splitlines]
+    confidence = np.array([float(x[1]) for x in splitlines])
+    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
+
+    # sort by confidence
+    sorted_ind = np.argsort(-confidence)
+    BB = BB[sorted_ind, :]
+    image_ids = [image_ids[x] for x in sorted_ind]
+
+    # go down dets and mark TPs and FPs
+    nd = len(image_ids)
+    tp = np.zeros(nd)
+    fp = np.zeros(nd)
+    for d in range(nd):
+        R = class_recs[image_ids[d]]
+        bb = BB[d, :].astype(float)
+        ovmax = -np.inf
+        BBGT = R["bbox"].astype(float)
+
+        if BBGT.size > 0:
+            # compute overlaps
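+            # note: PASCAL VOC boxes use inclusive pixel coordinates, hence the
+            # +1.0 terms when computing widths, heights, and areas below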
+            # intersection
+            ixmin = np.maximum(BBGT[:, 0], bb[0])
+            iymin = np.maximum(BBGT[:, 1], bb[1])
+            ixmax = np.minimum(BBGT[:, 2], bb[2])
+            iymax = np.minimum(BBGT[:, 3], bb[3])
+            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
+            ih = np.maximum(iymax - iymin + 1.0, 0.0)
+            inters = iw * ih
+
+            # union
+            uni = (
+                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
+                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
+                - inters
+            )
+
+            overlaps = inters / uni
+            ovmax = np.max(overlaps)
+            jmax = np.argmax(overlaps)
+
+        if ovmax > ovthresh:
+            if not R["difficult"][jmax]:
+                if not R["det"][jmax]:
+                    tp[d] = 1.0
+                    R["det"][jmax] = 1
+                else:
+                    fp[d] = 1.0
+        else:
+            fp[d] = 1.0
+
+    # compute precision and recall
+    fp = np.cumsum(fp)
+    tp = np.cumsum(tp)
+    rec = tp / float(npos)
+    # avoid divide by zero in case the first detection matches a difficult
+    # ground truth
+    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+    ap = voc_ap(rec, prec, use_07_metric)
+
+    return rec, prec, ap
diff --git a/det/yolox/evaluators/voc_evaluator.py b/det/yolox/evaluators/voc_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b002fe194b125b46cb1885b59ba5d2f1389b1f2c
--- /dev/null
+++ b/det/yolox/evaluators/voc_evaluator.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import sys
+import tempfile
+import time
+from collections import ChainMap
+from loguru import logger
+from tqdm import tqdm
+
+import numpy as np
+
+import torch
+
+from det.yolox.utils import (
+    gather,
+    is_main_process,
+    postprocess,
+    synchronize,
+    time_synchronized,
+)
+
+
+class VOCEvaluator:
+    """VOC AP Evaluation class."""
+
+    def __init__(
+        self,
+        dataloader,
+        img_size,
+        confthre,
+        nmsthre,
+        num_classes,
+    ):
+        """
+        Args:
+            dataloader (Dataloader): dataloader to evaluate on.
+            img_size (int): image size after preprocessing. Images are resized
+                to squares whose shape is (img_size, img_size).
+            confthre (float): confidence threshold ranging from 0 to 1, which
+                is defined in the config file.
+            nmsthre (float): IoU threshold of non-max suppression ranging from 0 to 1.
+        """
+        self.dataloader = dataloader
+        self.img_size = img_size
+        self.confthre = confthre
+        self.nmsthre = nmsthre
+        self.num_classes = num_classes
+        self.num_images = len(dataloader.dataset)
+
+    def evaluate(
+        self,
+        model,
+        distributed=False,
+        half=False,
+        trt_file=None,
+        decoder=None,
+        test_size=None,
+    ):
+        """VOC average precision (AP) Evaluation. Iterate inference on the test
+        dataset and the results are evaluated by COCO API.
+
+        NOTE: This function will change training mode to False, please save states if needed.
+
+        Args:
+            model : model to evaluate.
+
+        Returns:
+            ap50_95 (float) : COCO style AP of IoU=50:95
+            ap50 (float) : VOC 2007 metric AP of IoU=50
+            summary (sr): summary info of evaluation.
+        """
+        # TODO half to amp_test
+        tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
+        model = model.eval()
+        if half:
+            model = model.half()
+        ids = []
+        data_list = {}
+        progress_bar = tqdm if is_main_process() else iter
+
+        inference_time = 0
+        nms_time = 0
+        n_samples = max(len(self.dataloader) - 1, 1)
+
+        if trt_file is not None:
+            from torch2trt import TRTModule
+
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(trt_file))
+
+            x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
+            model(x)
+            model = model_trt
+
+        for cur_iter, (imgs, _, info_imgs, ids) in enumerate(progress_bar(self.dataloader)):
+            with torch.no_grad():
+                imgs = imgs.type(tensor_type)
+
+                # skip the last iters since the batch size might not be enough for batch inference
+                is_time_record = cur_iter < len(self.dataloader) - 1
+                if is_time_record:
+                    start = time.time()
+
+                outputs = model(imgs)
+                if decoder is not None:
+                    outputs = decoder(outputs, dtype=outputs.type())
+
+                if is_time_record:
+                    infer_end = time_synchronized()
+                    inference_time += infer_end - start
+
+                outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)
+                if is_time_record:
+                    nms_end = time_synchronized()
+                    nms_time += nms_end - infer_end
+
+            data_list.update(self.convert_to_voc_format(outputs, info_imgs, ids))
+
+        statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
+        if distributed:
+            data_list = gather(data_list, dst=0)
+            data_list = ChainMap(*data_list)
+            torch.distributed.reduce(statistics, dst=0)
+
+        eval_results = self.evaluate_prediction(data_list, statistics)
+        synchronize()
+        return eval_results
+
+    def convert_to_voc_format(self, outputs, info_imgs, ids):
+        predictions = {}
+        for (output, img_h, img_w, img_id) in zip(outputs, info_imgs[0], info_imgs[1], ids):
+            if output is None:
+                predictions[int(img_id)] = (None, None, None)
+                continue
+            output = output.cpu()
+
+            bboxes = output[:, 0:4]
+
+            # undo the test-time resize so boxes are in original image coordinates
+            scale = min(self.img_size[0] / float(img_h), self.img_size[1] / float(img_w))
+            bboxes /= scale
+
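+            # assumed post-NMS layout from det.yolox.utils.postprocess:
+            # columns 4/5/6 are obj_conf, cls_conf, cls_id, so score = obj_conf * cls_conf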
+            cls = output[:, 6]
+            scores = output[:, 4] * output[:, 5]
+
+            predictions[int(img_id)] = (bboxes, cls, scores)
+        return predictions
+
+    def evaluate_prediction(self, data_dict, statistics):
+        if not is_main_process():
+            return 0, 0, None
+
+        logger.info("Evaluate in main process...")
+
+        inference_time = statistics[0].item()
+        nms_time = statistics[1].item()
+        n_samples = statistics[2].item()
+
+        a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
+        a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)
+
+        time_info = ", ".join(
+            [
+                "Average {} time: {:.2f} ms".format(k, v)
+                for k, v in zip(
+                    ["forward", "NMS", "inference"],
+                    [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
+                )
+            ]
+        )
+
+        info = time_info + "\n"
+
+        all_boxes = [[[] for _ in range(self.num_images)] for _ in range(self.num_classes)]
+        for img_num in range(self.num_images):
+            bboxes, cls, scores = data_dict[img_num]
+            if bboxes is None:
+                for j in range(self.num_classes):
+                    all_boxes[j][img_num] = np.empty([0, 5], dtype=np.float32)
+                continue
+            for j in range(self.num_classes):
+                mask_c = cls == j
+                if sum(mask_c) == 0:
+                    all_boxes[j][img_num] = np.empty([0, 5], dtype=np.float32)
+                    continue
+
+                c_dets = torch.cat((bboxes, scores.unsqueeze(1)), dim=1)
+                all_boxes[j][img_num] = c_dets[mask_c].numpy()
+
+            sys.stdout.write("im_eval: {:d}/{:d} \r".format(img_num + 1, self.num_images))
+            sys.stdout.flush()
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            mAP50, mAP70 = self.dataloader.dataset.evaluate_detections(all_boxes, tempdir)
+            return mAP50, mAP70, info
diff --git a/det/yolox/evaluators/yolox_coco_evaluator.py b/det/yolox/evaluators/yolox_coco_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..03102a98cec3652ce600ce0c21a75db9e7bedec8
--- /dev/null
+++ b/det/yolox/evaluators/yolox_coco_evaluator.py
@@ -0,0 +1,758 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import contextlib
+import copy
+import io
+import itertools
+import json
+import logging
+import numpy as np
+import os
+import os.path as osp
+import pickle
+from collections import OrderedDict
+import pycocotools.mask as mask_util
+import torch
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from tabulate import tabulate
+
+from detectron2.config import CfgNode
+from detectron2.data import MetadataCatalog
+from detectron2.data.datasets.coco import convert_to_coco_json
+from detectron2.evaluation.fast_eval_api import COCOeval_opt
+from detectron2.structures import Boxes, BoxMode, pairwise_iou
+from detectron2.utils.file_io import PathManager
+from detectron2.utils.logger import create_small_table
+
+from detectron2.evaluation.evaluator import DatasetEvaluator
+
+from det.yolox.utils import (
+    gather,
+    postprocess,
+    synchronize,
+    time_synchronized,
+    xyxy2xywh,
+)
+import core.utils.my_comm as comm
+
+import ref
+
+
+class YOLOX_COCOEvaluator(DatasetEvaluator):
+    """Evaluate AR for object proposals, AP for instance
+    detection/segmentation, AP for keypoint detection outputs using COCO's
+    metrics. See http://cocodataset.org/#detection-eval and
+    http://cocodataset.org/#keypoints-eval to understand its metrics. The
+    metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
+    the metric cannot be computed (e.g. due to no predictions made).
+
+    In addition to COCO, this evaluator is able to support any bounding
+    box detection, instance segmentation, or keypoint detection dataset.
+    """
+
+    def __init__(
+        self,
+        dataset_name,
+        filter_scene=False,
+        tasks=None,
+        distributed=True,
+        output_dir=None,
+        *,
+        use_fast_impl=True,
+        kpt_oks_sigmas=(),
+    ):
+        """
+        Args:
+            dataset_name (str): name of the dataset to be evaluated.
+                It must have either the following corresponding metadata:
+
+                    "json_file": the path to the COCO format annotation
+
+                Or it must be in detectron2's standard dataset format
+                so it can be converted to COCO format automatically.
+            tasks (tuple[str]): tasks that can be evaluated under the given
+                configuration. A task is one of "bbox", "segm", "keypoints".
+                By default, will infer this automatically from predictions.
+            distributed (bool): if True, will collect results from all ranks and run evaluation
+                in the main process.
+                Otherwise, will only evaluate the results in the current process.
+            output_dir (str): optional, an output directory to dump all
+                results predicted on the dataset. The dump contains the following files:
+
+                1. "instances_predictions.pth" a file that can be loaded with `torch.load` and
+                   contains all the results in the format they are produced by the model.
+                2. "coco_instances_results.json" a json file in COCO's result format.
+            use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP.
+                Although the results should be very close to the official implementation in COCO
+                API, it is still recommended to compute results with the official API for use in
+                papers. The faster implementation also uses more RAM.
+            kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS.
+                See http://cocodataset.org/#keypoints-eval
+                When empty, it will use the defaults in COCO.
+                Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
+        """
+        self._logger = logging.getLogger(__name__)
+        self._distributed = distributed
+        self._output_dir = output_dir
+        self._use_fast_impl = use_fast_impl
+        self.filter_scene = filter_scene
+
+        if tasks is not None and isinstance(tasks, CfgNode):
+            kpt_oks_sigmas = tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas
+            self._logger.warning(
+                "COCO Evaluator instantiated using config, this is deprecated behavior."
+                " Please pass in explicit arguments instead."
+            )
+            self._tasks = None  # Inferring it from predictions should be better
+        else:
+            self._tasks = tasks
+
+        self._cpu_device = torch.device("cpu")
+
+        self._metadata = MetadataCatalog.get(dataset_name)
+        if not hasattr(self._metadata, "json_file"):
+            self._logger.info(
+                f"'{dataset_name}' is not registered by `register_coco_instances`."
+                " Therefore trying to convert it to COCO format ..."
+            )
+
+            cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
+            self._metadata.json_file = cache_path
+            convert_to_coco_json(dataset_name, cache_path)
+
+        json_file = PathManager.get_local_path(self._metadata.json_file)
+        with contextlib.redirect_stdout(io.StringIO()):
+            self._coco_api = COCO(json_file)
+
+        # Test set json files do not contain annotations (evaluation must be
+        # performed using the COCO evaluation server).
+        self._do_evaluation = "annotations" in self._coco_api.dataset
+        if self._do_evaluation:
+            self._kpt_oks_sigmas = kpt_oks_sigmas
+
+        if self.filter_scene:
+            self.data_ref = ref.__dict__[self._metadata.ref_key]
+            self.objs = self._metadata.objs
+
+    def reset(self):
+        self._predictions = []
+
+    # def process(self, inputs, outputs):
+    #     """
+    #     Args:
+    #         inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
+    #             It is a list of dict. Each dict corresponds to an image and
+    #             contains keys like "height", "width", "file_name", "image_id".
+    #         outputs: the outputs of a COCO model. It is a list of dicts with key
+    #             "instances" that contains :class:`Instances`.
+    #     """
+    #     for input, output in zip(inputs, outputs):
+    #         prediction = {"image_id": input["image_id"]}
+
+    #         if "instances" in output:
+    #             instances = output["instances"].to(self._cpu_device)
+    #             prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
+    #         if "proposals" in output:
+    #             prediction["proposals"] = output["proposals"].to(self._cpu_device)
+    #         if len(prediction) > 1:
+    #             self._predictions.append(prediction)
+
+    def process(self, outputs, scene_im_ids, info_imgs, ids, cfg):
+        # TODO: use inputs as dataset dicts
+        # cur_predictions = self.convert_to_coco_format(outputs["det_preds"], scene_im_ids, info_imgs, ids, cfg)
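+        # NOTE: unlike detectron2's DatasetEvaluator.process(inputs, outputs), this
+        # variant takes raw YOLOX detections plus image meta info and converts them here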
+        cur_predictions = self.convert_to_coco_format_bop(
+            outputs["det_preds"], scene_im_ids, info_imgs, ids, outputs["time"], cfg
+        )
+        self._predictions.extend(cur_predictions)
+
+    def convert_to_coco_format(self, det_preds, scene_im_ids, info_imgs, ids, cfg):
+        data_list = []
+        for (det_pred, scene_im_id, img_h, img_w, img_id) in zip(
+            det_preds, scene_im_ids, info_imgs[0], info_imgs[1], ids
+        ):
+            if det_pred is None:
+                continue
+
+            prediction = {"image_id": int(img_id)}
+            det_pred = det_pred.cpu()
+
+            instances = []
+            bboxes = det_pred[:, 0:4]
+
+            # undo the test-time resize so boxes are in original image coordinates
+            scale = min(cfg.test_size[0] / float(img_h), cfg.test_size[1] / float(img_w))
+            bboxes /= scale
+            bboxes = xyxy2xywh(bboxes)
+
+            labels = det_pred[:, 6]
+            scores = det_pred[:, 4] * det_pred[:, 5]
+            for ind in range(bboxes.shape[0]):
+                # label = self.dataloader.dataset.class_ids[int(cls[ind])]
+                if self.filter_scene:
+                    label = int(labels[ind])
+                    obj_name = self.objs[label]
+                    obj_id = self.data_ref.obj2id[obj_name]
+                    scene_id = int(scene_im_id.split("/")[0])
+                    if scene_id != obj_id:
+                        continue
+
+                pred_data = {
+                    "image_id": int(img_id),
+                    "category_id": int(labels[ind]),
+                    "bbox": bboxes[ind].numpy().tolist(),
+                    "score": scores[ind].numpy().item(),
+                    # "segmentation": [],
+                }  # COCO json format
+                instances.append(pred_data)
+            prediction["instances"] = instances
+            data_list.append(prediction)
+        return data_list
+
+    def convert_to_coco_format_bop(self, det_preds, scene_im_ids, info_imgs, ids, time, cfg):
+        data_list = []
+        for (det_pred, scene_im_id, img_h, img_w, img_id) in zip(
+            det_preds, scene_im_ids, info_imgs[0], info_imgs[1], ids
+        ):
+            if det_pred is None:
+                continue
+
+            prediction = {"image_id": int(img_id)}
+            det_pred = det_pred.cpu()
+
+            instances = []
+            instances_bop = []
+            bboxes = det_pred[:, 0:4]
+
+            # undo the test-time resize so boxes are in original image coordinates
+            scale = min(cfg.test_size[0] / float(img_h), cfg.test_size[1] / float(img_w))
+            bboxes /= scale
+            bboxes = xyxy2xywh(bboxes)
+
+            labels = det_pred[:, 6]
+            scores = det_pred[:, 4] * det_pred[:, 5]
+            for ind in range(bboxes.shape[0]):
+                # label = self.dataloader.dataset.class_ids[int(cls[ind])]
+                scene_id = int(scene_im_id.split("/")[0])
+                if self.filter_scene:
+                    label = int(labels[ind])
+                    obj_name = self.objs[label]
+                    obj_id = self.data_ref.obj2id[obj_name]
+                    if scene_id != obj_id:
+                        continue
+
+                pred_data = {
+                    "image_id": int(img_id),
+                    "category_id": int(labels[ind]),
+                    "score": scores[ind].numpy().item(),
+                    "bbox": bboxes[ind].numpy().tolist(),
+                    # "segmentation": [],
+                }  # COCO json format
+                instances.append(pred_data)
+
+                pred_data_bop = {
+                    "scene_id": int(scene_id),
+                    "image_id": int(scene_im_id.split("/")[1]),  # careful!
+                    "category_id": int(labels[ind] + 1),
+                    "score": scores[ind].numpy().item(),
+                    "bbox": bboxes[ind].numpy().tolist(),
+                    # "segmentation": [],
+                    "time": time,
+                }  # COCO json format
+                instances_bop.append(pred_data_bop)
+            prediction["instances"] = instances
+            prediction["instances_bop"] = instances_bop
+            data_list.append(prediction)
+        return data_list
+
+    def evaluate(self, img_ids=None, eval_cached=False):
+        """
+        Args:
+            img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset
+        """
+        if eval_cached:
+            if self._distributed:
+                if not comm.is_main_process():
+                    return {}
+            file_path = os.path.join(self._output_dir, "instances_predictions.pth")
+            assert osp.exists(file_path), file_path
+            self._logger.info("evaluating from cached results: {}".format(file_path))
+            self._predictions = predictions = torch.load(file_path)
+        else:
+            if self._distributed:
+                comm.synchronize()
+                predictions = comm.gather(self._predictions, dst=0)
+                predictions = list(itertools.chain(*predictions))
+
+                if not comm.is_main_process():
+                    return {}
+            else:
+                predictions = self._predictions
+
+            if len(predictions) == 0:
+                self._logger.warning("[YOLOX_COCOEvaluator] Did not receive valid predictions.")
+                return {}
+
+            if self._output_dir:
+                PathManager.mkdirs(self._output_dir)
+                file_path = os.path.join(self._output_dir, "instances_predictions.pth")
+                with PathManager.open(file_path, "wb") as f:
+                    torch.save(predictions, f)
+
+        self._results = OrderedDict()
+        if "proposals" in predictions[0]:
+            self._eval_box_proposals(predictions)
+        if "instances" in predictions[0]:
+            self._eval_predictions(predictions, img_ids=img_ids)
+        # Copy so the caller can do whatever with results
+        return copy.deepcopy(self._results)
+
+    def _tasks_from_predictions(self, predictions):
+        """Get COCO API "tasks" (i.e. iou_type) from COCO-format
+        predictions."""
+        tasks = {"bbox"}
+        for pred in predictions:
+            if "segmentation" in pred:
+                tasks.add("segm")
+            if "keypoints" in pred:
+                tasks.add("keypoints")
+        return sorted(tasks)
+
+    def _eval_predictions(self, predictions, img_ids=None):
+        """Evaluate predictions.
+
+        Fill self._results with the metrics of the tasks.
+        """
+        self._logger.info("Preparing results for COCO format ...")
+        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
+        coco_results_bop = list(itertools.chain(*[x["instances_bop"] for x in predictions]))
+        tasks = self._tasks or self._tasks_from_predictions(coco_results)
+
+        # unmap the category ids for COCO
+        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
+            dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
+            all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
+            num_classes = len(all_contiguous_ids)
+            assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
+
+            reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
+            for result in coco_results:
+                category_id = result["category_id"]
+                assert category_id < num_classes, (
+                    f"A prediction has class={category_id}, "
+                    f"but the dataset only has {num_classes} classes and "
+                    f"predicted class id should be in [0, {num_classes - 1}]."
+                )
+                result["category_id"] = reverse_id_mapping[category_id]
+
+        if self._output_dir:
+            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
+            self._logger.info("Saving results to {}".format(file_path))
+            with PathManager.open(file_path, "w") as f:
+                f.write(json.dumps(coco_results))
+                f.flush()
+
+        # unmap the category ids for COCO BOP
+        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
+            dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
+            all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
+            num_classes = len(all_contiguous_ids)
+            assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
+
+            reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
+            for result in coco_results_bop:
+                category_id = result["category_id"]
+                assert category_id < num_classes, (
+                    f"A prediction has class={category_id}, "
+                    f"but the dataset only has {num_classes} classes and "
+                    f"predicted class id should be in [0, {num_classes - 1}]."
+                )
+                result["category_id"] = reverse_id_mapping[category_id]
+
+        if self._output_dir:
+            file_path = os.path.join(self._output_dir, "coco_instances_results_bop.json")
+            self._logger.info("Saving results to {}".format(file_path))
+            with PathManager.open(file_path, "w") as f:
+                f.write(json.dumps(coco_results_bop))
+                f.flush()
+
+        if not self._do_evaluation:
+            self._logger.info("Annotations are not available for evaluation.")
+            return
+
+        self._logger.info(
+            "Evaluating predictions with {} COCO API...".format("unofficial" if self._use_fast_impl else "official")
+        )
+        for task in sorted(tasks):
+            assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
+            coco_eval = (
+                _evaluate_predictions_on_coco(
+                    self._coco_api,
+                    coco_results,
+                    task,
+                    kpt_oks_sigmas=self._kpt_oks_sigmas,
+                    use_fast_impl=self._use_fast_impl,
+                    img_ids=img_ids,
+                )
+                if len(coco_results) > 0
+                else None  # cocoapi does not handle empty results very well
+            )
+
+            res = self._derive_coco_results(coco_eval, task, class_names=self._metadata.get("thing_classes"))
+            self._results[task] = res
+
+    def _eval_box_proposals(self, predictions):
+        """Evaluate the box proposals in predictions.
+
+        Fill self._results with the metrics for "box_proposals" task.
+        """
+        if self._output_dir:
+            # Saving generated box proposals to file.
+            # Predicted box_proposals are in XYXY_ABS mode.
+            bbox_mode = BoxMode.XYXY_ABS.value
+            ids, boxes, objectness_logits = [], [], []
+            for prediction in predictions:
+                ids.append(prediction["image_id"])
+                boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy())
+                objectness_logits.append(prediction["proposals"].objectness_logits.numpy())
+
+            proposal_data = {
+                "boxes": boxes,
+                "objectness_logits": objectness_logits,
+                "ids": ids,
+                "bbox_mode": bbox_mode,
+            }
+            with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f:
+                pickle.dump(proposal_data, f)
+
+        if not self._do_evaluation:
+            self._logger.info("Annotations are not available for evaluation.")
+            return
+
+        self._logger.info("Evaluating bbox proposals ...")
+        res = {}
+        areas = {"all": "", "small": "s", "medium": "m", "large": "l"}
+        for limit in [100, 1000]:
+            for area, suffix in areas.items():
+                stats = _evaluate_box_proposals(predictions, self._coco_api, area=area, limit=limit)
+                key = "AR{}@{:d}".format(suffix, limit)
+                res[key] = float(stats["ar"].item() * 100)
+        self._logger.info("Proposal metrics: \n" + create_small_table(res))
+        self._results["box_proposals"] = res
+
+    def _derive_coco_results(self, coco_eval, iou_type, class_names=None):
+        """Derive the desired score numbers from summarized COCOeval.
+
+        Args:
+            coco_eval (None or COCOEval): None represents no predictions from model.
+            iou_type (str):
+            class_names (None or list[str]): if provided, will use it to compute
+                per-category AP.
+
+        Returns:
+            a dict of {metric name: score}
+        """
+
+        metrics = {
+            "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
+            "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"],
+            "keypoints": ["AP", "AP50", "AP75", "APm", "APl"],
+        }[iou_type]
+
+        if coco_eval is None:
+            self._logger.warn("No predictions from the model!")
+            return {metric: float("nan") for metric in metrics}
+
+        # the standard metrics
+        results = {
+            metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan")
+            for idx, metric in enumerate(metrics)
+        }
+        self._logger.info("Evaluation results for {}: \n".format(iou_type) + create_small_table(results))
+        if not np.isfinite(sum(results.values())):
+            self._logger.info("Some metrics cannot be computed and is shown as NaN.")
+
+        if class_names is None or len(class_names) <= 1:
+            return results
+        # Compute per-category AP
+        # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa
+        precisions = coco_eval.eval["precision"]
+        # precision has dims (iou, recall, cls, area range, max dets)
+        assert len(class_names) == precisions.shape[2]
+
+        results_per_category = []
+        for idx, name in enumerate(class_names):
+            # area range index 0: all area ranges
+            # max dets index -1: typically 100 per image
+            precision = precisions[:, :, idx, 0, -1]
+            precision = precision[precision > -1]
+            ap = np.mean(precision) if precision.size else float("nan")
+            results_per_category.append(("{}".format(name), float(ap * 100)))
+
+        # tabulate it
+        N_COLS = min(6, len(results_per_category) * 2)
+        results_flatten = list(itertools.chain(*results_per_category))
+        results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)])
+        table = tabulate(
+            results_2d,
+            tablefmt="pipe",
+            floatfmt=".3f",
+            headers=["category", "AP"] * (N_COLS // 2),
+            numalign="left",
+        )
+        self._logger.info("Per-category {} AP: \n".format(iou_type) + table)
+
+        results.update({"AP-" + name: ap for name, ap in results_per_category})
+
+        ###############################
+        # Compute per-category AR
+        recalls = coco_eval.eval["recall"]
+        # recall has dims (iou, cls, area range, max dets)
+        assert len(class_names) == recalls.shape[1]
+
+        recall_results_per_category = []
+        for idx, name in enumerate(class_names):
+            # area range index 0: all area ranges
+            # max dets index -1: typically 100 per image
+            _recall = recalls[:, idx, 0, -1]
+            _recall = _recall[_recall > -1]
+            ar = np.mean(_recall) if _recall.size else float("nan")
+            recall_results_per_category.append(("{}".format(name), float(ar * 100)))
+
+        # tabulate it
+        N_COLS = min(10, len(recall_results_per_category) * 2)
+        recall_results_flatten = list(itertools.chain(*recall_results_per_category))
+        recall_results_2d = itertools.zip_longest(*[recall_results_flatten[i::N_COLS] for i in range(N_COLS)])
+        table = tabulate(
+            recall_results_2d,
+            tablefmt="pipe",
+            floatfmt=".3f",
+            headers=["category", "AR"] * (N_COLS // 2),
+            numalign="left",
+        )
+        self._logger.info("Per-category {} AR: \n".format(iou_type) + table)
+
+        results.update({"AR-" + name: _ar for name, _ar in recall_results_per_category})
+        return results
+
+
+# def instances_to_coco_json(instances, img_id):
+#     """
+#     Dump an "Instances" object to a COCO-format json that's used for evaluation.
+
+#     Args:
+#         instances (Instances):
+#         img_id (int): the image id
+
+#     Returns:
+#         list[dict]: list of json annotations in COCO format.
+#     """
+#     num_instance = len(instances)
+#     if num_instance == 0:
+#         return []
+
+#     boxes = instances.pred_boxes.tensor.numpy()
+#     boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
+#     boxes = boxes.tolist()
+#     scores = instances.scores.tolist()
+#     classes = instances.pred_classes.tolist()
+
+#     has_mask = instances.has("pred_masks")
+#     if has_mask:
+#         # use RLE to encode the masks, because they are too large and takes memory
+#         # since this evaluator stores outputs of the entire dataset
+#         rles = [
+#             mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
+#             for mask in instances.pred_masks
+#         ]
+#         for rle in rles:
+#             # "counts" is an array encoded by mask_util as a byte-stream. Python3's
+#             # json writer which always produces strings cannot serialize a bytestream
+#             # unless you decode it. Thankfully, utf-8 works out (which is also what
+#             # the pycocotools/_mask.pyx does).
+#             rle["counts"] = rle["counts"].decode("utf-8")
+
+#     has_keypoints = instances.has("pred_keypoints")
+#     if has_keypoints:
+#         keypoints = instances.pred_keypoints
+
+#     results = []
+#     for k in range(num_instance):
+#         result = {
+#             "image_id": img_id,
+#             "category_id": classes[k],
+#             "bbox": boxes[k],
+#             "score": scores[k],
+#         }
+#         if has_mask:
+#             result["segmentation"] = rles[k]
+#         if has_keypoints:
+#             # In COCO annotations,
+#             # keypoints coordinates are pixel indices.
+#             # However our predictions are floating point coordinates.
+#             # Therefore we subtract 0.5 to be consistent with the annotation format.
+#             # This is the inverse of data loading logic in `datasets/coco.py`.
+#             keypoints[k][:, :2] -= 0.5
+#             result["keypoints"] = keypoints[k].flatten().tolist()
+#         results.append(result)
+#     return results
+
+
+# inspired from Detectron:
+# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa
+def _evaluate_box_proposals(dataset_predictions, coco_api, thresholds=None, area="all", limit=None):
+    """Evaluate detection proposal recall metrics.
+
+    This function is a much faster alternative to the official COCO API
+    recall evaluation code. However, it produces slightly different
+    results.
+    """
+    # Record max overlap value for each gt box
+    # Return vector of overlap values
+    areas = {
+        "all": 0,
+        "small": 1,
+        "medium": 2,
+        "large": 3,
+        "96-128": 4,
+        "128-256": 5,
+        "256-512": 6,
+        "512-inf": 7,
+    }
+    area_ranges = [
+        [0**2, 1e5**2],  # all
+        [0**2, 32**2],  # small
+        [32**2, 96**2],  # medium
+        [96**2, 1e5**2],  # large
+        [96**2, 128**2],  # 96-128
+        [128**2, 256**2],  # 128-256
+        [256**2, 512**2],  # 256-512
+        [512**2, 1e5**2],
+    ]  # 512-inf
+    assert area in areas, "Unknown area range: {}".format(area)
+    area_range = area_ranges[areas[area]]
+    gt_overlaps = []
+    num_pos = 0
+
+    for prediction_dict in dataset_predictions:
+        predictions = prediction_dict["proposals"]
+
+        # sort predictions in descending order
+        # TODO maybe remove this and make it explicit in the documentation
+        inds = predictions.objectness_logits.sort(descending=True)[1]
+        predictions = predictions[inds]
+
+        ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"])
+        anno = coco_api.loadAnns(ann_ids)
+        gt_boxes = [
+            BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) for obj in anno if obj["iscrowd"] == 0
+        ]
+        gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4)  # guard against no boxes
+        gt_boxes = Boxes(gt_boxes)
+        gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0])
+
+        if len(gt_boxes) == 0 or len(predictions) == 0:
+            continue
+
+        valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1])
+        gt_boxes = gt_boxes[valid_gt_inds]
+
+        num_pos += len(gt_boxes)
+
+        if len(gt_boxes) == 0:
+            continue
+
+        if limit is not None and len(predictions) > limit:
+            predictions = predictions[:limit]
+
+        overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes)
+
+        _gt_overlaps = torch.zeros(len(gt_boxes))
+        for j in range(min(len(predictions), len(gt_boxes))):
+            # find which proposal box maximally covers each gt box
+            # and get the iou amount of coverage for each gt box
+            max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+            # find which gt box is 'best' covered (i.e. 'best' = most iou)
+            gt_ovr, gt_ind = max_overlaps.max(dim=0)
+            assert gt_ovr >= 0
+            # find the proposal box that covers the best covered gt box
+            box_ind = argmax_overlaps[gt_ind]
+            # record the iou coverage of this gt box
+            _gt_overlaps[j] = overlaps[box_ind, gt_ind]
+            assert _gt_overlaps[j] == gt_ovr
+            # mark the proposal box and the gt box as used
+            overlaps[box_ind, :] = -1
+            overlaps[:, gt_ind] = -1
+
+        # append recorded iou coverage level
+        gt_overlaps.append(_gt_overlaps)
+    gt_overlaps = torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32)
+    gt_overlaps, _ = torch.sort(gt_overlaps)
+
+    if thresholds is None:
+        step = 0.05
+        thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32)
+    recalls = torch.zeros_like(thresholds)
+    # compute recall for each iou threshold
+    for i, t in enumerate(thresholds):
+        recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos)
+    # ar = 2 * np.trapz(recalls, thresholds)
+    ar = recalls.mean()
+    return {
+        "ar": ar,
+        "recalls": recalls,
+        "thresholds": thresholds,
+        "gt_overlaps": gt_overlaps,
+        "num_pos": num_pos,
+    }
+
+
+def _evaluate_predictions_on_coco(
+    coco_gt,
+    coco_results,
+    iou_type,
+    kpt_oks_sigmas=None,
+    use_fast_impl=True,
+    img_ids=None,
+):
+    """Evaluate the coco results using COCOEval API."""
+    assert len(coco_results) > 0
+
+    if iou_type == "segm":
+        coco_results = copy.deepcopy(coco_results)
+        # When evaluating mask AP, if the results contain bbox, cocoapi will
+        # use the box area as the area of the instance, instead of the mask area.
+        # This leads to a different definition of small/medium/large.
+        # We remove the bbox field to let mask AP use mask area.
+        for c in coco_results:
+            c.pop("bbox", None)
+
+    coco_dt = coco_gt.loadRes(coco_results)
+    coco_eval = (COCOeval_opt if use_fast_impl else COCOeval)(coco_gt, coco_dt, iou_type)
+    if img_ids is not None:
+        coco_eval.params.imgIds = img_ids
+
+    if iou_type == "keypoints":
+        # Use the COCO default keypoint OKS sigmas unless overrides are specified
+        if kpt_oks_sigmas:
+            assert hasattr(coco_eval.params, "kpt_oks_sigmas"), "pycocotools is too old!"
+            coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas)
+        # COCOAPI requires every detection and every gt to have keypoints, so
+        # we just take the first entry from both
+        num_keypoints_dt = len(coco_results[0]["keypoints"]) // 3
+        num_keypoints_gt = len(next(iter(coco_gt.anns.values()))["keypoints"]) // 3
+        num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas)
+        assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, (
+            f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. "
+            f"Ground truth contains {num_keypoints_gt} keypoints. "
+            f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. "
+            "They have to agree with each other. For meaning of OKS, please refer to "
+            "http://cocodataset.org/#keypoints-eval."
+        )
+
+    coco_eval.evaluate()
+    coco_eval.accumulate()
+    coco_eval.summarize()
+
+    return coco_eval
diff --git a/det/yolox/exp/__init__.py b/det/yolox/exp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..951195cb905195145ac10a6b9aefd84f9d9c3b03
--- /dev/null
+++ b/det/yolox/exp/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+from .base_exp import BaseExp
+from .build import get_exp
+from .yolox_base import Exp
diff --git a/det/yolox/exp/base_exp.py b/det/yolox/exp/base_exp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fa507562d080510e32d85cf6edbd711e414411e
--- /dev/null
+++ b/det/yolox/exp/base_exp.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import ast
+import pprint
+from abc import ABCMeta, abstractmethod
+from typing import Dict
+from tabulate import tabulate
+
+import torch
+from torch.nn import Module
+
+from det.yolox.utils import LRScheduler
+
+
+class BaseExp(metaclass=ABCMeta):
+    """Basic class for any experiment."""
+
+    def __init__(self):
+        self.seed = None
+        self.output_dir = "./YOLOX_outputs"
+        self.print_interval = 100
+        self.eval_interval = 10
+
+    @abstractmethod
+    def get_model(self) -> Module:
+        pass
+
+    @abstractmethod
+    def get_data_loader(self, batch_size: int, is_distributed: bool) -> Dict[str, torch.utils.data.DataLoader]:
+        pass
+
+    @abstractmethod
+    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
+        pass
+
+    @abstractmethod
+    def get_lr_scheduler(self, lr: float, iters_per_epoch: int, **kwargs) -> LRScheduler:
+        pass
+
+    @abstractmethod
+    def get_evaluator(self):
+        pass
+
+    @abstractmethod
+    def eval(self, model, evaluator, weights):
+        pass
+
+    def __repr__(self):
+        table_header = ["keys", "values"]
+        exp_table = [(str(k), pprint.pformat(v)) for k, v in vars(self).items() if not k.startswith("_")]
+        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
+
+    def merge(self, cfg_list):
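+        # Merge a flat [key1, value1, key2, value2, ...] override list into this Exp,
+        # e.g. (hypothetical) exp.merge(["max_epoch", "24", "nmsthre", "0.5"]);
+        # string values are cast back to the existing attribute's type when possible.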
+        assert len(cfg_list) % 2 == 0
+        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
+            # only update value with same key
+            if hasattr(self, k):
+                src_value = getattr(self, k)
+                src_type = type(src_value)
+                if src_value is not None and src_type != type(v):
+                    try:
+                        v = src_type(v)
+                    except Exception:
+                        v = ast.literal_eval(v)
+                setattr(self, k, v)
diff --git a/det/yolox/exp/build.py b/det/yolox/exp/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f82d7ef8e579fa41d102813f0b71fd1eaf05e93
--- /dev/null
+++ b/det/yolox/exp/build.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import importlib
+import os
+import sys
+
+
+def get_exp_by_file(exp_file):
+    try:
+        sys.path.append(os.path.dirname(exp_file))
+        current_exp = importlib.import_module(os.path.basename(exp_file).split(".")[0])
+        exp = current_exp.Exp()
+    except Exception:
+        raise ImportError("{} doesn't contain a class named 'Exp'".format(exp_file))
+    return exp
+
+
+def get_exp_by_name(exp_name):
+    from det import yolox
+
+    yolox_path = os.path.dirname(os.path.dirname(yolox.__file__))
+    filedict = {
+        "yolox-s": "yolox_s.py",
+        "yolox-m": "yolox_m.py",
+        "yolox-l": "yolox_l.py",
+        "yolox-x": "yolox_x.py",
+        "yolox-tiny": "yolox_tiny.py",
+        "yolox-nano": "nano.py",
+        "yolov3": "yolov3.py",
+    }
+    filename = filedict[exp_name]
+    exp_path = os.path.join(yolox_path, "yolox/exps", "default", filename)
+    # print(exp_path)
+    return get_exp_by_file(exp_path)
+
+
+def get_exp(exp_file, exp_name):
+    """get Exp object by file or name. If exp_file and exp_name are both
+    provided, get Exp by exp_file.
+
+    Args:
+        exp_file (str): file path of the experiment.
+        exp_name (str): name of the experiment, e.g. "yolox-s".
+    """
+    assert exp_file is not None or exp_name is not None, "please provide an exp file or an exp name."
+    if exp_file is not None:
+        return get_exp_by_file(exp_file)
+    else:
+        return get_exp_by_name(exp_name)
diff --git a/det/yolox/exp/yolox_base.py b/det/yolox/exp/yolox_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..26d0de631c82ad2a1b7b0de1f939adc42d6449e8
--- /dev/null
+++ b/det/yolox/exp/yolox_base.py
@@ -0,0 +1,265 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import os
+import random
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from .base_exp import BaseExp
+
+
+class Exp(BaseExp):
+    def __init__(self):
+        super().__init__()
+
+        # ---------------- model config ---------------- #
+        self.num_classes = 80
+        self.depth = 1.00
+        self.width = 1.00
+
+        # ---------------- dataloader config ---------------- #
+        # set num_workers to 4 for shorter dataloader init time
+        self.data_num_workers = 4
+        self.input_size = (640, 640)
+        # Actual multiscale ranges: [640-5*32, 640+5*32].
+        # To disable multiscale training, set the
+        # self.multiscale_range to 0.
+        self.multiscale_range = 5
+        # You can uncomment this line to specify a multiscale range
+        # self.random_size = (14, 26)
+        self.data_dir = None
+        self.train_ann = "instances_train2017.json"
+        self.val_ann = "instances_val2017.json"
+
+        # --------------- transform config ----------------- #
+        self.mosaic_prob = 1.0
+        self.mixup_prob = 1.0
+        self.degrees = 10.0
+        self.translate = 0.1
+        self.mosaic_scale = (0.1, 2)
+        self.mixup_scale = (0.5, 1.5)
+        self.shear = 2.0
+        self.perspective = 0.0
+        self.enable_mixup = True
+
+        # --------------  training config --------------------- #
+        self.warmup_epochs = 5
+        self.max_epoch = 300
+        self.warmup_lr = 0
+        self.basic_lr_per_img = 0.01 / 64.0
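+        # the effective learning rate is basic_lr_per_img * total batch size
+        # (see get_optimizer below)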
+        self.scheduler = "yoloxwarmcos"
+        self.no_aug_epochs = 15
+        self.min_lr_ratio = 0.05
+        self.ema = True
+
+        self.weight_decay = 5e-4
+        self.momentum = 0.9
+        self.print_interval = 10
+        self.eval_interval = 10
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+        # -----------------  testing config ------------------ #
+        self.test_size = (640, 640)
+        self.test_conf = 0.01
+        self.nmsthre = 0.65
+
+    def get_model(self):
+        from det.yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+
+        if getattr(self, "model", None) is None:
+            in_channels = [256, 512, 1024]
+            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
+            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels)
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
+
+    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
+        from det.yolox.data import (
+            COCODataset,
+            DataLoader,
+            InfiniteSampler,
+            MosaicDetection,
+            TrainTransform,
+            YoloBatchSampler,
+            worker_init_reset_seed,
+        )
+        from det.yolox.utils import (
+            wait_for_the_master,
+            get_local_rank,
+        )
+
+        local_rank = get_local_rank()
+
+        with wait_for_the_master(local_rank):
+            dataset = COCODataset(
+                data_dir=self.data_dir,
+                json_file=self.train_ann,
+                img_size=self.input_size,
+                preproc=TrainTransform(max_labels=50),
+                cache=cache_img,
+            )
+
+        dataset = MosaicDetection(
+            dataset,
+            mosaic=not no_aug,
+            img_size=self.input_size,
+            preproc=TrainTransform(max_labels=120),
+            degrees=self.degrees,
+            translate=self.translate,
+            mosaic_scale=self.mosaic_scale,
+            mixup_scale=self.mixup_scale,
+            shear=self.shear,
+            perspective=self.perspective,
+            enable_mixup=self.enable_mixup,
+            mosaic_prob=self.mosaic_prob,
+            mixup_prob=self.mixup_prob,
+        )
+
+        self.dataset = dataset
+
+        if is_distributed:
+            batch_size = batch_size // dist.get_world_size()
+
+        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
+
+        batch_sampler = YoloBatchSampler(
+            sampler=sampler,
+            batch_size=batch_size,
+            drop_last=False,
+            mosaic=not no_aug,
+        )
+
+        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
+        dataloader_kwargs["batch_sampler"] = batch_sampler
+
+        # Make sure each process has different random seed, especially for 'fork' method.
+        # Check https://github.com/pytorch/pytorch/issues/63311 for more details.
+        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed
+
+        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
+
+        return train_loader
+
+    def random_resize(self, data_loader, epoch, rank, is_distributed):
+        tensor = torch.LongTensor(2).cuda()
+
+        if rank == 0:
+            size_factor = self.input_size[1] * 1.0 / self.input_size[0]
+            if not hasattr(self, "random_size"):
+                min_size = int(self.input_size[0] / 32) - self.multiscale_range
+                max_size = int(self.input_size[0] / 32) + self.multiscale_range
+                self.random_size = (min_size, max_size)
+            size = random.randint(*self.random_size)
+            size = (int(32 * size), 32 * int(size * size_factor))
+            tensor[0] = size[0]
+            tensor[1] = size[1]
+
+        if is_distributed:
+            dist.barrier()
+            dist.broadcast(tensor, 0)
+
+        input_size = (tensor[0].item(), tensor[1].item())
+        return input_size
+
+    def preprocess(self, inputs, targets, tsize):
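+        # rescale images and their box targets together when the multiscale size changes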
+        scale = tsize[0] / self.input_size[0]
+        if scale != 1:
+            inputs = nn.functional.interpolate(inputs, size=tsize, mode="bilinear", align_corners=False)
+            targets[..., 1:] = targets[..., 1:] * scale
+        return inputs, targets
+
+    def get_optimizer(self, batch_size):
+        if "optimizer" not in self.__dict__:
+            if self.warmup_epochs > 0:
+                lr = self.warmup_lr
+            else:
+                lr = self.basic_lr_per_img * batch_size
+
+            pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
+
+            for k, v in self.model.named_modules():
+                if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
+                    pg2.append(v.bias)  # biases
+                if isinstance(v, nn.BatchNorm2d) or "bn" in k:
+                    pg0.append(v.weight)  # no decay
+                elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
+                    pg1.append(v.weight)  # apply decay
+
+            optimizer = torch.optim.SGD(pg0, lr=lr, momentum=self.momentum, nesterov=True)
+            optimizer.add_param_group({"params": pg1, "weight_decay": self.weight_decay})  # add pg1 with weight_decay
+            optimizer.add_param_group({"params": pg2})
+            self.optimizer = optimizer
+
+        return self.optimizer
+
+    def get_lr_scheduler(self, lr, iters_per_epoch):
+        from det.yolox.utils import LRScheduler
+
+        scheduler = LRScheduler(
+            self.scheduler,
+            lr,
+            iters_per_epoch,
+            self.max_epoch,
+            warmup_epochs=self.warmup_epochs,
+            warmup_lr_start=self.warmup_lr,
+            no_aug_epochs=self.no_aug_epochs,
+            min_lr_ratio=self.min_lr_ratio,
+        )
+        return scheduler
+
+    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
+        from det.yolox.data import COCODataset, ValTransform
+
+        valdataset = COCODataset(
+            data_dir=self.data_dir,
+            json_file=self.val_ann if not testdev else "image_info_test-dev2017.json",
+            name="val2017" if not testdev else "test2017",
+            img_size=self.test_size,
+            preproc=ValTransform(legacy=legacy),
+        )
+
+        if is_distributed:
+            batch_size = batch_size // dist.get_world_size()
+            sampler = torch.utils.data.distributed.DistributedSampler(valdataset, shuffle=False)
+        else:
+            sampler = torch.utils.data.SequentialSampler(valdataset)
+
+        dataloader_kwargs = {
+            "num_workers": self.data_num_workers,
+            "pin_memory": True,
+            "sampler": sampler,
+        }
+        dataloader_kwargs["batch_size"] = batch_size
+        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
+
+        return val_loader
+
+    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
+        from det.yolox.evaluators import COCOEvaluator
+
+        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev, legacy=legacy)
+        evaluator = COCOEvaluator(
+            img_size=self.test_size,
+            conf_thr=self.test_conf,
+            nms_thr=self.nmsthre,
+            num_classes=self.num_classes,
+            testdev=testdev,
+        )
+        evaluator.set_dataloader(val_loader)
+        return evaluator
+
+    def eval(self, model, evaluator, is_distributed, half=False):
+        return evaluator.evaluate(model, is_distributed, half)
diff --git a/det/yolox/exps/default/nano.py b/det/yolox/exps/default/nano.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd4cf78c1132a78686398b40ea9d420b2f2822e5
--- /dev/null
+++ b/det/yolox/exps/default/nano.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+import torch.nn as nn
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.25
+        self.input_size = (416, 416)
+        self.random_size = (10, 20)
+        self.mosaic_scale = (0.5, 1.5)
+        self.test_size = (416, 416)
+        self.mosaic_prob = 0.5
+        self.enable_mixup = False
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+    def get_model(self, sublinear=False):
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+
+        if "model" not in self.__dict__:
+            from det.yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+
+            in_channels = [256, 512, 1024]
+            # The NANO model uses depthwise=True, which is the main difference.
+            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, depthwise=True)
+            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, depthwise=True)
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
diff --git a/det/yolox/exps/default/yolov3.py b/det/yolox/exps/default/yolov3.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aadf67509518d95e64ef48bf4768809a5cd0ecb
--- /dev/null
+++ b/det/yolox/exps/default/yolov3.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+import torch
+import torch.nn as nn
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.0
+        self.width = 1.0
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+    def get_model(self, sublinear=False):
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+
+        if "model" not in self.__dict__:
+            from det.yolox.models import YOLOX, YOLOFPN, YOLOXHead
+
+            backbone = YOLOFPN()
+            head = YOLOXHead(self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu")
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
diff --git a/det/yolox/exps/default/yolox_l.py b/det/yolox/exps/default/yolox_l.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2c240e53e3588d760c340f593b86e78eeb9267
--- /dev/null
+++ b/det/yolox/exps/default/yolox_l.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.0
+        self.width = 1.0
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/det/yolox/exps/default/yolox_m.py b/det/yolox/exps/default/yolox_m.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f79deabd71ac2a3180527bf78d5f6072142d8e2
--- /dev/null
+++ b/det/yolox/exps/default/yolox_m.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.67
+        self.width = 0.75
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/det/yolox/exps/default/yolox_s.py b/det/yolox/exps/default/yolox_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..b475d54473ec4fbf7f6c0339e614beea97c89810
--- /dev/null
+++ b/det/yolox/exps/default/yolox_s.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.50
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/det/yolox/exps/default/yolox_tiny.py b/det/yolox/exps/default/yolox_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8c9678bd5dd464bf11205de01d35a099f6270f3
--- /dev/null
+++ b/det/yolox/exps/default/yolox_tiny.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.375
+        self.input_size = (416, 416)
+        self.mosaic_scale = (0.5, 1.5)
+        self.random_size = (10, 20)
+        self.test_size = (416, 416)
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        self.enable_mixup = False
diff --git a/det/yolox/exps/default/yolox_x.py b/det/yolox/exps/default/yolox_x.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aded8f4bb888eda6ce56cb2a70e6065fc87cd80
--- /dev/null
+++ b/det/yolox/exps/default/yolox_x.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.33
+        self.width = 1.25
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/det/yolox/exps/example/custom/nano.py b/det/yolox/exps/example/custom/nano.py
new file mode 100644
index 0000000000000000000000000000000000000000..7319ec1eee874a7d2706d41e5755506149b3e7ff
--- /dev/null
+++ b/det/yolox/exps/example/custom/nano.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+import torch.nn as nn
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.25
+        self.input_size = (416, 416)
+        self.mosaic_scale = (0.5, 1.5)
+        self.random_size = (10, 20)
+        self.test_size = (416, 416)
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        self.enable_mixup = False
+
+        # Define your own dataset path
+        self.data_dir = "datasets/coco128"
+        self.train_ann = "instances_train2017.json"
+        self.val_ann = "instances_val2017.json"
+
+        self.num_classes = 71
+
+    def get_model(self, sublinear=False):
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+
+        if "model" not in self.__dict__:
+            from det.yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+
+            in_channels = [256, 512, 1024]
+            # The NANO model uses depthwise=True, which is the main difference.
+            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, depthwise=True)
+            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, depthwise=True)
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
diff --git a/det/yolox/exps/example/custom/yolox_s.py b/det/yolox/exps/example/custom/yolox_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f0b0a5f76b63a993c24e3f33c69fd960144a42c
--- /dev/null
+++ b/det/yolox/exps/example/custom/yolox_s.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import os
+
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.50
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+        # Define your own dataset path
+        self.data_dir = "datasets/coco128"
+        self.train_ann = "instances_train2017.json"
+        self.val_ann = "instances_val2017.json"
+
+        self.num_classes = 71
+
+        self.max_epoch = 300
+        self.data_num_workers = 4
+        self.eval_interval = 1
diff --git a/det/yolox/exps/example/yolox_voc/yolox_voc_s.py b/det/yolox/exps/example/yolox_voc/yolox_voc_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..5148d22b4abcf4717745372123d47abacccc5215
--- /dev/null
+++ b/det/yolox/exps/example/yolox_voc/yolox_voc_s.py
@@ -0,0 +1,123 @@
+# encoding: utf-8
+import os
+
+import torch
+import torch.distributed as dist
+
+from det.yolox.utils import get_yolox_datadir
+from det.yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.num_classes = 20
+        self.depth = 0.33
+        self.width = 0.50
+        self.warmup_epochs = 1
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
+        from det.yolox.data import (
+            VOCDetection,
+            TrainTransform,
+            YoloBatchSampler,
+            DataLoader,
+            InfiniteSampler,
+            MosaicDetection,
+            worker_init_reset_seed,
+        )
+        from det.yolox.utils import (
+            wait_for_the_master,
+            get_local_rank,
+        )
+
+        local_rank = get_local_rank()
+
+        with wait_for_the_master(local_rank):
+            dataset = VOCDetection(
+                data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
+                image_sets=[("2007", "trainval"), ("2012", "trainval")],
+                img_size=self.input_size,
+                preproc=TrainTransform(max_labels=50),
+                cache=cache_img,
+            )
+
+        dataset = MosaicDetection(
+            dataset,
+            mosaic=not no_aug,
+            img_size=self.input_size,
+            preproc=TrainTransform(max_labels=120),
+            degrees=self.degrees,
+            translate=self.translate,
+            mosaic_scale=self.mosaic_scale,
+            mixup_scale=self.mixup_scale,
+            shear=self.shear,
+            perspective=self.perspective,
+            enable_mixup=self.enable_mixup,
+            mosaic_prob=self.mosaic_prob,
+            mixup_prob=self.mixup_prob,
+        )
+
+        self.dataset = dataset
+
+        if is_distributed:
+            batch_size = batch_size // dist.get_world_size()
+
+        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
+
+        batch_sampler = YoloBatchSampler(
+            sampler=sampler,
+            batch_size=batch_size,
+            drop_last=False,
+            mosaic=not no_aug,
+        )
+
+        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
+        dataloader_kwargs["batch_sampler"] = batch_sampler
+
+        # Make sure each worker process has a different random seed, especially with the 'fork' start method.
+        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed
+
+        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
+
+        return train_loader
+
+    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
+        from det.yolox.data import VOCDetection, ValTransform
+
+        valdataset = VOCDetection(
+            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
+            image_sets=[("2007", "test")],
+            img_size=self.test_size,
+            preproc=ValTransform(legacy=legacy),
+        )
+
+        if is_distributed:
+            batch_size = batch_size // dist.get_world_size()
+            sampler = torch.utils.data.distributed.DistributedSampler(valdataset, shuffle=False)
+        else:
+            sampler = torch.utils.data.SequentialSampler(valdataset)
+
+        dataloader_kwargs = {
+            "num_workers": self.data_num_workers,
+            "pin_memory": True,
+            "sampler": sampler,
+        }
+        dataloader_kwargs["batch_size"] = batch_size
+        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
+
+        return val_loader
+
+    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
+        from det.yolox.evaluators import VOCEvaluator
+
+        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
+        evaluator = VOCEvaluator(
+            dataloader=val_loader,
+            img_size=self.test_size,
+            confthre=self.test_conf,
+            nmsthre=self.nmsthre,
+            num_classes=self.num_classes,
+        )
+        return evaluator
diff --git a/det/yolox/models/__init__.py b/det/yolox/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4641a61bf466259c88e0a0b92e4ff55b2abcd61
--- /dev/null
+++ b/det/yolox/models/__init__.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+from .darknet import CSPDarknet, Darknet
+from .losses import IOUloss
+from .yolo_fpn import YOLOFPN
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+from .yolox import YOLOX
diff --git a/det/yolox/models/darknet.py b/det/yolox/models/darknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b0a8284de5e2691548c162f7661887db64015f
--- /dev/null
+++ b/det/yolox/models/darknet.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+from torch import nn
+
+from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck
+
+
+class Darknet(nn.Module):
+    # number of blocks from dark2 to dark5.
+    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}
+
+    def __init__(
+        self,
+        depth,
+        in_channels=3,
+        stem_out_channels=32,
+        out_features=("dark3", "dark4", "dark5"),
+    ):
+        """
+        Args:
+            depth (int): depth of the darknet used in the model; typically 21 or 53.
+            in_channels (int): number of input channels, e.g. 3 for an RGB image.
+            stem_out_channels (int): number of output channels of the darknet stem.
+                It determines the channels of darknet layer2 to layer5.
+            out_features (Tuple[str]): desired output layer names.
+        """
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        self.stem = nn.Sequential(
+            BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
+            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
+        )
+        in_channels = stem_out_channels * 2  # 64
+
+        num_blocks = Darknet.depth2blocks[depth]
+        # create darknet with `stem_out_channels` and `num_blocks` layers.
+        # to keep the model structure explicit, we don't use a `for` loop here.
+        self.dark2 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[0], stride=2))
+        in_channels *= 2  # 128
+        self.dark3 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[1], stride=2))
+        in_channels *= 2  # 256
+        self.dark4 = nn.Sequential(*self.make_group_layer(in_channels, num_blocks[2], stride=2))
+        in_channels *= 2  # 512
+
+        self.dark5 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[3], stride=2),
+            *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2),
+        )
+
+    def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
+        """starts with conv layer then has `num_blocks` `ResLayer`"""
+        return [
+            BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"),
+            *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)],
+        ]
+
+    def make_spp_block(self, filters_list, in_filters):
+        m = nn.Sequential(
+            *[
+                BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                SPPBottleneck(
+                    in_channels=filters_list[1],
+                    out_channels=filters_list[0],
+                    activation="lrelu",
+                ),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"),
+            ]
+        )
+        return m
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
+
+
+class CSPDarknet(nn.Module):
+    def __init__(
+        self,
+        dep_mul,
+        wid_mul,
+        out_features=("dark3", "dark4", "dark5"),
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        Conv = DWConv if depthwise else BaseConv
+
+        base_channels = int(wid_mul * 64)  # 64
+        base_depth = max(round(dep_mul * 3), 1)  # 3
+
+        # stem
+        self.stem = Focus(3, base_channels, ksize=3, act=act)
+
+        # dark2
+        self.dark2 = nn.Sequential(
+            Conv(base_channels, base_channels * 2, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 2,
+                base_channels * 2,
+                n=base_depth,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark3
+        self.dark3 = nn.Sequential(
+            Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 4,
+                base_channels * 4,
+                n=base_depth * 3,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark4
+        self.dark4 = nn.Sequential(
+            Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 8,
+                base_channels * 8,
+                n=base_depth * 3,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark5
+        self.dark5 = nn.Sequential(
+            Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
+            SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
+            CSPLayer(
+                base_channels * 16,
+                base_channels * 16,
+                n=base_depth,
+                shortcut=False,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
diff --git a/det/yolox/models/losses.py b/det/yolox/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..a17543065723f59669f292164b7a65d185c3187d
--- /dev/null
+++ b/det/yolox/models/losses.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+
+class IOUloss(nn.Module):
+    def __init__(self, reduction="none", loss_type="iou"):
+        super(IOUloss, self).__init__()
+        self.reduction = reduction
+        self.loss_type = loss_type
+
+    def forward(self, pred, target):
+        assert pred.shape[0] == target.shape[0]
+
+        pred = pred.view(-1, 4)
+        target = target.view(-1, 4)
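+        # boxes are in (cx, cy, w, h) format; tl/br are the intersection's
+        # top-left and bottom-right corners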
+        tl = torch.max((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2))
+        br = torch.min((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2))
+
+        area_p = torch.prod(pred[:, 2:], 1)
+        area_g = torch.prod(target[:, 2:], 1)
+
+        en = (tl < br).type(tl.type()).prod(dim=1)
+        area_i = torch.prod(br - tl, 1) * en
+        area_u = area_p + area_g - area_i
+        iou = (area_i) / (area_u + 1e-16)
+
+        if self.loss_type == "iou":
+            loss = 1 - iou**2
+        elif self.loss_type == "giou":
+            c_tl = torch.min((pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2))
+            c_br = torch.max((pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2))
+            area_c = torch.prod(c_br - c_tl, 1)
+            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
+            loss = 1 - giou.clamp(min=-1.0, max=1.0)
+
+        if self.reduction == "mean":
+            loss = loss.mean()
+        elif self.reduction == "sum":
+            loss = loss.sum()
+
+        return loss
diff --git a/det/yolox/models/network_blocks.py b/det/yolox/models/network_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1d5b078ee675f184ba608d55707a0dd34b840e9
--- /dev/null
+++ b/det/yolox/models/network_blocks.py
@@ -0,0 +1,196 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_activation(name="silu", inplace=True):
+    if name == "silu":
+        module = nn.SiLU(inplace=inplace)
+    elif name == "relu":
+        module = nn.ReLU(inplace=inplace)
+    elif name == "lrelu":
+        module = nn.LeakyReLU(0.1, inplace=inplace)
+    elif name == "mish":
+        module = nn.Mish(inplace=inplace)
+    elif name == "gelu":
+        module = nn.GELU()
+    else:
+        raise AttributeError("Unsupported act type: {}".format(name))
+    return module
+
+
+class BaseConv(nn.Module):
+    """A Conv2d -> Batchnorm -> silu/leaky relu block."""
+
+    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
+        super().__init__()
+        # same padding
+        pad = (ksize - 1) // 2
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=ksize,
+            stride=stride,
+            padding=pad,
+            groups=groups,
+            bias=bias,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.act = get_activation(act, inplace=True)
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+class DWConv(nn.Module):
+    """Depthwise Conv + Conv."""
+
+    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
+        super().__init__()
+        self.dconv = BaseConv(
+            in_channels,
+            in_channels,
+            ksize=ksize,
+            stride=stride,
+            groups=in_channels,
+            act=act,
+        )
+        self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act)
+
+    def forward(self, x):
+        x = self.dconv(x)
+        return self.pconv(x)
+
+
+class Bottleneck(nn.Module):
+    # Standard bottleneck
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        shortcut=True,
+        expansion=0.5,
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)
+        Conv = DWConv if depthwise else BaseConv
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
+        self.use_add = shortcut and in_channels == out_channels
+
+    def forward(self, x):
+        y = self.conv2(self.conv1(x))
+        if self.use_add:
+            y = y + x
+        return y
+
+
+class ResLayer(nn.Module):
+    """Residual layer with `in_channels` inputs."""
+
+    def __init__(self, in_channels: int):
+        super().__init__()
+        mid_channels = in_channels // 2
+        self.layer1 = BaseConv(in_channels, mid_channels, ksize=1, stride=1, act="lrelu")
+        self.layer2 = BaseConv(mid_channels, in_channels, ksize=3, stride=1, act="lrelu")
+
+    def forward(self, x):
+        out = self.layer2(self.layer1(x))
+        return x + out
+
+
+class SPPBottleneck(nn.Module):
+    """Spatial pyramid pooling layer used in YOLOv3-SPP."""
+
+    def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
+        super().__init__()
+        hidden_channels = in_channels // 2
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
+        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes])
+        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
+        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
+        x = self.conv2(x)
+        return x
+
+
+class CSPLayer(nn.Module):
+    """C3 in yolov5, CSP Bottleneck with 3 convolutions."""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        n=1,
+        shortcut=True,
+        expansion=0.5,
+        depthwise=False,
+        act="silu",
+    ):
+        """
+        Args:
+            in_channels (int): input channels.
+            out_channels (int): output channels.
+            n (int): number of Bottlenecks. Default value: 1.
+        """
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)  # hidden channels
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
+        module_list = [
+            Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n)
+        ]
+        self.m = nn.Sequential(*module_list)
+
+    def forward(self, x):
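+        # CSP / C3 pattern: conv1 feeds the bottleneck stack, conv2 is a parallel shortcut;
+        # the two branches are concatenated and fused by conv3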
+        x_1 = self.conv1(x)
+        x_2 = self.conv2(x)
+        x_1 = self.m(x_1)
+        x = torch.cat((x_1, x_2), dim=1)
+        return self.conv3(x)
+
+
+class Focus(nn.Module):
+    """Focus width and height information into channel space."""
+
+    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
+        super().__init__()
+        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
+
+    def forward(self, x):
+        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
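+        # space-to-depth: the four pixel-offset views below tile the image, so stacking
+        # them along the channel dim halves H/W and quadruples C without losing information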
+        patch_top_left = x[..., ::2, ::2]
+        patch_top_right = x[..., ::2, 1::2]
+        patch_bot_left = x[..., 1::2, ::2]
+        patch_bot_right = x[..., 1::2, 1::2]
+        x = torch.cat(
+            (
+                patch_top_left,
+                patch_bot_left,
+                patch_top_right,
+                patch_bot_right,
+            ),
+            dim=1,
+        )
+        return self.conv(x)
diff --git a/det/yolox/models/yolo_fpn.py b/det/yolox/models/yolo_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1e127e846d3be8ba769fcbad6dcfd50ccd110cc
--- /dev/null
+++ b/det/yolox/models/yolo_fpn.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from .darknet import Darknet
+from .network_blocks import BaseConv
+
+
+class YOLOFPN(nn.Module):
+    """YOLOFPN module.
+
+    Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self,
+        depth=53,
+        in_features=["dark3", "dark4", "dark5"],
+    ):
+        super().__init__()
+
+        self.backbone = Darknet(depth)
+        self.in_features = in_features
+
+        # out 1
+        self.out1_cbl = self._make_cbl(512, 256, 1)
+        self.out1 = self._make_embedding([256, 512], 512 + 256)
+
+        # out 2
+        self.out2_cbl = self._make_cbl(256, 128, 1)
+        self.out2 = self._make_embedding([128, 256], 256 + 128)
+
+        # upsample
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+    def _make_cbl(self, _in, _out, ks):
+        return BaseConv(_in, _out, ks, stride=1, act="lrelu")
+
+    def _make_embedding(self, filters_list, in_filters):
+        m = nn.Sequential(
+            *[
+                self._make_cbl(in_filters, filters_list[0], 1),
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+            ]
+        )
+        return m
+
+    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
+        with open(filename, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu")
+        print("loading pretrained weights...")
+        self.backbone.load_state_dict(state_dict)
+
+    def forward(self, inputs):
+        """
+        Args:
+            inputs (Tensor): input image.
+
+        Returns:
+            Tuple[Tensor]: FPN output features.
+        """
+        #  backbone
+        out_features = self.backbone(inputs)
+        x2, x1, x0 = [out_features[f] for f in self.in_features]
+
+        #  yolo branch 1
+        x1_in = self.out1_cbl(x0)
+        x1_in = self.upsample(x1_in)
+        x1_in = torch.cat([x1_in, x1], 1)
+        out_dark4 = self.out1(x1_in)
+
+        #  yolo branch 2
+        x2_in = self.out2_cbl(out_dark4)
+        x2_in = self.upsample(x2_in)
+        x2_in = torch.cat([x2_in, x2], 1)
+        out_dark3 = self.out2(x2_in)
+
+        outputs = (out_dark3, out_dark4, x0)
+        return outputs
diff --git a/det/yolox/models/yolo_head.py b/det/yolox/models/yolo_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..07a5c2b44fbe4ff4a5b91a219423606c3ed3e37f
--- /dev/null
+++ b/det/yolox/models/yolo_head.py
@@ -0,0 +1,627 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import logging
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+from fvcore.nn.focal_loss import sigmoid_focal_loss
+from det.yolox.utils import bboxes_iou
+
+from .losses import IOUloss
+from .network_blocks import BaseConv, DWConv
+
+logger = logging.getLogger(__name__)
+
+
+class YOLOXHead(nn.Module):
+    def __init__(
+        self,
+        num_classes,
+        width=1.0,
+        strides=[8, 16, 32],
+        in_channels=[256, 512, 1024],
+        act="silu",
+        depthwise=False,
+        iou_loss_type="iou",  # iou | giou
+        cls_loss_type="bce",  # bce | focal
+        obj_loss_type="bce",  # bce
+        fl_alpha=0.25,
+        fl_gamma=2.0,
+    ):
+        """
+        Args:
+            act (str): activation type of conv. Default value: "silu".
+            depthwise (bool): whether to apply depthwise conv in the conv branch. Default value: False.
+        """
+        super().__init__()
+
+        self.n_anchors = 1
+        self.num_classes = num_classes
+        self.decode_in_inference = True  # for deploy, set to False
+
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        self.cls_preds = nn.ModuleList()
+        self.reg_preds = nn.ModuleList()
+        self.obj_preds = nn.ModuleList()
+        self.stems = nn.ModuleList()
+        Conv = DWConv if depthwise else BaseConv
+
+        for i in range(len(in_channels)):
+            self.stems.append(
+                BaseConv(
+                    in_channels=int(in_channels[i] * width),
+                    out_channels=int(256 * width),
+                    ksize=1,
+                    stride=1,
+                    act=act,
+                )
+            )
+            self.cls_convs.append(
+                nn.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.reg_convs.append(
+                nn.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.cls_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=self.n_anchors * self.num_classes,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.reg_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=4,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.obj_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=self.n_anchors * 1,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+
+        self.use_l1 = False
+        self.l1_loss = nn.L1Loss(reduction="none")
+
+        self.cls_loss_type = cls_loss_type
+        self.fl_gamma = fl_gamma
+        self.fl_alpha = fl_alpha
+        self.obj_loss_type = obj_loss_type
+        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
+
+        self.iou_loss = IOUloss(reduction="none", loss_type=iou_loss_type)
+        self.strides = strides
+        self.grids = [torch.zeros(1)] * len(in_channels)
+
+    def initialize_biases(self, prior_prob):
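+        # set cls/obj prediction biases to -log((1 - p) / p) so the initial sigmoid
+        # output is ~prior_prob (focal-loss style prior initialization)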
+        for conv in self.cls_preds:
+            b = conv.bias.view(self.n_anchors, -1)
+            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
+            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+        for conv in self.obj_preds:
+            b = conv.bias.view(self.n_anchors, -1)
+            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
+            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def forward(self, xin, labels=None, imgs=None):
+        outputs = []
+        origin_preds = []
+        x_shifts = []
+        y_shifts = []
+        expanded_strides = []
+
+        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+            zip(self.cls_convs, self.reg_convs, self.strides, xin)
+        ):
+            x = self.stems[k](x)
+            cls_x = x
+            reg_x = x
+
+            cls_feat = cls_conv(cls_x)
+            cls_output = self.cls_preds[k](cls_feat)
+
+            reg_feat = reg_conv(reg_x)
+            reg_output = self.reg_preds[k](reg_feat)
+            obj_output = self.obj_preds[k](reg_feat)
+
+            if self.training:
+                output = torch.cat([reg_output, obj_output, cls_output], 1)
+                output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())
+                x_shifts.append(grid[:, :, 0])
+                y_shifts.append(grid[:, :, 1])
+                expanded_strides.append(torch.zeros(1, grid.shape[1]).fill_(stride_this_level).type_as(xin[0]))
+                if self.use_l1:
+                    batch_size = reg_output.shape[0]
+                    hsize, wsize = reg_output.shape[-2:]
+                    reg_output = reg_output.view(batch_size, self.n_anchors, 4, hsize, wsize)
+                    reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(batch_size, -1, 4)
+                    origin_preds.append(reg_output.clone())
+
+            else:
+                output = torch.cat([reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1)
+
+            outputs.append(output)
+
+        if self.training:
+            return self.get_losses(
+                imgs,
+                x_shifts,
+                y_shifts,
+                expanded_strides,
+                labels,
+                torch.cat(outputs, 1),
+                origin_preds,
+                dtype=xin[0].dtype,
+            )
+        else:
+            self.hw = [x.shape[-2:] for x in outputs]
+            # [batch, n_anchors_all, 85]
+            outputs = torch.cat([x.flatten(start_dim=2) for x in outputs], dim=2).permute(0, 2, 1)
+            if self.decode_in_inference:
+                outputs = self.decode_outputs(outputs, dtype=xin[0].type())
+                # "box_preds": outputs[:, :4],
+                # "obj_score_preds": outputs[:, 4],
+                # "cls_score_preds": outputs[:, 5],
+                # "cls_preds": outputs[:, 6]
+                out_dict = {"det_preds": outputs}
+                return out_dict
+            else:  # for export and deploy
+                return outputs
+
+    def get_output_and_grid(self, output, k, stride, dtype):
+        grid = self.grids[k]
+
+        batch_size = output.shape[0]
+        n_ch = 5 + self.num_classes
+        hsize, wsize = output.shape[-2:]
+        if grid.shape[2:4] != output.shape[2:4]:
+            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
+            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
+            self.grids[k] = grid
+
+        output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
+        output = output.permute(0, 1, 3, 4, 2).reshape(batch_size, self.n_anchors * hsize * wsize, -1)
+        grid = grid.view(1, -1, 2)
+        output[..., :2] = (output[..., :2] + grid) * stride
+        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
+        return output, grid
+
+    def decode_outputs(self, outputs, dtype):
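+        # decode raw predictions back to image coordinates:
+        # xy = (pred_xy + grid_cell) * stride, wh = exp(pred_wh) * stride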
+        grids = []
+        strides = []
+        for (hsize, wsize), stride in zip(self.hw, self.strides):
+            yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)], indexing="ij")
+            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
+            grids.append(grid)
+            shape = grid.shape[:2]
+            strides.append(torch.full((*shape, 1), stride))
+
+        grids = torch.cat(grids, dim=1).type(dtype)
+        strides = torch.cat(strides, dim=1).type(dtype)
+
+        outputs[..., :2] = (outputs[..., :2] + grids) * strides
+        outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
+        return outputs
+
+    def get_losses(
+        self,
+        imgs,
+        x_shifts,
+        y_shifts,
+        expanded_strides,
+        labels,
+        outputs,
+        origin_preds,
+        dtype,
+    ):
+        bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
+        obj_preds = outputs[:, :, 4].unsqueeze(-1)  # [batch, n_anchors_all, 1]
+        cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
+
+        # calculate targets
+        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects
+
+        total_num_anchors = outputs.shape[1]
+        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
+        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
+        expanded_strides = torch.cat(expanded_strides, 1)
+        if self.use_l1:
+            origin_preds = torch.cat(origin_preds, 1)
+
+        cls_targets = []
+        reg_targets = []
+        l1_targets = []
+        obj_targets = []
+        fg_masks = []
+
+        num_fg = 0.0
+        num_gts = 0.0
+
+        for batch_idx in range(outputs.shape[0]):
+            num_gt = int(nlabel[batch_idx])
+            num_gts += num_gt
+            if num_gt == 0:
+                cls_target = outputs.new_zeros((0, self.num_classes))
+                reg_target = outputs.new_zeros((0, 4))
+                l1_target = outputs.new_zeros((0, 4))
+                obj_target = outputs.new_zeros((total_num_anchors, 1))
+                fg_mask = outputs.new_zeros(total_num_anchors).bool()
+            else:
+                gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]
+                gt_classes = labels[batch_idx, :num_gt, 0]
+                bboxes_preds_per_image = bbox_preds[batch_idx]
+
+                try:
+                    (
+                        gt_matched_classes,
+                        fg_mask,
+                        pred_ious_this_matching,
+                        matched_gt_inds,
+                        num_fg_img,
+                    ) = self.get_assignments(  # noqa
+                        batch_idx,
+                        num_gt,
+                        total_num_anchors,
+                        gt_bboxes_per_image,
+                        gt_classes,
+                        bboxes_preds_per_image,
+                        expanded_strides,
+                        x_shifts,
+                        y_shifts,
+                        cls_preds,
+                        bbox_preds,
+                        obj_preds,
+                        labels,
+                        imgs,
+                    )
+                except RuntimeError as e:
+                    logger.error("GPU mode failed: {}, changed to CPU mode.".format(e))
+                    torch.cuda.empty_cache()
+                    (
+                        gt_matched_classes,
+                        fg_mask,
+                        pred_ious_this_matching,
+                        matched_gt_inds,
+                        num_fg_img,
+                    ) = self.get_assignments(  # noqa
+                        batch_idx,
+                        num_gt,
+                        total_num_anchors,
+                        gt_bboxes_per_image,
+                        gt_classes,
+                        bboxes_preds_per_image,
+                        expanded_strides,
+                        x_shifts,
+                        y_shifts,
+                        cls_preds,
+                        bbox_preds,
+                        obj_preds,
+                        labels,
+                        imgs,
+                        mode="cpu",
+                        # mode='gpu',
+                    )
+
+                torch.cuda.empty_cache()
+                num_fg += num_fg_img
+
+                cls_target = F.one_hot(
+                    gt_matched_classes.to(torch.int64), self.num_classes
+                ) * pred_ious_this_matching.unsqueeze(-1)
+                obj_target = fg_mask.unsqueeze(-1)
+                reg_target = gt_bboxes_per_image[matched_gt_inds]
+                if self.use_l1:
+                    l1_target = self.get_l1_target(
+                        outputs.new_zeros((num_fg_img, 4)),
+                        gt_bboxes_per_image[matched_gt_inds],
+                        expanded_strides[0][fg_mask],
+                        x_shifts=x_shifts[0][fg_mask],
+                        y_shifts=y_shifts[0][fg_mask],
+                    )
+
+            cls_targets.append(cls_target)
+            reg_targets.append(reg_target)
+            obj_targets.append(obj_target.to(dtype))
+            fg_masks.append(fg_mask)
+            if self.use_l1:
+                l1_targets.append(l1_target)
+
+        cls_targets = torch.cat(cls_targets, 0)
+        reg_targets = torch.cat(reg_targets, 0)
+        obj_targets = torch.cat(obj_targets, 0)
+        fg_masks = torch.cat(fg_masks, 0)
+        if self.use_l1:
+            l1_targets = torch.cat(l1_targets, 0)
+
+        num_fg = max(num_fg, 1)
+        loss_iou = (self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)).sum() / num_fg
+
+        if self.obj_loss_type == "bce":
+            loss_obj = (self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)).sum() / num_fg
+        else:
+            raise NotImplementedError("Unknown obj_loss_type: {}".format(self.obj_loss_type))
+
+        if self.cls_loss_type == "bce":
+            loss_cls = self.bcewithlog_loss(cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets).sum() / num_fg
+        elif self.cls_loss_type == "focal":
+            loss_cls = (
+                sigmoid_focal_loss(
+                    cls_preds.view(-1, self.num_classes)[fg_masks],
+                    cls_targets,
+                    alpha=self.fl_alpha,
+                    gamma=self.fl_gamma,
+                    reduction="none",
+                ).sum()
+                / num_fg
+            )
+        else:
+            raise NotImplementedError("Unknown cls_loss_type: {}".format(self.cls_loss_type))
+
+        if self.use_l1:
+            loss_l1 = (self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
+        else:
+            loss_l1 = 0.0
+
+        reg_weight = 5.0  # TODO: make this configurable
+        # loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
+        loss_dict = {
+            "loss_iou": reg_weight * loss_iou,
+            "loss_conf": loss_obj,
+            "loss_l1": loss_l1,
+            "loss_cls": loss_cls,
+        }
+        out_dict = {
+            "num_fg": num_fg / max(num_gts, 1),
+        }
+        return out_dict, loss_dict
+
+    def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
+        l1_target[:, 0] = gt[:, 0] / stride - x_shifts
+        l1_target[:, 1] = gt[:, 1] / stride - y_shifts
+        l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
+        l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
+        return l1_target
+
+    @torch.no_grad()
+    def get_assignments(
+        self,
+        batch_idx,
+        num_gt,
+        total_num_anchors,
+        gt_bboxes_per_image,
+        gt_classes,
+        bboxes_preds_per_image,
+        expanded_strides,
+        x_shifts,
+        y_shifts,
+        cls_preds,
+        bbox_preds,
+        obj_preds,
+        labels,
+        imgs,
+        mode="gpu",
+    ):
+
+        if mode == "cpu":
+            print("------------CPU Mode for This Batch-------------")
+            gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
+            bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
+            gt_classes = gt_classes.cpu().float()
+            expanded_strides = expanded_strides.cpu().float()
+            x_shifts = x_shifts.cpu()
+            y_shifts = y_shifts.cpu()
+
+        fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
+            gt_bboxes_per_image,
+            expanded_strides,
+            x_shifts,
+            y_shifts,
+            total_num_anchors,
+            num_gt,
+        )
+
+        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
+        cls_preds_ = cls_preds[batch_idx][fg_mask]
+        obj_preds_ = obj_preds[batch_idx][fg_mask]
+        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
+
+        if mode == "cpu":
+            gt_bboxes_per_image = gt_bboxes_per_image.cpu()
+            bboxes_preds_per_image = bboxes_preds_per_image.cpu()
+
+        pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
+
+        gt_cls_per_image = (
+            F.one_hot(gt_classes.to(torch.int64), self.num_classes)
+            .float()
+            .unsqueeze(1)
+            .repeat(1, num_in_boxes_anchor, 1)
+        )
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
+
+        if mode == "cpu":
+            cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.float().cpu()
+
+        with autocast(False):
+            cls_preds_ = (
+                cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+                * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+            )
+            pair_wise_cls_loss = F.binary_cross_entropy(cls_preds_.sqrt_(), gt_cls_per_image, reduction="none").sum(-1)
+
+        del cls_preds_
+
+        cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 100000.0 * (~is_in_boxes_and_center)
+
+        (
+            num_fg,
+            gt_matched_classes,
+            pred_ious_this_matching,
+            matched_gt_inds,
+        ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
+        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
+
+        if mode == "cpu":
+            gt_matched_classes = gt_matched_classes.cuda()
+            fg_mask = fg_mask.cuda()
+            pred_ious_this_matching = pred_ious_this_matching.cuda()
+            matched_gt_inds = matched_gt_inds.cuda()
+
+        return (
+            gt_matched_classes,
+            fg_mask,
+            pred_ious_this_matching,
+            matched_gt_inds,
+            num_fg,
+        )
+
+    def get_in_boxes_info(
+        self,
+        gt_bboxes_per_image,
+        expanded_strides,
+        x_shifts,
+        y_shifts,
+        total_num_anchors,
+        num_gt,
+    ):
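+        # an anchor point is a candidate if its center lies inside a gt box ("in boxes")
+        # or within a center_radius * stride square around the gt center ("in centers")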
+        expanded_strides_per_image = expanded_strides[0]
+        x_shifts_per_image = x_shifts[0] * expanded_strides_per_image
+        y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
+        x_centers_per_image = (
+            (x_shifts_per_image + 0.5 * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
+        )  # [n_anchor] -> [n_gt, n_anchor]
+        y_centers_per_image = (y_shifts_per_image + 0.5 * expanded_strides_per_image).unsqueeze(0).repeat(num_gt, 1)
+
+        gt_bboxes_per_image_l = (
+            (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
+        )
+        gt_bboxes_per_image_r = (
+            (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2]).unsqueeze(1).repeat(1, total_num_anchors)
+        )
+        gt_bboxes_per_image_t = (
+            (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
+        )
+        gt_bboxes_per_image_b = (
+            (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3]).unsqueeze(1).repeat(1, total_num_anchors)
+        )
+
+        b_l = x_centers_per_image - gt_bboxes_per_image_l
+        b_r = gt_bboxes_per_image_r - x_centers_per_image
+        b_t = y_centers_per_image - gt_bboxes_per_image_t
+        b_b = gt_bboxes_per_image_b - y_centers_per_image
+        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
+
+        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        # in fixed center
+
+        center_radius = 2.5
+
+        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
+            1, total_num_anchors
+        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
+        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
+            1, total_num_anchors
+        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
+        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
+            1, total_num_anchors
+        ) - center_radius * expanded_strides_per_image.unsqueeze(0)
+        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
+            1, total_num_anchors
+        ) + center_radius * expanded_strides_per_image.unsqueeze(0)
+
+        c_l = x_centers_per_image - gt_bboxes_per_image_l
+        c_r = gt_bboxes_per_image_r - x_centers_per_image
+        c_t = y_centers_per_image - gt_bboxes_per_image_t
+        c_b = gt_bboxes_per_image_b - y_centers_per_image
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+        # in boxes and in centers
+        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+        is_in_boxes_and_center = is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+        return is_in_boxes_anchor, is_in_boxes_and_center
+
+    def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
+        # Dynamic K
+        # ---------------------------------------------------------------
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        ious_in_boxes_matrix = pair_wise_ious
+        n_candidate_k = min(10, ious_in_boxes_matrix.size(1))
+        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+        dynamic_ks = dynamic_ks.tolist()
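+        # for each GT, take the dynamic_k candidate anchors with the lowest matching cost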
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx], largest=False)
+            matching_matrix[gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        anchor_matching_gt = matching_matrix.sum(0)
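+        # if an anchor is assigned to multiple GTs, keep only the GT with the minimum cost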
+        if (anchor_matching_gt > 1).sum() > 0:
+            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+            matching_matrix[:, anchor_matching_gt > 1] *= 0
+            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        num_fg = fg_mask_inboxes.sum().item()
+
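+        # update fg_mask in place so that it only keeps the finally matched anchors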
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        gt_matched_classes = gt_classes[matched_gt_inds]
+
+        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[fg_mask_inboxes]
+        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
diff --git a/det/yolox/models/yolo_pafpn.py b/det/yolox/models/yolo_pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6cd08397939de24b02d48b7107a4a5ace94d032
--- /dev/null
+++ b/det/yolox/models/yolo_pafpn.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from .darknet import CSPDarknet
+from .network_blocks import BaseConv, CSPLayer, DWConv
+
+
+class YOLOPAFPN(nn.Module):
+    """YOLOv3 model.
+
+    Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self,
+        depth=1.0,
+        width=1.0,
+        in_features=("dark3", "dark4", "dark5"),
+        in_channels=[256, 512, 1024],
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
+        self.in_features = in_features
+        self.in_channels = in_channels
+        Conv = DWConv if depthwise else BaseConv
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+        self.lateral_conv0 = BaseConv(int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act)
+        self.C3_p4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )  # cat
+
+        self.reduce_conv1 = BaseConv(int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act)
+        self.C3_p3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[0] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+        # bottom-up conv
+        self.bu_conv2 = Conv(int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act)
+        self.C3_n3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+        # bottom-up conv
+        self.bu_conv1 = Conv(int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act)
+        self.C3_n4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[2] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+    def forward(self, input):
+        """
+        Args:
+            inputs: input images.
+
+        Returns:
+            Tuple[Tensor]: FPN feature.
+        """
+
+        #  backbone
+        out_features = self.backbone(input)
+        features = [out_features[f] for f in self.in_features]
+        [x2, x1, x0] = features
+
+        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
+        f_out0 = self.upsample(fpn_out0)  # 512/16
+        f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
+        f_out0 = self.C3_p4(f_out0)  # 1024->512/16
+
+        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
+        f_out1 = self.upsample(fpn_out1)  # 256/8
+        f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
+        pan_out2 = self.C3_p3(f_out1)  # 512->256/8
+
+        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
+        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
+        pan_out1 = self.C3_n3(p_out1)  # 512->512/16
+
+        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
+        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
+        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32
+
+        outputs = (pan_out2, pan_out1, pan_out0)
+        return outputs
diff --git a/det/yolox/models/yolox.py b/det/yolox/models/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..1245110e549fc7c817e92b3a226e342bcf16bb40
--- /dev/null
+++ b/det/yolox/models/yolox.py
@@ -0,0 +1,100 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+from torch.nn.modules import loss
+
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+from det.yolox.utils.model_utils import scale_img
+
+
+class YOLOX(nn.Module):
+    """YOLOX model module.
+
+    The module list is defined by create_yolov3_modules function. The
+    network returns loss values from three YOLO layers during training
+    and detection results during test.
+    """
+
+    def __init__(self, backbone=None, head=None):
+        super().__init__()
+        if backbone is None:
+            backbone = YOLOPAFPN()
+        if head is None:
+            head = YOLOXHead(80)
+
+        self.backbone = backbone
+        self.head = head
+
+        self.init_yolo()
+
+    def init_yolo(self):
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eps = 1e-3
+                m.momentum = 0.03
+        self.head.initialize_biases(prior_prob=0.01)
+
+    # def forward(self, x, targets=None):
+    #     # fpn output content features of [dark3, dark4, dark5]
+    #     fpn_outs = self.backbone(x)
+
+    #     if self.training:
+    #         assert targets is not None
+    #         outputs, loss_dict = self.head(fpn_outs, targets, x)
+    #         return outputs, loss_dict
+    #     else:
+    #         outputs = self.head(fpn_outs)
+    #         return outputs
+
+    def forward(self, x, targets=None, augment=False, cfg=None):
+        if augment:
+            assert not self.training, "multiscale training is not implemented"
+            img_size = x.shape[-2:]
+            scales = cfg.scales
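+            # test-time augmentation: run inference at each scale and merge the de-scaled predictions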
+            # flips = [None, 3, None]  # flips (2-ud, 3-lr)
+            det_preds = []
+            # for si, fi in zip(scales, flips):
+            for si in scales:
+                # xi = scale_img(x.flip(fi) if fi else x, si, gs=32)
+                xi = scale_img(x, si, gs=32)
+                yi = self.forward_once(xi, targets)
+                yi["det_preds"][:, :, :4] /= si  # de-scale
+                # if fi == 2:
+                #     yi["det_preds"][:, :, 1] = img_size[0] - yi["det_preds"][:, :, 1]  # de-flip ud
+                # elif fi == 3:
+                #     yi["det_preds"][:, :, 0] = img_size[1] - yi["det_preds"][:, :, 0]  # de-flip lr
+                # adaptive small medium large objects
+                # if si < 1:
+                #     yi["det_preds"][:, :, 5] = torch.where(
+                #         yi["det_preds"][:, :, 2] * yi["det_preds"][:, :, 3] < 96 * 96,
+                #         yi["det_preds"][:, :, 5] * 0.6,
+                #         yi["det_preds"][:, :, 5]
+                #     )
+                # elif si > 1:
+                #     yi["det_preds"][:, :, 5] = torch.where(
+                #         yi["det_preds"][:, :, 2] * yi["det_preds"][:, :, 3] > 32 * 32,
+                #         yi["det_preds"][:, :, 5] * 0.6,
+                #         yi["det_preds"][:, :, 5]
+                #     )
+                det_preds.append(yi["det_preds"])
+            det_preds = torch.cat(det_preds, 1)
+            outputs = dict(det_preds=det_preds)
+            return outputs  # augmented inference; TODO: multi-scale training
+        else:
+            return self.forward_once(x, targets)
+
+    def forward_once(self, x, targets=None):
+        # fpn output content features of [dark3, dark4, dark5]
+        fpn_outs = self.backbone(x)
+
+        if self.training:
+            assert targets is not None
+            outputs, loss_dict = self.head(fpn_outs, targets, x)
+            return outputs, loss_dict
+        else:
+            outputs = self.head(fpn_outs)
+            return outputs
diff --git a/det/yolox/tools/convert_trt.py b/det/yolox/tools/convert_trt.py
new file mode 100644
index 0000000000000000000000000000000000000000..b50d66eb654dd4340d9b23096d3eebb3b2884d7d
--- /dev/null
+++ b/det/yolox/tools/convert_trt.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+import os.path as osp
+import sys
+import logging
+from loguru import logger
+
+import tensorrt as trt
+import torch
+from torch2trt import torch2trt
+
+from detectron2.engine import default_argument_parser
+from detectron2.config import LazyConfig, instantiate
+
+cur_dir = osp.abspath(osp.dirname(__file__))
+sys.path.insert(0, osp.join(cur_dir, "../../../"))
+from core.utils.my_checkpoint import MyCheckpointer
+from det.yolox.engine.yolox_setup import default_yolox_setup
+from det.yolox.engine.yolox_trainer import YOLOX_DefaultTrainer
+
+
+def setup(args):
+    """Create configs and perform basic setups."""
+    cfg = LazyConfig.load(args.config_file)
+    cfg = LazyConfig.apply_overrides(cfg, args.opts)
+    default_yolox_setup(cfg, args)
+    return cfg
+
+
+@logger.catch
+def main(args):
+    cfg = setup(args)
+    Trainer = YOLOX_DefaultTrainer
+    model = Trainer.build_model(cfg)
+
+    ckpt_file = args.ckpt
+    MyCheckpointer(model).load(ckpt_file)
+    logger.info("loaded checkpoint done.")
+
+    model.eval()
+    model.head.decode_in_inference = False
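+    # NOTE: in-network decoding is disabled before tracing with torch2trt; outputs are decoded
+    # separately via head.decode_outputs at inference time (cf. demo.py / eval.py)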
+    x = torch.ones(1, 3, cfg.test.test_size[0], cfg.test.test_size[1]).cuda()
+    model_trt = torch2trt(
+        model,
+        [x],
+        fp16_mode=True,
+        log_level=trt.Logger.INFO,
+        max_workspace_size=(1 << 32),
+    )
+
+    filename_wo_ext, ext = osp.splitext(ckpt_file)
+    trt_file = filename_wo_ext + "_trt" + ext
+    torch.save(model_trt.state_dict(), trt_file)
+    logger.info("Converted TensorRT model done.")
+
+    engine_file = filename_wo_ext + "_trt.engine"
+    with open(engine_file, "wb") as f:
+        f.write(model_trt.engine.serialize())
+    logger.info("Converted TensorRT model engine file is saved for C++ inference.")
+
+
+if __name__ == "__main__":
+    """python det/yolox/tools/convert_trt.py --config-file <path/to/cfg.py>
+
+    --ckpt <path/to/ckpt.pth>
+    """
+    parser = default_argument_parser()
+    parser.add_argument("--ckpt", type=str, help="ckpt path")
+    args = parser.parse_args()
+    main(args)
diff --git a/det/yolox/tools/demo.py b/det/yolox/tools/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..882beef12d7352ce8bd4b0d30efbeecc3090758c
--- /dev/null
+++ b/det/yolox/tools/demo.py
@@ -0,0 +1,303 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import argparse
+import os
+import time
+from loguru import logger
+
+import cv2
+
+import torch
+
+from det.yolox.data import ValTransform
+from det.yolox.data.datasets import COCO_CLASSES
+from det.yolox.exp import get_exp
+from det.yolox.utils import fuse_model, get_model_info, postprocess, vis
+
+IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX Demo!")
+    parser.add_argument("demo", default="image", help="demo type, eg. image, video and webcam")
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+
+    parser.add_argument("--path", default="./assets/dog.jpg", help="path to images or video")
+    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
+    parser.add_argument(
+        "--save_result",
+        action="store_true",
+        help="whether to save the inference result of image/video",
+    )
+
+    # exp file
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="pls input your experiment description file",
+    )
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
+    parser.add_argument(
+        "--device",
+        default="cpu",
+        type=str,
+        help="device to run our model, can either be cpu or gpu",
+    )
+    parser.add_argument("--conf", default=0.3, type=float, help="test conf")
+    parser.add_argument("--nms", default=0.3, type=float, help="test nms threshold")
+    parser.add_argument("--tsize", default=None, type=int, help="test img size")
+    parser.add_argument(
+        "--fp16",
+        dest="fp16",
+        default=False,
+        action="store_true",
+        help="Adopting mix precision evaluating.",
+    )
+    parser.add_argument(
+        "--legacy",
+        dest="legacy",
+        default=False,
+        action="store_true",
+        help="To be compatible with older versions",
+    )
+    parser.add_argument(
+        "--fuse",
+        dest="fuse",
+        default=False,
+        action="store_true",
+        help="Fuse conv and bn for testing.",
+    )
+    parser.add_argument(
+        "--trt",
+        dest="trt",
+        default=False,
+        action="store_true",
+        help="Using TensorRT model for testing.",
+    )
+    return parser
+
+
+def get_image_list(path):
+    image_names = []
+    for maindir, subdir, file_name_list in os.walk(path):
+        for filename in file_name_list:
+            apath = os.path.join(maindir, filename)
+            ext = os.path.splitext(apath)[1]
+            if ext in IMAGE_EXT:
+                image_names.append(apath)
+    return image_names
+
+
+class Predictor(object):
+    def __init__(
+        self,
+        model,
+        exp,
+        cls_names=COCO_CLASSES,
+        trt_file=None,
+        decoder=None,
+        device="cpu",
+        fp16=False,
+        legacy=False,
+    ):
+        self.model = model
+        self.cls_names = cls_names
+        self.decoder = decoder
+        self.num_classes = exp.num_classes
+        self.confthre = exp.test_conf
+        self.nmsthre = exp.nmsthre
+        self.test_size = exp.test_size
+        self.device = device
+        self.fp16 = fp16
+        self.preproc = ValTransform(legacy=legacy)
+        if trt_file is not None:
+            from torch2trt import TRTModule
+
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(trt_file))
+
+            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
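+            # run one dummy forward pass with the original model before swapping in the TRT module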
+            self.model(x)
+            self.model = model_trt
+
+    def inference(self, img):
+        img_info = {"id": 0}
+        if isinstance(img, str):
+            img_info["file_name"] = os.path.basename(img)
+            img = cv2.imread(img)
+        else:
+            img_info["file_name"] = None
+
+        height, width = img.shape[:2]
+        img_info["height"] = height
+        img_info["width"] = width
+        img_info["raw_img"] = img
+
+        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
+        img_info["ratio"] = ratio
+
+        img, _ = self.preproc(img, None, self.test_size)
+        img = torch.from_numpy(img).unsqueeze(0)
+        img = img.float()
+        if self.device == "gpu":
+            img = img.cuda()
+            if self.fp16:
+                img = img.half()  # to FP16
+
+        with torch.no_grad():
+            t0 = time.time()
+            outputs = self.model(img)
+            if self.decoder is not None:
+                outputs = self.decoder(outputs, dtype=outputs.type())
+            outputs = postprocess(
+                outputs,
+                self.num_classes,
+                self.confthre,
+                self.nmsthre,
+                class_agnostic=True,
+            )
+            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
+        return outputs, img_info
+
+    def visual(self, output, img_info, cls_conf=0.35):
+        ratio = img_info["ratio"]
+        img = img_info["raw_img"]
+        if output is None:
+            return img
+        output = output.cpu()
+
+        bboxes = output[:, 0:4]
+
+        # preprocessing: resize
+        bboxes /= ratio
+
+        cls = output[:, 6]
+        scores = output[:, 4] * output[:, 5]
+
+        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
+        return vis_res
+
+
+def image_demo(predictor, vis_folder, path, current_time, save_result):
+    if os.path.isdir(path):
+        files = get_image_list(path)
+    else:
+        files = [path]
+    files.sort()
+    for image_name in files:
+        outputs, img_info = predictor.inference(image_name)
+        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
+        if save_result:
+            save_folder = os.path.join(vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
+            os.makedirs(save_folder, exist_ok=True)
+            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
+            logger.info("Saving detection result in {}".format(save_file_name))
+            cv2.imwrite(save_file_name, result_image)
+        ch = cv2.waitKey(0)
+        if ch == 27 or ch == ord("q") or ch == ord("Q"):
+            break
+
+
+def imageflow_demo(predictor, vis_folder, current_time, args):
+    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
+    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
+    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    save_folder = os.path.join(vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
+    os.makedirs(save_folder, exist_ok=True)
+    if args.demo == "video":
+        save_path = os.path.join(save_folder, args.path.split("/")[-1])
+    else:
+        save_path = os.path.join(save_folder, "camera.mp4")
+    logger.info(f"video save_path is {save_path}")
+    vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height)))
+    while True:
+        ret_val, frame = cap.read()
+        if ret_val:
+            outputs, img_info = predictor.inference(frame)
+            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
+            if args.save_result:
+                vid_writer.write(result_frame)
+            ch = cv2.waitKey(1)
+            if ch == 27 or ch == ord("q") or ch == ord("Q"):
+                break
+        else:
+            break
+
+
+def main(exp, args):
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    file_name = os.path.join(exp.output_dir, args.experiment_name)
+    os.makedirs(file_name, exist_ok=True)
+
+    vis_folder = None
+    if args.save_result:
+        vis_folder = os.path.join(file_name, "vis_res")
+        os.makedirs(vis_folder, exist_ok=True)
+
+    if args.trt:
+        args.device = "gpu"
+
+    logger.info("Args: {}".format(args))
+
+    if args.conf is not None:
+        exp.test_conf = args.conf
+    if args.nms is not None:
+        exp.nmsthre = args.nms
+    if args.tsize is not None:
+        exp.test_size = (args.tsize, args.tsize)
+
+    model = exp.get_model()
+    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
+
+    if args.device == "gpu":
+        model.cuda()
+        if args.fp16:
+            model.half()  # to FP16
+    model.eval()
+
+    if not args.trt:
+        if args.ckpt is None:
+            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+        else:
+            ckpt_file = args.ckpt
+        logger.info("loading checkpoint")
+        ckpt = torch.load(ckpt_file, map_location="cpu")
+        # load the model state dict
+        model.load_state_dict(ckpt["model"])
+        logger.info("loaded checkpoint done.")
+
+    if args.fuse:
+        logger.info("\tFusing model...")
+        model = fuse_model(model)
+
+    if args.trt:
+        assert not args.fuse, "TensorRT model is not support model fusing!"
+        trt_file = os.path.join(file_name, "model_trt.pth")
+        assert os.path.exists(trt_file), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
+        model.head.decode_in_inference = False
+        decoder = model.head.decode_outputs
+        logger.info("Using TensorRT to inference")
+    else:
+        trt_file = None
+        decoder = None
+
+    predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.fp16, args.legacy)
+    current_time = time.localtime()
+    if args.demo == "image":
+        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
+    elif args.demo == "video" or args.demo == "webcam":
+        imageflow_demo(predictor, vis_folder, current_time, args)
+
+
+if __name__ == "__main__":
+    args = make_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+
+    main(exp, args)
diff --git a/det/yolox/tools/eval.py b/det/yolox/tools/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e880aba9dded07138bd02a42a571d72f2eefb7b
--- /dev/null
+++ b/det/yolox/tools/eval.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+import os.path as osp
+import sys
+
+import random
+import warnings
+from loguru import logger
+
+import torch
+import torch.backends.cudnn as cudnn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+sys.path.insert(0, osp.join(cur_dir, "../../../"))
+from det.yolox.engine.launch import launch
+from det.yolox.exp import get_exp
+from det.yolox.utils import (
+    configure_nccl,
+    fuse_model,
+    get_local_rank,
+    get_model_info,
+    setup_logger,
+)
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX Eval")
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+
+    # distributed
+    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+    parser.add_argument(
+        "--dist-url",
+        default=None,
+        type=str,
+        help="url used to set up distributed training",
+    )
+    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
+    parser.add_argument("-d", "--devices", default=None, type=int, help="device for training")
+    parser.add_argument("--num_machines", default=1, type=int, help="num of node for training")
+    parser.add_argument("--machine_rank", default=0, type=int, help="node rank for multi-node training")
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="pls input your expriment description file",
+    )
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
+    parser.add_argument("--conf", default=None, type=float, help="test conf")
+    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
+    parser.add_argument("--tsize", default=None, type=int, help="test img size")
+    parser.add_argument("--seed", default=None, type=int, help="eval seed")
+    parser.add_argument(
+        "--fp16",
+        dest="fp16",
+        default=False,
+        action="store_true",
+        help="Adopting mix precision evaluating.",
+    )
+    parser.add_argument(
+        "--fuse",
+        dest="fuse",
+        default=False,
+        action="store_true",
+        help="Fuse conv and bn for testing.",
+    )
+    parser.add_argument(
+        "--trt",
+        dest="trt",
+        default=False,
+        action="store_true",
+        help="Using TensorRT model for testing.",
+    )
+    parser.add_argument(
+        "--legacy",
+        dest="legacy",
+        default=False,
+        action="store_true",
+        help="To be compatible with older versions",
+    )
+    parser.add_argument(
+        "--test",
+        dest="test",
+        default=False,
+        action="store_true",
+        help="Evaluating on test-dev set.",
+    )
+    parser.add_argument(
+        "--speed",
+        dest="speed",
+        default=False,
+        action="store_true",
+        help="speed test only.",
+    )
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+@logger.catch
+def main(exp, args, num_gpu):
+    if args.seed is not None:
+        random.seed(args.seed)
+        torch.manual_seed(args.seed)
+        cudnn.deterministic = True
+        warnings.warn("You have chosen to seed testing. This will turn on the CUDNN deterministic setting, ")
+
+    is_distributed = num_gpu > 1
+
+    # set environment variables for distributed training
+    configure_nccl()
+    cudnn.benchmark = True
+
+    rank = get_local_rank()
+
+    file_name = os.path.join(exp.output_dir, args.experiment_name)
+
+    if rank == 0:
+        os.makedirs(file_name, exist_ok=True)
+
+    setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
+    logger.info("Args: {}".format(args))
+
+    if args.conf is not None:
+        exp.test_conf = args.conf
+    if args.nms is not None:
+        exp.nmsthre = args.nms
+    if args.tsize is not None:
+        exp.test_size = (args.tsize, args.tsize)
+
+    model = exp.get_model()
+    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
+    logger.info("Model Structure:\n{}".format(str(model)))
+
+    evaluator = exp.get_evaluator(args.batch_size, is_distributed, args.test, args.legacy)
+
+    torch.cuda.set_device(rank)
+    model.cuda(rank)
+    model.eval()
+
+    if not args.speed and not args.trt:
+        if args.ckpt is None:
+            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+        else:
+            ckpt_file = args.ckpt
+        logger.info("loading checkpoint from {}".format(ckpt_file))
+        loc = "cuda:{}".format(rank)
+        ckpt = torch.load(ckpt_file, map_location=loc)
+        model.load_state_dict(ckpt["model"])
+        logger.info("loaded checkpoint done.")
+
+    if is_distributed:
+        model = DDP(model, device_ids=[rank])
+
+    if args.fuse:
+        logger.info("\tFusing model...")
+        model = fuse_model(model)
+
+    if args.trt:
+        assert (
+            not args.fuse and not is_distributed and args.batch_size == 1
+        ), "TensorRT model is not support model fusing and distributed inferencing!"
+        trt_file = os.path.join(file_name, "model_trt.pth")
+        assert os.path.exists(trt_file), "TensorRT model is not found!\n Run tools/trt.py first!"
+        model.head.decode_in_inference = False
+        decoder = model.head.decode_outputs
+    else:
+        trt_file = None
+        decoder = None
+
+    # start evaluate
+    *_, summary = evaluator.evaluate(model, is_distributed, args.fp16, trt_file, decoder, exp.test_size)
+    logger.info("\n" + summary)
+
+
+if __name__ == "__main__":
+    args = make_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
+    assert num_gpu <= torch.cuda.device_count(), f"num_gpu: {num_gpu} device count: {torch.cuda.device_count()}"
+
+    dist_url = "auto" if args.dist_url is None else args.dist_url
+    # import ipdb; ipdb.set_trace()
+    launch(
+        main,
+        num_gpu,
+        args.num_machines,
+        args.machine_rank,
+        backend=args.dist_backend,
+        dist_url=dist_url,
+        args=(exp, args, num_gpu),
+    )
diff --git a/det/yolox/tools/export_onnx.py b/det/yolox/tools/export_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba5b9ae6535e08d5a5ee5b72651494d0acf2cd3
--- /dev/null
+++ b/det/yolox/tools/export_onnx.py
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import argparse
+import os
+from loguru import logger
+
+import torch
+from torch import nn
+
+from det.yolox.exp import get_exp
+from det.yolox.models.network_blocks import SiLU
+from det.yolox.utils import replace_module
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX onnx deploy")
+    parser.add_argument("--output-name", type=str, default="yolox.onnx", help="output name of models")
+    parser.add_argument("--input", default="images", type=str, help="input node name of onnx model")
+    parser.add_argument("--output", default="output", type=str, help="output node name of onnx model")
+    parser.add_argument("-o", "--opset", default=11, type=int, help="onnx opset version")
+    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
+    parser.add_argument(
+        "--dynamic",
+        action="store_true",
+        help="whether the input shape should be dynamic or not",
+    )
+    parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="expriment description file",
+    )
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+
+    return parser
+
+
+@logger.catch
+def main():
+    args = make_parser().parse_args()
+    logger.info("args value: {}".format(args))
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    model = exp.get_model()
+    if args.ckpt is None:
+        file_name = os.path.join(exp.output_dir, args.experiment_name)
+        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+    else:
+        ckpt_file = args.ckpt
+
+    # load the model state dict
+    ckpt = torch.load(ckpt_file, map_location="cpu")
+
+    model.eval()
+    if "model" in ckpt:
+        ckpt = ckpt["model"]
+    model.load_state_dict(ckpt)
+    model = replace_module(model, nn.SiLU, SiLU)
+    model.head.decode_in_inference = False
+
+    logger.info("loading checkpoint done.")
+    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])
+
+    torch.onnx._export(
+        model,
+        dummy_input,
+        args.output_name,
+        input_names=[args.input],
+        output_names=[args.output],
+        dynamic_axes={args.input: {0: "batch"}, args.output: {0: "batch"}} if args.dynamic else None,
+        opset_version=args.opset,
+    )
+    logger.info("generated onnx model named {}".format(args.output_name))
+
+    if not args.no_onnxsim:
+        import onnx
+
+        from onnxsim import simplify
+
+        input_shapes = {args.input: list(dummy_input.shape)} if args.dynamic else None
+
+        # use onnx-simplifier to remove redundant nodes from the model.
+        onnx_model = onnx.load(args.output_name)
+        model_simp, check = simplify(onnx_model, dynamic_input_shape=args.dynamic, input_shapes=input_shapes)
+        assert check, "Simplified ONNX model could not be validated"
+        onnx.save(model_simp, args.output_name)
+        logger.info("generated simplified onnx model named {}".format(args.output_name))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/det/yolox/tools/export_torchscript.py b/det/yolox/tools/export_torchscript.py
new file mode 100644
index 0000000000000000000000000000000000000000..16424a28fd87e740bfe8568901d4c49310c9d0d7
--- /dev/null
+++ b/det/yolox/tools/export_torchscript.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+from loguru import logger
+
+import torch
+
+from det.yolox.exp import get_exp
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX torchscript deploy")
+    parser.add_argument(
+        "--output-name",
+        type=str,
+        default="yolox.torchscript.pt",
+        help="output name of models",
+    )
+    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="expriment description file",
+    )
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+
+    return parser
+
+
+@logger.catch
+def main():
+    args = make_parser().parse_args()
+    logger.info("args value: {}".format(args))
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    model = exp.get_model()
+    if args.ckpt is None:
+        file_name = os.path.join(exp.output_dir, args.experiment_name)
+        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+    else:
+        ckpt_file = args.ckpt
+
+    # load the model state dict
+    ckpt = torch.load(ckpt_file, map_location="cpu")
+
+    model.eval()
+    if "model" in ckpt:
+        ckpt = ckpt["model"]
+    model.load_state_dict(ckpt)
+    model.head.decode_in_inference = False
+
+    logger.info("loading checkpoint done.")
+    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])
+
+    mod = torch.jit.trace(model, dummy_input)
+    mod.save(args.output_name)
+    logger.info("generated torchscript model named {}".format(args.output_name))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/det/yolox/tools/main_yolox.py b/det/yolox/tools/main_yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..9190b49d85ade773ae3f0f8e67de9e090c0c773b
--- /dev/null
+++ b/det/yolox/tools/main_yolox.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+import logging
+from loguru import logger as loguru_logger
+import os.path as osp
+from setproctitle import setproctitle
+from detectron2.engine import (
+    default_argument_parser,
+    launch,
+)
+from detectron2.engine.defaults import create_ddp_model
+from detectron2.config import LazyConfig, instantiate
+
+import cv2
+
+cv2.setNumThreads(0)  # pytorch issue 1355: possible deadlock in dataloader
+# OpenCL may be enabled by default in OpenCV3; disable it because it's not
+# thread safe and causes unwanted GPU memory allocations.
+cv2.ocl.setUseOpenCL(False)
+
+import sys
+
+cur_dir = osp.dirname(osp.abspath(__file__))
+sys.path.insert(0, osp.join(cur_dir, "../../../"))
+
+from lib.utils.time_utils import get_time_str
+import core.utils.my_comm as comm
+from core.utils.my_checkpoint import MyCheckpointer
+from det.yolox.engine.yolox_setup import default_yolox_setup
+from det.yolox.engine.yolox_trainer import YOLOX_DefaultTrainer
+from det.yolox.utils import fuse_model
+from det.yolox.data.datasets.dataset_factory import register_datasets_in_cfg
+
+
+logger = logging.getLogger("detectron2")
+
+
+def setup(args):
+    """Create configs and perform basic setups."""
+    cfg = LazyConfig.load(args.config_file)
+    cfg = LazyConfig.apply_overrides(cfg, args.opts)
+    default_yolox_setup(cfg, args)
+    register_datasets_in_cfg(cfg)
+    setproctitle("{}.{}".format(cfg.train.exp_name, get_time_str()))
+    return cfg
+
+
+@loguru_logger.catch
+def main(args):
+    cfg = setup(args)
+    Trainer = YOLOX_DefaultTrainer
+    if args.eval_only:  # eval
+        model = Trainer.build_model(cfg)
+        MyCheckpointer(model, save_dir=cfg.train.output_dir).resume_or_load(
+            cfg.train.init_checkpoint, resume=args.resume
+        )
+        if cfg.test.fuse_conv_bn:
+            logger.info("\tFusing conv bn...")
+            model = fuse_model(model)
+        res = Trainer.test(cfg, model)
+        # import ipdb; ipdb.set_trace()
+        return res
+    # train
+    trainer = Trainer(cfg)
+    trainer.resume_or_load(resume=args.resume)
+    return trainer.train()
+
+
+if __name__ == "__main__":
+    args = default_argument_parser().parse_args()
+    launch(
+        main,
+        args.num_gpus,
+        num_machines=args.num_machines,
+        machine_rank=args.machine_rank,
+        dist_url=args.dist_url,
+        args=(args,),
+    )
diff --git a/det/yolox/tools/plain_main_yolox.py b/det/yolox/tools/plain_main_yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1e89c29fa67a51389a01f72781372e7cf328ffc
--- /dev/null
+++ b/det/yolox/tools/plain_main_yolox.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+import logging
+import os.path as osp
+
+from detectron2.config import get_cfg
+from detectron2.data import build_detection_test_loader, build_detection_train_loader
+from detectron2.engine import (
+    default_argument_parser,
+    launch,
+)
+from detectron2.engine.defaults import create_ddp_model
+from detectron2.config import LazyConfig, instantiate
+
+from lib.utils.setup_logger import setup_my_logger
+import core.utils.my_comm as comm
+from core.utils.my_checkpoint import MyCheckpointer
+from det.yolox.engine.yolox_train_test_plain import do_train_yolox, do_test_yolox
+from det.yolox.engine.yolox_setup import default_yolox_setup
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("detectron2")
+
+
+def setup(args):
+    """Create configs and perform basic setups."""
+    cfg = LazyConfig.load(args.config_file)
+    cfg = LazyConfig.apply_overrides(cfg, args.opts)
+    default_yolox_setup(cfg, args)
+    return cfg
+
+
+def main(args):
+    cfg = setup(args)
+    if args.eval_only:  # eval
+        model = instantiate(cfg.model)
+        model.to(cfg.train.device)
+        model = create_ddp_model(model)
+        MyCheckpointer(model, save_dir=cfg.train.output_dir).resume_or_load(
+            cfg.train.init_checkpoint, resume=args.resume
+        )
+        print(do_test_yolox(cfg, model))
+    else:  # train
+        do_train_yolox(args, cfg)
+
+
+if __name__ == "__main__":
+    args = default_argument_parser().parse_args()
+    launch(
+        main,
+        args.num_gpus,
+        num_machines=args.num_machines,
+        machine_rank=args.machine_rank,
+        dist_url=args.dist_url,
+        args=(args,),
+    )
diff --git a/det/yolox/tools/test_yolox.sh b/det/yolox/tools/test_yolox.sh
new file mode 100755
index 0000000000000000000000000000000000000000..dfdc1d2262ad226fb9cd0ea3cebd186bdd025dd4
--- /dev/null
+++ b/det/yolox/tools/test_yolox.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+# test
+set -x
+this_dir=$(dirname "$0")
+# commonly used opts:
+# train.init_checkpoint: resume or pretrained, or test checkpoint
+CFG=$1
+CUDA_VISIBLE_DEVICES=$2
+IFS=',' read -ra GPUS <<< "$CUDA_VISIBLE_DEVICES"
+# GPUS=($(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n'))
+NGPU=${#GPUS[@]}  # echo "${GPUS[0]}"
+echo "use gpu ids: $CUDA_VISIBLE_DEVICES num gpus: $NGPU"
+CKPT=$3
+if [ ! -f "$CKPT" ]; then
+    echo "$CKPT does not exist."
+    exit 1
+fi
+NCCL_DEBUG=INFO
+OMP_NUM_THREADS=1
+MKL_NUM_THREADS=1
+PYTHONPATH="$this_dir/../..":$PYTHONPATH \
+CUDA_VISIBLE_DEVICES=$2 python $this_dir/main_yolox.py \
+    --config-file $CFG --num-gpus $NGPU --eval-only \
+    train.init_checkpoint=$CKPT \
+    ${@:4}
+
+# tensorboard --logdir /path/to/logdir --bind_all # --port 6007
+# to see tensorboard logs locally:
+# ssh -L 6006:localhost:6006 user@server
diff --git a/det/yolox/tools/train.py b/det/yolox/tools/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..50ac8094c40c0694669d8ce4e3aa0b71cc7a6a02
--- /dev/null
+++ b/det/yolox/tools/train.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import argparse
+import random
+import warnings
+from loguru import logger
+
+import torch
+import torch.backends.cudnn as cudnn
+
+from det.yolox.engine.trainer import Trainer
+from det.yolox.engine.launch import launch
+from det.yolox.exp import get_exp
+from det.yolox.utils import configure_nccl, configure_omp, get_num_devices
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX train parser")
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+
+    # distributed
+    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
+    parser.add_argument(
+        "--dist-url",
+        default=None,
+        type=str,
+        help="url used to set up distributed training",
+    )
+    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
+    parser.add_argument("-d", "--devices", default=None, type=int, help="device for training")
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="plz input your experiment description file",
+    )
+    parser.add_argument("--resume", default=False, action="store_true", help="resume training")
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="checkpoint file")
+    parser.add_argument(
+        "-e",
+        "--start_epoch",
+        default=None,
+        type=int,
+        help="resume training start epoch",
+    )
+    parser.add_argument("--num_machines", default=1, type=int, help="num of node for training")
+    parser.add_argument("--machine_rank", default=0, type=int, help="node rank for multi-node training")
+    parser.add_argument(
+        "--fp16",
+        dest="fp16",
+        default=False,
+        action="store_true",
+        help="Adopting mix precision training.",
+    )
+    parser.add_argument(
+        "--cache",
+        dest="cache",
+        default=False,
+        action="store_true",
+        help="Caching imgs to RAM for fast training.",
+    )
+    parser.add_argument(
+        "-o",
+        "--occupy",
+        dest="occupy",
+        default=False,
+        action="store_true",
+        help="occupy GPU memory first for training.",
+    )
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+@logger.catch
+def main(exp, args):
+    if exp.seed is not None:
+        random.seed(exp.seed)
+        torch.manual_seed(exp.seed)
+        cudnn.deterministic = True
+        warnings.warn(
+            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
+            "which can slow down your training considerably! You may see unexpected behavior "
+            "when restarting from checkpoints."
+        )
+
+    # set environment variables for distributed training
+    configure_nccl()
+    configure_omp()
+    cudnn.benchmark = True
+
+    trainer = Trainer(exp, args)
+    trainer.train()
+
+
+if __name__ == "__main__":
+    args = make_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    num_gpu = get_num_devices() if args.devices is None else args.devices
+    assert num_gpu <= get_num_devices()
+
+    dist_url = "auto" if args.dist_url is None else args.dist_url
+    launch(
+        main,
+        num_gpu,
+        args.num_machines,
+        args.machine_rank,
+        backend=args.dist_backend,
+        dist_url=dist_url,
+        args=(exp, args),
+    )
diff --git a/det/yolox/tools/train_yolox.sh b/det/yolox/tools/train_yolox.sh
new file mode 100755
index 0000000000000000000000000000000000000000..0d058003e29742c63e8d65ce4cba85e16fb12d74
--- /dev/null
+++ b/det/yolox/tools/train_yolox.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+set -x
+this_dir=$(dirname "$0")
+# commonly used opts:
+# train.init_checkpoint: resume or pretrained, or test checkpoint
+CFG=$1
+CUDA_VISIBLE_DEVICES=$2
+IFS=',' read -ra GPUS <<< "$CUDA_VISIBLE_DEVICES"
+# GPUS=($(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n'))
+NGPU=${#GPUS[@]}  # echo "${GPUS[0]}"
+echo "use gpu ids: $CUDA_VISIBLE_DEVICES num gpus: $NGPU"
+# CUDA_LAUNCH_BLOCKING=1
+NCCL_DEBUG=INFO
+OMP_NUM_THREADS=1
+MKL_NUM_THREADS=1
+PYTHONPATH="$this_dir/../..":$PYTHONPATH \
+CUDA_VISIBLE_DEVICES=$2 python $this_dir/main_yolox.py \
+    --config-file $CFG --num-gpus $NGPU  ${@:3}
diff --git a/det/yolox/tools/trt.py b/det/yolox/tools/trt.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c3c28022271fbae77ddb2fe32a976b53cb0b00f
--- /dev/null
+++ b/det/yolox/tools/trt.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import argparse
+import os
+import shutil
+from loguru import logger
+
+import tensorrt as trt
+import torch
+from torch2trt import torch2trt
+
+from det.yolox.exp import get_exp
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX ncnn deploy")
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="pls input your experiment description file",
+    )
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
+    parser.add_argument("-w", "--workspace", type=int, default=32, help="max workspace size in detect")
+    parser.add_argument("-b", "--batch", type=int, default=1, help="max batch size in detect")
+    return parser
+
+
+@logger.catch
+def main():
+    args = make_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    model = exp.get_model()
+    file_name = os.path.join(exp.output_dir, args.experiment_name)
+    os.makedirs(file_name, exist_ok=True)
+    if args.ckpt is None:
+        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+    else:
+        ckpt_file = args.ckpt
+
+    ckpt = torch.load(ckpt_file, map_location="cpu")
+    # load the model state dict
+
+    model.load_state_dict(ckpt["model"])
+    logger.info("loaded checkpoint done.")
+    model.eval()
+    model.cuda()
+    model.head.decode_in_inference = False
+    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
+    model_trt = torch2trt(
+        model,
+        [x],
+        fp16_mode=True,
+        log_level=trt.Logger.INFO,
+        max_workspace_size=(1 << args.workspace),
+        max_batch_size=args.batch,
+    )
+    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
+    logger.info("Converted TensorRT model done.")
+    engine_file = os.path.join(file_name, "model_trt.engine")
+    engine_file_demo = os.path.join("demo", "TensorRT", "cpp", "model_trt.engine")
+    with open(engine_file, "wb") as f:
+        f.write(model_trt.engine.serialize())
+
+    shutil.copyfile(engine_file, engine_file_demo)
+
+    logger.info("Converted TensorRT model engine file is saved for C++ inference.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/det/yolox/utils/__init__.py b/det/yolox/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52da8f02d3a94f1e84dc28178386fd9d38a231a
--- /dev/null
+++ b/det/yolox/utils/__init__.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+from .allreduce_norm import *
+from .boxes import *
+from .checkpoint import load_ckpt, save_checkpoint
+from .demo_utils import *
+from .dist import *
+from .ema import *
+from .logger import setup_logger
+from .lr_scheduler import LRScheduler
+from .metric import *
+from .model_utils import *
+from .setup_env import *
+from .visualize import *
diff --git a/det/yolox/utils/allreduce_norm.py b/det/yolox/utils/allreduce_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..df3e0d924acf1dbd7d57f99d8a8d907cf2911719
--- /dev/null
+++ b/det/yolox/utils/allreduce_norm.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import pickle
+from collections import OrderedDict
+
+import torch
+from torch import distributed as dist
+from torch import nn
+
+from .dist import _get_global_gloo_group, get_world_size
+
+ASYNC_NORM = (
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.InstanceNorm1d,
+    nn.InstanceNorm2d,
+    nn.InstanceNorm3d,
+)
+
+__all__ = [
+    "get_async_norm_states",
+    "pyobj2tensor",
+    "tensor2pyobj",
+    "all_reduce",
+    "all_reduce_norm",
+]
+
+
+def get_async_norm_states(module):
+    async_norm_states = OrderedDict()
+    for name, child in module.named_modules():
+        if isinstance(child, ASYNC_NORM):
+            for k, v in child.state_dict().items():
+                async_norm_states[".".join([name, k])] = v
+    return async_norm_states
+
+
+def pyobj2tensor(pyobj, device="cuda"):
+    """serialize picklable python object to tensor."""
+    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
+    return torch.ByteTensor(storage).to(device=device)
+
+
+def tensor2pyobj(tensor):
+    """deserialize tensor to picklable python object."""
+    return pickle.loads(tensor.cpu().numpy().tobytes())
+
+
+def _get_reduce_op(op_name):
+    return {
+        "sum": dist.ReduceOp.SUM,
+        "mean": dist.ReduceOp.SUM,
+    }[op_name.lower()]
+
+
+def all_reduce(py_dict, op="sum", group=None):
+    """
+    Apply all reduce function for python dict object.
+    NOTE: make sure that every py_dict has the same keys and values are in the same shape.
+
+    Args:
+        py_dict (dict): dict to apply all reduce op.
+        op (str): operator, could be "sum" or "mean".
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return py_dict
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return py_dict
+
+    # all reduce logic across different devices.
+    py_key = list(py_dict.keys())
+    py_key_tensor = pyobj2tensor(py_key)
+    dist.broadcast(py_key_tensor, src=0)
+    py_key = tensor2pyobj(py_key_tensor)
+
+    tensor_shapes = [py_dict[k].shape for k in py_key]
+    tensor_numels = [py_dict[k].numel() for k in py_key]
+
+    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
+    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
+    if op == "mean":
+        flatten_tensor /= world_size
+
+    split_tensors = [x.reshape(shape) for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)]
+    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
+
+
+def all_reduce_norm(module):
+    """All reduce norm statistics in different devices."""
+    states = get_async_norm_states(module)
+    states = all_reduce(states, op="mean")
+    module.load_state_dict(states, strict=False)
diff --git a/det/yolox/utils/boxes.py b/det/yolox/utils/boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..197edb08c8ed11903195c0cf91d2124444ce7528
--- /dev/null
+++ b/det/yolox/utils/boxes.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import logging
+import numpy as np
+
+import torch
+import torchvision
+
+__all__ = [
+    "filter_box",
+    "postprocess",
+    "bboxes_iou",
+    "matrix_iou",
+    "adjust_box_anns",
+    "xyxy2xywh",
+    "xyxy2cxcywh",
+]
+
+logger = logging.getLogger(__name__)
+
+
+def filter_box(output, scale_range):
+    """
+    output: (N, 5+class) shape
+    """
+    min_scale, max_scale = scale_range
+    w = output[:, 2] - output[:, 0]
+    h = output[:, 3] - output[:, 1]
+    keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale)
+    return output[keep]
+
+
+def postprocess(det_preds, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
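+    # convert (cx, cy, w, h) predictions to (x1, y1, x2, y2) corners before thresholding and NMS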
+    box_corner = det_preds.new(det_preds.shape)
+    box_corner[:, :, 0] = det_preds[:, :, 0] - det_preds[:, :, 2] / 2
+    box_corner[:, :, 1] = det_preds[:, :, 1] - det_preds[:, :, 3] / 2
+    box_corner[:, :, 2] = det_preds[:, :, 0] + det_preds[:, :, 2] / 2
+    box_corner[:, :, 3] = det_preds[:, :, 1] + det_preds[:, :, 3] / 2
+    det_preds[:, :, :4] = box_corner[:, :, :4]
+
+    output = [None for _ in range(len(det_preds))]
+    for i, image_pred in enumerate(det_preds):
+
+        # If none are remaining => process next image
+        if not image_pred.size(0):
+            # logger.warn(f"image_pred.size: {image_pred.size(0)}")
+            continue
+        # Get score and class with highest confidence
+        class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True)
+
+        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
+        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
+        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
+        detections = detections[conf_mask]
+        if not detections.size(0):
+            # logger.warn(f"detections.size(0) {detections.size(0)} num_classes: {num_classes} conf_thr: {conf_thre} nms_thr: {nms_thre}")
+            continue
+
+        if class_agnostic:
+            nms_out_index = torchvision.ops.nms(
+                detections[:, :4],
+                detections[:, 4] * detections[:, 5],
+                nms_thre,
+            )
+        else:
+            nms_out_index = torchvision.ops.batched_nms(
+                detections[:, :4],
+                detections[:, 4] * detections[:, 5],
+                detections[:, 6],
+                nms_thre,
+            )
+
+        detections = detections[nms_out_index]
+        if output[i] is None:
+            output[i] = detections
+        else:
+            output[i] = torch.cat((output[i], detections))
+    return output
+
+
+def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
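+    # xyxy=True: boxes are given as (x1, y1, x2, y2); otherwise as (cx, cy, w, h)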
+    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
+        raise IndexError
+
+    if xyxy:
+        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
+        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
+        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
+        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
+    else:
+        tl = torch.max(
+            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
+        )
+        br = torch.min(
+            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
+        )
+
+        area_a = torch.prod(bboxes_a[:, 2:], 1)
+        area_b = torch.prod(bboxes_b[:, 2:], 1)
+    en = (tl < br).type(tl.type()).prod(dim=2)
+    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
+    return area_i / (area_a[:, None] + area_b - area_i)
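+
+
+# Illustrative note (not in the original file): `bboxes_iou` returns an (Na, Nb) IoU
+# matrix; with xyxy=False the boxes are interpreted as (cx, cy, w, h), e.g.
+#   ious = bboxes_iou(gt_boxes, pred_boxes, xyxy=False)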
+
+
+def matrix_iou(a, b):
+    """return iou of a and b, numpy version for data augenmentation."""
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+    return area_i / (area_a[:, np.newaxis] + area_b - area_i + 1e-12)
+
+
+def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
+    bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max)
+    bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max)
+    return bbox
+
+
+def xyxy2xywh(bboxes):
+    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+    return bboxes
+
+
+def xyxy2cxcywh(bboxes):
+    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
+    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
+    return bboxes
diff --git a/det/yolox/utils/checkpoint.py b/det/yolox/utils/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd3b2f4954371d58b924d740647f384ec04070df
--- /dev/null
+++ b/det/yolox/utils/checkpoint.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import os
+import shutil
+from loguru import logger
+
+import torch
+
+
+def load_ckpt(model, ckpt):
+    model_state_dict = model.state_dict()
+    load_dict = {}
+    for key_model, v in model_state_dict.items():
+        if key_model not in ckpt:
+            logger.warning("{} is not in the ckpt. Please double check and see if this is desired.".format(key_model))
+            continue
+        v_ckpt = ckpt[key_model]
+        if v.shape != v_ckpt.shape:
+            logger.warning(
+                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
+                    key_model, v_ckpt.shape, key_model, v.shape
+                )
+            )
+            continue
+        load_dict[key_model] = v_ckpt
+
+    model.load_state_dict(load_dict, strict=False)
+    return model
+
+
+def save_checkpoint(state, is_best, save_dir, model_name=""):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    filename = os.path.join(save_dir, model_name + "_ckpt.pth")
+    torch.save(state, filename)
+    if is_best:
+        best_filename = os.path.join(save_dir, "best_ckpt.pth")
+        shutil.copyfile(filename, best_filename)
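+
+
+# Illustrative usage (not part of the original file), assuming the checkpoint was
+# saved as {"model": state_dict, ...} under a hypothetical path:
+#   ckpt = torch.load("output/yolox_ckpt.pth", map_location="cpu")
+#   model = load_ckpt(model, ckpt["model"])
+#   save_checkpoint({"model": model.state_dict()}, is_best=False, save_dir="output", model_name="yolox")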
diff --git a/det/yolox/utils/convert_lmo_det.py b/det/yolox/utils/convert_lmo_det.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f2322537d11781f7045efb6406a160d1511cfa
--- /dev/null
+++ b/det/yolox/utils/convert_lmo_det.py
@@ -0,0 +1,92 @@
+import mmcv
+import argparse
+import json
+import copy
+
+from detectron2.utils.file_io import PathManager
+
+parser = argparse.ArgumentParser(description="convert lmo det from lm category to lmo category")
+parser.add_argument("--input_path", type=str, default="0", help="input path")
+parser.add_argument("--out_path", type=str, default="0", help="outpur path")
+args = parser.parse_args()
+
+ds = mmcv.load(args.input_path)
+
+outs = []
+
+catid2obj = {
+    1: "ape",
+    5: "can",
+    6: "cat",
+    8: "driller",
+    9: "duck",
+    10: "eggbox",
+    11: "glue",
+    12: "holepuncher",
+}
+objects = [
+    "ape",
+    "can",
+    "cat",
+    "driller",
+    "duck",
+    "eggbox",
+    "glue",
+    "holepuncher",
+]
+obj2id = {_name: _id for _id, _name in catid2obj.items()}
+
+
+for d in ds:
+    d_new = copy.deepcopy(d)
+
+    obj_id = d_new["category_id"]
+    obj_name = objects[obj_id - 1]
+    category_id = obj2id[obj_name]
+
+    d_new["category_id"] = category_id
+
+    outs.append(d_new)
+
+with PathManager.open(args.out_path, "w") as f:
+    f.write(json.dumps(outs))
+    f.flush()
+
+
+def save_json(path, content, sort=False):
+    """Saves the provided content to a JSON file.
+
+    :param path: Path to the output JSON file.
+    :param content: Dictionary/list to save.
+    """
+    with open(path, "w") as f:
+
+        if isinstance(content, dict):
+            f.write("{\n")
+            if sort:
+                content_sorted = sorted(content.items(), key=lambda x: x[0])
+            else:
+                content_sorted = content.items()
+            for elem_id, (k, v) in enumerate(content_sorted):
+                f.write('  "{}": {}'.format(k, json.dumps(v, sort_keys=True)))
+                if elem_id != len(content) - 1:
+                    f.write(",")
+                f.write("\n")
+            f.write("}")
+
+        elif isinstance(content, list):
+            f.write("[\n")
+            for elem_id, elem in enumerate(content):
+                f.write("  {}".format(json.dumps(elem, sort_keys=True)))
+                if elem_id != len(content) - 1:
+                    f.write(",")
+                f.write("\n")
+            f.write("]")
+
+        else:
+            json.dump(content, f, sort_keys=True)
+
+
+# save_json(args.out_path, outs)
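+
+# Illustrative usage (not part of the original file); the JSON file names are hypothetical:
+#   python det/yolox/utils/convert_lmo_det.py --input_path lm_dets.json --out_path lmo_dets.json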
diff --git a/det/yolox/utils/demo_utils.py b/det/yolox/utils/demo_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1e0dfc200faba3235c38c49bedacbae93217cce
--- /dev/null
+++ b/det/yolox/utils/demo_utils.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import os
+
+import numpy as np
+
+__all__ = ["mkdir", "nms", "multiclass_nms", "demo_postprocess"]
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def nms(boxes, scores, nms_thr):
+    """Single class NMS implemented in Numpy."""
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= nms_thr)[0]
+        order = order[inds + 1]
+
+    return keep
+
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
+    """Multiclass NMS implemented in Numpy."""
+    if class_agnostic:
+        nms_method = multiclass_nms_class_agnostic
+    else:
+        nms_method = multiclass_nms_class_aware
+    return nms_method(boxes, scores, nms_thr, score_thr)
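+
+
+# Illustrative usage (not in the original file): `boxes` is an (N, 4) xyxy array and
+# `scores` is (N, num_classes); the result is an (M, 6) array [x1, y1, x2, y2, score, cls]
+# or None when nothing survives the thresholds:
+#   dets = multiclass_nms(boxes, scores, nms_thr=0.45, score_thr=0.1, class_agnostic=True)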
+
+
+def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in Numpy.
+
+    Class-aware version.
+    """
+    final_dets = []
+    num_classes = scores.shape[1]
+    for cls_ind in range(num_classes):
+        cls_scores = scores[:, cls_ind]
+        valid_score_mask = cls_scores > score_thr
+        if valid_score_mask.sum() == 0:
+            continue
+        else:
+            valid_scores = cls_scores[valid_score_mask]
+            valid_boxes = boxes[valid_score_mask]
+            keep = nms(valid_boxes, valid_scores, nms_thr)
+            if len(keep) > 0:
+                cls_inds = np.ones((len(keep), 1)) * cls_ind
+                dets = np.concatenate([valid_boxes[keep], valid_scores[keep, None], cls_inds], 1)
+                final_dets.append(dets)
+    if len(final_dets) == 0:
+        return None
+    return np.concatenate(final_dets, 0)
+
+
+def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in Numpy.
+
+    Class-agnostic version.
+    """
+    cls_inds = scores.argmax(1)
+    cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
+
+    valid_score_mask = cls_scores > score_thr
+    if valid_score_mask.sum() == 0:
+        return None
+    valid_scores = cls_scores[valid_score_mask]
+    valid_boxes = boxes[valid_score_mask]
+    valid_cls_inds = cls_inds[valid_score_mask]
+    keep = nms(valid_boxes, valid_scores, nms_thr)
+    dets = None  # stay None if NMS keeps nothing, so the function never returns an undefined name
+    if keep:
+        dets = np.concatenate([valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1)
+    return dets
+
+
+def demo_postprocess(outputs, img_size, p6=False):
+
+    grids = []
+    expanded_strides = []
+
+    if not p6:
+        strides = [8, 16, 32]
+    else:
+        strides = [8, 16, 32, 64]
+
+    hsizes = [img_size[0] // stride for stride in strides]
+    wsizes = [img_size[1] // stride for stride in strides]
+
+    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+        grids.append(grid)
+        shape = grid.shape[:2]
+        expanded_strides.append(np.full((*shape, 1), stride))
+
+    grids = np.concatenate(grids, 1)
+    expanded_strides = np.concatenate(expanded_strides, 1)
+    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+    return outputs
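+
+
+# Illustrative usage (not in the original file): decode raw exported-model outputs of
+# shape (1, num_anchors, 5 + num_classes) into input-scale cxcywh boxes, assuming a
+# 640x640 network input:
+#   decoded = demo_postprocess(raw_outputs, img_size=(640, 640))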
diff --git a/det/yolox/utils/dist.py b/det/yolox/utils/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..2284d1b79c2e5c4fdb8ab0dff832d2edbbb376ca
--- /dev/null
+++ b/det/yolox/utils/dist.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# This file mainly comes from
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+"""This file contains primitives for multi-gpu communication.
+
+This is useful when doing distributed training.
+"""
+import functools
+import os
+import pickle
+import time
+from contextlib import contextmanager
+from loguru import logger
+
+import numpy as np
+
+import torch
+from torch import distributed as dist
+
+__all__ = [
+    "get_num_devices",
+    "wait_for_the_master",
+    "is_main_process",
+    "synchronize",
+    "get_world_size",
+    "get_rank",
+    "get_local_rank",
+    "get_local_size",
+    "time_synchronized",
+    "gather",
+    "all_gather",
+]
+
+_LOCAL_PROCESS_GROUP = None
+
+
+def get_num_devices():
+    gpu_list = os.getenv("CUDA_VISIBLE_DEVICES", None)
+    if gpu_list is not None:
+        return len(gpu_list.split(","))
+    else:
+        devices_list_info = os.popen("nvidia-smi -L")
+        devices_list_info = devices_list_info.read().strip().split("\n")
+        return len(devices_list_info)
+
+
+@contextmanager
+def wait_for_the_master(local_rank: int):
+    """Make all processes waiting for the master to do some task."""
+    if local_rank > 0:
+        dist.barrier()
+    yield
+    if local_rank == 0:
+        if not dist.is_available():
+            return
+        if not dist.is_initialized():
+            return
+        else:
+            dist.barrier()
+
+
+def synchronize():
+    """Helper function to synchronize (barrier) among all processes when using
+    distributed training."""
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return
+    dist.barrier()
+
+
+def get_world_size() -> int:
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank() -> int:
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def get_local_rank() -> int:
+    """
+    Returns:
+        The rank of the current process within the local (per-machine) process group.
+    """
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    assert _LOCAL_PROCESS_GROUP is not None
+    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+    """
+    Returns:
+        The size of the per-machine process group, i.e. the number of processes per machine.
+    """
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+    return get_rank() == 0
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+    """Return a process group based on gloo backend, containing all the ranks
+    The result is cached."""
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+    else:
+        return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+    backend = dist.get_backend(group)
+    assert backend in ["gloo", "nccl"]
+    device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024**3:
+        logger.warning(
+            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+                get_rank(), len(buffer) / (1024**3), device
+            )
+        )
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+    """
+    Returns:
+        list[int]: size of the tensor, on each rank
+        Tensor: padded tensor that has the max size
+    """
+    world_size = dist.get_world_size(group=group)
+    assert world_size >= 1, "comm.gather/all_gather must be called from ranks within the given group!"
+    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+    size_list = [torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)]
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+
+    max_size = max(size_list)
+
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    if local_size != max_size:
+        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+
+def all_gather(data, group=None):
+    """Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+    Args:
+        data: any picklable object
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return [data]
+
+    tensor = _serialize_to_tensor(data, group)
+
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
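+
+
+# Illustrative usage (not in the original file): gather arbitrary picklable per-rank
+# results (e.g. lists of detections) onto every rank:
+#   results = all_gather(local_results)          # list with world_size entries
+#   merged = [x for per_rank in results for x in per_rank]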
+
+
+def gather(data, dst=0, group=None):
+    """Run gather on arbitrary picklable data (not necessarily tensors).
+
+    Args:
+        data: any picklable object
+        dst (int): destination rank
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+
+    Returns:
+        list[data]: on dst, a list of data gathered from each rank. Otherwise,
+            an empty list.
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group=group) == 1:
+        return [data]
+    rank = dist.get_rank(group=group)
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+    # receiving Tensor from all ranks
+    if rank == dst:
+        max_size = max(size_list)
+        tensor_list = [torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list]
+        dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+        data_list = []
+        for size, tensor in zip(size_list, tensor_list):
+            buffer = tensor.cpu().numpy().tobytes()[:size]
+            data_list.append(pickle.loads(buffer))
+        return data_list
+    else:
+        dist.gather(tensor, [], dst=dst, group=group)
+        return []
+
+
+def shared_random_seed():
+    """
+    Returns:
+        int: a random number that is the same across all workers.
+            If workers need a shared RNG, they can use this shared seed to
+            create one.
+    All workers must call this function, otherwise it will deadlock.
+    """
+    ints = np.random.randint(2**31)
+    all_ints = all_gather(ints)
+    return all_ints[0]
+
+
+def time_synchronized():
+    """pytorch-accurate time."""
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    return time.perf_counter()
diff --git a/det/yolox/utils/ema.py b/det/yolox/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b545a1c20a5154b184545de73a25cf09b772a28
--- /dev/null
+++ b/det/yolox/utils/ema.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import math
+from copy import deepcopy
+import logging
+
+import torch
+import torch.nn as nn
+
+from lib.utils.setup_logger import log_first_n
+
+__all__ = ["ModelEMA", "is_parallel"]
+
+
+def is_parallel(model):
+    """check if model is in parallel mode."""
+
+    parallel_type = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+    return isinstance(model, parallel_type)
+
+
+class ModelEMA:
+    """Model Exponential Moving Average from
+    https://github.com/rwightman/pytorch-image-models Keep a moving average of
+    everything in the model state_dict (parameters and buffers).
+
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+
+    def __init__(self, model, decay=0.9999, updates=0):
+        """
+        Args:
+            model (nn.Module): model to apply EMA.
+            decay (float): ema decay rate.
+            updates (int): counter of EMA updates.
+        """
+        # Create EMA(FP32)
+        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
+        self.updates = updates
+        # decay exponential ramp (to help early epochs)
+        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        # Update EMA parameters
+        with torch.no_grad():
+            self.updates += 1
+            d = self.decay(self.updates)
+
+            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
+            for k, v in self.ema.state_dict().items():
+                if v.dtype.is_floating_point:
+                    v *= d
+                    v += (1.0 - d) * msd[k].detach()
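+
+
+# Illustrative usage (not part of the original file): create the EMA copy once the
+# model is on its device, update it after every optimizer step, and evaluate with
+# `ema_helper.ema` instead of `model`:
+#   ema_helper = ModelEMA(model, decay=0.9999)
+#   ...
+#   optimizer.step()
+#   ema_helper.update(model)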
diff --git a/det/yolox/utils/logger.py b/det/yolox/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..118f0cbc581ecea96addd2b46f49bb01c7482016
--- /dev/null
+++ b/det/yolox/utils/logger.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import inspect
+import os
+import sys
+from loguru import logger
+
+
+def get_caller_name(depth=0):
+    """
+    Args:
+        depth (int): Depth of caller context; use 0 for the direct caller. Default value: 0.
+
+    Returns:
+        str: module name of the caller
+    """
+    # the following logic is a little bit faster than inspect.stack() logic
+    frame = inspect.currentframe().f_back
+    for _ in range(depth):
+        frame = frame.f_back
+
+    return frame.f_globals["__name__"]
+
+
+class StreamToLoguru:
+    """stream object that redirects writes to a logger instance."""
+
+    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
+        """
+        Args:
+            level(str): log level string of loguru. Default value: "INFO".
+            caller_names(tuple): caller names of redirected module.
+                Default value: (apex, pycocotools).
+        """
+        self.level = level
+        self.linebuf = ""
+        self.caller_names = caller_names
+
+    def write(self, buf):
+        full_name = get_caller_name(depth=1)
+        module_name = full_name.rsplit(".", maxsplit=-1)[0]
+        if module_name in self.caller_names:
+            for line in buf.rstrip().splitlines():
+                # use caller level log
+                logger.opt(depth=2).log(self.level, line.rstrip())
+        else:
+            sys.__stdout__.write(buf)
+
+    def flush(self):
+        pass
+
+
+def redirect_sys_output(log_level="INFO"):
+    redirect_logger = StreamToLoguru(log_level)
+    sys.stderr = redirect_logger
+    sys.stdout = redirect_logger
+
+
+def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
+    """setup logger for training and testing.
+    Args:
+        save_dir(str): location to save log file
+        distributed_rank(int): device rank when multi-gpu environment
+        filename (string): log save name.
+        mode(str): log file write mode, "a" (append) or "o" (overwrite). Default value: "a".
+
+    Return:
+        logger instance.
+    """
+    loguru_format = (
+        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
+        "<level>{level: <8}</level> | "
+        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
+    )
+
+    logger.remove()
+    save_file = os.path.join(save_dir, filename)
+    if mode == "o" and os.path.exists(save_file):
+        os.remove(save_file)
+    # only keep logger in rank0 process
+    if distributed_rank == 0:
+        logger.add(
+            sys.stderr,
+            format=loguru_format,
+            level="INFO",
+            enqueue=True,
+        )
+        logger.add(save_file)
+
+    # redirect stdout/stderr to loguru
+    redirect_sys_output("INFO")
diff --git a/det/yolox/utils/lr_scheduler.py b/det/yolox/utils/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dcf409b48816287f50864d51b44d66a14baccdf
--- /dev/null
+++ b/det/yolox/utils/lr_scheduler.py
@@ -0,0 +1,176 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import math
+from functools import partial
+
+
+class LRScheduler:
+    def __init__(self, name, lr, iters_per_epoch, total_epochs, **kwargs):
+        """Supported lr schedulers: [cos, warmcos, multistep]
+
+        Args:
+            lr (float): learning rate.
+            iters_per_peoch (int): number of iterations in one epoch.
+            total_epochs (int): number of epochs in training.
+            kwargs (dict):
+                - cos: None
+                - warmcos: [warmup_epochs, warmup_lr_start (default 1e-6)]
+                - multistep: [milestones (epochs), gamma (default 0.1)]
+        """
+
+        self.lr = lr
+        self.iters_per_epoch = iters_per_epoch
+        self.total_epochs = total_epochs
+        self.total_iters = iters_per_epoch * total_epochs
+
+        self.__dict__.update(kwargs)
+
+        self.lr_func = self._get_lr_func(name)
+
+    def update_lr(self, iters):
+        return self.lr_func(iters)
+
+    def _get_lr_func(self, name):
+        if name == "cos":  # cosine lr schedule
+            lr_func = partial(cos_lr, self.lr, self.total_iters)
+        elif name == "warmcos":
+            warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
+            warmup_lr_start = getattr(self, "warmup_lr_start", 1e-6)
+            lr_func = partial(
+                warm_cos_lr,
+                self.lr,
+                self.total_iters,
+                warmup_total_iters,
+                warmup_lr_start,
+            )
+        elif name == "yoloxwarmcos":
+            warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
+            no_aug_iters = self.iters_per_epoch * self.no_aug_epochs
+            warmup_lr_start = getattr(self, "warmup_lr_start", 0)
+            min_lr_ratio = getattr(self, "min_lr_ratio", 0.2)
+            lr_func = partial(
+                yolox_warm_cos_lr,
+                self.lr,
+                min_lr_ratio,
+                self.total_iters,
+                warmup_total_iters,
+                warmup_lr_start,
+                no_aug_iters,
+            )
+        elif name == "yoloxsemiwarmcos":
+            warmup_lr_start = getattr(self, "warmup_lr_start", 0)
+            min_lr_ratio = getattr(self, "min_lr_ratio", 0.2)
+            warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
+            no_aug_iters = self.iters_per_epoch * self.no_aug_epochs
+            normal_iters = self.iters_per_epoch * self.semi_epoch
+            semi_iters = self.iters_per_epoch_semi * (self.total_epochs - self.semi_epoch - self.no_aug_epochs)
+            lr_func = partial(
+                yolox_semi_warm_cos_lr,
+                self.lr,
+                min_lr_ratio,
+                warmup_lr_start,
+                self.total_iters,
+                normal_iters,
+                no_aug_iters,
+                warmup_total_iters,
+                semi_iters,
+                self.iters_per_epoch,
+                self.iters_per_epoch_semi,
+            )
+        elif name == "multistep":  # stepwise lr schedule
+            milestones = [int(self.total_iters * milestone / self.total_epochs) for milestone in self.milestones]
+            gamma = getattr(self, "gamma", 0.1)
+            lr_func = partial(multistep_lr, self.lr, milestones, gamma)
+        else:
+            raise ValueError("Scheduler version {} not supported.".format(name))
+        return lr_func
+
+
+def cos_lr(lr, total_iters, iters):
+    """Cosine learning rate."""
+    lr *= 0.5 * (1.0 + math.cos(math.pi * iters / total_iters))
+    return lr
+
+
+def warm_cos_lr(lr, total_iters, warmup_total_iters, warmup_lr_start, iters):
+    """Cosine learning rate with warm up."""
+    if iters <= warmup_total_iters:
+        lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
+    else:
+        lr *= 0.5 * (1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters)))
+    return lr
+
+
+def yolox_warm_cos_lr(
+    lr,
+    min_lr_ratio,
+    total_iters,
+    warmup_total_iters,
+    warmup_lr_start,
+    no_aug_iter,
+    iters,
+):
+    """Cosine learning rate with warm up.
+
+    iters: current iter
+    """
+    min_lr = lr * min_lr_ratio
+    if iters <= warmup_total_iters:
+        # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
+        lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
+    elif iters >= total_iters - no_aug_iter:
+        lr = min_lr
+    else:
+        lr = min_lr + 0.5 * (lr - min_lr) * (
+            1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
+        )
+    return lr
+
+
+def yolox_semi_warm_cos_lr(
+    lr,
+    min_lr_ratio,
+    warmup_lr_start,
+    total_iters,
+    normal_iters,
+    no_aug_iters,
+    warmup_total_iters,
+    semi_iters,
+    iters_per_epoch,
+    iters_per_epoch_semi,
+    iters,
+):
+    """Cosine learning rate with warm up."""
+    min_lr = lr * min_lr_ratio
+    if iters <= warmup_total_iters:
+        # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
+        lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
+    elif iters >= normal_iters + semi_iters:
+        lr = min_lr
+    elif iters <= normal_iters:
+        lr = min_lr + 0.5 * (lr - min_lr) * (
+            1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iters))
+        )
+    else:
+        lr = min_lr + 0.5 * (lr - min_lr) * (
+            1.0
+            + math.cos(
+                math.pi
+                * (
+                    normal_iters
+                    - warmup_total_iters
+                    + (iters - normal_iters) * iters_per_epoch * 1.0 / iters_per_epoch_semi
+                )
+                / (total_iters - warmup_total_iters - no_aug_iters)
+            )
+        )
+    return lr
+
+
+def multistep_lr(lr, milestones, gamma, iters):
+    """MultiStep learning rate."""
+    for milestone in milestones:
+        lr *= gamma if iters >= milestone else 1.0
+    return lr
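+
+
+# Illustrative usage (not part of the original file): per-iteration LR update with the
+# YOLOX warm-cos schedule; the optimizer, epoch sizes and epoch counts are example values:
+#   scheduler = LRScheduler("yoloxwarmcos", lr=0.01, iters_per_epoch=iters_per_epoch,
+#                           total_epochs=30, warmup_epochs=5, no_aug_epochs=2)
+#   lr = scheduler.update_lr(cur_iter)
+#   for pg in optimizer.param_groups:
+#       pg["lr"] = lr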
diff --git a/det/yolox/utils/metric.py b/det/yolox/utils/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb8bf373ab6d034077c0389b73fb22990e7a6819
--- /dev/null
+++ b/det/yolox/utils/metric.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import functools
+import os
+import time
+from collections import defaultdict, deque
+
+import numpy as np
+
+import torch
+
+__all__ = [
+    "AverageMeter",
+    "MeterBuffer",
+    "get_total_and_free_memory_in_Mb",
+    "occupy_mem",
+    "gpu_mem_usage",
+]
+
+
+def get_total_and_free_memory_in_Mb(cuda_device):
+    devices_info_str = os.popen("nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader")
+    devices_info = devices_info_str.read().strip().split("\n")
+    total, used = devices_info[int(cuda_device)].split(",")
+    return int(total), int(used)
+
+
+def occupy_mem(cuda_device, mem_ratio=0.9):
+    """pre-allocate gpu memory for training to avoid memory Fragmentation."""
+    total, used = get_total_and_free_memory_in_Mb(cuda_device)
+    max_mem = int(total * mem_ratio)
+    block_mem = max_mem - used
+    x = torch.cuda.FloatTensor(256, 1024, block_mem)
+    del x
+    time.sleep(5)
+
+
+def gpu_mem_usage():
+    """Compute the GPU memory usage for the current device (MB)."""
+    mem_usage_bytes = torch.cuda.max_memory_allocated()
+    return mem_usage_bytes / (1024 * 1024)
+
+
+class AverageMeter:
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average."""
+
+    def __init__(self, window_size=50):
+        self._deque = deque(maxlen=window_size)
+        self._total = 0.0
+        self._count = 0
+
+    def update(self, value):
+        self._deque.append(value)
+        self._count += 1
+        self._total += value
+
+    @property
+    def median(self):
+        d = np.array(list(self._deque))
+        return np.median(d)
+
+    @property
+    def avg(self):
+        # if deque is empty, nan will be returned.
+        d = np.array(list(self._deque))
+        return d.mean()
+
+    @property
+    def global_avg(self):
+        return self._total / max(self._count, 1e-5)
+
+    @property
+    def latest(self):
+        return self._deque[-1] if len(self._deque) > 0 else None
+
+    @property
+    def total(self):
+        return self._total
+
+    def reset(self):
+        self._deque.clear()
+        self._total = 0.0
+        self._count = 0
+
+    def clear(self):
+        self._deque.clear()
+
+
+class MeterBuffer(defaultdict):
+    """Computes and stores the average and current value."""
+
+    def __init__(self, window_size=20):
+        factory = functools.partial(AverageMeter, window_size=window_size)
+        super().__init__(factory)
+
+    def reset(self):
+        for v in self.values():
+            v.reset()
+
+    def get_filtered_meter(self, filter_key="time"):
+        return {k: v for k, v in self.items() if filter_key in k}
+
+    def update(self, values=None, **kwargs):
+        if values is None:
+            values = {}
+        values.update(kwargs)
+        for k, v in values.items():
+            if isinstance(v, torch.Tensor):
+                v = v.detach()
+            self[k].update(v)
+
+    def clear_meters(self):
+        for v in self.values():
+            v.clear()
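+
+
+# Illustrative usage (not part of the original file):
+#   meters = MeterBuffer(window_size=20)
+#   meters.update(iter_time=0.31, total_loss=2.4)
+#   print(meters["total_loss"].avg, meters.get_filtered_meter("time"))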
diff --git a/det/yolox/utils/model_utils.py b/det/yolox/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ba7fc2bd2ef05d577ed4ad7915f9ede85177d47
--- /dev/null
+++ b/det/yolox/utils/model_utils.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from thop import profile
+
+__all__ = [
+    "fuse_conv_and_bn",
+    "fuse_model",
+    "get_model_info",
+    "replace_module",
+    "scale_img",
+]
+
+
+def get_model_info(model, tsize):
+
+    stride = 64
+    img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
+    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
+    params /= 1e6
+    flops /= 1e9
+    flops *= tsize[0] * tsize[1] / stride / stride * 2  # Gflops
+    info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
+    return info
+
+
+def fuse_conv_and_bn(conv, bn):
+    # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+    fusedconv = (
+        nn.Conv2d(
+            conv.in_channels,
+            conv.out_channels,
+            kernel_size=conv.kernel_size,
+            stride=conv.stride,
+            padding=conv.padding,
+            groups=conv.groups,
+            bias=True,
+        )
+        .requires_grad_(False)
+        .to(conv.weight.device)
+    )
+
+    # prepare filters
+    w_conv = conv.weight.clone().view(conv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
+
+    # prepare spatial bias
+    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+    return fusedconv
+
+
+def fuse_model(model):
+    from det.yolox.models.network_blocks import BaseConv
+
+    for m in model.modules():
+        if type(m) is BaseConv and hasattr(m, "bn"):
+            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+            delattr(m, "bn")  # remove batchnorm
+            m.forward = m.fuseforward  # update forward
+    return model
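+
+
+# Illustrative usage (not in the original file): fuse Conv+BN pairs for faster
+# inference; weights are modified in place, so do this on an eval-mode copy:
+#   model.eval()
+#   model = fuse_model(model)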
+
+
+def replace_module(module, replaced_module_type, new_module_type, replace_func=None):
+    """Replace given type in module to a new type. mostly used in deploy.
+
+    Args:
+        module (nn.Module): model to apply replace operation.
+        replaced_module_type (Type): module type to be replaced.
+        new_module_type (Type): module type to replace with.
+        replace_func (function): python function to describe replace logic. Default value: None.
+
+    Returns:
+        model (nn.Module): module that already been replaced.
+    """
+
+    def default_replace_func(replaced_module_type, new_module_type):
+        return new_module_type()
+
+    if replace_func is None:
+        replace_func = default_replace_func
+
+    model = module
+    if isinstance(module, replaced_module_type):
+        model = replace_func(replaced_module_type, new_module_type)
+    else:  # recursively replace
+        for name, child in module.named_children():
+            new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
+            if new_child is not child:  # child is already replaced
+                model.add_module(name, new_child)
+
+    return model
+
+
+def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
+    # scales img(bs,3,y,x) by ratio constrained to gs-multiple
+    if ratio == 1.0:
+        return img
+    else:
+        h, w = img.shape[2:]
+        s = (int(h * ratio), int(w * ratio))  # new size
+        img = F.interpolate(img, size=s, mode="bilinear", align_corners=False)  # resize
+        if not same_shape:  # pad/crop img
+            h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
+        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447 * 255)  # value = imagenet mean
diff --git a/det/yolox/utils/setup_env.py b/det/yolox/utils/setup_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..03e4c69e907103b3f71f2d1ed2c6cca0f4db5240
--- /dev/null
+++ b/det/yolox/utils/setup_env.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+import os
+import subprocess
+from loguru import logger
+import cv2
+from .dist import get_world_size, is_main_process
+
+
+__all__ = ["configure_nccl", "configure_module", "get_yolox_datadir", "configure_omp"]
+
+
+def get_yolox_datadir():
+    """get dataset dir of YOLOX.
+
+    If the environment variable `YOLOX_DATADIR` is set, this function returns its
+    value. Otherwise, it falls back to the `datasets` directory next to the `det` package.
+    """
+    yolox_datadir = os.getenv("YOLOX_DATADIR", None)
+    if yolox_datadir is None:
+        from det import yolox
+
+        yolox_path = os.path.dirname(os.path.dirname(yolox.__file__))
+        yolox_datadir = os.path.join(yolox_path, "../datasets")
+    return yolox_datadir
+
+
+def configure_nccl():
+    """Configure multi-machine environment variables of NCCL."""
+    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
+    os.environ["NCCL_IB_HCA"] = subprocess.getoutput(
+        "pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; "
+        "do cat $i/ports/1/gid_attrs/types/* 2>/dev/null "
+        "| grep v >/dev/null && echo $i ; done; popd > /dev/null"
+    )
+    os.environ["NCCL_IB_GID_INDEX"] = "3"
+    os.environ["NCCL_IB_TC"] = "106"
+
+
+def configure_omp(num_threads=1):
+    """If OMP_NUM_THREADS is not configured and world_size is greater than 1,
+    Configure OMP_NUM_THREADS environment variables of NCCL to `num_thread`.
+
+    Args:
+        num_threads (int): value of `OMP_NUM_THREADS` to set.
+    """
+    # We set OMP_NUM_THREADS=1 by default, which achieves the best speed on our machines
+    # feel free to change it for better performance.
+    if "OMP_NUM_THREADS" not in os.environ and get_world_size() > 1:
+        os.environ["OMP_NUM_THREADS"] = str(num_threads)
+        if is_main_process():
+            logger.info(
+                "\n***************************************************************\n"
+                "We set `OMP_NUM_THREADS` for each process to {} to speed up.\n"
+                "please further tune the variable for optimal performance.\n"
+                "***************************************************************".format(os.environ["OMP_NUM_THREADS"])
+            )
+
+
+def configure_module(ulimit_value=8192):
+    """Configure pytorch module environment. setting of ulimit and cv2 will be
+    set.
+
+    Args:
+        ulimit_value(int): default open file number on linux. Default value: 8192.
+    """
+    # system setting
+    try:
+        import resource
+
+        rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+        resource.setrlimit(resource.RLIMIT_NOFILE, (ulimit_value, rlimit[1]))
+    except Exception:
+        # An exception might be raised on Windows, or when the requested rlimit exceeds
+        # the hard limit. Raising the rlimit is not strictly necessary, so we ignore it.
+        pass
+
+    # cv2
+    # OpenCV's own multithreading/OpenCL can hurt torch DataLoader performance
+    os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
+    try:
+        cv2.setNumThreads(0)
+        cv2.ocl.setUseOpenCL(False)
+    except Exception:
+        # cv2 version mismatch might raise exceptions.
+        pass
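+
+
+# Illustrative usage (not part of the original file): typical call order at the top of
+# a training entry point, before building dataloaders and the model:
+#   configure_module()
+#   configure_nccl()
+#   configure_omp()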
diff --git a/det/yolox/utils/visualize.py b/det/yolox/utils/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf234d9990d35e0de64caa2d334877ce1d358e4b
--- /dev/null
+++ b/det/yolox/utils/visualize.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+
+import cv2
+import numpy as np
+import os.path as osp
+import time
+
+import ref
+
+__all__ = ["vis"]
+
+
+def vis_train(inps, targets, cfg):
+    for i in range(inps.shape[0]):
+        image = inps[i].cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy()  # CHW -> HWC
+        target = targets[i].cpu().numpy().astype(np.int64).copy()
+        bbox = target[:, 1:]
+        # bbox[:, 2] = bbox[:, 0] + bbox[:, 2]
+        # bbox[:, 3] = bbox[:, 1] + bbox[:, 3]
+        # scene_id = int(scene_im_id[0].split("/")[0])
+        # im_id = int(scene_im_id[0].split("/")[1])
+        out_file = osp.join(cfg.train["output_dir"], "{}.png".format(str(time.perf_counter())))
+        # scores = np.ones(bbox.shape[0])
+        scores = np.zeros(bbox.shape[0])
+        cls_ids = target[:, 0]
+        class_names = ref.hb.objects
+        vis_image = vis(image, bbox, scores, cls_ids, 0.5, class_names)
+        cv2.imwrite(out_file, vis_image)
+
+
+def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
+
+    for i in range(len(boxes)):
+        box = boxes[i]
+        cls_id = int(cls_ids[i])
+        score = scores[i]
+        if score < conf:
+            continue
+        x0 = int(box[0])
+        y0 = int(box[1])
+        x1 = int(box[2])
+        y1 = int(box[3])
+
+        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
+        text = "{}:{:.1f}%".format(class_names[cls_id], score * 100)
+        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
+        font = cv2.FONT_HERSHEY_SIMPLEX
+
+        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
+        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)
+
+        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
+        cv2.rectangle(
+            img,
+            (x0, y0 + 1),
+            (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
+            txt_bk_color,
+            -1,
+        )
+        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)
+
+    return img
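+
+
+# Illustrative usage (not in the original file): draw detections on a BGR image,
+# assuming `dets` is an (N, 7) `postprocess` output [x1, y1, x2, y2, obj, cls_conf, cls]:
+#   vis_img = vis(img, dets[:, :4], dets[:, 4] * dets[:, 5], dets[:, 6], conf=0.35, class_names=class_names)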
+
+
+_COLORS = (
+    np.array(
+        [
+            0.000,
+            0.447,
+            0.741,
+            0.850,
+            0.325,
+            0.098,
+            0.929,
+            0.694,
+            0.125,
+            0.494,
+            0.184,
+            0.556,
+            0.466,
+            0.674,
+            0.188,
+            0.301,
+            0.745,
+            0.933,
+            0.635,
+            0.078,
+            0.184,
+            0.300,
+            0.300,
+            0.300,
+            0.600,
+            0.600,
+            0.600,
+            1.000,
+            0.000,
+            0.000,
+            1.000,
+            0.500,
+            0.000,
+            0.749,
+            0.749,
+            0.000,
+            0.000,
+            1.000,
+            0.000,
+            0.000,
+            0.000,
+            1.000,
+            0.667,
+            0.000,
+            1.000,
+            0.333,
+            0.333,
+            0.000,
+            0.333,
+            0.667,
+            0.000,
+            0.333,
+            1.000,
+            0.000,
+            0.667,
+            0.333,
+            0.000,
+            0.667,
+            0.667,
+            0.000,
+            0.667,
+            1.000,
+            0.000,
+            1.000,
+            0.333,
+            0.000,
+            1.000,
+            0.667,
+            0.000,
+            1.000,
+            1.000,
+            0.000,
+            0.000,
+            0.333,
+            0.500,
+            0.000,
+            0.667,
+            0.500,
+            0.000,
+            1.000,
+            0.500,
+            0.333,
+            0.000,
+            0.500,
+            0.333,
+            0.333,
+            0.500,
+            0.333,
+            0.667,
+            0.500,
+            0.333,
+            1.000,
+            0.500,
+            0.667,
+            0.000,
+            0.500,
+            0.667,
+            0.333,
+            0.500,
+            0.667,
+            0.667,
+            0.500,
+            0.667,
+            1.000,
+            0.500,
+            1.000,
+            0.000,
+            0.500,
+            1.000,
+            0.333,
+            0.500,
+            1.000,
+            0.667,
+            0.500,
+            1.000,
+            1.000,
+            0.500,
+            0.000,
+            0.333,
+            1.000,
+            0.000,
+            0.667,
+            1.000,
+            0.000,
+            1.000,
+            1.000,
+            0.333,
+            0.000,
+            1.000,
+            0.333,
+            0.333,
+            1.000,
+            0.333,
+            0.667,
+            1.000,
+            0.333,
+            1.000,
+            1.000,
+            0.667,
+            0.000,
+            1.000,
+            0.667,
+            0.333,
+            1.000,
+            0.667,
+            0.667,
+            1.000,
+            0.667,
+            1.000,
+            1.000,
+            1.000,
+            0.000,
+            1.000,
+            1.000,
+            0.333,
+            1.000,
+            1.000,
+            0.667,
+            1.000,
+            0.333,
+            0.000,
+            0.000,
+            0.500,
+            0.000,
+            0.000,
+            0.667,
+            0.000,
+            0.000,
+            0.833,
+            0.000,
+            0.000,
+            1.000,
+            0.000,
+            0.000,
+            0.000,
+            0.167,
+            0.000,
+            0.000,
+            0.333,
+            0.000,
+            0.000,
+            0.500,
+            0.000,
+            0.000,
+            0.667,
+            0.000,
+            0.000,
+            0.833,
+            0.000,
+            0.000,
+            1.000,
+            0.000,
+            0.000,
+            0.000,
+            0.167,
+            0.000,
+            0.000,
+            0.333,
+            0.000,
+            0.000,
+            0.500,
+            0.000,
+            0.000,
+            0.667,
+            0.000,
+            0.000,
+            0.833,
+            0.000,
+            0.000,
+            1.000,
+            0.000,
+            0.000,
+            0.000,
+            0.143,
+            0.143,
+            0.143,
+            0.286,
+            0.286,
+            0.286,
+            0.429,
+            0.429,
+            0.429,
+            0.571,
+            0.571,
+            0.571,
+            0.714,
+            0.714,
+            0.714,
+            0.857,
+            0.857,
+            0.857,
+            0.000,
+            0.447,
+            0.741,
+            0.314,
+            0.717,
+            0.741,
+            0.50,
+            0.5,
+            0,
+        ]
+    )
+    .astype(np.float32)
+    .reshape(-1, 3)
+)