diff --git a/.gitignore b/.gitignore index daa7f462c86eaf2b5fb13386aefb02423f316eb3..ded95b5c8558122940b59f598f87605d69d90c16 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pt +*.pth .ipynb_checkpoints /data /wandb diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index e2de6e56b61f5f6a995c66d32d58b7cf79582a66..0000000000000000000000000000000000000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "segment-anything"] - path = segment-anything - url = git@github.com:facebookresearch/segment-anything.git diff --git a/README.md b/README.md index c43ed0956bdca3733a5a9fa8b00b985ff332349c..f3b7874a06966944e245d2b40ec3b28b67018ce6 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ All credit for the model goes to the original authors. <img src="https://drive.google.com/uc?id=1Gsoxlab6c3wOL0Bdj6SEV8L1RsI-mhWF" alt="MSN Example" width="900"/> ## Setup +`git clone --recursive git@github.com:alexcbb/OSRT-experiments.git` After cloning the repository and creating a new conda environment, the following steps will get you started: ### Data @@ -27,6 +28,13 @@ required to load OSRT's MultiShapeNet data, though the CPU version suffices. Rendering videos additionally depends on `ffmpeg>=4.3` being available in your `$PATH`. +To install the Segment Anything dependencies: +```bash +cd segment-anything/ + +pip install -e . +``` + ## Running Experiments Each run's config, checkpoints, and visualization are stored in a dedicated directory. Recommended configs can be found under `runs/[dataset]/[model]`. diff --git a/array.txt b/array.txt new file mode 100644 index 0000000000000000000000000000000000000000..8f52fa4c2d67ed09eaee8206684309132a60b92b --- /dev/null +++ b/array.txt @@ -0,0 +1,480 @@ +... (480 values elided: all 0.000000000000000000e+00 except two 1.000000000000000000e+00 entries) ...
+0.000000000000000000e+00 +0.000000000000000000e+00 diff --git a/automatic_mask_train.py b/automatic_mask_train.py new file mode 100644 index 0000000000000000000000000000000000000000..8df774b2337558bcec063e009feecef0e87420d7 --- /dev/null +++ b/automatic_mask_train.py @@ -0,0 +1,98 @@ +import argparse +import torch +from osrt.encoder import SamAutomaticMask +from segment_anything import sam_model_registry +from torchvision import transforms +import time +import matplotlib.pyplot as plt +import numpy as np +import cv2 + +def show_anns(masks): + ax = plt.gca() + ax.set_autoscale_on(False) + sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True) + img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4)) + img[:,:,3] = 0 + for ann in sorted_anns: + m = ann['segmentation'] + color_mask = np.concatenate([np.random.random(3), [0.95]]) + img[m] = color_mask + ax.imshow(img) + +def show_points(coords, ax, marker_size=100): + ax.scatter(coords[:, 0], coords[:, 1], color='#2ca02c', marker='.', s=marker_size) + + +if __name__ == '__main__': + # Arguments + parser = argparse.ArgumentParser( + description='Test Segment Anything Auto Mask simplified implementation' + ) + parser.add_argument('--model', default='vit_l', type=str, help='Model to use') + parser.add_argument('--path_model', default='.', type=str, help='Path to the model') + + args = parser.parse_args() + device = "cuda" + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + model_type = args.model + if args.model == 'vit_h': + checkpoint = args.path_model + '/sam_vit_h_4b8939.pth' + elif args.model == 'vit_b': + checkpoint = args.path_model + '/sam_vit_b_01ec64.pth' + else: + checkpoint = args.path_model + '/sam_vit_l_0b3195.pth' + + ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/" + images_path = [] + with open(ycb_path + "image_sets/train.txt", 'r') as f: + for line in f.readlines(): + line = line.strip() + images_path.append(ycb_path + 'data/' + line + "-color.png") + + import random + random.shuffle(images_path) + + sam = sam_model_registry[model_type](checkpoint=checkpoint) + sam.to(device=device) + #mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88) + sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh= 0.9, pred_iou_thresh=0.88, points_per_side=8, points_per_batch=64, min_mask_region_area=4000) + sam_mask.to(device) + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + labels = [1 for i in range(len(sam_mask.points_grid))] + with torch.no_grad(): + for image in images_path: + img = cv2.imread(image) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + h, w, _ = img.shape + points = sam_mask.points_grid + new_points= [] + for val in points: + x, y = val[0], val[1] + x *= w + y *= h + new_points.append([x, y]) + new_points = np.array(new_points) + img_batch = [] + img_el = {} + img_el["image"] = img + img_el["original_size"] = (h, w) + img_batch.append(img_el) + + start = time.time() + masks = sam_mask(img_batch) + end = time.time() + print(f"Inference time : {int((end-start) * 1000)}ms") + plt.figure(figsize=(15,15)) + plt.imshow(img) + show_anns(masks[0]["annotations"]) + show_points(new_points, plt.gca()) + plt.axis('off') + plt.show() + diff --git a/osrt/encoder.py b/osrt/encoder.py index 
1dff3bc325e574b586487ab5bfcb199a6321f47b..8f2dcb8e84fec9ec913f16400d78d89731bd2ec0 100644 --- a/osrt/encoder.py +++ b/osrt/encoder.py @@ -1,9 +1,25 @@ import numpy as np import torch import torch.nn as nn +from typing import Any, Dict, List, Optional, Tuple +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore +from torchvision.ops.boxes import batched_nms +from typing import Any, Dict, List, Optional, Tuple + from osrt.layers import RayEncoder, Transformer, SlotAttention +from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score + from segment_anything import SamAutomaticMaskGenerator, sam_model_registry +from segment_anything.modeling import Sam +from segment_anything.modeling import Sam +from segment_anything.modeling.image_encoder import ImageEncoderViT +from segment_anything.modeling.mask_decoder import MaskDecoder +from segment_anything.modeling.prompt_encoder import PromptEncoder +from segment_anything.utils.transforms import ResizeLongestSide +from segment_anything.utils.amg import batched_mask_to_box, remove_small_regions, mask_to_rle_pytorch, area_from_rle +import cv2 class SRTConvBlock(nn.Module): def __init__(self, idim, hdim=None, odim=None): @@ -92,60 +108,554 @@ class OSRTEncoder(nn.Module): slot_latents = self.slot_attention(set_latents) return slot_latents + class FeatureMasking(nn.Module): - def __init__(self, pos_start_octave=0, num_slots=6, num_conv_blocks=3, num_att_blocks=5, slot_dim=1536, slot_iters=1, sam_model="default", sam_path="sam_vit_h_4b8939.pth", + def __init__(self, + #pos_start_octave=0, + points_per_side=8, + box_nms_thresh = 0.7, + stability_score_thresh = 0.9, + pred_iou_thresh=0.88, + points_per_batch=64, + min_mask_region_area=4000, + num_slots=6, + slot_dim=1536, + slot_iters=1, + sam_model="default", + sam_path="sam_vit_h_4b8939.pth", randomize_initial_slots=False): - super().__init__() + super().__init__() - ########################################################## - ### TODO : REPLACE + # We first initialize the automatic mask generator from SAM sam = sam_model_registry[sam_model](checkpoint=sam_path) - self.mask_generator = SamAutomaticMaskGenerator(sam) + self.mask_generator = SamAutomaticMask(sam.image_encoder, + sam.prompt_encoder, + sam.mask_decoder, + box_nms_thresh=box_nms_thresh, + stability_score_thresh = stability_score_thresh, + pred_iou_thresh=pred_iou_thresh, + points_per_side=points_per_side, + points_per_batch=points_per_batch, + min_mask_region_area=min_mask_region_area) + + """self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave, + ray_octaves=15)""" + # We will see if this is usefull later on... 
for now, keep it simple - # They first encode the rays - self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave, - ray_octaves=15) - # This is concatenated to the image features - # They then extract features for each images - conv_blocks = [SRTConvBlock(idim=183, hdim=96)] - cur_hdim = 192 - for i in range(1, num_conv_blocks): - conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None)) - cur_hdim *= 2 + self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters, + randomize_initial_slots=randomize_initial_slots) - self.conv_blocks = nn.Sequential(*conv_blocks) + def forward(self, images): + # Generate images + masks = self.mask_generator(images) - self.per_patch_linear = nn.Conv2d(cur_hdim, 768, kernel_size=1) - self.transformer = Transformer(768, depth=num_att_blocks, heads=12, dim_head=64, - mlp_dim=1536, selfatt=True) + set_latents = None + num_masks = None + + # TODO : set the number of slots according to the masks number + self.slot_attention.change_slots_number(num_masks) - ### TODO : REPLACE - ########################################################## + # [batch_size, num_inputs, dim] + slot_latents = self.slot_attention(set_latents) - self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters, - randomize_initial_slots=randomize_initial_slots) + return slot_latents - def forward(self, images, camera_pos, rays): + +class SamAutomaticMask(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + points_per_side: Optional[int] = 0, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + min_mask_region_area: int = 0 + ) -> None: + """ + This class adapts SAM implementation from original repository but adapting it to our needs : + - Training only the MaskDecoder + - Performing automatic Mask Discovery (combined with AutomaticMask from original repo) + + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. 
+ """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + + # Freeze the image encoder and prompt encoder + for param in self.image_encoder.parameters(): + param.requires_grad = False + for param in self.prompt_encoder.parameters(): + param.requires_grad = False + + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + # Transform image to a square by putting it to the longest side + self.transform = ResizeLongestSide(self.image_encoder.img_size) - masks = self.mask_generator.generate(images) - batch_size, num_images = images.shape[:2] + if points_per_side > 0: + self.points_grid = self.create_points_grid(points_per_side) + else: + self.points_grid = None + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.min_mask_region_area = min_mask_region_area + + + @property + def device(self) -> Any: + return self.pixel_mean.device + + def forward( + self, + batched_input: List[Dict[str, Any]], + extract_embeddings: bool = True + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as in 3xHxW format + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. - x = images.flatten(0, 1) - camera_pos = camera_pos.flatten(0, 1) - rays = rays.flatten(0, 1) + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + # TODO : add a way to extract mask embeddings + # Extract image embeddings + input_images = [self.preprocess(x["image"]) for x in batched_input][0] + with torch.no_grad(): + image_embeddings = self.image_encoder(input_images)#, before_channel_reduc=True), embed_no_red + """ + # Extract image embedding before channel reduction, cf. 
https://github.com/facebookresearch/segment-anything/issues/283 + if before_channel_reduc : + return x, embed_no_red """ + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + # TODO : check if we've got the points given in the batch (to change the current point_grid !) + + im_size = self.transform.apply_image(image_record["image"]).shape[:2] + points_scale = np.array(im_size)[None, ::-1] + points_for_image = self.points_grid * points_scale + + mask_data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self.process_batch(points, im_size, curr_embedding, image_record["original_size"]) + mask_data.cat(batch_data) + del batch_data + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + mask_data["boxes"].float(), + mask_data["iou_preds"], + torch.zeros_like(mask_data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + mask_data.filter(keep_by_nms) + + mask_data.to_numpy() + # TODO : find a way to extract masks + #new_masks = self.complete_holes(mask_data["masks"]) + #print(new_masks) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + self.box_nms_thresh, + ) + + mask_data["segmentations"] = mask_data["masks"] + mask_embed = self.extract_mask_embedding(mask_data, embed_no_red, scale_box=1.5) + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item() + } + if extract_embeddings: + # TODO : add embeddings into the annotations + continue + curr_anns.append(ann) + outputs.append( + { + "annotations": curr_anns + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. - ray_enc = self.ray_encoder(camera_pos, rays) - x = torch.cat((x, ray_enc), 1) - x = self.conv_blocks(x) - x = self.per_patch_linear(x) - x = x.flatten(2, 3).permute(0, 2, 1) + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. - patches_per_image, channels_per_patch = x.shape[1:] - x = x.reshape(batch_size, num_images * patches_per_image, channels_per_patch) + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. 
- x = self.transformer(x) + Edits mask_data in place. - slot_latents = self.slot_attention(x) + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for mask in mask_data["masks"]: + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["masks"][i_mask] = mask_torch + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Rescale the image relative to the longest side + x = self.transform.apply_image(x) + x = torch.as_tensor(x, device=self.device) + x = x.permute(2, 0, 1).contiguous()[None, :, :, :] + + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x - return slot_latents + def create_points_grid(self, number_points): + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * number_points) + points_one_side = np.linspace(offset, 1 - offset, number_points) + points_x = np.tile(points_one_side[None, :], (number_points, 1)) + points_y = np.tile(points_one_side[:, None], (1, number_points)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + def process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + curr_embedding, + curr_orig_size + ): + + # Run model on this batch + transformed_points = self.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + + masks, iou_preds, _ = self.predict_masks( + in_points[:, None, :], + in_labels[:, None], + curr_orig_size, + im_size, + curr_embedding, + multimask_output=True, + return_logits=True + ) + + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > 
self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + data["rles"] = mask_to_rle_pytorch(data["masks"]) + + return data + + def predict_masks( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + curr_orig_size, + curr_input_size, + curr_embedding, + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the current image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + # Embed prompts + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.postprocess_masks(low_res_masks, curr_input_size, curr_orig_size) + + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def extract_mask_embedding(self, mask_data, image_embed, scale_box=1.5): + """ + Predicts the embeddings from each mask given the global embedding and + a scale factor around each mask. 
+ + Arguments: + mask_data : the data of the masks extracted by SAM + image_embed : embedding of the corresponding image + scale_box : factor by which the bounding box will be scaled + + Returns: + embeddings : the embeddings for each mask extracted from the image + """ + for idx in range(len(mask_data["segmentations"])): + mask = mask_data["segmentations"][idx] + box = mask_data["boxes"][idx] + + def scale_bounding_box(box, scale_factor): + x1, y1, x2, y2 = box + + width = x2 - x1 + height = y2 - y1 + + new_width = width * scale_factor + new_height = height * scale_factor + + new_x1 = x1 - (new_width - width) / 2 + new_y1 = y1 - (new_height - height) / 2 + new_x2 = new_x1 + new_width + new_y2 = new_y1 + new_height + + return new_x1, new_y1, new_x2, new_y2 + + # Scale bounding box + scaled_box = scale_bounding_box(box, scale_box) + print(image_embed.shape) + + masks_embedding = None + return masks_embedding + + def complete_holes(self, + masks): + """" + The purpose of this function is to segment EVERYTHING from the image, without letting any remaining hole + """ + total_mask = masks[0] + for idx in range(len(masks)): + if idx > 0: + total_mask += masks[idx] + + des = total_mask.astype(np.uint8)*255 + kernel = np.ones((4, 4), np.uint8) + img_dilate = cv2.dilate(des, kernel, iterations=1) + + import matplotlib.pyplot as plt + plt.imshow(img_dilate) + plt.show() + + inverse_dilate = np.zeros((total_mask.shape), dtype=np.uint8) + inverse_dilate = np.logical_not(img_dilate).astype(np.uint8)*255 + + contours, _ = cv2.findContours(inverse_dilate, cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE) + result_masks = [] + for contour in contours: + area = cv2.contourArea(contour) + if area > 4000: + mask = np.zeros((total_mask.shape), dtype=np.uint8) + cv2.drawContours(mask, [contour], 0, 255, -1) + result_masks.append(mask) + + new_masks_data = MaskData( + masks=torch.tensor(result_masks), + iou_preds=torch.tensor([0.9 for i in range(len(result_masks))]) + ) + + new_masks_data["stability_score"] = calculate_stability_score( + new_masks_data["masks"], self.mask_threshold, self.stability_score_offset + ) + + new_masks_data["boxes"] = batched_mask_to_box(new_masks_data["masks"]) + + new_masks_data["rles"] = mask_to_rle_pytorch(new_masks_data["masks"]) + + return new_masks_data.to_numpy() \ No newline at end of file diff --git a/osrt/layers.py b/osrt/layers.py index 2cfdecc1610caf22087d7befdea7ffe388169431..019d5c025d9981e19d90bdcdf31a840d1917f0a7 100644 --- a/osrt/layers.py +++ b/osrt/layers.py @@ -251,4 +251,5 @@ class SlotAttention(nn.Module): return slots - + def change_slots_number(self, num_slots): + self.num_slots = num_slots diff --git a/osrt/utils/common.py b/osrt/utils/common.py index 6ac5489d9833f72ec015d5fd4044e62cff69d0c5..832d5e3ef87c17503ac1dc244554916f10163916 100644 --- a/osrt/utils/common.py +++ b/osrt/utils/common.py @@ -4,9 +4,106 @@ import os import torch import torch.distributed as dist +from typing import Any, List, Generator, ItemsView +import numpy as np +import math +from copy import deepcopy __LOG10 = math.log(10) +# Method extracted from SAM repo in segment_anything.utils.amg.py +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." 
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + +# Method extracted from SAM repo in segment_anything.utils.amg.py +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + +# Method extracted from SAM repo in segment_anything.utils.amg.py +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. 
+ # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions def mse2psnr(x): return -10.*torch.log(x)/__LOG10 diff --git a/sam_test.py b/sam_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a62e962695918cc8412443e86f4a61eb85093b59 --- /dev/null +++ b/sam_test.py @@ -0,0 +1,123 @@ +import argparse +import torch +from segment_anything import sam_model_registry, SamAutomaticMaskGenerator +from torchvision import transforms +from PIL import Image +import time +import matplotlib.pyplot as plt +import matplotlib as mpl +import numpy as np +import cv2 + +def show_anns(masks): + if len(masks) == 0: + return + sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True) + ax = plt.gca() + ax.set_autoscale_on(False) + + img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4)) + img[:,:,3] = 0 + for ann in sorted_anns: + m = ann['segmentation'] + color_mask = np.concatenate([np.random.random(3), [0.95]]) + img[m] = color_mask + ax.imshow(img) + +def show_points(coords, ax, marker_size=100): + ax.scatter(coords[:, 0], coords[:, 1], color='#2ca02c', marker='.', s=marker_size) + + +if __name__ == '__main__': + # Arguments + parser = argparse.ArgumentParser( + description='Test Segment Anything Auto Mask simplified implementation' + ) + parser.add_argument('--model', default='vit_b', type=str, help='Model to use') + parser.add_argument('--path_model', default='.', type=str, help='Path to the model') + + args = parser.parse_args() + device = "cuda" + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + model_type = args.model + if args.model == 'vit_h': + checkpoint = args.path_model + '/sam_vit_h_4b8939.pth' + elif args.model == 'vit_b': + checkpoint = args.path_model + '/sam_vit_b_01ec64.pth' + else: + checkpoint = args.path_model + '/sam_vit_l_0b3195.pth' + + ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/" + images_path = [] + with open(ycb_path + "image_sets/train.txt", 'r') as f: + for line in f.readlines(): + line = line.strip() + images_path.append(ycb_path + 'data/' + line + "-color.png") + + import random + #random.shuffle(images_path) + + sam = sam_model_registry[model_type](checkpoint=checkpoint) + sam.to(device=device) + mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88) + + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + labels = [1 for i in range(len(mask_generator.point_grids))] + with torch.no_grad(): + for image in images_path: + img = cv2.imread(image) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + """img_depth = cv2.imread(image.replace("color", "depth")) + img_depth = cv2.cvtColor(img_depth, cv2.COLOR_BGR2GRAY)""" + + h, w, _ = img.shape + points = mask_generator.point_grids[0] + new_points= [] + for val in points: + x, y = val[0], val[1] + x *= w + y *= h + new_points.append([x, y]) + new_points = np.array(new_points) + + start = time.time() + masks = mask_generator.generate(img) + end = time.time() + print(f"Inference time : {int((end-start) * 1000)}ms") + + plt.figure(figsize=(15,15)) + plt.imshow(img) + show_anns(masks) + show_points(new_points, 
plt.gca()) + plt.axis('off') + plt.show() + + """fig, ax = plt.subplots() + cmap = plt.cm.get_cmap('plasma') + img = ax.imshow(img_depth, cmap=cmap) + cbar = fig.colorbar(img, ax=ax) + depth_array_new = img.get_array() + plt.show() + + depth_array_new = cv2.cvtColor(depth_array_new, cv2.COLOR_GRAY2RGB) + plt.imshow(depth_array_new) + plt.show() + print(depth_array_new.shape) + + start = time.time() + masks = mask_generator.generate(depth_array_new) + end = time.time() + print(f"Inference time : {int((end-start) * 1000)}ms") + + + plt.figure(figsize=(15,15)) + plt.imshow(depth_array_new) + show_anns(masks) + show_points(new_points, plt.gca()) + plt.axis('off') + plt.show()""" + diff --git a/segment-anything b/segment-anything deleted file mode 160000 index 6fdee8f2727f4506cfbbe553e23b895e27956588..0000000000000000000000000000000000000000 --- a/segment-anything +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6fdee8f2727f4506cfbbe553e23b895e27956588
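
Below the patch, a minimal sketch (not part of the diff) of how the SAM `amg` helpers vendored into `osrt/utils/common.py` fit together: `batch_iterator` chunks a point grid, `MaskData` accumulates and filters per-batch results, and `calculate_stability_score` rates thresholded mask logits, mirroring their use in `SamAutomaticMask.process_batch`. The random logits, grid size, and 0.5 IoU cut-off are illustrative assumptions only.

```python
import numpy as np
import torch

from osrt.utils.common import MaskData, batch_iterator, calculate_stability_score

# Stand-in for a normalized point grid (e.g. what create_points_grid returns).
points = np.random.rand(20, 2)

data = MaskData()
for (batch,) in batch_iterator(8, points):       # iterate the grid in chunks of 8
    logits = torch.randn(len(batch), 32, 32)     # stand-in for low-res mask logits
    data.cat(MaskData(
        masks=logits,
        iou_preds=torch.rand(len(batch)),
        points=torch.as_tensor(batch),
    ))

# Stability score = IoU between the masks thresholded at +/- the offset around mask_threshold.
data["stability_score"] = calculate_stability_score(
    data["masks"], mask_threshold=0.0, threshold_offset=1.0
)

# Boolean filtering, as done with pred_iou_thresh / stability_score_thresh in process_batch.
data.filter(data["iou_preds"] > 0.5)
data.to_numpy()
print(data["masks"].shape, data["stability_score"].shape)
```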