Skip to content
Snippets Groups Projects
Unverified Commit 6b87b3b8 authored by Thomas Müller's avatar Thomas Müller Committed by GitHub
Browse files

Merge pull request #1193 from yenchenlin/add-segmentation

Add code to mask out specified object categories in training images
parents fdbaf996 3e8d9812
No related branches found
No related tags found
No related merge requests found
...@@ -147,11 +147,12 @@ jobs: ...@@ -147,11 +147,12 @@ jobs:
docs/assets_readme/ docs/assets_readme/
data/ data/
scripts/flip/* scripts/flip/*
scripts/category2id.json
scripts/colmap2nerf.py
scripts/common.py scripts/common.py
scripts/convert_image.py scripts/convert_image.py
scripts/download_colmap.bat scripts/download_colmap.bat
scripts/download_ffmpeg.bat scripts/download_ffmpeg.bat
scripts/colmap2nerf.py
scripts/nsvf2nerf.py scripts/nsvf2nerf.py
scripts/record3d2nerf.py scripts/record3d2nerf.py
......
{
"person": 0,
"bicycle": 1,
"car": 2,
"motorcycle": 3,
"airplane": 4,
"bus": 5,
"train": 6,
"truck": 7,
"boat": 8,
"traffic light": 9,
"fire hydrant": 10,
"stop sign": 11,
"parking meter": 12,
"bench": 13,
"bird": 14,
"cat": 15,
"dog": 16,
"horse": 17,
"sheep": 18,
"cow": 19,
"elephant": 20,
"bear": 21,
"zebra": 22,
"giraffe": 23,
"backpack": 24,
"umbrella": 25,
"handbag": 26,
"tie": 27,
"suitcase": 28,
"frisbee": 29,
"skis": 30,
"snowboard": 31,
"sports ball": 32,
"kite": 33,
"baseball bat": 34,
"baseball glove": 35,
"skateboard": 36,
"surfboard": 37,
"tennis racket": 38,
"bottle": 39,
"wine glass": 40,
"cup": 41,
"fork": 42,
"knife": 43,
"spoon": 44,
"bowl": 45,
"banana": 46,
"apple": 47,
"sandwich": 48,
"orange": 49,
"broccoli": 50,
"carrot": 51,
"hot dog": 52,
"pizza": 53,
"donut": 54,
"cake": 55,
"chair": 56,
"couch": 57,
"potted plant": 58,
"bed": 59,
"dining table": 60,
"toilet": 61,
"tv": 62,
"laptop": 63,
"mouse": 64,
"remote": 65,
"keyboard": 66,
"cell phone": 67,
"microwave": 68,
"oven": 69,
"toaster": 70,
"sink": 71,
"refrigerator": 72,
"book": 73,
"clock": 74,
"vase": 75,
"scissors": 76,
"teddy bear": 77,
"hair drier": 78,
"toothbrush": 79
}
\ No newline at end of file
...@@ -43,6 +43,7 @@ def parse_args(): ...@@ -43,6 +43,7 @@ def parse_args():
parser.add_argument("--out", default="transforms.json", help="Output path.") parser.add_argument("--out", default="transforms.json", help="Output path.")
parser.add_argument("--vocab_path", default="", help="Vocabulary tree path.") parser.add_argument("--vocab_path", default="", help="Vocabulary tree path.")
parser.add_argument("--overwrite", action="store_true", help="Do not ask for confirmation for overwriting existing images and COLMAP data.") parser.add_argument("--overwrite", action="store_true", help="Do not ask for confirmation for overwriting existing images and COLMAP data.")
parser.add_argument("--mask_categories", nargs="*", type=str, default=[], help="Object categories that should be masked out from the training images. See `scripts/category2id.json` for supported categories.")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -389,3 +390,46 @@ if __name__ == "__main__": ...@@ -389,3 +390,46 @@ if __name__ == "__main__":
print(f"writing {OUT_PATH}") print(f"writing {OUT_PATH}")
with open(OUT_PATH, "w") as outfile: with open(OUT_PATH, "w") as outfile:
json.dump(out, outfile, indent=2) json.dump(out, outfile, indent=2)
if args.mask_categories:
    # Optional post-processing step: for every frame listed in `out`, run a
    # COCO-pretrained Mask R-CNN (detectron2) and write a binary PNG mask
    # named `dynamic_mask_<image stem>.png` next to the image. Pixels covered
    # by any detected instance of a requested category are 255, others are 0.
    from pathlib import Path

    script_dir = Path(os.path.dirname(os.path.realpath(__file__)))
    with open(script_dir / "category2id.json", "r") as category_file:
        category2id = json.load(category_file)

    # Validate the requested categories before any heavy import or install,
    # so that a typo fails fast with an actionable message.
    unknown_categories = [c for c in args.mask_categories if c not in category2id]
    if unknown_categories:
        raise ValueError(
            f"Unsupported mask categories: {unknown_categories}. "
            f"See {script_dir / 'category2id.json'} for supported categories."
        )
    mask_ids = [category2id[c] for c in args.mask_categories]

    # Check if detectron2 is installed. If not, install it.
    try:
        import detectron2
    except ModuleNotFoundError:
        input("Detectron2 is not installed. Press enter to install it.")
        import importlib
        import subprocess
        package = 'git+https://github.com/facebookresearch/detectron2.git'
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        # The package was installed while this interpreter was running; make
        # sure the import system notices the new files before retrying.
        importlib.invalidate_caches()
        import detectron2
    from detectron2.config import get_cfg
    from detectron2 import model_zoo
    from detectron2.engine import DefaultPredictor

    model_config = "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    cfg = get_cfg()
    # Add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
    cfg.merge_from_file(model_zoo.get_config_file(model_config))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
    # Find a model from detectron2's model zoo.
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_config)
    predictor = DefaultPredictor(cfg)

    for frame in out['frames']:
        img = cv2.imread(frame['file_path'])
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"warning: could not read {frame['file_path']}, skipping mask generation")
            continue

        instances = predictor(img)['instances']
        # Move the predictions to host memory once per frame rather than once
        # per detected instance.
        pred_classes = instances.pred_classes.cpu().numpy()
        selected = np.isin(pred_classes, mask_ids)

        output_mask = np.zeros((img.shape[0], img.shape[1]), dtype=bool)
        if selected.any():
            pred_masks = instances.pred_masks.cpu().numpy()
            # Union of the masks of all instances in the requested categories.
            output_mask = pred_masks[selected].any(axis=0)

        rgb_path = Path(frame['file_path'])
        # Always write a lossless PNG, whatever the source image extension is
        # (.jpg, .jpeg, .JPG, .png, ...); JPEG would corrupt a binary mask.
        mask_path = rgb_path.parent / ('dynamic_mask_' + rgb_path.stem + '.png')
        cv2.imwrite(str(mask_path), output_mask.astype(np.uint8) * 255)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment