Segment-Object-Centric · Commit a9d24741
Authored 2 years ago by Alexandre Chapin
Add extraction of mask embeddings
Parent: ab044bc4

Showing 5 changed files:
.gitmodules              +3  −0
automatic_mask_train.py  +19 −2
osrt/encoder.py          +61 −36
osrt/layers.py           +1  −0
segment-anything         +1  −0

85 additions and 38 deletions in total.
.gitmodules (new file, +3 −0)

[submodule "segment-anything"]
	path = segment-anything
	url = git@github.com:facebookresearch/segment-anything.git
automatic_mask_train.py (+19 −2)
@@ -7,6 +7,7 @@ import time
 import matplotlib.pyplot as plt
 import numpy as np
 import cv2
+import matplotlib
 
 def show_anns(masks):
     ax = plt.gca()
@@ -58,7 +59,7 @@ if __name__ == '__main__':
     sam = sam_model_registry[model_type](checkpoint=checkpoint)
     sam.to(device=device)
     #mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=12, box_nms_thresh=0.7, crop_n_layers=0, points_per_batch=128, pred_iou_thresh=0.88)
-    sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_side=8, points_per_batch=64, min_mask_region_area=4000)
+    sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_side=8, points_per_batch=64) #, min_mask_region_area=2000)
     sam_mask.to(device)
     transform = transforms.Compose([
@@ -66,7 +67,10 @@ if __name__ == '__main__':
     ])
     labels = [1 for i in range(len(sam_mask.points_grid))]
 
     with torch.no_grad():
         j = 0
         for image in images_path:
+            #import os
+            #os.mkdir(f"./results/test_{j}")
             img = cv2.imread(image)
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
@@ -86,13 +90,26 @@ if __name__ == '__main__':
             img_batch.append(img_el)
 
             start = time.time()
-            masks = sam_mask(img_batch)
+            masks = sam_mask(img_batch, extract_embeddings=True)
             end = time.time()
             print(f"Inference time : {int((end - start) * 1000)} ms")
             plt.figure(figsize=(15, 15))
             plt.imshow(img)
             show_anns(masks[0]["annotations"])
             show_points(new_points, plt.gca())
+            #plt.savefig(f"./results/test_{j}/masks.png")
             plt.axis('off')
             plt.show()
+            """
+            from PIL import Image
+            i = 0
+            for mask in masks[0]["annotations"]:
+                cm = matplotlib.cm.get_cmap('viridis')
+                img_src = Image.fromarray(mask["embeddings"]).convert('L')
+                im = np.array(img_src)
+                im = cm(im)
+                im = np.uint8(im * 255)
+                im = Image.fromarray(im)
+                im.save(f"./results/test_{j}/mask_{i}.png")
+                i += 1
+            j += 1
+            """
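The commented-out block above saves each per-mask embedding as a colormapped image. For downstream use, the commit's own TODO in osrt/encoder.py (`np.mean(masked_embed, axis=(1, 2))`) points at average pooling. A minimal sketch of that post-processing step, assuming the output structure built in this commit, where each annotation dict carries an "embeddings" array of shape [channels, h, w]; the helper name is ours, not the repo's:

    import numpy as np

    def pool_mask_tokens(batch_output):
        """Average-pool each per-mask embedding into a single token vector.

        `batch_output` is one element of the list returned by
        sam_mask(img_batch, extract_embeddings=True).
        """
        tokens = []
        for ann in batch_output["annotations"]:
            embed = ann["embeddings"]               # [channels, h, w] numpy array
            tokens.append(embed.mean(axis=(1, 2)))  # -> [channels]
        return np.stack(tokens)                     # [num_masks, channels]

    # Usage: tokens = pool_mask_tokens(masks[0])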
osrt/encoder.py (+61 −36)
@@ -10,9 +10,7 @@ from typing import Any, Dict, List, Optional, Tuple
 from osrt.layers import RayEncoder, Transformer, SlotAttention
 from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score
 
-from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
-from segment_anything.modeling import Sam
 from segment_anything.modeling import Sam
+from segment_anything import sam_model_registry
 from segment_anything.modeling.image_encoder import ImageEncoderViT
 from segment_anything.modeling.mask_decoder import MaskDecoder
 from segment_anything.modeling.prompt_encoder import PromptEncoder
@@ -147,13 +145,15 @@ class FeatureMasking(nn.Module):
     def forward(self, images):
         # Generate images
-        masks = self.mask_generator(images)
+        masks = self.mask_generator(images, extract_embeddings=True)
 
         # TODO : find a way to handle multiple image from a same scene instead of just one
-        set_latents = None
-        num_masks = None
+        set_latents = masks[:]["annotations"][:]["embeddings"]
+        num_masks = []
+        for batch in masks:
+            num_masks.append(len(batch["annotations"]))
 
         # TODO : set the number of slots according to the masks number
         # Set the number of slots for current batch
         self.slot_attention.change_slots_number(num_masks)
 
         # [batch_size, num_inputs, dim]
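Note that the new `set_latents` line will raise a TypeError at runtime: `masks[:]` is still a list, and a list cannot be indexed with the string "annotations". A minimal sketch of what the gathering step presumably intends, using only names that appear in this diff:

    # Collect per-image embedding lists and mask counts explicitly.
    set_latents = [
        [ann["embeddings"] for ann in batch["annotations"]]
        for batch in masks
    ]
    num_masks = [len(batch["annotations"]) for batch in masks]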
@@ -233,7 +233,7 @@ class SamAutomaticMask(nn.Module):
     def forward(
         self,
         batched_input: List[Dict[str, Any]],
-        extract_embeddings: bool = True
+        extract_embeddings: bool = False
     ) -> List[Dict[str, torch.Tensor]]:
         """
         Predicts masks end-to-end from provided images and prompts.
@@ -276,16 +276,11 @@ class SamAutomaticMask(nn.Module):
         # Extract image embeddings
         input_images = [self.preprocess(x["image"]) for x in batched_input][0]
         with torch.no_grad():
-            image_embeddings = self.image_encoder(input_images) #, before_channel_reduc=True), embed_no_red
-        """
-        # Extract image embedding before channel reduction, cf. https://github.com/facebookresearch/segment-anything/issues/283
-        if before_channel_reduc :
-            return x, embed_no_red
-        """
+            image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True)
 
         outputs = []
-        for image_record, curr_embedding in zip(batched_input, image_embeddings):
+        for image_record, curr_embedding, curr_emb_no_red in zip(batched_input, image_embeddings, embed_no_red):
             # TODO : check if we've got the points given in the batch (to change the current point_grid !)
             im_size = self.transform.apply_image(image_record["image"]).shape[:2]
             points_scale = np.array(im_size)[None, ::-1]
             points_for_image = self.points_grid * points_scale
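The `before_channel_reduc=True` call assumes the pinned segment-anything submodule exposes the pre-neck ViT features; upstream does not, cf. the issue referenced in the removed comment (facebookresearch/segment-anything#283). A sketch of the encoder-side patch this presumably corresponds to, following the structure of ImageEncoderViT.forward:

    # Assumed patch inside segment_anything/modeling/image_encoder.py (not part
    # of this diff): capture the [B, H', W', C] features before the neck reduces
    # channels, and return both tensors when the caller asks for them.
    def forward(self, x, before_channel_reduc=False):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        embed_no_red = x                      # pre-neck features [B, H', W', C]
        x = self.neck(x.permute(0, 3, 1, 2))  # channel reduction to the SAM width
        if before_channel_reduc:
            return x, embed_no_red
        return x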
@@ -319,7 +314,8 @@ class SamAutomaticMask(nn.Module):
             )
             mask_data["segmentations"] = mask_data["masks"]
-            mask_embed = self.extract_mask_embedding(mask_data, embed_no_red, scale_box=1.5)
+            if extract_embeddings:
+                mask_embed = self.extract_mask_embedding(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
 
             # Write mask records
             curr_anns = []
@@ -332,8 +328,7 @@ class SamAutomaticMask(nn.Module):
                     "stability_score": mask_data["stability_score"][idx].item()
                 }
                 if extract_embeddings:
-                    # TODO : add embeddings into the annotations
-                    continue
+                    ann["embeddings"] = mask_embed[idx]
                 curr_anns.append(ann)
 
             outputs.append({
@@ -575,7 +570,7 @@ class SamAutomaticMask(nn.Module):
         return masks, iou_predictions, low_res_masks
 
-    def extract_mask_embedding(self, mask_data, image_embed, scale_box=1.5):
+    def extract_mask_embedding(self, mask_data, image_embed, input_size, scale_box=1.5):
         """
         Predicts the embeddings from each mask given the global embedding and
         a scale factor around each mask.
@@ -588,31 +583,61 @@ class SamAutomaticMask(nn.Module):
         Returns:
             embeddings : the embeddings for each mask extracted from the image
         """
+        image_embed = image_embed.permute(2, 0, 1)
+        orig_H, orig_W = mask_data["segmentations"][0].shape[:2]
+        # We follow the same process to put the images back to the right format
+        scaled_img_emb = self.postprocess_masks(image_embed.unsqueeze(0), input_size, (orig_H, orig_W))
+
+        def scale_bounding_box(box, scale_factor, img_size):
+            x1, y1, x2, y2 = box
+            width = x2 - x1
+            height = y2 - y1
+            new_width = width * scale_factor
+            new_height = height * scale_factor
+            # Clamping values of the box inside of the image
+            new_x1 = int(max(0, x1 - (new_width - width) / 2))
+            new_y1 = int(max(0, y1 - (new_height - height) / 2))
+            new_x2 = int(min(img_size[1], new_x1 + new_width))
+            new_y2 = int(min(img_size[0], new_y1 + new_height))
+            return (new_x1, new_y1, new_x2, new_y2)
+
         masks_embedding = []
         for idx in range(len(mask_data["segmentations"])):
             mask = mask_data["segmentations"][idx]
             box = mask_data["boxes"][idx]
-            def scale_bounding_box(box, scale_factor):
-                x1, y1, x2, y2 = box
-                width = x2 - x1
-                height = y2 - y1
-                new_width = width * scale_factor
-                new_height = height * scale_factor
-                new_x1 = x1 - (new_width - width) / 2
-                new_y1 = y1 - (new_height - height) / 2
-                new_x2 = new_x1 + new_width
-                new_y2 = new_y1 + new_height
-                return new_x1, new_y1, new_x2, new_y2
             # Scale bounding box
-            scaled_box = scale_bounding_box(box, scale_box)
-            print(image_embed.shape)
+            scaled_box = scale_bounding_box(box, scale_box, (orig_H, orig_W))
+            # Crop image embedding around bbox
+            croped_im_embed = scaled_img_emb[0, :, scaled_box[1]:scaled_box[3], scaled_box[0]:scaled_box[2]].cpu().numpy() # [channels, h, w]
+            crop_mask = mask[scaled_box[1]:scaled_box[3], scaled_box[0]:scaled_box[2]] # [h, w]
+            # Apply mask to bounding box
+            print(f"{croped_im_embed[:].shape} {crop_mask.shape}")
+            masked_embed = croped_im_embed[:] * crop_mask # [channels, h, w]
             # Apply average pooling on masked region
             # TODO : find a way to export tokens
             #final_token = np.mean(masked_embed, axis=(1, 2))
             #print(f"Final token : {final_token}")
-            masks_embedding = None
+            masks_embedding.append(masked_embed)
             #mean_embed = masked_embed / np.mean(masked_embed)
             """
             print(f"Shape of im embedding {scaled_img_emb.shape}")
             print(f"Shape of masked embedding {masked_embed.shape}")
             print(f"Shape of token {final_token.shape}")
             print("########################")
             """
 
         return masks_embedding
 
     def complete_holes(self,
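In short, extract_mask_embedding upsamples the pre-neck image embedding to the original image resolution, grows each mask's box by scale_box, crops the embedding to that box, and zeroes out features outside the mask; pooling each crop down to one token is left as a TODO. A self-contained sketch of the same technique on plain numpy arrays (all names here are illustrative, not the module's API):

    import numpy as np

    def mask_token(image_embed, mask, box, scale=1.5):
        """image_embed: [C, H, W] floats; mask: [H, W] bool; box: (x1, y1, x2, y2)."""
        H, W = mask.shape
        x1, y1, x2, y2 = box
        w, h = (x2 - x1) * scale, (y2 - y1) * scale
        # Grow the box around its centre, clamping to the image bounds.
        nx1 = int(max(0, x1 - (w - (x2 - x1)) / 2))
        ny1 = int(max(0, y1 - (h - (y2 - y1)) / 2))
        nx2 = int(min(W, nx1 + w))
        ny2 = int(min(H, ny1 + h))
        # Crop the embedding, zero out features outside the mask, then mean-pool.
        crop = image_embed[:, ny1:ny2, nx1:nx2] * mask[ny1:ny2, nx1:nx2]
        return crop.mean(axis=(1, 2))  # one [C] token per mask (the commit's TODO)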
osrt/layers.py (+1 −0)

@@ -222,6 +222,7 @@ class SlotAttention(nn.Module):
         Args:
             inputs: set-latent representation [batch_size, num_inputs, dim]
         """
+        # TODO : change number slots depending on the batch
         batch_size, num_inputs, dim = inputs.shape
         inputs = self.norm_input(inputs)
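FeatureMasking.forward above calls self.slot_attention.change_slots_number(num_masks), but that method is not part of this diff. A purely hypothetical sketch of what it could do, sizing the slot count to the largest per-image mask count in the batch:

    # Hypothetical helper, not in this commit: adapt SlotAttention's slot
    # count to the number of masks SAM found in the current batch.
    def change_slots_number(self, num_masks):
        if isinstance(num_masks, (list, tuple)):
            num_masks = max(num_masks)  # one slot per mask, padded to the batch max
        self.num_slots = num_masks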
segment-anything @ 6fdee8f2 (new submodule, +1 −0)

Subproject commit 6fdee8f2727f4506cfbbe553e23b895e27956588