Commit e6fe586f authored by Alexandre Chapin

Change mask decoder

parent 26447806
@@ -63,6 +63,7 @@ if __name__ == '__main__':
         'pred_iou_thresh': 0.88,
         'sam_model': model_type,
         'sam_path': checkpoint,
+        'points_per_batch': 16
     }
     cfg['decoder_kwargs'] = {
         'pos_start_octave': -5,
@@ -91,7 +92,7 @@ if __name__ == '__main__':
     images= []
     from torchvision import transforms
     transform = transforms.ToTensor()
-    for j in range(1):
+    for j in range(2):
         image = images_path[j]
         img = cv2.imread(image)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -99,8 +100,11 @@ if __name__ == '__main__':
     #images.append(np.expand_dims(img, axis=0))
    #images_np = np.array(images)
    images_t = torch.stack(images).to(device)
-    _, h, w = images_t[0][0].shape
+    print(f"Shape image {images_t.shape}")
+    images_t = images_t.permute(0, 1, 3, 4, 2)
+    h, w, c = images_t[0][0].shape
     #print(f"Begin shape {images_np.shape}")
     points = model.encoder.mask_generator.points_grid
     new_points= []
@@ -115,14 +119,18 @@ if __name__ == '__main__':
     # TODO : set ray and camera directions
     #with torch.no_grad():
     with torch.cuda.amp.autocast():
-        masks, slots = model.encoder(images_t, (h, w), None, None)
+        masks, slots = model.encoder(images_t, (h, w), None, None, extract_masks=True)
     end = time.time()
     print(f"Inference time : {int((end-start) * 1000)}ms")
     if args.visualize:
-        plt.figure(figsize=(15,15))
-        plt.imshow(img)
-        show_anns(masks[0][0]) # show masks
-        show_points(new_points, plt.gca()) # show points
-        #plt.axis('off')
-        plt.show()
\ No newline at end of file
+        for j in range(2):
+            image = images_path[j]
+            img = cv2.imread(image)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            plt.figure(figsize=(15,15))
+            plt.imshow(img)
+            show_anns(masks[j][0]) # show masks
+            show_points(new_points, plt.gca()) # show points
+            #plt.axis('off')
+            plt.show()
\ No newline at end of file
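
For context on the shape handling above: a minimal, standalone sketch of the channel-last permute the new code performs before reading h, w, c. The batch/view/image sizes here are assumptions for illustration, not values taken from this repository.

import torch

# Assumed layout: B scenes, V views per scene, 3-channel H x W images,
# channel-first as produced by torchvision's ToTensor().
B, V, H, W = 1, 2, 480, 640
images_t = torch.rand(B, V, 3, H, W)

# Move channels last, mirroring images_t.permute(0, 1, 3, 4, 2) in the diff.
images_t = images_t.permute(0, 1, 3, 4, 2)   # shape (B, V, H, W, 3)

# Per-view spatial size, as read by `h, w, c = images_t[0][0].shape`.
h, w, c = images_t[0][0].shape
print(h, w, c)  # 480 640 3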
@@ -131,25 +131,6 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         sam = sam_model_registry[sam_model](checkpoint=sam_path)
-        """if sam_model == "default" or sam_model == "vit_h":
-            encoder_embed_dim=1280
-            encoder_depth=32
-            encoder_num_heads=16
-            encoder_global_attn_indexes=[7, 15, 23, 31]
-        elif sam_model == "vit_l":
-            encoder_embed_dim=1024
-            encoder_depth=24
-            encoder_num_heads=16
-            encoder_global_attn_indexes=[5, 11, 17, 23]
-        else:
-            encoder_embed_dim=768
-            encoder_depth=12
-            encoder_num_heads=12
-            encoder_global_attn_indexes=[2, 5, 8, 11]
-        prompt_embed_dim = 256
-        image_size = 1024
-        vit_patch_size = 16
-        image_embedding_size = image_size // vit_patch_size"""
         self.mask_generator = SamAutomaticMask(copy.deepcopy(sam.image_encoder),
                                                copy.deepcopy(sam.prompt_encoder),
                                                ...
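
For readers unfamiliar with the surrounding code: the context lines show SAM being loaded through segment-anything's sam_model_registry and its pretrained sub-modules deep-copied into the project's SamAutomaticMask wrapper. A rough sketch of that loading pattern follows; the model type and checkpoint path are assumptions, and SamAutomaticMask's full argument list is truncated in the diff.

import copy
from segment_anything import sam_model_registry

sam_model = "vit_h"                                  # assumed model type
sam_path = "checkpoints/sam_vit_h_4b8939.pth"        # assumed checkpoint path
sam = sam_model_registry[sam_model](checkpoint=sam_path)

# Deep-copy the pretrained sub-modules so the wrapper owns its own parameters
# instead of sharing them with the original `sam` instance.
image_encoder = copy.deepcopy(sam.image_encoder)
prompt_encoder = copy.deepcopy(sam.prompt_encoder)
# ... these are then passed to SamAutomaticMask(image_encoder, prompt_encoder, ...) in the repo.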