Skip to content
Snippets Groups Projects
Commit 1c4c72b2 authored by Yen-Chen Lin's avatar Yen-Chen Lin
Browse files

Add code to segment out certain categories

parent fdbaf996
No related branches found
No related tags found
No related merge requests found
...@@ -43,6 +43,7 @@ def parse_args(): ...@@ -43,6 +43,7 @@ def parse_args():
parser.add_argument("--out", default="transforms.json", help="Output path.") parser.add_argument("--out", default="transforms.json", help="Output path.")
parser.add_argument("--vocab_path", default="", help="Vocabulary tree path.") parser.add_argument("--vocab_path", default="", help="Vocabulary tree path.")
parser.add_argument("--overwrite", action="store_true", help="Do not ask for confirmation for overwriting existing images and COLMAP data.") parser.add_argument("--overwrite", action="store_true", help="Do not ask for confirmation for overwriting existing images and COLMAP data.")
parser.add_argument("--mask_categories", nargs="*", type=int, default=[], help="Object categories that should be masked out from the training images. See the category-to-id mapping here: https://gist.github.com/yenchenlin/75931e87cedad3d23868a79046fc97c2")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -389,3 +390,40 @@ if __name__ == "__main__": ...@@ -389,3 +390,40 @@ if __name__ == "__main__":
print(f"writing {OUT_PATH}") print(f"writing {OUT_PATH}")
with open(OUT_PATH, "w") as outfile: with open(OUT_PATH, "w") as outfile:
json.dump(out, outfile, indent=2) json.dump(out, outfile, indent=2)
if len(args.mask_categories) > 1:
# Check if detectron2 is installed. If not, install it.
import importlib.util
package_name = 'detectron2'
spec = importlib.util.find_spec(package_name)
if spec is None:
input("Detectron2 is not installed. Press enter to install it.")
os.system("python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'")
import torch, detectron2
from pathlib import Path
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
cfg = get_cfg()
# Add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
# Find a model from detectron2's model zoo.
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
for frame in out['frames']:
img = cv2.imread(frame['file_path'])
outputs = predictor(img)
output_mask = np.zeros((img.shape[0], img.shape[1]))
for i in range(len(outputs['instances'])):
if outputs['instances'][i].pred_classes.cpu().numpy()[0] in args.mask_categories:
pred_mask = outputs['instances'][i].pred_masks.cpu().numpy()[0]
output_mask = np.logical_or(output_mask, pred_mask)
rgb_path = Path(frame['file_path'])
mask_name = str(rgb_path.parents[0] / Path('dynamic_mask_' + rgb_path.name.replace('.jpg', '.png')))
cv2.imwrite(mask_name, (output_mask*255).astype(np.uint8))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment