Commit ce20222c authored by Alexandre Chapin

Add batch and multi-image support + extract embeddings for slot attention

parent 1af58823
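A minimal usage sketch of the new batched API this commit introduces (shapes follow the forward() docstrings in the diff below; the constructor arguments mirror the test script, and the image path is illustrative):

import cv2
import numpy as np
import torch
from osrt.encoder import FeatureMasking

device = "cuda" if torch.cuda.is_available() else "cpu"
model = FeatureMasking(points_per_side=12, box_nms_thresh=0.7,
                       stability_score_thresh=0.9, pred_iou_thresh=0.88,
                       points_per_batch=64)
model.to(device)

img = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
h, w, _ = img.shape
images_np = img[None, None]  # [batch_size=1, num_images=1, H, W, 3]

with torch.no_grad():
    masks, slots = model(images_np, (h, w), extract_masks=True)
# masks: [B, N] array of per-image annotation dicts (embeddings, segmentations)
# slots: slot latents produced by slot attention over the mask embeddings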
import argparse
import torch
from osrt.encoder import SamAutomaticMask
from osrt.encoder import SamAutomaticMask, FeatureMasking
from segment_anything import sam_model_registry
import time
import matplotlib.pyplot as plt
@@ -54,12 +54,14 @@ if __name__ == '__main__':
import random
random.shuffle(images_path)
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_side=12, points_per_batch=64)#, min_mask_region_area=2000)
sam_mask.to(device)
"""sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, points_per_side=12, points_per_batch=64)#, min_mask_region_area=2000)
sam_mask.to(device)"""
model = FeatureMasking(points_per_side=12, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_batch=64)
model.to(device)
images = []
for j in range(8):
for j in range(1):
image = images_path[j]
img = cv2.imread(image)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -68,7 +70,7 @@ if __name__ == '__main__':
#images_np = images_np.reshape(2, 2, images_np.shape[2], images_np.shape[3], images_np.shape[4])
h, w, _ = images_np[0][0].shape
points = sam_mask.points_grid
points = model.mask_generator.points_grid
new_points = []
for val in points:
x, y = val[0], val[1]
@@ -80,14 +82,14 @@
start = time.time()
# TODO : set ray and camera directions
with torch.no_grad():
masks = sam_mask(images_np, (h, w),extract_embeddings=True)
masks, slots = model(images_np, (h, w), extract_masks=True)
end = time.time()
print(f"Inference time : {int((end-start) * 1000)}ms")
if args.visualize:
plt.figure(figsize=(15,15))
plt.imshow(img)
show_anns(masks[0]) # show masks
show_anns(masks[0][0]) # show masks
show_points(new_points, plt.gca()) # show points
plt.axis('off')
plt.show()
@@ -3,13 +3,13 @@ import torch
import torch.nn as nn
from typing import Any, Dict, List, Optional, Tuple
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image # type: ignore
import math
from torchvision.ops.boxes import batched_nms
import torchvision.transforms.functional as func
from osrt.layers import RayEncoder, Transformer, SlotAttention
from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score, get_indices_sorted_pos, get_positional_embedding
from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score, get_indices_sorted_pos, get_positional_embedding, create_points_grid
from segment_anything import sam_model_registry
from segment_anything.modeling.image_encoder import ImageEncoderViT
@@ -116,12 +116,14 @@ class FeatureMasking(nn.Module):
stability_score_thresh=0.9,
pred_iou_thresh=0.88,
points_per_batch=64,
min_mask_region_area=4000,
min_mask_region_area=0,
num_slots=6,
slot_dim=1536,
slot_iters=1,
num_att_blocks=5,
sam_model="default",
sam_path="sam_vit_h_4b8939.pth",
tokenizer="mean",
randomize_initial_slots=False):
super().__init__()
@@ -135,12 +137,16 @@
pred_iou_thresh=pred_iou_thresh,
points_per_side=points_per_side,
points_per_batch=points_per_batch,
tokenizer=tokenizer,
min_mask_region_area=min_mask_region_area)
self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters,
self.transformer = Transformer(self.mask_generator.token_dim, depth=num_att_blocks, heads=12, dim_head=64,
mlp_dim=1536, selfatt=True)
self.slot_attention = SlotAttention(num_slots, input_dim=self.mask_generator.token_dim, slot_dim=slot_dim, iters=slot_iters,
randomize_initial_slots=randomize_initial_slots)
def forward(self, images, original_size, camera_pos=None, rays=None):
def forward(self, images, original_size, camera_pos=None, rays=None, extract_masks=True):
"""
Args:
images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
@@ -160,20 +166,30 @@
# Generate masks for each input image
masks = self.mask_generator(images, original_size, camera_pos, rays, extract_embeddings=True) # [B, N]
num_masks = []
for batch in masks:
num_masks.append(len(batch))
set_latents = masks[:]["embeddings"]
# Set the number of slots for current batch
self.slot_attention.change_slots_number(num_masks)
B, N = masks.shape
dim = masks[0][0]["embeddings"][0].shape[0]
# We infer each batch element separately as each one handles a different number of slots
set_latents = None
for b in range(B):
# Gather every mask embedding of this batch element into one tensor
latents_list = []
# TODO : set a new number of slots
for n in range(N):
embeds = masks[b][n]["embeddings"]
for embed in embeds:
latents_list.append(embed.unsqueeze(0))
latents_batch = torch.cat(latents_list, 0) # [num_masks, dim]
if set_latents is None:
set_latents = latents_batch.unsqueeze(0)
else:
set_latents = torch.cat((set_latents, latents_batch.unsqueeze(0)), 0)
# [batch_size, num_inputs, dim]
slot_latents = self.slot_attention(set_latents)
return slot_latents
if extract_masks:
return masks, slot_latents
else:
return slot_latents
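Note that the torch.cat stacking in forward() assumes every batch element produces the same number of mask embeddings; with a variable number of masks per image, the sets would first need padding to a common length. A minimal sketch of such a helper (hypothetical, not part of this commit):

def pad_embedding_sets(sets, dim, device):
    # sets: list of [num_masks_i, dim] tensors whose first dimension varies
    max_len = max(s.shape[0] for s in sets)
    padded = torch.zeros((len(sets), max_len, dim), device=device)
    for i, s in enumerate(sets):
        padded[i, :s.shape[0]] = s  # zero-pad the tail of shorter sets
    return padded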
class SamAutomaticMask(nn.Module):
mask_threshold: float = 0.0
@@ -194,6 +210,7 @@ class SamAutomaticMask(nn.Module):
box_nms_thresh: float = 0.7,
min_mask_region_area: int = 0,
pos_start_octave=0,
tokenizer="mean",
patch_size=16
) -> None:
"""
@@ -232,9 +249,9 @@
self.transform = ResizeLongestSide(self.image_encoder.img_size)
if points_per_side > 0:
self.points_grid = self.create_points_grid(points_per_side)
self.points_grid = create_points_grid(points_per_side)
else:
self.points_grid = None
self.points_grid = create_points_grid(32)
self.points_per_batch = points_per_batch
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_thresh = stability_score_thresh
@@ -243,15 +260,18 @@
self.min_mask_region_area = min_mask_region_area
# TODO : set the token dim and the input size
input_size = 0 # depends on the image size
self.token_dim = (self.image_encoder.img_size // patch_size)**2
self.tokenizer = nn.Sequential(
nn.Linear(input_size, 100),
nn.ReLU(),
nn.Linear(100, 50),
nn.ReLU(),
nn.Linear(50, self.token_dim),
)
self.tokenizer_type = tokenizer
if tokenizer == "mlp":
input_size = 0 # depends on the image size
self.token_dim = (self.image_encoder.img_size // patch_size)**2
self.tokenizer = nn.Sequential(
nn.Linear(input_size, 100),
nn.ReLU(),
nn.Linear(100, 50),
nn.ReLU(),
nn.Linear(50, self.token_dim),
)
# Space positional embedding
self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
@@ -268,7 +288,6 @@
camera_pos=None,
rays=None,
extract_embeddings=False):
"""
Args:
images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
@@ -299,7 +318,8 @@
image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True) # [B x N, H, W, C]
# TODO : add camera position embedding to the 2D image embedding with @position_embeding_3d
annotations = []
annotations = np.empty((B, N), dtype=object)
i = 0
for curr_embedding, curr_emb_no_red in zip(image_embeddings, embed_no_red):
mask_data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
@@ -322,16 +342,7 @@
# Extract mask embeddings
if extract_embeddings:
self.extract_mask_embedding(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
"""print(f"Before concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
for tensor in mask_data['embeddings']:
print(tensor.shape)
print(ray_enc.shape)
final = torch.cat((tensor, ray_enc), 1)
print(final.shape)
break"""
#mask_data['embeddings'] = [torch.cat((tensor, ray_enc), 1) for tensor in mask_data['embeddings']]
#print(f"After concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
# TODO : add the 3D positional embedding here
mask_data.to_numpy()
@@ -343,19 +354,8 @@
self.box_nms_thresh,
)
# TODO : have a more efficient way to store the data
# Write mask records
"""curr_anns = []
print(mask_data.items())
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data["segmentations"][idx],
"area": area_from_rle(mask_data["rles"][idx]),
"predicted_iou": mask_data["iou_preds"][idx].item(),
"point_coords": [mask_data["points"][idx].tolist()],
"stability_score": mask_data["stability_score"][idx].item()
}
if extract_embeddings:
ann["embeddings"] = mask_data["embeddings"][idx]"""
if extract_embeddings:
curr_ann = {
"embeddings": mask_data["embeddings"],
@@ -365,10 +365,11 @@
curr_ann = {
"segmentations": mask_data["segmentations"]
}
#annotations.append({"annotations": curr_anns})
annotations.append(curr_ann)
annotations = np.array(annotations).reshape(B*N)
return annotations # [BxN] : dict containing diverse annotations such as segmentation, area or also embedding
batch = i // N
num_im = i % N
annotations[batch][num_im] = curr_ann
i += 1
return annotations # [B, N] : array of dicts containing annotations such as segmentations, areas or embeddings
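# Indexing sketch for the returned array (assuming extract_embeddings=True):
#   annotations[b][n]["embeddings"]     # list of [token_dim] mask embeddings
# and, when only masks were requested:
#   annotations[b][n]["segmentations"]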
def postprocess_masks(
self,
@@ -467,15 +468,6 @@
x = F.pad(x, (0, padw, 0, padh))
return x
def create_points_grid(self, number_points):
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * number_points)
points_one_side = np.linspace(offset, 1 - offset, number_points)
points_x = np.tile(points_one_side[None, :], (number_points, 1))
points_y = np.tile(points_one_side[:, None], (1, number_points))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
def process_batch(
self,
points: np.ndarray,
@@ -629,7 +621,7 @@
indices = get_indices_sorted_pos(mask_data)
mask_data.sort_by_indices(indices)
# TODO : add positional encoding
# TODO : add 3D positional encoding
for idx in range(len(mask_data["segmentations"])):
mask = mask_data["segmentations"][idx]
@@ -648,51 +640,6 @@
# Apply mask to image embedding
mask_data["embeddings"].append(torch.tensor(mask_embed, device=self.device)) # [token_dim]
def complete_holes(self,
masks):
""""
The purpose of this function is to segment EVERYTHING from the image, without letting any remaining hole
"""
total_mask = masks[0]
for idx in range(len(masks)):
if idx > 0:
total_mask += masks[idx]
des = total_mask.astype(np.uint8)*255
kernel = np.ones((4, 4), np.uint8)
img_dilate = cv2.dilate(des, kernel, iterations=1)
import matplotlib.pyplot as plt
plt.imshow(img_dilate)
plt.show()
inverse_dilate = np.zeros((total_mask.shape), dtype=np.uint8)
inverse_dilate = np.logical_not(img_dilate).astype(np.uint8)*255
contours, _ = cv2.findContours(inverse_dilate, cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)
result_masks = []
for contour in contours:
area = cv2.contourArea(contour)
if area > 4000:
mask = np.zeros((total_mask.shape), dtype=np.uint8)
cv2.drawContours(mask, [contour], 0, 255, -1)
result_masks.append(mask)
new_masks_data = MaskData(
masks=torch.tensor(result_masks),
iou_preds=torch.tensor([0.9 for i in range(len(result_masks))])
)
new_masks_data["stability_score"] = calculate_stability_score(
new_masks_data["masks"], self.mask_threshold, self.stability_score_offset
)
new_masks_data["boxes"] = batched_mask_to_box(new_masks_data["masks"])
new_masks_data["rles"] = mask_to_rle_pytorch(new_masks_data["masks"])
return new_masks_data.to_numpy()
def position_embeding_3d(self, img_feats, camera_info):
# TODO : adapt this function to our use case
"""
@@ -184,6 +184,8 @@ class Transformer(nn.Module):
class SlotAttention(nn.Module):
"""
Slot Attention as introduced by Locatello et al.
@edit : the code was modified to handle a varying number of slots depending on the input images
"""
def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
randomize_initial_slots=False):
@@ -225,7 +227,7 @@
inputs = self.norm_input(inputs)
if self.randomize_initial_slots:
slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
else:
slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
@@ -242,13 +244,13 @@
# shape: [batch_size, num_slots, num_inputs]
attn = dots.softmax(dim=1) + self.eps
attn = attn / attn.sum(dim=-1, keepdim=True)
updates = torch.einsum('bjd,bij->bid', v, attn)
updates = torch.einsum('bjd,bij->bid', v, attn) # shape: [batch_size, num_slots, slot_dim]
slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
slots = slots + self.mlp(self.norm_pre_mlp(slots))
return slots
return slots # [batch_size, num_slots, slot_dim]
def change_slots_number(self, num_slots):
self.num_slots = num_slots
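A usage sketch of the per-batch adaptation described in the @edit note above, mirroring the change_slots_number(num_masks) call removed from FeatureMasking.forward (this assumes the module tolerates slot counts other than the one given at construction):

slot_attention = SlotAttention(num_slots=6, input_dim=768, slot_dim=1536)
slot_attention.change_slots_number(9)  # e.g. nine masks detected in this batch
# subsequent forward passes reshape and update using the new slot count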
@@ -99,6 +99,15 @@ class MaskData:
if isinstance(v, torch.Tensor):
self._stats[k] = v.detach().cpu().numpy()
def create_points_grid(number_points):
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * number_points)
points_one_side = np.linspace(offset, 1 - offset, number_points)
points_x = np.tile(points_one_side[None, :], (number_points, 1))
points_y = np.tile(points_one_side[:, None], (1, number_points))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
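# Worked example: create_points_grid(2) returns
#   [[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]]
# i.e. number_points**2 cell-centred points on a uniform grid over [0,1]x[0,1],
# which callers scale to pixel coordinates (see the test script's new_points loop).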
def get_positional_embedding(position, token_dim):
position_encodings = np.zeros(token_dim)
div_term = np.exp(np.arange(0, token_dim, 2).astype(np.float32) * (-math.log(10000.0) / token_dim))