diff --git a/automatic_mask_train.py b/automatic_mask_train.py
index b85f53c214be857867ff1f415635d2895790846c..1110115b001d048d25e115fdda97d47befcea4d9 100644
--- a/automatic_mask_train.py
+++ b/automatic_mask_train.py
@@ -63,6 +63,7 @@ if __name__ == '__main__':
         'pred_iou_thresh': 0.88,
         'sam_model': model_type,
         'sam_path': checkpoint,
+        'points_per_batch': 16
     }
     cfg['decoder_kwargs'] = {
         'pos_start_octave': -5,
@@ -91,7 +92,7 @@ if __name__ == '__main__':
     images= []
     from torchvision import transforms
     transform = transforms.ToTensor()
-    for j in range(1):
+    for j in range(2):
         image = images_path[j]
         img = cv2.imread(image)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -99,8 +100,11 @@ if __name__ == '__main__':
     #images.append(np.expand_dims(img, axis=0))
     #images_np = np.array(images)
     images_t = torch.stack(images).to(device)
+    print(f"Shape image {images_t.shape}")
 
-    _, h, w = images_t[0][0].shape
+    images_t = images_t.permute(0, 1, 3, 4, 2)
+
+    h, w, c = images_t[0][0].shape
     #print(f"Begin shape {images_np.shape}")
     points = model.encoder.mask_generator.points_grid
     new_points= []
@@ -115,14 +119,18 @@ if __name__ == '__main__':
     # TODO : set ray and camera directions
     #with torch.no_grad():
     with torch.cuda.amp.autocast():
-        masks, slots = model.encoder(images_t, (h, w), None, None)
+        masks, slots = model.encoder(images_t, (h, w), None, None, extract_masks=True)
     end = time.time()
     print(f"Inference time : {int((end-start) * 1000)}ms")
 
     if args.visualize:
-        plt.figure(figsize=(15,15))
-        plt.imshow(img)
-        show_anns(masks[0][0]) # show masks
-        show_points(new_points, plt.gca()) # show points
-        #plt.axis('off')
-        plt.show()
\ No newline at end of file
+        for j in range(2):
+            image = images_path[j]
+            img = cv2.imread(image)
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            plt.figure(figsize=(15,15))
+            plt.imshow(img)
+            show_anns(masks[j][0]) # show masks
+            show_points(new_points, plt.gca()) # show points
+            #plt.axis('off')
+            plt.show()
\ No newline at end of file
diff --git a/osrt/encoder.py b/osrt/encoder.py
index a96650e3c43aa6cc4eee6834cbadd47a34aff773..aa258ee9355d00b4d7cb6828f35a91bc04736b21 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -131,25 +131,6 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         sam = sam_model_registry[sam_model](checkpoint=sam_path)
-        """if sam_model == "default" or sam_model == "vit_h":
-            encoder_embed_dim=1280
-            encoder_depth=32
-            encoder_num_heads=16
-            encoder_global_attn_indexes=[7, 15, 23, 31]
-        elif sam_model == "vit_l":
-            encoder_embed_dim=1024
-            encoder_depth=24
-            encoder_num_heads=16
-            encoder_global_attn_indexes=[5, 11, 17, 23]
-        else:
-            encoder_embed_dim=768
-            encoder_depth=12
-            encoder_num_heads=12
-            encoder_global_attn_indexes=[2, 5, 8, 11]
-        prompt_embed_dim = 256
-        image_size = 1024
-        vit_patch_size = 16
-        image_embedding_size = image_size // vit_patch_size"""
         self.mask_generator = SamAutomaticMask(copy.deepcopy(sam.image_encoder),
                                                copy.deepcopy(sam.prompt_encoder),
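
Note (not part of the patch): a minimal sketch of the tensor-layout change the patch makes in automatic_mask_train.py, assuming the stacked views come out of torchvision's ToTensor as [batch, views, C, H, W]; the permute presumably moves them to the channels-last [batch, views, H, W, C] layout consumed by the SAM-based mask generator, and the per-view height/width are then read from a single view. The sizes below are purely illustrative.

    import torch

    # Illustrative sizes: 1 scene, 2 views, 3-channel 480x640 images.
    images_t = torch.rand(1, 2, 3, 480, 640)    # [batch, views, C, H, W] from stacking ToTensor() outputs
    images_t = images_t.permute(0, 1, 3, 4, 2)  # -> [batch, views, H, W, C], channels-last as in the patch
    h, w, c = images_t[0][0].shape              # spatial size of one view, passed to the encoder as (h, w)
    print(h, w, c)                              # 480 640 3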