Commit a037d337 authored by Alexandre Chapin

Add beginning of batch forward and multi-images

parent 61bf2c1e
@@ -68,50 +68,55 @@ if __name__ == '__main__':
    ])
    labels = [1 for i in range(len(sam_mask.points_grid))]
    with torch.no_grad():
        images = []
        for j in range(16):
            image = images_path[j]
            #import os
            #os.mkdir(f"./results/test_{j}")
            img = cv2.imread(image)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            images.append(np.expand_dims(img, axis=0))
        images_np = np.array(images)

        h, w, _ = images_np[0][0].shape
        points = sam_mask.points_grid
        new_points = []
        for val in points:
            x, y = val[0], val[1]
            x *= w
            y *= h
            new_points.append([x, y])
        new_points = np.array(new_points)

        start = time.time()
        camera_pos = torch.tensor(np.expand_dims(np.expand_dims(np.array([0, 0, 0]), axis=0), axis=0))
        # TODO : set the right direction
        ray_dir = torch.tensor(np.expand_dims(np.expand_dims(np.expand_dims(np.expand_dims(np.array([0, 0, 0]), axis=0), axis=0), axis=0), axis=0))
        print(camera_pos.shape)
        print(ray_dir.shape)

        masks = sam_mask(images_np, [(h, w)], extract_embeddings=True)
        end = time.time()
        print(f"Inference time : {int((end-start) * 1000)}ms")

        """plt.figure(figsize=(15,15))
        plt.imshow(img)
        show_anns(masks[0]["annotations"])
        show_points(new_points, plt.gca())
        #plt.savefig(f"./results/test_{j}/masks.png")
        plt.axis('off')
        plt.show()
        i = 0"""
        """for mask in masks[0]["annotations"]:
            cm = matplotlib.cm.get_cmap('plasma')
            plt.imshow(mask["embeddings"])
            plt.show()
            im = cm(mask["embeddings"])
            #im = np.uint8(im * 255)
            plt.imshow(im)
            plt.show()
            #im = Image.fromarray(im)
            #im.save(f"./results/test_{j}/mask_{i}.png")
            i += 1
        j += 1"""
@@ -245,86 +245,140 @@ class SamAutomaticMask(nn.Module):
        self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                      ray_octaves=15)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    def forward(
        self,
        images,
        orig_size,
        camera_pos=None,
        rays=None,
        extract_embeddings=False):
        """
        Args:
            images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
            orig_size: tuple(height, width) The original size of the image before transformation.
            camera_pos: [batch_size, num_images, 3]
            rays: [batch_size, num_images, height, width, 3]

        Returns:
            (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
                masks: (torch.Tensor) Batched binary mask predictions,
                    with shape BxCxHxW, where B is the number of input prompts,
                    C is determined by multimask_output, and (H, W) is the
                    original size of the image.
                iou_predictions: (torch.Tensor) The model's predictions
                    of mask quality, in shape BxC.
                low_res_logits: (torch.Tensor) Low resolution logits with
                    shape BxCxHxW, where H=W=256. Can be passed as mask input
                    to subsequent iterations of prediction.
            scene representation: [batch_size, num_patches, channels_per_patch]
        """
        # TODO : handle the following to concatenate token with camera position
        # Encode camera position and direction following SRT's paper
        """if len(camera_pos) > 0 and len(rays) > 0 and extract_embeddings:
            camera_pos = camera_pos.flatten(0, 1)
            rays = rays.flatten(0, 1)
            ray_enc = self.ray_encoder(camera_pos, rays).to(self.device)
            #x = torch.cat((x, ray_enc), 1)"""

        B, N, H, W, C = images.shape
        outputs = []
        for batch in range(B):
            input_images = torch.zeros(0, C, self.image_encoder.img_size, self.image_encoder.img_size).to(self.device)
            for img in images[batch]:
                input_images = torch.cat((input_images, self.preprocess(img)), dim=0)
            with torch.no_grad():
                image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True)
            for n in range(len(input_images)):
                curr_embedding = image_embeddings[n]
                curr_emb_no_red = embed_no_red[n]
                image_record = input_images[n]
                im_size = self.transform.apply_image(image_record).shape[:2]
                points_scale = np.array(im_size)[None, ::-1]
                points_for_image = self.points_grid * points_scale
                mask_data = MaskData()
                for (points,) in batch_iterator(self.points_per_batch, points_for_image):
                    batch_data = self.process_batch(points, im_size, curr_embedding, orig_size)
                    mask_data.cat(batch_data)
                    del batch_data
                del curr_embedding

                # Remove duplicates
                keep_by_nms = batched_nms(
                    mask_data["boxes"].float(),
                    mask_data["iou_preds"],
                    torch.zeros_like(mask_data["boxes"][:, 0]),  # categories
                    iou_threshold=self.box_nms_thresh,
                )
                mask_data.filter(keep_by_nms)
                mask_data["segmentations"] = mask_data["masks"]

                # Extract mask embeddings
                if extract_embeddings:
                    self.extract_mask_embedding(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
                    """print(f"Before concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
                    for tensor in mask_data['embeddings']:
                        print(tensor.shape)
                        print(ray_enc.shape)
                        final = torch.cat((tensor, ray_enc), 1)
                        print(final.shape)
                        break"""
                    #mask_data['embeddings'] = [torch.cat((tensor, ray_enc), 1) for tensor in mask_data['embeddings']]
                    #print(f"After concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
                mask_data.to_numpy()

                # Filter small disconnected regions and holes in masks NOT USED
                if self.min_mask_region_area > 0:
                    mask_data = self.postprocess_small_regions(
                        mask_data,
                        self.min_mask_region_area,
                        self.box_nms_thresh,
                    )

                # Write mask records
                curr_anns = []
                for idx in range(len(mask_data["segmentations"])):
                    ann = {
                        "segmentation": mask_data["segmentations"][idx],
                        "area": area_from_rle(mask_data["rles"][idx]),
                        "predicted_iou": mask_data["iou_preds"][idx].item(),
                        "point_coords": [mask_data["points"][idx].tolist()],
                        "stability_score": mask_data["stability_score"][idx].item()
                    }
                    if extract_embeddings:
                        ann["embeddings"] = mask_data["embeddings"][idx]
                    curr_anns.append(ann)
                outputs.append(
                    {
                        "annotations": curr_anns
                    }
                )

        # Extract masks for individual images
        """for image_record, orig_size, curr_embedding, curr_emb_no_red in zip(images, orinal_size_img, image_embeddings, embed_no_red):
            im_size = self.transform.apply_image(image_record).shape[:2]
            points_scale = np.array(im_size)[None, ::-1]
            points_for_image = self.points_grid * points_scale
            mask_data = MaskData()
            for (points,) in batch_iterator(self.points_per_batch, points_for_image):
                batch_data = self.process_batch(points, im_size, curr_embedding, orig_size)
                mask_data.cat(batch_data)
                del batch_data
            del curr_embedding

            # Remove duplicates
            keep_by_nms = batched_nms(
                mask_data["boxes"].float(),
                mask_data["iou_preds"],

@@ -335,8 +389,19 @@ class SamAutomaticMask(nn.Module):
            mask_data["segmentations"] = mask_data["masks"]

            # Extract mask embeddings
            if extract_embeddings:
                self.extract_mask_embedding(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
                print(f"Before concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
                for tensor in mask_data['embeddings']:
                    print(tensor.shape)
                    print(ray_enc.shape)
                    final = torch.cat((tensor, ray_enc), 1)
                    print(final.shape)
                    break
                #mask_data['embeddings'] = [torch.cat((tensor, ray_enc), 1) for tensor in mask_data['embeddings']]
                #print(f"After concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
            mask_data.to_numpy()

@@ -366,7 +431,7 @@ class SamAutomaticMask(nn.Module):
                    "annotations": curr_anns
                }
            )
        """
        return outputs
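The ray-encoding concatenation is still commented out above. As an illustration only (dimensions and names below are made up, not the repository's), appending a pooled camera/ray encoding to each mask token with torch.cat along dim=1 would require shapes like:

import torch

token_dim, ray_dim = 256, 180                                    # illustrative sizes
mask_embeddings = [torch.randn(1, token_dim) for _ in range(8)]  # one token per mask
ray_enc = torch.randn(1, ray_dim)                                # pooled encoding of the current view

# concatenation along dim=1 requires matching leading dimensions
augmented = [torch.cat((emb, ray_enc), dim=1) for emb in mask_embeddings]
print(augmented[0].shape)                                        # torch.Size([1, 436])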
    def postprocess_masks(

@@ -645,7 +710,7 @@ class SamAutomaticMask(nn.Module):
        mask_embed += pos_embedding

        # Apply mask to image embedding
        mask_data["embeddings"].append(torch.tensor(mask_embed, device=self.device))  # [token_dim]
    def complete_holes(self,
                       masks):

@@ -690,4 +755,60 @@ class SamAutomaticMask(nn.Module):
        new_masks_data["rles"] = mask_to_rle_pytorch(new_masks_data["masks"])
        return new_masks_data.to_numpy()

    def position_embeding_3d(self, img_feats, camera_info):
        # TODO : adapt this function to our use case
        """
        3D position embedding on image features following PETR's work in :
        https://github.com/megvii-research/PETR/blob/main/projects/mmdet3d_plugin/models/dense_heads/petr_head.py#L282
        """
        eps = 1e-5
        B, N, C, H, W = img_feats.shape
        coords_h = torch.arange(H, device=self.device).float()
        coords_w = torch.arange(W, device=self.device).float()

        # TODO : check these two values and decide how to set them
        depth_num = 64
        position_range = [-65, -65, -8.0, 65, 65, 8.0]  # TODO : set this value with the correct ranges
        # [xmin, ymin, zmin, xmax, ymax, zmax] ROI of the 3D world space
        depth_start = 1
        # END TODO

        index = torch.arange(start=0, end=depth_num, device=self.device).float()
        bin_size = (position_range[3] - depth_start) / depth_num
        coords_d = depth_start + bin_size * index
        D = coords_d.shape[0]

        coords = torch.stack(torch.meshgrid([coords_w, coords_h, coords_d])).permute(1, 2, 3, 0)  # W, H, D, 3
        coords = torch.cat((coords, torch.ones_like(coords[..., :1])), -1)
        coords[..., :2] = coords[..., :2] * torch.maximum(coords[..., 2:3], torch.ones_like(coords[..., 2:3]) * eps)

        # Extract extrinsics
        camera_extrinsic = camera_info["extrinsics"]  # (B, N, 4, 4)  # TODO : take this value from the camera info

        # Apply transform
        coords = coords.view(1, 1, W, H, D, 4, 1).repeat(B, N, 1, 1, 1, 1, 1)
        camera_extrinsic = camera_extrinsic.view(B, N, 1, 1, 1, 4, 4).repeat(1, 1, W, H, D, 1, 1)
        coords3d = torch.matmul(camera_extrinsic, coords).squeeze(-1)[..., :3]

        # Normalize
        coords3d[..., 0:1] = (coords3d[..., 0:1] - position_range[0]) / (position_range[3] - position_range[0])
        coords3d[..., 1:2] = (coords3d[..., 1:2] - position_range[1]) / (position_range[4] - position_range[1])
        coords3d[..., 2:3] = (coords3d[..., 2:3] - position_range[2]) / (position_range[5] - position_range[2])

        # Final embedding
        coords3d = coords3d.permute(0, 1, 4, 5, 3, 2).contiguous().view(B * N, -1, H, W)
        position_dim = 3 * depth_num
        embed_dims = 256
        position_encoder = nn.Sequential(
            nn.Conv2d(position_dim, embed_dims * 4, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(embed_dims * 4, embed_dims, kernel_size=1, stride=1, padding=0),
        )
        coords_position_embeding = position_encoder(coords3d)
        return coords_position_embeding.view(B, N, embed_dims, H, W)
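As a quick shape check for the new position_embeding_3d, a hedged dummy call (assuming sam_mask is an instantiated SamAutomaticMask kept on CPU, since the locally created position_encoder is not moved to self.device) could look like:

import torch

B, N, C, H, W = 1, 2, 256, 32, 32       # illustrative sizes
img_feats = torch.zeros(B, N, C, H, W)  # only the shape of this tensor is used
camera_info = {"extrinsics": torch.eye(4).view(1, 1, 4, 4).repeat(B, N, 1, 1)}  # (B, N, 4, 4)

pos_embed = sam_mask.position_embeding_3d(img_feats, camera_info)
print(pos_embed.shape)                  # (B, N, 256, H, W), i.e. embed_dims = 256 per view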
@@ -58,7 +58,7 @@ class PositionalEncoding(nn.Module):
        octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves)
        octaves = octaves.float().to(coords)
        multipliers = 2**octaves * math.pi
        coords = coords.unsqueeze(-1)

        while len(multipliers.shape) < len(coords.shape):
            multipliers = multipliers.unsqueeze(0)
@@ -79,12 +79,12 @@ class RayEncoder(nn.Module):
    def forward(self, pos, rays):
        if len(rays.shape) == 4:
            batchsize, height, width, _ = rays.shape
            pos_enc = self.pos_encoding(pos.unsqueeze(1))
            pos_enc = pos_enc.view(batchsize, pos_enc.shape[-1], 1, 1)
            pos_enc = pos_enc.repeat(1, 1, height, width)

            rays = rays.flatten(1, 2)
            ray_enc = self.ray_encoding(rays)
            ray_enc = ray_enc.view(batchsize, height, width, ray_enc.shape[-1])
            ray_enc = ray_enc.permute((0, 3, 1, 2))
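For intuition, the octave-based encoding applied to camera positions and ray directions maps each coordinate through sines and cosines at frequencies 2**octave * pi, so a dims-dimensional input becomes 2 * dims * num_octaves features (90 for 3D inputs with 15 octaves). A standalone sketch of that idea, not necessarily matching the repository's exact ordering of the sin/cos terms:

import math
import torch

def octave_encode(coords, num_octaves=15, start_octave=0):
    # coords: [..., dims] -> [..., 2 * dims * num_octaves]
    octaves = torch.arange(start_octave, start_octave + num_octaves, dtype=coords.dtype)
    multipliers = 2 ** octaves * math.pi          # one frequency per octave
    scaled = coords.unsqueeze(-1) * multipliers   # [..., dims, num_octaves]
    return torch.cat((torch.sin(scaled).flatten(-2), torch.cos(scaled).flatten(-2)), dim=-1)

rays = torch.randn(2, 64, 64, 3)                  # [batch, height, width, 3]
print(octave_encode(rays.flatten(1, 2)).shape)    # torch.Size([2, 4096, 90])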