Commit 2d0f120c authored by Alexandre Chapin

Tests on time optimization

parent e6fe586f
 import argparse
 import torch
-from osrt.encoder import SamAutomaticMask, FeatureMasking
 from osrt.model import OSRT
+from segment_anything import sam_model_registry
 import time
 import matplotlib.pyplot as plt
 import numpy as np
 import cv2

 def show_anns(masks):
     ax = plt.gca()
     ax.set_autoscale_on(False)
@@ -57,13 +56,13 @@ if __name__ == '__main__':
     cfg['encoder'] = 'sam'
     cfg['decoder'] = 'slot_mixer'
     cfg['encoder_kwargs'] = {
-        'points_per_side': 12,
+        'points_per_side': 32,
         'box_nms_thresh': 0.7,
         'stability_score_thresh': 0.9,
         'pred_iou_thresh': 0.88,
         'sam_model': model_type,
         'sam_path': checkpoint,
-        'points_per_batch': 16
+        'points_per_batch': 12
     }
     cfg['decoder_kwargs'] = {
         'pos_start_octave': -5,
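
The encoder change above trades denser prompting for smaller decode batches: SAM's automatic mask generator samples a points_per_side x points_per_side grid of point prompts per image and decodes them points_per_batch at a time. A rough back-of-envelope of what that means for runtime (the prompt_batches helper is illustrative only, not repository code):

import math

def prompt_batches(points_per_side: int, points_per_batch: int) -> tuple[int, int]:
    # Total point prompts on the grid, and decoder forward passes needed per image
    n_points = points_per_side ** 2
    return n_points, math.ceil(n_points / points_per_batch)

print(prompt_batches(12, 16))  # old config: 144 prompts in 9 decode batches
print(prompt_batches(32, 12))  # new config: 1024 prompts in 86 decode batches
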
@@ -71,10 +70,10 @@ if __name__ == '__main__':
     model = OSRT(cfg)  # FeatureMasking(points_per_side=12, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_batch=64)
     model.to(device, non_blocking=True)
-    num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
+    """num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
     num_decoder_params = sum(p.numel() for p in model.decoder.parameters())
-    """print('Number of parameters:')
+    print('Number of parameters:')
     print(f'\tEncoder: {num_encoder_params}')
     num_mask_encoder_params = sum(p.numel() for p in model.encoder.mask_generator.parameters())
@@ -87,12 +86,12 @@ if __name__ == '__main__':
     print(f'\t\t\tMask Decoder: {num_mask_params}.')
     print(f'\t\t\tPrompt Encoder: {num_prompt_params}.')
     print(f'\t\tSlot Attention: {num_slotatt_params}.')
-    print(f'\tDecoder: {num_decoder_params}')
-    """
+    print(f'\tDecoder: {num_decoder_params}')"""
     images = []
     from torchvision import transforms
     transform = transforms.ToTensor()
-    for j in range(2):
+    for j in range(10):
         image = images_path[j]
         img = cv2.imread(image)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -119,12 +118,12 @@ if __name__ == '__main__':
     # TODO : set ray and camera directions
     #with torch.no_grad():
     with torch.cuda.amp.autocast():
-        masks, slots = model.encoder(images_t, (h, w), None, None, extract_masks=True)
+        masks = model.encoder.mask_generator(images_t, (h, w), None, None, extract_embeddings=False)
     end = time.time()
     print(f"Inference time : {int((end-start) * 1000)}ms")
     if args.visualize:
-        for j in range(2):
+        for j in range(10):
             image = images_path[j]
             img = cv2.imread(image)
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
...
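
One caveat on the timing above: CUDA kernels launch asynchronously, so bracketing the call with time.time() alone can under-report GPU work that is still queued when the clock is read. A minimal synchronized measurement, assuming the same model, images_t, h and w as in the test script:

import time
import torch

torch.cuda.synchronize()  # drain previously queued kernels before starting the clock
start = time.time()
with torch.cuda.amp.autocast():
    masks = model.encoder.mask_generator(images_t, (h, w), None, None,
                                         extract_embeddings=False)
torch.cuda.synchronize()  # wait for the forward pass to actually finish
end = time.time()
print(f"Inference time : {int((end - start) * 1000)}ms")
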
@@ -123,8 +123,8 @@ class FeatureMasking(nn.Module):
                  num_slots=32,
                  slot_dim=1536,
                  slot_iters=3,
-                 sam_model="default",
-                 sam_path="sam_vit_h_4b8939.pth",
+                 sam_model="vit_t",
+                 sam_path="mobile_sam.pt",
                  randomize_initial_slots=False):
         super().__init__()
@@ -252,8 +252,6 @@ class SamAutomaticMask(nn.Module):
         self.mask_decoder = mask_decoder
         for param in self.mask_decoder.parameters():
             param.requires_grad = True
-        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
-        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
         # Transform image to a square by putting it to the longest side
         #self.resize = transforms.Resize(self.image_encoder.img_size, interpolation=transforms.InterpolationMode.BILINEAR)
@@ -280,6 +278,9 @@ class SamAutomaticMask(nn.Module):
             nn.Linear(2500, self.token_dim),
         )
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
         # Space positional embedding
         self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                       ray_octaves=15)
...
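
For reference, the pixel_mean/pixel_std lines that moved in this hunk use nn.Module.register_buffer, which stores non-trainable state that follows the module across .to(device) calls; the trailing False is persistent=False, which keeps the buffer out of state_dict(). A standalone sketch of the same pattern (the Normalizer class is hypothetical, and the SAM-style normalization constants are assumed, not taken from this repository):

import torch
import torch.nn as nn

class Normalizer(nn.Module):
    def __init__(self, pixel_mean=(123.675, 116.28, 103.53),
                 pixel_std=(58.395, 57.12, 57.375)):
        super().__init__()
        # Non-persistent buffers: moved by .to(device), excluded from checkpoints
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, 3, H, W) image batch in the same value range as the statistics
        return (x - self.pixel_mean) / self.pixel_std

print(list(Normalizer().state_dict().keys()))  # [] -- non-persistent buffers are excluded
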
-import argparse
 import torch
 from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
-from torchvision import transforms
-from PIL import Image
-import time
 import matplotlib.pyplot as plt
-import matplotlib as mpl
 import numpy as np
 import cv2
+import time

-def show_anns(masks):
-    if len(masks) == 0:
+def show_anns(anns):
+    if len(anns) == 0:
         return
-    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
+    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
     ax = plt.gca()
     ax.set_autoscale_on(False)
@@ -24,100 +20,44 @@ def show_anns(masks):
     img[m] = color_mask
     ax.imshow(img)
-def show_points(coords, ax, marker_size=100):
-    ax.scatter(coords[:, 0], coords[:, 1], color='#2ca02c', marker='.', s=marker_size)
-
-if __name__ == '__main__':
-    # Arguments
-    parser = argparse.ArgumentParser(
-        description='Test Segment Anything Auto Mask simplified implementation'
-    )
-    parser.add_argument('--model', default='vit_b', type=str, help='Model to use')
-    parser.add_argument('--path_model', default='.', type=str, help='Path to the model')
-    args = parser.parse_args()
-
-    device = "cuda"
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-    model_type = args.model
-    if args.model == 'vit_h':
-        checkpoint = args.path_model + '/sam_vit_h_4b8939.pth'
-    elif args.model == 'vit_b':
-        checkpoint = args.path_model + '/sam_vit_b_01ec64.pth'
-    else:
-        checkpoint = args.path_model + '/sam_vit_l_0b3195.pth'
-
-    ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/"
-    images_path = []
-    with open(ycb_path + "image_sets/train.txt", 'r') as f:
-        for line in f.readlines():
-            line = line.strip()
-            images_path.append(ycb_path + 'data/' + line + "-color.png")
-
-    import random
-    random.shuffle(images_path)
-
-    sam = sam_model_registry[model_type](checkpoint=checkpoint)
-    sam.to(device=device)
-    mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88)
-    transform = transforms.Compose([
-        transforms.ToTensor(),
-    ])
-    labels = [1 for i in range(len(mask_generator.point_grids))]
-    with torch.no_grad():
-        for image in images_path:
-            img = cv2.imread(image)
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            """img_depth = cv2.imread(image.replace("color", "depth"))
-            img_depth = cv2.cvtColor(img_depth, cv2.COLOR_BGR2GRAY)"""
-            h, w, _ = img.shape
-            points = mask_generator.point_grids[0]
-            new_points = []
-            for val in points:
-                x, y = val[0], val[1]
-                x *= w
-                y *= h
-                new_points.append([x, y])
-            new_points = np.array(new_points)
-            start = time.time()
-            masks = mask_generator.generate(img)
-            end = time.time()
-            print(f"Inference time : {int((end-start) * 1000)}ms")
-            plt.figure(figsize=(15,15))
-            plt.imshow(img)
-            show_anns(masks)
-            show_points(new_points, plt.gca())
-            plt.axis('off')
-            plt.show()
-            """fig, ax = plt.subplots()
-            cmap = plt.cm.get_cmap('plasma')
-            img = ax.imshow(img_depth, cmap=cmap)
-            cbar = fig.colorbar(img, ax=ax)
-            depth_array_new = img.get_array()
-            plt.show()
-            depth_array_new = cv2.cvtColor(depth_array_new, cv2.COLOR_GRAY2RGB)
-            plt.imshow(depth_array_new)
-            plt.show()
-            print(depth_array_new.shape)
-            start = time.time()
-            masks = mask_generator.generate(depth_array_new)
-            end = time.time()
-            print(f"Inference time : {int((end-start) * 1000)}ms")
-            plt.figure(figsize=(15,15))
-            plt.imshow(depth_array_new)
-            show_anns(masks)
-            show_points(new_points, plt.gca())
-            plt.axis('off')
-            plt.show()"""
+model_type = "vit_t"
+sam_checkpoint = "./mobile_sam.pt"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+mobile_sam.to(device=device)
+mobile_sam.eval()
+
+mask_generator = SamAutomaticMaskGenerator(mobile_sam, points_per_side=16, points_per_batch=12)
+
+ycb_path = "/home/achapin/Documents/Datasets/YCB_Video_Dataset/"
+images_path = []
+with open(ycb_path + "image_sets/train.txt", 'r') as f:
+    for line in f.readlines():
+        line = line.strip()
+        images_path.append(ycb_path + 'data/' + line + "-color.png")
+
+import random
+#random.shuffle(images_path)
+
+images = []
+from torchvision import transforms
+transform = transforms.ToTensor()
+for j in range(20):
+    image = images_path[j]
+    img = cv2.imread(image)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    images.append(transform(img).unsqueeze(0))
+images_t = torch.stack(images).to(device)
+
+start = time.time()
+masks = mask_generator.generate(images_t)
+end = time.time()
+print(f"Inference time : {int((end-start) * 1000)}ms")
+plt.figure(figsize=(15,15))
+plt.imshow(img)
+show_anns(masks)  # show masks
+plt.axis('off')
+plt.show()
\ No newline at end of file
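
Note that the rewritten script stacks ToTensor outputs into a batched tensor and passes it to mask_generator.generate. The stock SamAutomaticMaskGenerator.generate in segment-anything (and in MobileSAM) expects a single HxWx3 uint8 numpy array per call, so unless a local fork adds batching, a per-image loop along these lines would be needed (a sketch, reusing mask_generator and images_path from above):

import time
import cv2

for image_file in images_path[:20]:
    img = cv2.imread(image_file)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # generate() expects an RGB uint8 array
    start = time.time()
    masks = mask_generator.generate(img)
    end = time.time()
    print(f"{image_file}: {len(masks)} masks in {int((end - start) * 1000)}ms")
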