Commit 48bdf252 authored by Alexandre Chapin

Static slot

parent 52d11e07
@@ -7,19 +7,19 @@ https://github.com/bowang-lab/MedSAM/blob/main/utils/precompute_img_embed.py
 import numpy as np
 import os
 join = os.path.join
 from skimage import io, segmentation
 from tqdm import tqdm
 import torch
 from segment_anything import sam_model_registry
 from segment_anything.utils.transforms import ResizeLongestSide
 import argparse
 import cv2
 parser = argparse.ArgumentParser(
     description='Extract image embeddings from SAM model'
 )
 parser.add_argument('-i', '--img_path', type=str, default='', help='Path to the folder containing images')
-parser.add_argument('-o', '--save_path', type=str, default='', help='Path to the file containing the embeddings')
+parser.add_argument('-o', '--save_path', type=str, default='', help='Path to the folder containing the final embeddings')
 parser.add_argument('--model_type', type=str, default='vit_h', help='model type')
 parser.add_argument('--path_model', type=str, default='.', help='path to the pre-trained SAM model')
 args = parser.parse_args()
@@ -34,22 +34,22 @@ else:
     model_type = 'vit_l'
     checkpoint = args.path_model + '/sam_vit_l_0b3195.pth'
-pre_img_path = args.img_path
-save_img_emb_path = join(args.save_path, 'npy_embs')
-save_gt_path = join(args.save_path, 'npy_gts')
-os.makedirs(save_img_emb_path, exist_ok=True)
-os.makedirs(save_gt_path, exist_ok=True)
-npz_files = sorted(os.listdir(pre_img_path))
+img_path = args.img_path
+img_files = sorted(os.listdir(img_path))
 # use the checkpoint path resolved above (argparse defines --path_model, not --checkpoint)
 sam_model = sam_model_registry[args.model_type](checkpoint=checkpoint).to('cuda:0')
 sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
 # compute image embeddings
-for name in tqdm(npz_files):
-    img = np.load(join(pre_img_path, name))['img']  # (256, 256, 3)
-    gt = np.load(join(pre_img_path, name))['gt']
+images = []
+for name in tqdm(img_files):
+    img = cv2.imread(join(img_path, name))  # listdir yields bare file names, so join with the folder
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img_np = np.array(img)
     resize_img = sam_transform.apply_image(img)
     resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to('cuda:0')
     # model input: (1, 3, 1024, 1024)
     input_image = sam_model.preprocess(resize_img_tensor[None, :, :, :])  # (1, 3, 1024, 1024)
     assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), 'input image should be resized to 1024*1024'
@@ -57,10 +57,4 @@ for name in tqdm(npz_files):
     embedding = sam_model.image_encoder(input_image)
     # save as npy
-    np.save(join(save_img_emb_path, name.split('.npz')[0] + '.npy'), embedding.cpu().numpy()[0])
-    np.save(join(save_gt_path, name.split('.npz')[0] + '.npy'), gt)
-    # sanity check
-    img_idx = img.copy()
-    bd = segmentation.find_boundaries(gt, mode='inner')
-    img_idx[bd, :] = [255, 0, 0]
-    io.imsave(save_img_emb_path + '.png', img_idx)
\ No newline at end of file
+    np.save(join(args.save_path, name.split('.')[0] + '.npy'), embedding.cpu().numpy()[0])
\ No newline at end of file
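
After this change the script reads raw images instead of preprocessed .npz files and writes one .npy embedding per image directly into --save_path. For context, here is a minimal standalone sanity check of the output, not part of the commit; the folder path is hypothetical, and the (256, 64, 64) shape assumes SAM's image encoder with its default 1024-pixel input:

    # Load each precomputed embedding and verify its shape (hypothetical path).
    import os
    import numpy as np

    save_path = './embeddings'  # the folder passed as --save_path
    for fname in sorted(os.listdir(save_path)):
        emb = np.load(os.path.join(save_path, fname))
        # SAM's image encoder emits a (256, 64, 64) feature map per image.
        assert emb.shape == (256, 64, 64), f'unexpected shape {emb.shape} for {fname}'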
@@ -120,7 +120,7 @@ class FeatureMasking(nn.Module):
         pred_iou_thresh=0.88,
         points_per_batch=64,
         min_mask_region_area=0,
-        num_slots=6,
+        num_slots=10,
         slot_dim=1536,
         slot_iters=1,
         sam_model="default",
@@ -193,14 +193,14 @@ class FeatureMasking(nn.Module):
         set_latents = None
         # TODO: set the number of slots according to whether we want the min or the max
         with torch.no_grad():
-            num_slots = 100000
+            # num_slots = 100000
             embedding_batch = []
             masks_batch = []
             for b in range(B):
                 latents_batch = torch.empty((0, dim), device=self.mask_generator.device)
                 for n in range(N):
                     embeds = masks[b][n]["embeddings"]
-                    num_slots = min(len(embeds), num_slots)
+                    # num_slots = min(len(embeds), num_slots)
                     for embed in embeds:
                         latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0)
                 masks_batch.append(torch.zeros(latents_batch.shape[:1]))
@@ -209,7 +209,7 @@ class FeatureMasking(nn.Module):
             attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)
             # [batch_size, num_inputs = num_mask_embed x num_im, dim]
-            self.slot_attention.change_slots_number(num_slots)
+            # self.slot_attention.change_slots_number(num_slots)
             slot_latents = self.slot_attention(set_latents, attention_mask)
         if extract_masks:
...
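
The two changes above freeze the slot count, matching the commit message: the data-dependent num_slots (the minimum number of mask embeddings across the batch) and the change_slots_number call are both disabled, so SlotAttention keeps the static num_slots it was constructed with (now 10 by default). A standalone sketch of the padding this forward pass relies on, using the same pad_sequence call as the code above; the tensor values are illustrative:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    dim = 1536  # matches the slot_dim default above
    # Two images with 3 and 5 mask embeddings respectively.
    latents_batch = [torch.randn(3, dim), torch.randn(5, dim)]
    masks_batch = [torch.zeros(3), torch.zeros(5)]  # 0.0 marks a real input

    set_latents = pad_sequence(latents_batch, batch_first=True)                       # (2, 5, 1536)
    attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)   # (2, 5)
    # Positions where attention_mask == 1.0 are padding; slot attention should
    # ignore them, so the number of slots can stay fixed across batches.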
@@ -251,7 +251,7 @@ class SlotAttention(nn.Module):
             attn = dots.softmax(dim=1) + self.eps
             attn = attn / attn.sum(dim=-1, keepdim=True)
             updates = torch.einsum('bjd,bij->bid', v, attn)  # shape: [batch_size, num_slots, slot_dim]
             slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
             slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
             slots = slots + self.mlp(self.norm_pre_mlp(slots))
...
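
For reference, the update step shown above in isolation: a self-contained sketch with assumed shapes (q projected from the slots, k and v from the inputs, as is standard for slot attention). It reproduces the softmax-over-slots competition and the per-input normalization from the lines above:

    import torch

    batch_size, num_slots, num_inputs, slot_dim = 2, 10, 7, 1536
    eps = 1e-8
    q = torch.randn(batch_size, num_slots, slot_dim)   # queries from slots
    k = torch.randn(batch_size, num_inputs, slot_dim)  # keys from inputs
    v = torch.randn(batch_size, num_inputs, slot_dim)  # values from inputs

    dots = torch.einsum('bid,bjd->bij', q, k) * slot_dim ** -0.5   # (b, slots, inputs)
    attn = dots.softmax(dim=1) + eps                  # slots compete for each input
    attn = attn / attn.sum(dim=-1, keepdim=True)      # normalize each slot over inputs
    updates = torch.einsum('bjd,bij->bid', v, attn)   # (b, num_slots, slot_dim)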