Commit ce20222c in Segment-Object-Centric (Alexandre Chapin)
Authored 2 years ago by Alexandre Chapin
Parent: 1af58823

    Add batch and multi-image + extract embeddings for slot attention

Showing 4 changed files, with 82 additions and 122 deletions:

- automatic_mask_train.py: 11 additions, 9 deletions
- osrt/encoder.py: 57 additions, 110 deletions
- osrt/layers.py: 5 additions, 3 deletions
- osrt/utils/common.py: 9 additions, 0 deletions
automatic_mask_train.py (+11, −9)
 import argparse
 import torch
-from osrt.encoder import SamAutomaticMask
+from osrt.encoder import SamAutomaticMask, FeatureMasking
 from segment_anything import sam_model_registry
 import time
 import matplotlib.pyplot as plt
...
@@ -54,12 +54,14 @@ if __name__ == '__main__':
     import random
     random.shuffle(images_path)

-    sam = sam_model_registry[model_type](checkpoint=checkpoint)
-    sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_side=12, points_per_batch=64) #, min_mask_region_area=2000)
-    sam_mask.to(device)
+    """
+    sam = sam_model_registry[model_type](checkpoint=checkpoint)
+    sam_mask = SamAutomaticMask(sam.image_encoder, sam.prompt_encoder, sam.mask_decoder, box_nms_thresh=0.7, stability_score_thresh=0.9, points_per_side=12, points_per_batch=64) #, min_mask_region_area=2000)
+    sam_mask.to(device)
+    """
+    model = FeatureMasking(points_per_side=12, box_nms_thresh=0.7, stability_score_thresh=0.9, pred_iou_thresh=0.88, points_per_batch=64)
+    model.to(device)

     images = []
-    for j in range(8):
+    for j in range(1):
         image = images_path[j]
         img = cv2.imread(image)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
...
@@ -68,7 +70,7 @@ if __name__ == '__main__':
     #images_np = images_np.reshape(2, 2, images_np.shape[2], images_np.shape[3], images_np.shape[4])
     h, w, _ = images_np[0][0].shape

-    points = sam_mask.points_grid
+    points = model.mask_generator.points_grid
     new_points = []
     for val in points:
         x, y = val[0], val[1]
...
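The grid points come back normalized to [0,1], so before plotting they presumably need rescaling to pixel coordinates. A hedged sketch of what the elided lines likely do with the `w` and `h` computed above (the repo's exact scaling may differ):

```python
# Hedged sketch: map normalized (x, y) grid points into pixel space for the
# show_points overlay; assumes val[0] is x (width axis) and val[1] is y.
new_points = []
for val in points:
    x, y = val[0], val[1]
    new_points.append((x * w, y * h))
```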
@@ -80,14 +82,14 @@ if __name__ == '__main__':
     start = time.time()
     # TODO : set ray and camera directions
     with torch.no_grad():
-        masks = sam_mask(images_np, (h, w), extract_embeddings=True)
+        masks, slots = model(images_np, (h, w), extract_masks=True)
     end = time.time()
     print(f"Inference time : {int((end - start) * 1000)} ms")

     if args.visualize:
         plt.figure(figsize=(15, 15))
         plt.imshow(img)
-        show_anns(masks[0]) # show masks
+        show_anns(masks[0][0]) # show masks
         show_points(new_points, plt.gca()) # show points
         plt.axis('off')
         plt.show()
...
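Taken together, the new script path builds a nested batch of views and calls `FeatureMasking` once. A minimal, hedged sketch of that entry point (random arrays stand in for real images; the SAM checkpoint defaults are assumed from the constructor shown in osrt/encoder.py below):

```python
# Minimal sketch of the new FeatureMasking entry point, mirroring the script:
# images arrive as a [batch][view] nesting of HxWx3 uint8 RGB arrays.
import numpy as np
import torch
from osrt.encoder import FeatureMasking

model = FeatureMasking(points_per_side=12, box_nms_thresh=0.7,
                       stability_score_thresh=0.9, pred_iou_thresh=0.88,
                       points_per_batch=64)
model.to("cuda" if torch.cuda.is_available() else "cpu")

images_np = [[np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)]
             for _ in range(2)]            # 2 scenes, 1 view each
h, w, _ = images_np[0][0].shape

with torch.no_grad():
    masks, slots = model(images_np, (h, w), extract_masks=True)
# masks[b][n] holds the annotation dict for view n of scene b;
# slots has shape [batch_size, num_slots, slot_dim].
```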
osrt/encoder.py (+57, −110)
...
@@ -3,13 +3,13 @@ import torch
 import torch.nn as nn
-from typing import Any, Dict, List, Optional, Tuple
 from torch.nn import functional as F
 from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
 import math
 from torchvision.ops.boxes import batched_nms
 import torchvision.transforms.functional as func
+from typing import Any, Dict, List, Optional, Tuple
 from osrt.layers import RayEncoder, Transformer, SlotAttention
-from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score, get_indices_sorted_pos, get_positional_embedding
+from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score, get_indices_sorted_pos, get_positional_embedding, create_points_grid
 from segment_anything import sam_model_registry
 from segment_anything.modeling.image_encoder import ImageEncoderViT
...
@@ -116,12 +116,14 @@ class FeatureMasking(nn.Module):
                  stability_score_thresh=0.9,
                  pred_iou_thresh=0.88,
                  points_per_batch=64,
-                 min_mask_region_area=4000,
+                 min_mask_region_area=0,
                  num_slots=6,
                  slot_dim=1536,
-                 slot_iters=1,
+                 slot_iters=1,
+                 num_att_blocks=5,
                  sam_model="default",
                  sam_path="sam_vit_h_4b8939.pth",
                  tokenizer="mean",
                  randomize_initial_slots=False):
         super().__init__()
...
@@ -135,12 +137,16 @@ class FeatureMasking(nn.Module):
             pred_iou_thresh=pred_iou_thresh,
             points_per_side=points_per_side,
             points_per_batch=points_per_batch,
             tokenizer=tokenizer,
             min_mask_region_area=min_mask_region_area)

-        self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters,
+        self.transformer = Transformer(self.mask_generator.token_dim, depth=num_att_blocks, heads=12, dim_head=64, mlp_dim=1536, selfatt=True)
+        self.slot_attention = SlotAttention(num_slots, input_dim=self.mask_generator.token_dim, slot_dim=slot_dim, iters=slot_iters,
                                             randomize_initial_slots=randomize_initial_slots)

-    def forward(self, images, original_size, camera_pos=None, rays=None):
+    def forward(self, images, original_size, camera_pos=None, rays=None, extract_masks=True):
         """
         Args:
             images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
...
@@ -160,20 +166,30 @@ class FeatureMasking(nn.Module):
         # Generate masks
         masks = self.mask_generator(images, original_size, camera_pos, rays, extract_embeddings=True) # [B, N]

-        num_masks = []
-        for batch in masks:
-            num_masks.append(len(batch))
-        set_latents = masks[:]["embeddings"]
-        # Set the number of slots for the current batch
-        self.slot_attention.change_slots_number(num_masks)
+        B, N = masks.shape
+        dim = masks[0][0]["embeddings"][0].shape[0]
+        # We infer each batch element separately as each handles a different number of slots
+        set_latents = None
+        for b in range(B):
+            # start from an empty [0, dim] tensor so no uninitialized row leaks into the latents
+            latents_batch = torch.empty((0, dim), device=self.mask_generator.device)
+            # TODO : set a new number of slots
+            for n in range(N):
+                embeds = masks[b][n]["embeddings"]
+                for embed in embeds:
+                    latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0)
+            if set_latents is None:
+                set_latents = latents_batch.unsqueeze(0)
+            else:
+                set_latents = torch.cat((set_latents, latents_batch.unsqueeze(0)), 0)

         # [batch_size, num_inputs, dim]
         slot_latents = self.slot_attention(set_latents)

-        return slot_latents
+        if extract_masks:
+            return masks, slot_latents
+        else:
+            return slot_latents
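The concatenation loop above only batches cleanly when every scene contributes the same total number of embeddings; otherwise the `torch.cat` along the batch dimension fails. A hedged sketch of a padding-based alternative (the `pad_latents` helper is hypothetical, not part of this commit):

```python
# Hypothetical helper: pad per-scene mask embeddings to a common length and
# return a validity mask, so scenes with different mask counts still batch.
from typing import List, Tuple
import torch

def pad_latents(per_scene: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # each element of per_scene is [num_masks_i, dim]
    dim = per_scene[0].shape[1]
    longest = max(t.shape[0] for t in per_scene)
    batch = torch.zeros(len(per_scene), longest, dim)
    valid = torch.zeros(len(per_scene), longest, dtype=torch.bool)
    for i, t in enumerate(per_scene):
        batch[i, :t.shape[0]] = t
        valid[i, :t.shape[0]] = True
    return batch, valid  # [B, longest, dim], [B, longest]

latents, valid = pad_latents([torch.randn(4, 256), torch.randn(6, 256)])
```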
class SamAutomaticMask(nn.Module):
    mask_threshold: float = 0.0
...
@@ -194,6 +210,7 @@ class SamAutomaticMask(nn.Module):
                  box_nms_thresh: float = 0.7,
                  min_mask_region_area: int = 0,
                  pos_start_octave=0,
+                 tokenizer="mean",
                  patch_size=16) -> None:
         """
...
@@ -232,9 +249,9 @@ class SamAutomaticMask(nn.Module):
         self.transform = ResizeLongestSide(self.image_encoder.img_size)
         if points_per_side > 0:
-            self.points_grid = self.create_points_grid(points_per_side)
+            self.points_grid = create_points_grid(points_per_side)
         else:
-            self.points_grid = None
+            self.points_grid = create_points_grid(32)
         self.points_per_batch = points_per_batch
         self.pred_iou_thresh = pred_iou_thresh
         self.stability_score_thresh = stability_score_thresh
...
@@ -243,15 +260,18 @@ class SamAutomaticMask(nn.Module):
         self.min_mask_region_area = min_mask_region_area

-        # TODO : set the token dim and the input size
-        input_size = 0 # depends on the image size
-        self.token_dim = (self.image_encoder.img_size // patch_size) ** 2
-        self.tokenizer = nn.Sequential(
-            nn.Linear(input_size, 100),
-            nn.ReLU(),
-            nn.Linear(100, 50),
-            nn.ReLU(),
-            nn.Linear(50, self.token_dim),
-        )
+        self.tokenizer_type = tokenizer
+        if tokenizer == "mlp":
+            input_size = 0 # depends on the image size
+            self.token_dim = (self.image_encoder.img_size // patch_size) ** 2
+            self.tokenizer = nn.Sequential(
+                nn.Linear(input_size, 100),
+                nn.ReLU(),
+                nn.Linear(100, 50),
+                nn.ReLU(),
+                nn.Linear(50, self.token_dim),
+            )

         # Space positional embedding
         self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
...
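Note that the `mlp` branch is still a stub: `input_size = 0` cannot feed `nn.Linear`, and `self.token_dim` is only assigned inside the branch even though `FeatureMasking` reads `mask_generator.token_dim` unconditionally. For the default `"mean"` tokenizer, the mask-averaged feature below is a hedged guess at the intended computation (the helper is illustrative, not from the repo):

```python
import torch

def mean_tokenize(features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average the spatial features covered by one binary mask.

    features: [C, H, W] image embedding; mask: [H, W] bool. Returns [C].
    Illustrative sketch of a 'mean' tokenizer, not the repo's exact code.
    """
    return features[:, mask].mean(dim=1)

token = mean_tokenize(torch.randn(256, 64, 64), torch.rand(64, 64) > 0.5)
```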
@@ -268,7 +288,6 @@ class SamAutomaticMask(nn.Module):
                 camera_pos=None,
                 rays=None,
                 extract_embeddings=False):
         """
         Args:
             images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
...
@@ -299,7 +318,8 @@ class SamAutomaticMask(nn.Module):
         image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True) # [B x N, H, W, C]
         # TODO : add camera position embedding to the 2D image embedding with @position_embeding_3d
-        annotations = []
+        annotations = np.empty((B, N), dtype=object)
+        i = 0
         for curr_embedding, curr_emb_no_red in zip(image_embeddings, embed_no_red):
             mask_data = MaskData()
             for (points,) in batch_iterator(self.points_per_batch, points_for_image):
...
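`batch_iterator` chunks the prompt points so that at most `points_per_batch` prompts go through the decoder per forward pass. A minimal re-implementation of the iteration pattern (the real helper lives in osrt.utils.common, following SAM's amg utilities):

```python
import numpy as np

def batch_iterator(batch_size, *args):
    # yield aligned slices of every argument, batch_size rows at a time
    n = len(args[0])
    for start in range(0, n, batch_size):
        yield [a[start:start + batch_size] for a in args]

points_for_image = np.random.rand(144, 2)    # e.g. a 12x12 grid, scaled
for (points,) in batch_iterator(64, points_for_image):
    print(points.shape)                       # (64, 2), (64, 2), (16, 2)
```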
@@ -322,16 +342,7 @@ class SamAutomaticMask(nn.Module):
             # Extract mask embeddings
             if extract_embeddings:
                 self.extract_mask_embedding(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
-                """
-                print(f"Before concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
-                for tensor in mask_data['embeddings']:
-                    print(tensor.shape)
-                    print(ray_enc.shape)
-                    final = torch.cat((tensor, ray_enc), 1)
-                    print(final.shape)
-                    break
-                """
-                #mask_data['embeddings'] = [torch.cat((tensor, ray_enc), 1) for tensor in mask_data['embeddings']]
-                #print(f"After concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
+                # TODO : add the 3D positional embedding here
             mask_data.to_numpy()
...
@@ -343,19 +354,8 @@ class SamAutomaticMask(nn.Module):
             self.box_nms_thresh,
         )

-        # TODO : have a more efficient way to store the data
         # Write mask records
-        """
-        curr_anns = []
-        print(mask_data.items())
-        for idx in range(len(mask_data["segmentations"])):
-            ann = {
-                "segmentation": mask_data["segmentations"][idx],
-                "area": area_from_rle(mask_data["rles"][idx]),
-                "predicted_iou": mask_data["iou_preds"][idx].item(),
-                "point_coords": [mask_data["points"][idx].tolist()],
-                "stability_score": mask_data["stability_score"][idx].item()
-            }
-            if extract_embeddings:
-                ann["embeddings"] = mask_data["embeddings"][idx]
-        """
         if extract_embeddings:
             curr_ann = {
                 "embeddings": mask_data["embeddings"],
...
@@ -365,10 +365,11 @@ class SamAutomaticMask(nn.Module):
             curr_ann = {
                 "segmentations": mask_data["segmentations"]
             }
-            #annotations.append({"annotations": curr_anns})
-            annotations.append(curr_ann)
-        annotations = np.array(annotations).reshape(B*N)
-        return annotations # [BxN] : dicts containing diverse annotations such as segmentation, area or embeddings
+            batch = math.floor(i / N)
+            num_im = i % N
+            annotations[batch][num_im] = curr_ann
+            i += 1
+        return annotations # [B, N] : dicts containing diverse annotations such as segmentation, area or embeddings
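The flat loop counter is unpacked into (scene, view) indices with a floor division and a modulo; Python's `divmod` expresses the same mapping in one call:

```python
# divmod(i, N) == (i // N, i % N): the (batch, num_im) pair used above.
N = 3                          # views per scene
for i in range(2 * N):
    batch, num_im = divmod(i, N)
    assert (batch, num_im) == (i // N, i % N)
```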
    def postprocess_masks(self,
...
@@ -467,15 +468,6 @@ class SamAutomaticMask(nn.Module):
         x = F.pad(x, (0, padw, 0, padh))
         return x

-    def create_points_grid(self, number_points):
-        """
-        Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
-        """
-        offset = 1 / (2 * number_points)
-        points_one_side = np.linspace(offset, 1 - offset, number_points)
-        points_x = np.tile(points_one_side[None, :], (number_points, 1))
-        points_y = np.tile(points_one_side[:, None], (1, number_points))
-        points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
-        return points

     def process_batch(self,
                       points: np.ndarray,
...
@@ -629,7 +621,7 @@ class SamAutomaticMask(nn.Module):
         indices = get_indices_sorted_pos(mask_data)
         mask_data.sort_by_indices(indices)
-        # TODO : add positional encoding
+        # TODO : add 3D positional encoding
         for idx in range(len(mask_data["segmentations"])):
             mask = mask_data["segmentations"][idx]
...
@@ -648,51 +640,6 @@ class SamAutomaticMask(nn.Module):
             # Apply mask to image embedding
             mask_data["embeddings"].append(torch.tensor(mask_embed, device=self.device)) # [token_dim]

-    def complete_holes(self, masks):
-        """
-        The purpose of this function is to segment EVERYTHING in the image, without leaving any remaining hole
-        """
-        total_mask = masks[0]
-        for idx in range(len(masks)):
-            if idx > 0:
-                total_mask += masks[idx]
-        des = total_mask.astype(np.uint8) * 255
-        kernel = np.ones((4, 4), np.uint8)
-        img_dilate = cv2.dilate(des, kernel, iterations=1)
-        import matplotlib.pyplot as plt
-        plt.imshow(img_dilate)
-        plt.show()
-        inverse_dilate = np.logical_not(img_dilate).astype(np.uint8) * 255
-        contours, _ = cv2.findContours(inverse_dilate, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-        result_masks = []
-        for contour in contours:
-            area = cv2.contourArea(contour)
-            if area > 4000:
-                mask = np.zeros(total_mask.shape, dtype=np.uint8)
-                cv2.drawContours(mask, [contour], 0, 255, -1)
-                result_masks.append(mask)
-        new_masks_data = MaskData(
-            masks=torch.tensor(result_masks),
-            iou_preds=torch.tensor([0.9 for i in range(len(result_masks))])
-        )
-        new_masks_data["stability_score"] = calculate_stability_score(
-            new_masks_data["masks"], self.mask_threshold, self.stability_score_offset
-        )
-        new_masks_data["boxes"] = batched_mask_to_box(new_masks_data["masks"])
-        new_masks_data["rles"] = mask_to_rle_pytorch(new_masks_data["masks"])
-        return new_masks_data.to_numpy()

     def position_embeding_3d(self, img_feats, camera_info):
         # TODO : adapt this function to our use case
         """
...
osrt/layers.py (+5, −3)
...
@@ -184,6 +184,8 @@ class Transformer(nn.Module):
 class SlotAttention(nn.Module):
     """
     Slot Attention as introduced by Locatello et al.
+    @edit : we changed the code to make it possible to handle a different number of slots depending on the input images
     """
     def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
                  randomize_initial_slots=False):
...
@@ -225,7 +227,7 @@ class SlotAttention(nn.Module):
         inputs = self.norm_input(inputs)

         if self.randomize_initial_slots:
-            slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
+            slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
             slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
         else:
             slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
...
@@ -242,13 +244,13 @@ class SlotAttention(nn.Module):
             # shape: [batch_size, num_slots, num_inputs]
             attn = dots.softmax(dim=1) + self.eps
             attn = attn / attn.sum(dim=-1, keepdim=True)

-            updates = torch.einsum('bjd,bij->bid', v, attn)
+            updates = torch.einsum('bjd,bij->bid', v, attn) # shape: [batch_size, num_slots, slot_dim]
             slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
             slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
             slots = slots + self.mlp(self.norm_pre_mlp(slots))

-        return slots
+        return slots # [batch_size, num_slots, dim]
+
+    def change_slots_number(self, num_slots):
+        self.num_slots = num_slots
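`change_slots_number` only mutates `num_slots`; the learned `initial_slots` parameter keeps its construction-time shape [num_slots, slot_dim], and the `reshape` inside `forward` assumes the two agree. A hedged usage sketch under that caveat:

```python
# Illustrative only: the setter changes the bookkeeping count, but callers
# presumably also need initial slots consistent with the new number before
# the next forward pass (the reshape above uses self.num_slots).
import torch
from osrt.layers import SlotAttention

slot_attn = SlotAttention(num_slots=6, input_dim=768, slot_dim=1536)
slot_attn.change_slots_number(4)
assert slot_attn.num_slots == 4
```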
osrt/utils/common.py (+9, −0)
...
@@ -99,6 +99,15 @@ class MaskData:
         if isinstance(v, torch.Tensor):
             self._stats[k] = v.detach().cpu().numpy()

+def create_points_grid(number_points):
+    """
+    Generates a 2D grid of points evenly spaced in [0,1]x[0,1].
+    """
+    offset = 1 / (2 * number_points)
+    points_one_side = np.linspace(offset, 1 - offset, number_points)
+    points_x = np.tile(points_one_side[None, :], (number_points, 1))
+    points_y = np.tile(points_one_side[:, None], (1, number_points))
+    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+    return points
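A quick check of the new helper's contract (an n-by-n grid flattened to [n*n, 2], inset from the borders by half a cell):

```python
import numpy as np
from osrt.utils.common import create_points_grid

grid = create_points_grid(4)
assert grid.shape == (16, 2)                 # 4x4 grid, flattened
assert np.isclose(grid.min(), 1 / 8)         # offset = 1 / (2 * 4)
assert np.isclose(grid.max(), 1 - 1 / 8)     # symmetric inset from 1.0
```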

 def get_positional_embedding(position, token_dim):
     position_encodings = np.zeros(token_dim)
     div_term = np.exp(np.arange(0, token_dim, 2).astype(np.float32) * (-math.log(10000.0) / token_dim))
...
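The context cuts off here, but `div_term` is the standard transformer frequency schedule, so the remainder of `get_positional_embedding` presumably fills even indices with sines and odd indices with cosines. A hedged sketch of that classic form (Vaswani et al.), not necessarily the repo's exact code:

```python
import math
import numpy as np

def sinusoidal_embedding(position, token_dim):
    # Classic fixed sin/cos positional encoding for a single position.
    enc = np.zeros(token_dim)
    div_term = np.exp(np.arange(0, token_dim, 2).astype(np.float32)
                      * (-math.log(10000.0) / token_dim))
    enc[0::2] = np.sin(position * div_term)
    enc[1::2] = np.cos(position * div_term)
    return enc

emb = sinusoidal_embedding(5, 128)   # shape [token_dim]
```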