diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 79e47d146d4634472099336ef5312c0477cc70b8..af4129437b8c45c737584d563896584d391b70cc 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -147,11 +147,12 @@ jobs: docs/assets_readme/ data/ scripts/flip/* + scripts/category2id.json + scripts/colmap2nerf.py scripts/common.py scripts/convert_image.py scripts/download_colmap.bat scripts/download_ffmpeg.bat - scripts/colmap2nerf.py scripts/nsvf2nerf.py scripts/record3d2nerf.py diff --git a/scripts/category2id.json b/scripts/category2id.json new file mode 100644 index 0000000000000000000000000000000000000000..8a66f7378c17afcf6a9ea68612cad0897bc9fa5e --- /dev/null +++ b/scripts/category2id.json @@ -0,0 +1,82 @@ +{ + "person": 0, + "bicycle": 1, + "car": 2, + "motorcycle": 3, + "airplane": 4, + "bus": 5, + "train": 6, + "truck": 7, + "boat": 8, + "traffic light": 9, + "fire hydrant": 10, + "stop sign": 11, + "parking meter": 12, + "bench": 13, + "bird": 14, + "cat": 15, + "dog": 16, + "horse": 17, + "sheep": 18, + "cow": 19, + "elephant": 20, + "bear": 21, + "zebra": 22, + "giraffe": 23, + "backpack": 24, + "umbrella": 25, + "handbag": 26, + "tie": 27, + "suitcase": 28, + "frisbee": 29, + "skis": 30, + "snowboard": 31, + "sports ball": 32, + "kite": 33, + "baseball bat": 34, + "baseball glove": 35, + "skateboard": 36, + "surfboard": 37, + "tennis racket": 38, + "bottle": 39, + "wine glass": 40, + "cup": 41, + "fork": 42, + "knife": 43, + "spoon": 44, + "bowl": 45, + "banana": 46, + "apple": 47, + "sandwich": 48, + "orange": 49, + "broccoli": 50, + "carrot": 51, + "hot dog": 52, + "pizza": 53, + "donut": 54, + "cake": 55, + "chair": 56, + "couch": 57, + "potted plant": 58, + "bed": 59, + "dining table": 60, + "toilet": 61, + "tv": 62, + "laptop": 63, + "mouse": 64, + "remote": 65, + "keyboard": 66, + "cell phone": 67, + "microwave": 68, + "oven": 69, + "toaster": 70, + "sink": 71, + "refrigerator": 72, + "book": 73, + "clock": 74, + "vase": 75, + "scissors": 76, + "teddy bear": 77, + "hair drier": 78, + "toothbrush": 79 +} \ No newline at end of file diff --git a/scripts/colmap2nerf.py b/scripts/colmap2nerf.py index d61b67208f00f30fc2fb774471c771a94bc6e598..ed1faf53ba1be7bd7600736f385fd5eeb5662b37 100755 --- a/scripts/colmap2nerf.py +++ b/scripts/colmap2nerf.py @@ -43,6 +43,7 @@ def parse_args(): parser.add_argument("--out", default="transforms.json", help="Output path.") parser.add_argument("--vocab_path", default="", help="Vocabulary tree path.") parser.add_argument("--overwrite", action="store_true", help="Do not ask for confirmation for overwriting existing images and COLMAP data.") + parser.add_argument("--mask_categories", nargs="*", type=str, default=[], help="Object categories that should be masked out from the training images. See `scripts/category2id.json` for supported categories.") args = parser.parse_args() return args @@ -389,3 +390,46 @@ if __name__ == "__main__": print(f"writing {OUT_PATH}") with open(OUT_PATH, "w") as outfile: json.dump(out, outfile, indent=2) + + if len(args.mask_categories) > 0: + # Check if detectron2 is installed. If not, install it. + try: + import detectron2 + except ModuleNotFoundError: + input("Detectron2 is not installed. Press enter to install it.") + import subprocess + package = 'git+https://github.com/facebookresearch/detectron2.git' + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + import detectron2 + + import torch + from pathlib import Path + from detectron2.config import get_cfg + from detectron2 import model_zoo + from detectron2.engine import DefaultPredictor + + dir_path = Path(os.path.dirname(os.path.realpath(__file__))) + category2id = json.load(open(dir_path / "category2id.json", "r")) + mask_ids = [category2id[c] for c in args.mask_categories] + + cfg = get_cfg() + # Add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library + cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")) + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model + # Find a model from detectron2's model zoo. + cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") + predictor = DefaultPredictor(cfg) + + for frame in out['frames']: + img = cv2.imread(frame['file_path']) + outputs = predictor(img) + + output_mask = np.zeros((img.shape[0], img.shape[1])) + for i in range(len(outputs['instances'])): + if outputs['instances'][i].pred_classes.cpu().numpy()[0] in mask_ids: + pred_mask = outputs['instances'][i].pred_masks.cpu().numpy()[0] + output_mask = np.logical_or(output_mask, pred_mask) + + rgb_path = Path(frame['file_path']) + mask_name = str(rgb_path.parents[0] / Path('dynamic_mask_' + rgb_path.name.replace('.jpg', '.png'))) + cv2.imwrite(mask_name, (output_mask*255).astype(np.uint8))