Skip to content
Snippets Groups Projects
Commit e6fe586f authored by Alexandre Chapin :race_car:
Browse files

Change mask decoder

parent 26447806
No related branches found
No related tags found
No related merge requests found
......@@ -63,6 +63,7 @@ if __name__ == '__main__':
'pred_iou_thresh': 0.88,
'sam_model': model_type,
'sam_path': checkpoint,
'points_per_batch': 16
}
cfg['decoder_kwargs'] = {
'pos_start_octave': -5,
......@@ -91,7 +92,7 @@ if __name__ == '__main__':
images= []
from torchvision import transforms
transform = transforms.ToTensor()
for j in range(1):
for j in range(2):
image = images_path[j]
img = cv2.imread(image)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
......@@ -99,8 +100,11 @@ if __name__ == '__main__':
#images.append(np.expand_dims(img, axis=0))
#images_np = np.array(images)
images_t = torch.stack(images).to(device)
print(f"Shape image {images_t.shape}")
_, h, w = images_t[0][0].shape
images_t = images_t.permute(0, 1, 3, 4, 2)
h, w, c = images_t[0][0].shape
#print(f"Begin shape {images_np.shape}")
points = model.encoder.mask_generator.points_grid
new_points= []
......@@ -115,14 +119,18 @@ if __name__ == '__main__':
# TODO : set ray and camera directions
#with torch.no_grad():
with torch.cuda.amp.autocast():
masks, slots = model.encoder(images_t, (h, w), None, None)
masks, slots = model.encoder(images_t, (h, w), None, None, extract_masks=True)
end = time.time()
print(f"Inference time : {int((end-start) * 1000)}ms")
if args.visualize:
plt.figure(figsize=(15,15))
plt.imshow(img)
show_anns(masks[0][0]) # show masks
show_points(new_points, plt.gca()) # show points
#plt.axis('off')
plt.show()
\ No newline at end of file
for j in range(2):
image = images_path[j]
img = cv2.imread(image)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(15,15))
plt.imshow(img)
show_anns(masks[j][0]) # show masks
show_points(new_points, plt.gca()) # show points
#plt.axis('off')
plt.show()
\ No newline at end of file
......@@ -131,25 +131,6 @@ class FeatureMasking(nn.Module):
# We first initialize the automatic mask generator from SAM
# TODO : change the loading here !!!!
sam = sam_model_registry[sam_model](checkpoint=sam_path)
"""if sam_model == "default" or sam_model == "vit_h":
encoder_embed_dim=1280
encoder_depth=32
encoder_num_heads=16
encoder_global_attn_indexes=[7, 15, 23, 31]
elif sam_model == "vit_l":
encoder_embed_dim=1024
encoder_depth=24
encoder_num_heads=16
encoder_global_attn_indexes=[5, 11, 17, 23]
else:
encoder_embed_dim=768
encoder_depth=12
encoder_num_heads=12
encoder_global_attn_indexes=[2, 5, 8, 11]
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size"""
self.mask_generator = SamAutomaticMask(copy.deepcopy(sam.image_encoder),
copy.deepcopy(sam.prompt_encoder),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment