Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
Segment-Object-Centric
Manage
Activity
Members
Labels
Plan
Issues
3
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Alexandre Chapin
Segment-Object-Centric
Commits
2d0f120c
Commit
2d0f120c
authored
2 years ago
by
Alexandre Chapin
Browse files
Options
Downloads
Patches
Plain Diff
Tests on time optimization
parent
e6fe586f
No related branches found
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
automatic_mask_train.py
+10
-11
10 additions, 11 deletions
automatic_mask_train.py
osrt/encoder.py
+5
-4
5 additions, 4 deletions
osrt/encoder.py
sam_test.py
+38
-98
38 additions, 98 deletions
sam_test.py
with
53 additions
and
113 deletions
automatic_mask_train.py
+
10
−
11
View file @
2d0f120c
import
argparse
import
argparse
import
torch
import
torch
from
osrt.encoder
import
SamAutomaticMask
,
FeatureMasking
from
osrt.model
import
OSRT
from
osrt.model
import
OSRT
from
segment_anything
import
sam_model_registry
import
time
import
time
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
numpy
as
np
import
numpy
as
np
import
cv2
import
cv2
def
show_anns
(
masks
):
def
show_anns
(
masks
):
ax
=
plt
.
gca
()
ax
=
plt
.
gca
()
ax
.
set_autoscale_on
(
False
)
ax
.
set_autoscale_on
(
False
)
...
@@ -57,13 +56,13 @@ if __name__ == '__main__':
...
@@ -57,13 +56,13 @@ if __name__ == '__main__':
cfg
[
'
encoder
'
]
=
'
sam
'
cfg
[
'
encoder
'
]
=
'
sam
'
cfg
[
'
decoder
'
]
=
'
slot_mixer
'
cfg
[
'
decoder
'
]
=
'
slot_mixer
'
cfg
[
'
encoder_kwargs
'
]
=
{
cfg
[
'
encoder_kwargs
'
]
=
{
'
points_per_side
'
:
1
2
,
'
points_per_side
'
:
3
2
,
'
box_nms_thresh
'
:
0.7
,
'
box_nms_thresh
'
:
0.7
,
'
stability_score_thresh
'
:
0.9
,
'
stability_score_thresh
'
:
0.9
,
'
pred_iou_thresh
'
:
0.88
,
'
pred_iou_thresh
'
:
0.88
,
'
sam_model
'
:
model_type
,
'
sam_model
'
:
model_type
,
'
sam_path
'
:
checkpoint
,
'
sam_path
'
:
checkpoint
,
'
points_per_batch
'
:
1
6
'
points_per_batch
'
:
1
2
}
}
cfg
[
'
decoder_kwargs
'
]
=
{
cfg
[
'
decoder_kwargs
'
]
=
{
'
pos_start_octave
'
:
-
5
,
'
pos_start_octave
'
:
-
5
,
...
@@ -71,10 +70,10 @@ if __name__ == '__main__':
...
@@ -71,10 +70,10 @@ if __name__ == '__main__':
model
=
OSRT
(
cfg
)
#FeatureMasking(points_per_side=12, box_nms_thresh=0.7, stability_score_thresh= 0.9, pred_iou_thresh=0.88, points_per_batch=64)
model
=
OSRT
(
cfg
)
#FeatureMasking(points_per_side=12, box_nms_thresh=0.7, stability_score_thresh= 0.9, pred_iou_thresh=0.88, points_per_batch=64)
model
.
to
(
device
,
non_blocking
=
True
)
model
.
to
(
device
,
non_blocking
=
True
)
num_encoder_params
=
sum
(
p
.
numel
()
for
p
in
model
.
encoder
.
parameters
())
"""
num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
num_decoder_params = sum(p.numel() for p in model.decoder.parameters())
num_decoder_params = sum(p.numel() for p in model.decoder.parameters())
"""
print(
'
Number of parameters:
'
)
print(
'
Number of parameters:
'
)
print(f
'
\t
Encoder: {num_encoder_params}
'
)
print(f
'
\t
Encoder: {num_encoder_params}
'
)
num_mask_encoder_params = sum(p.numel() for p in model.encoder.mask_generator.parameters())
num_mask_encoder_params = sum(p.numel() for p in model.encoder.mask_generator.parameters())
...
@@ -87,12 +86,12 @@ if __name__ == '__main__':
...
@@ -87,12 +86,12 @@ if __name__ == '__main__':
print(f
'
\t\t\t
Mask Decoder: {num_mask_params}.
'
)
print(f
'
\t\t\t
Mask Decoder: {num_mask_params}.
'
)
print(f
'
\t\t\t
Prompt Encoder: {num_prompt_params}.
'
)
print(f
'
\t\t\t
Prompt Encoder: {num_prompt_params}.
'
)
print(f
'
\t\t
Slot Attention: {num_slotatt_params}.
'
)
print(f
'
\t\t
Slot Attention: {num_slotatt_params}.
'
)
print(f
'
\t
Decoder: {num_decoder_params}
'
)
print(f
'
\t
Decoder: {num_decoder_params}
'
)
"""
"""
images
=
[]
images
=
[]
from
torchvision
import
transforms
from
torchvision
import
transforms
transform
=
transforms
.
ToTensor
()
transform
=
transforms
.
ToTensor
()
for
j
in
range
(
2
):
for
j
in
range
(
10
):
image
=
images_path
[
j
]
image
=
images_path
[
j
]
img
=
cv2
.
imread
(
image
)
img
=
cv2
.
imread
(
image
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
...
@@ -119,12 +118,12 @@ if __name__ == '__main__':
...
@@ -119,12 +118,12 @@ if __name__ == '__main__':
# TODO : set ray and camera directions
# TODO : set ray and camera directions
#with torch.no_grad():
#with torch.no_grad():
with
torch
.
cuda
.
amp
.
autocast
():
with
torch
.
cuda
.
amp
.
autocast
():
masks
,
slots
=
model
.
encoder
(
images_t
,
(
h
,
w
),
None
,
None
,
extract_
masks
=
Tru
e
)
masks
=
model
.
encoder
.
mask_generator
(
images_t
,
(
h
,
w
),
None
,
None
,
extract_
embeddings
=
Fals
e
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
f
"
Inference time :
{
int
((
end
-
start
)
*
1000
)
}
ms
"
)
print
(
f
"
Inference time :
{
int
((
end
-
start
)
*
1000
)
}
ms
"
)
if
args
.
visualize
:
if
args
.
visualize
:
for
j
in
range
(
2
):
for
j
in
range
(
10
):
image
=
images_path
[
j
]
image
=
images_path
[
j
]
img
=
cv2
.
imread
(
image
)
img
=
cv2
.
imread
(
image
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
...
...
This diff is collapsed.
Click to expand it.
osrt/encoder.py
+
5
−
4
View file @
2d0f120c
...
@@ -123,8 +123,8 @@ class FeatureMasking(nn.Module):
...
@@ -123,8 +123,8 @@ class FeatureMasking(nn.Module):
num_slots
=
32
,
num_slots
=
32
,
slot_dim
=
1536
,
slot_dim
=
1536
,
slot_iters
=
3
,
slot_iters
=
3
,
sam_model
=
"
defaul
t
"
,
sam_model
=
"
vit_
t
"
,
sam_path
=
"
sam_vit_h_4b8939
.pt
h
"
,
sam_path
=
"
mobile_sam
.pt
"
,
randomize_initial_slots
=
False
):
randomize_initial_slots
=
False
):
super
().
__init__
()
super
().
__init__
()
...
@@ -252,8 +252,6 @@ class SamAutomaticMask(nn.Module):
...
@@ -252,8 +252,6 @@ class SamAutomaticMask(nn.Module):
self
.
mask_decoder
=
mask_decoder
self
.
mask_decoder
=
mask_decoder
for
param
in
self
.
mask_decoder
.
parameters
():
for
param
in
self
.
mask_decoder
.
parameters
():
param
.
requires_grad
=
True
param
.
requires_grad
=
True
self
.
register_buffer
(
"
pixel_mean
"
,
torch
.
Tensor
(
pixel_mean
).
view
(
-
1
,
1
,
1
),
False
)
self
.
register_buffer
(
"
pixel_std
"
,
torch
.
Tensor
(
pixel_std
).
view
(
-
1
,
1
,
1
),
False
)
# Transform image to a square by putting it to the longest side
# Transform image to a square by putting it to the longest side
#self.resize = transforms.Resize(self.image_encoder.img_size, interpolation=transforms.InterpolationMode.BILINEAR)
#self.resize = transforms.Resize(self.image_encoder.img_size, interpolation=transforms.InterpolationMode.BILINEAR)
...
@@ -280,6 +278,9 @@ class SamAutomaticMask(nn.Module):
...
@@ -280,6 +278,9 @@ class SamAutomaticMask(nn.Module):
nn
.
Linear
(
2500
,
self
.
token_dim
),
nn
.
Linear
(
2500
,
self
.
token_dim
),
)
)
self
.
register_buffer
(
"
pixel_mean
"
,
torch
.
Tensor
(
pixel_mean
).
view
(
-
1
,
1
,
1
),
False
)
self
.
register_buffer
(
"
pixel_std
"
,
torch
.
Tensor
(
pixel_std
).
view
(
-
1
,
1
,
1
),
False
)
# Space positional embedding
# Space positional embedding
self
.
ray_encoder
=
RayEncoder
(
pos_octaves
=
15
,
pos_start_octave
=
pos_start_octave
,
self
.
ray_encoder
=
RayEncoder
(
pos_octaves
=
15
,
pos_start_octave
=
pos_start_octave
,
ray_octaves
=
15
)
ray_octaves
=
15
)
...
...
This diff is collapsed.
Click to expand it.
sam_test.py
+
38
−
98
View file @
2d0f120c
import
argparse
import
torch
import
torch
from
segment_anything
import
sam_model_registry
,
SamAutomaticMaskGenerator
from
segment_anything
import
sam_model_registry
,
SamAutomaticMaskGenerator
from
torchvision
import
transforms
from
PIL
import
Image
import
time
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
import
matplotlib
as
mpl
import
numpy
as
np
import
numpy
as
np
import
cv2
import
cv2
import
time
def
show_anns
(
mask
s
):
def
show_anns
(
ann
s
):
if
len
(
mask
s
)
==
0
:
if
len
(
ann
s
)
==
0
:
return
return
sorted_anns
=
sorted
(
mask
s
,
key
=
(
lambda
x
:
x
[
'
area
'
]),
reverse
=
True
)
sorted_anns
=
sorted
(
ann
s
,
key
=
(
lambda
x
:
x
[
'
area
'
]),
reverse
=
True
)
ax
=
plt
.
gca
()
ax
=
plt
.
gca
()
ax
.
set_autoscale_on
(
False
)
ax
.
set_autoscale_on
(
False
)
...
@@ -24,100 +20,44 @@ def show_anns(masks):
...
@@ -24,100 +20,44 @@ def show_anns(masks):
img
[
m
]
=
color_mask
img
[
m
]
=
color_mask
ax
.
imshow
(
img
)
ax
.
imshow
(
img
)
def
show_points
(
coords
,
ax
,
marker_size
=
100
):
ax
.
scatter
(
coords
[:,
0
],
coords
[:,
1
],
color
=
'
#2ca02c
'
,
marker
=
'
.
'
,
s
=
marker_size
)
if
__name__
==
'
__main__
'
:
# Arguments
parser
=
argparse
.
ArgumentParser
(
description
=
'
Test Segment Anything Auto Mask simplified implementation
'
)
parser
.
add_argument
(
'
--model
'
,
default
=
'
vit_b
'
,
type
=
str
,
help
=
'
Model to use
'
)
parser
.
add_argument
(
'
--path_model
'
,
default
=
'
.
'
,
type
=
str
,
help
=
'
Path to the model
'
)
args
=
parser
.
parse_args
()
device
=
"
cuda
"
torch
.
backends
.
cuda
.
matmul
.
allow_tf32
=
True
torch
.
backends
.
cudnn
.
allow_tf32
=
True
model_type
=
args
.
model
if
args
.
model
==
'
vit_h
'
:
checkpoint
=
args
.
path_model
+
'
/sam_vit_h_4b8939.pth
'
elif
args
.
model
==
'
vit_b
'
:
checkpoint
=
args
.
path_model
+
'
/sam_vit_b_01ec64.pth
'
else
:
checkpoint
=
args
.
path_model
+
'
/sam_vit_l_0b3195.pth
'
ycb_path
=
"
/home/achapin/Documents/Datasets/YCB_Video_Dataset/
"
images_path
=
[]
with
open
(
ycb_path
+
"
image_sets/train.txt
"
,
'
r
'
)
as
f
:
for
line
in
f
.
readlines
():
line
=
line
.
strip
()
images_path
.
append
(
ycb_path
+
'
data/
'
+
line
+
"
-color.png
"
)
import
random
model_type
=
"
vit_t
"
random
.
shuffle
(
images_path
)
sam_checkpoint
=
"
./mobile_sam.pt
"
sam
=
sam_model_registry
[
model_type
](
checkpoint
=
checkpoint
)
device
=
"
cuda
"
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
sam
.
to
(
device
=
device
)
mask_generator
=
SamAutomaticMaskGenerator
(
sam
,
points_per_side
=
12
,
box_nms_thresh
=
0.7
,
crop_n_layers
=
0
,
points_per_batch
=
128
,
pred_iou_thresh
=
0.88
)
transform
=
transforms
.
Compose
([
mobile_sam
=
sam_model_registry
[
model_type
](
checkpoint
=
sam_checkpoint
)
transforms
.
ToTensor
(),
mobile_sam
.
to
(
device
=
device
)
])
mobile_sam
.
eval
()
labels
=
[
1
for
i
in
range
(
len
(
mask_generator
.
point_grids
))]
with
torch
.
no_grad
():
for
image
in
images_path
:
img
=
cv2
.
imread
(
image
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
"""
img_depth = cv2.imread(image.replace(
"
color
"
,
"
depth
"
))
img_depth = cv2.cvtColor(img_depth, cv2.COLOR_BGR2GRAY)
"""
h
,
w
,
_
=
img
.
shape
mask_generator
=
SamAutomaticMaskGenerator
(
mobile_sam
,
points_per_side
=
16
,
points_per_batch
=
12
)
points
=
mask_generator
.
point_grids
[
0
]
ycb_path
=
"
/home/achapin/Documents/Datasets/YCB_Video_Dataset/
"
new_points
=
[]
images_path
=
[]
for
val
in
points
:
with
open
(
ycb_path
+
"
image_sets/train.txt
"
,
'
r
'
)
as
f
:
x
,
y
=
val
[
0
],
val
[
1
]
for
line
in
f
.
readlines
():
x
*=
w
line
=
line
.
strip
()
y
*=
h
images_path
.
append
(
ycb_path
+
'
data/
'
+
line
+
"
-color.png
"
)
new_points
.
append
([
x
,
y
])
new_points
=
np
.
array
(
new_points
)
start
=
time
.
time
()
import
random
masks
=
mask_generator
.
generate
(
img
)
#random.shuffle(images_path)
end
=
time
.
time
()
print
(
f
"
Inference time :
{
int
((
end
-
start
)
*
1000
)
}
ms
"
)
plt
.
figure
(
figsize
=
(
15
,
15
))
plt
.
imshow
(
img
)
show_anns
(
masks
)
show_points
(
new_points
,
plt
.
gca
())
plt
.
axis
(
'
off
'
)
plt
.
show
()
"""
fig, ax = plt.subplots()
cmap = plt.cm.get_cmap(
'
plasma
'
)
img = ax.imshow(img_depth, cmap=cmap)
cbar = fig.colorbar(img, ax=ax)
depth_array_new = img.get_array()
plt.show()
depth_array_new = cv2.cvtColor(depth_array_new, cv2.COLOR_GRAY2RGB)
plt.imshow(depth_array_new)
plt.show()
print(depth_array_new.shape)
start = time.time()
masks = mask_generator.generate(depth_array_new)
end = time.time()
print(f
"
Inference time : {int((end-start) * 1000)}ms
"
)
plt.figure(figsize=(15,15))
plt.imshow(depth_array_new)
show_anns(masks)
show_points(new_points, plt.gca())
plt.axis(
'
off
'
)
plt.show()
"""
images
=
[]
from
torchvision
import
transforms
transform
=
transforms
.
ToTensor
()
for
j
in
range
(
20
):
image
=
images_path
[
j
]
img
=
cv2
.
imread
(
image
)
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
images
.
append
(
transform
(
img
).
unsqueeze
(
0
))
images_t
=
torch
.
stack
(
images
).
to
(
device
)
start
=
time
.
time
()
masks
=
mask_generator
.
generate
(
images_t
)
end
=
time
.
time
()
print
(
f
"
Inference time :
{
int
((
end
-
start
)
*
1000
)
}
ms
"
)
plt
.
figure
(
figsize
=
(
15
,
15
))
plt
.
imshow
(
img
)
show_anns
(
masks
)
# show masks
plt
.
axis
(
'
off
'
)
plt
.
show
()
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment