Commit a9d24741 authored by Alexandre Chapin

Add extraction of mask embeddings

parent ab044bc4
+[submodule "segment-anything"]
+	path = segment-anything
+	url = git@github.com:facebookresearch/segment-anything.git
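Note: after cloning, the submodule registered above still has to be fetched, typically with "git submodule update --init segment-anything".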
@@ -7,6 +7,7 @@ import time
 import matplotlib.pyplot as plt
 import numpy as np
 import cv2
+import matplotlib

 def show_anns(masks):
     ax = plt.gca()
@@ -58,7 +59,7 @@ if __name__ == '__main__':
     sam = sam_model_registry[model_type](checkpoint=checkpoint)
     sam.to(device=device)
     #mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88)
-    sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_side=8, points_per_batch=64, min_mask_region_area=4000)
+    sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_side=8, points_per_batch=64)#, min_mask_region_area=2000)
     sam_mask.to(device)
     transform = transforms.Compose([
@@ -66,7 +67,10 @@ if __name__ == '__main__':
     ])
     labels = [1 for i in range(len(sam_mask.points_grid))]
     with torch.no_grad():
+        j = 0
         for image in images_path:
+            #import os
+            #os.mkdir(f"./results/test_{j}")
             img = cv2.imread(image)
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -86,13 +90,26 @@ if __name__ == '__main__':
             img_batch.append(img_el)
             start = time.time()
-            masks = sam_mask(img_batch)
+            masks = sam_mask(img_batch, extract_embeddings=True)
             end = time.time()
             print(f"Inference time : {int((end-start) * 1000)}ms")
             plt.figure(figsize=(15, 15))
             plt.imshow(img)
             show_anns(masks[0]["annotations"])
             show_points(new_points, plt.gca())
+            #plt.savefig(f"./results/test_{j}/masks.png")
             plt.axis('off')
             plt.show()
+            """from PIL import Image
+            i = 0
+            for mask in masks[0]["annotations"]:
+                cm = matplotlib.cm.get_cmap('viridis')
+                img_src = Image.fromarray(mask["embeddings"]).convert('L')
+                im = np.array(img_src)
+                im = cm(im)
+                im = np.uint8(im * 255)
+                im = Image.fromarray(im)
+                im.save(f"./results/test_{j}/mask_{i}.png")
+                i += 1
+            j += 1"""
@@ -10,9 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple
 from osrt.layers import RayEncoder, Transformer, SlotAttention
 from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score
-from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
-from segment_anything.modeling import Sam
+from segment_anything.modeling import Sam
+from segment_anything import sam_model_registry
 from segment_anything.modeling.image_encoder import ImageEncoderViT
 from segment_anything.modeling.mask_decoder import MaskDecoder
 from segment_anything.modeling.prompt_encoder import PromptEncoder
@@ -147,13 +145,15 @@ class FeatureMasking(nn.Module):
     def forward(self, images):
         # Generate masks
-        masks = self.mask_generator(images)
+        masks = self.mask_generator(images, extract_embeddings=True)
         # TODO : find a way to handle multiple images from the same scene instead of just one
-        set_latents = None
-        num_masks = None
-        # TODO : set the number of slots according to the masks number
+        set_latents = [[ann["embeddings"] for ann in batch["annotations"]] for batch in masks]
+        num_masks = []
+        for batch in masks:
+            num_masks.append(len(batch["annotations"]))
+        # Set the number of slots for the current batch
         self.slot_attention.change_slots_number(num_masks)
         # [batch_size, num_inputs, dim]
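set_latents above holds a variable number of per-mask embeddings per image, while SlotAttention.forward expects a dense [batch_size, num_inputs, dim] tensor. A minimal padding sketch (an assumption, not the committed code), assuming each embedding has already been pooled to a flat vector of size dim:

import torch

def pad_mask_embeddings(per_image_embeddings):
    # per_image_embeddings: one [num_masks_i, dim] tensor per image
    dim = per_image_embeddings[0].shape[-1]
    max_masks = max(e.shape[0] for e in per_image_embeddings)
    out = torch.zeros(len(per_image_embeddings), max_masks, dim)
    for i, emb in enumerate(per_image_embeddings):
        out[i, :emb.shape[0]] = emb   # zero-pad each image up to max_masks
    return out  # [batch_size, max_masks, dim]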
@@ -233,7 +233,7 @@ class SamAutomaticMask(nn.Module):
     def forward(
         self,
         batched_input: List[Dict[str, Any]],
-        extract_embeddings: bool = True
+        extract_embeddings: bool = False
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Predicts masks end-to-end from provided images and prompts.
@@ -276,16 +276,11 @@ class SamAutomaticMask(nn.Module):
         # Extract image embeddings
         input_images = [self.preprocess(x["image"]) for x in batched_input][0]
         with torch.no_grad():
-            image_embeddings = self.image_encoder(input_images)#, before_channel_reduc=True), embed_no_red
-            """
-            # Extract image embedding before channel reduction, cf. https://github.com/facebookresearch/segment-anything/issues/283
-            if before_channel_reduc :
-                return x, embed_no_red """
+            image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True)
         outputs = []
-        for image_record, curr_embedding in zip(batched_input, image_embeddings):
+        for image_record, curr_embedding, curr_emb_no_red in zip(batched_input, image_embeddings, embed_no_red):
             # TODO : check if we've got the points given in the batch (to change the current point_grid !)
             im_size = self.transform.apply_image(image_record["image"]).shape[:2]
             points_scale = np.array(im_size)[None, ::-1]
             points_for_image = self.points_grid * points_scale
@@ -319,7 +314,8 @@ class SamAutomaticMask(nn.Module):
             )
             mask_data["segmentations"] = mask_data["masks"]
-            mask_embed = self.extract_mask_embedding(mask_data, embed_no_red, scale_box=1.5)
+            if extract_embeddings:
+                mask_embed = self.extract_mask_embedding(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
             # Write mask records
             curr_anns = []
@@ -332,8 +328,7 @@ class SamAutomaticMask(nn.Module):
                     "stability_score": mask_data["stability_score"][idx].item()
                 }
                 if extract_embeddings:
-                    # TODO : add embeddings into the annotations
-                    continue
+                    ann["embeddings"] = mask_embed[idx]
                 curr_anns.append(ann)
             outputs.append(
                 {
@@ -575,7 +570,7 @@ class SamAutomaticMask(nn.Module):
         return masks, iou_predictions, low_res_masks

-    def extract_mask_embedding(self, mask_data, image_embed, scale_box=1.5):
+    def extract_mask_embedding(self, mask_data, image_embed, input_size, scale_box=1.5):
         """
         Predicts the embeddings from each mask given the global embedding and
         a scale factor around each mask.
@@ -588,31 +583,61 @@ class SamAutomaticMask(nn.Module):
         Returns:
             embeddings : the embeddings for each mask extracted from the image
         """
+        image_embed = image_embed.permute(2, 0, 1)
+        orig_H, orig_W = mask_data["segmentations"][0].shape[:2]
+        # Follow the same process as for the masks to put the embedding back into the original image format
+        scaled_img_emb = self.postprocess_masks(image_embed.unsqueeze(0), input_size, (orig_H, orig_W))
+
+        def scale_bounding_box(box, scale_factor, img_size):
+            x1, y1, x2, y2 = box
+            width = x2 - x1
+            height = y2 - y1
+            new_width = width * scale_factor
+            new_height = height * scale_factor
+            # Clamp the scaled box inside the image
+            new_x1 = int(max(0, x1 - (new_width - width) / 2))
+            new_y1 = int(max(0, y1 - (new_height - height) / 2))
+            new_x2 = int(min(img_size[1], new_x1 + new_width))
+            new_y2 = int(min(img_size[0], new_y1 + new_height))
+            return (new_x1, new_y1, new_x2, new_y2)
+
         masks_embedding = []
         for idx in range(len(mask_data["segmentations"])):
             mask = mask_data["segmentations"][idx]
             box = mask_data["boxes"][idx]
-            def scale_bounding_box(box, scale_factor):
-                x1, y1, x2, y2 = box
-                width = x2 - x1
-                height = y2 - y1
-                new_width = width * scale_factor
-                new_height = height * scale_factor
-                new_x1 = x1 - (new_width - width) / 2
-                new_y1 = y1 - (new_height - height) / 2
-                new_x2 = new_x1 + new_width
-                new_y2 = new_y1 + new_height
-                return new_x1, new_y1, new_x2, new_y2
             # Scale bounding box
-            scaled_box = scale_bounding_box(box, scale_box)
-            print(image_embed.shape)
-            # Apply average pooling on masked region
-            # TODO : find a way to export tokens
-            #final_token = np.mean(masked_embed, axis=(1, 2))
-            #print(f"Final token : {final_token}")
-            masks_embedding = None
-            ####### masks_embedding.append(masked_embed)
+            scaled_box = scale_bounding_box(box, scale_box, (orig_H, orig_W))
+            # Crop the image embedding around the bounding box
+            croped_im_embed = scaled_img_emb[0, :, scaled_box[1]:scaled_box[3], scaled_box[0]:scaled_box[2]].cpu().numpy()  # [channels, h, w]
+            crop_mask = mask[scaled_box[1]:scaled_box[3], scaled_box[0]:scaled_box[2]]  # [h, w]
+            # Apply the mask to the cropped embedding
+            masked_embed = croped_im_embed * crop_mask  # [channels, h, w]
+            #mean_embed = masked_embed / np.mean(masked_embed)
+            masks_embedding.append(masked_embed)
         return masks_embedding

     def complete_holes(self,
......
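Neither side of the hunk above commits to a pooling step: the old code only carried it as a commented-out np.mean over the masked crop. A minimal sketch of that step (an assumption, not the commit's choice), averaging only over in-mask pixels so that the background zeros do not dilute the token:

import numpy as np

def mask_token(masked_embed, crop_mask):
    # masked_embed: [channels, h, w]; crop_mask: [h, w] boolean
    n = max(int(crop_mask.sum()), 1)           # guard against empty crops
    return masked_embed.sum(axis=(1, 2)) / n   # -> [channels] token per mask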
@@ -222,6 +222,7 @@ class SlotAttention(nn.Module):
         Args:
             inputs: set-latent representation [batch_size, num_inputs, dim]
         """
+        # TODO : change the number of slots depending on the batch
         batch_size, num_inputs, dim = inputs.shape
         inputs = self.norm_input(inputs)
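FeatureMasking.forward above already calls self.slot_attention.change_slots_number(num_masks) with a per-image list of mask counts, which is what this TODO would consume. One simple policy (an assumption; the method body is not part of this diff) is to size the slot set to the largest mask count in the batch:

def change_slots_number(self, num_masks):
    # num_masks: number of detected masks per image in the batch
    self.num_slots = max(num_masks)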
......
+Subproject commit 6fdee8f2727f4506cfbbe553e23b895e27956588