Alexandre Chapin / Segment-Object-Centric / Commits

Commit 672cd717
authored 2 years ago by Alexandre Chapin
parent 78417e2e

    Make new model slot attention
Showing 4 changed files with 157 additions and 105 deletions:

  .visualisation_1639.png   +0  −0
  osrt/layers.py            +42 −52
  osrt/model.py             +34 −53
  visualise.py              +81 −0

.visualisation_1639.png  (new file, 0 → 100644)  +0 −0  (view file @ 672cd717)
  Binary file, 77.6 KiB.
osrt/layers.py  +42 −52  (view file @ 672cd717)

@@ -5,6 +5,7 @@ import numpy as np
 import math
 from einops import rearrange, repeat
+import torch.nn.functional as F

 __USE_DEFAULT_INIT__ = False
@@ -194,10 +195,11 @@ class SlotAttention(nn.Module):
     @edit : we changed the code as to make it possible to handle a different number of slots depending on the input images
     """
     def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
-                 randomize_initial_slots=False):
+                 randomize_initial_slots=False, gain=1, temperature_factor=1):
         super().__init__()
         self.num_slots = num_slots
+        self.temperature_factor = temperature_factor
         self.batch_slots = []
         self.iters = iters
         self.scale = slot_dim ** -0.5
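Note: the two new constructor arguments only affect initialisation and attention sharpness. A minimal instantiation sketch (argument values here are illustrative, not taken from the repository's configs):

    from osrt.layers import SlotAttention

    slot_attn = SlotAttention(
        num_slots=6,             # number of object slots (example value)
        input_dim=768,
        slot_dim=1536,
        hidden_dim=3072,
        iters=3,
        gain=1,                  # new: gain passed to the Xavier init of the q/k/v projections
        temperature_factor=1,    # new: divides the attention logits before the slot-wise softmax
    )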
@@ -207,24 +209,31 @@ class SlotAttention(nn.Module):
         self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))

         self.eps = eps
+        self.slots_mu = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 1, self.slot_dim)))
+        self.slots_log_sigma = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 1, self.slot_dim)))

-        self.to_q = nn.Linear(slot_dim, slot_dim, bias=False)
-        self.to_k = nn.Linear(input_dim, slot_dim, bias=False)
-        self.to_v = nn.Linear(input_dim, slot_dim, bias=False)
+        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
+        nn.init.xavier_uniform_(self.to_q.weight, gain=gain)
+        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
+        nn.init.xavier_uniform_(self.to_k.weight, gain=gain)
+        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)
+        nn.init.xavier_uniform_(self.to_v.weight, gain=gain)

         self.gru = nn.GRUCell(slot_dim, slot_dim)

         self.mlp = nn.Sequential(
-            JaxLinear(slot_dim, hidden_dim),
+            nn.Linear(slot_dim, hidden_dim),
             nn.ReLU(inplace=True),
-            JaxLinear(hidden_dim, slot_dim)
+            nn.Linear(hidden_dim, slot_dim)
         )

         self.norm_input = nn.LayerNorm(input_dim)
         self.norm_slots = nn.LayerNorm(slot_dim)
         self.norm_pre_mlp = nn.LayerNorm(slot_dim)

-    def forward(self, inputs, masks=None):
+    def forward(self, inputs):
         """
         Args:
             inputs: set-latent representation [batch_size, num_inputs, dim]
@@ -232,74 +241,56 @@ class SlotAttention(nn.Module):
         batch_size, num_inputs, dim = inputs.shape
         inputs = self.norm_input(inputs)

-        # Initialize the slots. Shape: [batch_size, num_slots, slot_dim].
-        if self.randomize_initial_slots:
-            slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
-            slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
-        else:
-            slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
-
         k, v = self.to_k(inputs), self.to_v(inputs)

+        if slots is None:
+            slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_size, device=self.slots_mu.device)
+
         # Multiple rounds of attention.
         for _ in range(self.iters):
             slots_prev = slots

-            norm_slots = self.norm_slots(slots)
-            q = self.to_q(norm_slots)
+            slots = self.norm_slots(slots)
+            q = self.to_q(slots)
+            q *= self.scale

-            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
-            if masks != None:
-                temp_masks = masks.unsqueeze(1)
-                attention_masking = torch.where(temp_masks == 1.0, float("-inf"), temp_masks).to(device=dots.device)
-                dots += attention_masking
-            attn = dots.softmax(dim=1) + self.eps  # shape: [batch_size, num_slots, num_inputs]
-
-            # Weighted mean
-            attn = attn / attn.sum(dim=-1, keepdim=True)
-            updates = torch.einsum('bjd,bij->bid', v, attn)  # shape: [batch_size, num_inputs, slot_dim]
+            # Dot product and normalization
+            attn_logits = torch.bmm(q, k.transpose(-1, -2))
+            attn_pixelwise = F.softmax(attn_logits / self.temperature_factor, dim=1)
+            # shape: [batch_size, num_slots, num_inputs]
+            attn_slotwise = F.normalize(attn_pixelwise + self.eps, p=1, dim=-1)
+
+            # shape: [batch_size, num_inputs, slot_dim]
+            updates = torch.bmm(attn_slotwise, v)

             # Slot update
             slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
             slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
             slots = slots + self.mlp(self.norm_pre_mlp(slots))

-        return slots # [batch_size, num_slots, dim]
+        return slots, attn_logits, attn_slotwise # [batch_size, num_slots, dim]

     def change_slots_number(self, num_slots):
         self.num_slots = num_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))
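As a reading aid, the new attention step normalises twice: a softmax over slots (each input position distributes itself across slots) followed by an L1 renormalisation over inputs (each slot takes a weighted mean of the inputs). A standalone sketch with made-up tensor sizes:

    import torch
    import torch.nn.functional as F

    B, S, N, D = 2, 6, 64, 32                      # batch, slots, inputs, feature dim (illustrative)
    q = torch.randn(B, S, D) * D ** -0.5           # scaled slot queries
    k, v = torch.randn(B, N, D), torch.randn(B, N, D)

    attn_logits = torch.bmm(q, k.transpose(-1, -2))                   # [B, S, N]
    attn_pixelwise = F.softmax(attn_logits / 1.0, dim=1)              # softmax over the slot axis
    attn_slotwise = F.normalize(attn_pixelwise + 1e-8, p=1, dim=-1)   # rows sum to 1 over the inputs
    updates = torch.bmm(attn_slotwise, v)                             # [B, S, D] per-slot weighted mean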
-### Utils for SlotAttentionAutoEncoder
-def build_grid(resolution):
-    ranges = [np.linspace(0., 1., num=res) for res in resolution]
-    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
-    grid = np.stack(grid, axis=-1)
-    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
-    grid = np.expand_dims(grid, axis=0)
-    grid = grid.astype(np.float32)
-    return np.concatenate([grid, 1.0 - grid], axis=-1)
-
-class SoftPositionEmbed(nn.Module):
-    """
-    Adds soft positional embedding with learnable projection.
-    Implementation extracted from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
-    """
-    def __init__(self, hidden_size, resolution):
-        """
-        Builds the soft position embedding layer.
-        Args:
-            hidden_size: Size of input feature dimension.
-            resolution: Tuple of integers specifying width and height of grid.
-        """
+class PositionEmbeddingImplicit(nn.Module):
+    """
+    Position embedding extracted from
+    https://github.com/vadimkantorov/yet_another_pytorch_slot_attention/blob/master/models.py
+    """
+    def __init__(self, hidden_dim):
         super().__init__()
-        self.dense = JaxLinear(4, hidden_size)
-        self.grid = build_grid(resolution)
+        self.dense = nn.Linear(4, hidden_dim)

-    def forward(self, inputs):
-        return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
+    def forward(self, x):
+        spatial_shape = x.shape[-3:-1]
+        grid = torch.stack(torch.meshgrid(*[torch.linspace(0., 1., r, device=x.device) for r in spatial_shape]), dim=-1)
+        grid = torch.cat([grid, 1 - grid], dim=-1)
+        return x + self.dense(grid)
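Unlike SoftPositionEmbed, the new module is resolution-agnostic: the coordinate grid is rebuilt from the input's spatial shape on the input's device, so the hard-coded .cuda() call and the precomputed build_grid() helper are no longer needed. A small usage sketch (tensor sizes are illustrative):

    import torch
    from osrt.layers import PositionEmbeddingImplicit

    pos = PositionEmbeddingImplicit(64)      # only the feature dimension is fixed
    feat = torch.randn(2, 32, 32, 64)        # channels-last features [B, H, W, C]
    out = pos(feat)                          # a [32, 32, 4] grid is built on the fly; out keeps feat's shape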
 def fourier_encode(x, max_freq, num_bands=4):
     x = x.unsqueeze(-1)

@@ -313,7 +304,6 @@ def fourier_encode(x, max_freq, num_bands = 4):
     x = torch.cat((x, orig_x), dim=-1)
     return x
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """

@@ -393,4 +383,4 @@ class TransformerSlotAttention(nn.Module):
             x_d = self_attn(inputs) + inputs
             inputs = self_ff(x_d) + x_d

-        return slots # [batch_size, num_slots, dim]
+        return slots, None, None # [batch_size, num_slots, dim]
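Both attention modules now return a 3-tuple, so callers can unpack them uniformly; a hedged sketch (module and set_latents are placeholders for either attention class and a [batch_size, num_inputs, dim] tensor):

    slots, attn_logits, attn_slotwise = module(set_latents)
    # TransformerSlotAttention returns (slots, None, None): no attention maps are exposed there.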
osrt/model.py  +34 −53  (view file @ 672cd717)
 from torch import nn
 import torch
+import torch.nn.functional as F
 import numpy as np

 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed, TransformerSlotAttention
+from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
 import osrt.layers as layers

 class OSRT(nn.Module):
     def __init__(self, cfg):
         super().__init__()
@@ -75,39 +78,30 @@ class SlotAttentionAutoEncoder(nn.Module):
         self.num_iterations = num_iterations

         self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(inplace=True)
+            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU()
         )

         self.decoder_initial_size = (8, 8)
         self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5),
-            nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 4, kernel_size=3)
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2), nn.ReLU(),
+            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1)
         )

-        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
-        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
+        self.encoder_pos = PositionEmbeddingImplicit(64)
+        self.decoder_pos = PositionEmbeddingImplicit(64)

         self.layer_norm = nn.LayerNorm(64)
         self.mlp = nn.Sequential(
-            JaxLinear(64, 64),
-            nn.ReLU(),
-            JaxLinear(64, 64)
+            nn.Linear(64, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, 64)
         )

         model_type = cfg['model']['model_type']
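For orientation, the new decoder upsamples the 8×8 broadcast grid by 2× four times (to 128×128) before a stride-1 refinement and the final 4-channel projection (RGB plus an alpha mask). A minimal shape check, assuming the layer stack above:

    import torch
    import torch.nn as nn

    decoder = nn.Sequential(
        nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(),
        nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(),
        nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(),
        nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(),
        nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2), nn.ReLU(),
        nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1),
    )
    x = torch.randn(2 * 6, 64, 8, 8)    # [batch_size * num_slots, channels, 8, 8] (illustrative sizes)
    print(decoder(x).shape)             # torch.Size([12, 4, 128, 128])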
@@ -128,35 +122,22 @@ class SlotAttentionAutoEncoder(nn.Module):
                                              depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model

     def forward(self, image):
-        # `image` has shape: [batch_size, num_channels, width, height].
-
-        # Convolutional encoder with position embedding.
-        x = self.encoder_cnn(image)  # CNN Backbone.
-        x = self.encoder_pos(x).permute(0, 2, 3, 1)  # Position embedding.
-        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
-        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
-        # `x` has shape: [batch_size, width*height, input_size].
-
-        # Slot Attention module.
-        slots = self.slot_attention(x)
-        # `slots` has shape: [batch_size, num_slots, slot_size].
-
-        # Spatial broadcast decoder.
-        x = spatial_broadcast(slots, self.decoder_initial_size).permute(0, 3, 1, 2)
-        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
-        x = self.decoder_pos(x)
-        x = self.decoder_cnn(x).permute(0, 2, 3, 1)
-        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
-
-        # Undo combination of slot and batch dimension; split alpha masks.
-        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
-        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
-        # `masks` has shape: [batch_size, num_slots, width, height, 1].
-        # Normalize alpha masks over slots.
-        masks = torch.softmax(masks, dim=1)
-        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
-        # `recon_combined` has shape: [batch_size, width, height, num_channels].
-
-        return recon_combined, recons, masks, slots
+        x = self.encoder_cnn(image).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim=1, end_dim=2), slots=slots)
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
+
+        x = self.decoder_pos(x)
+        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = F.interpolate(x, image.shape[-2:], mode=self.interpolate_mode)
+        x = x.unflatten(0, (len(image), len(x) // len(image)))
+
+        recons, masks = x.split((3, 1), dim=2)
+        masks = masks.softmax(dim=1)
+        recon_combined = (recons * masks).sum(dim=1)
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
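One step worth spelling out: the spatial-broadcast line turns each slot vector into its own 8×8 feature map before decoding. A standalone sketch of the shape bookkeeping (sizes are illustrative):

    import torch

    batch_size, num_slots, slot_dim = 2, 6, 64
    decoder_initial_size = (8, 8)

    slots = torch.randn(batch_size, num_slots, slot_dim)
    x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *decoder_initial_size, -1)
    print(x.shape)    # torch.Size([12, 8, 8, 64]) -> one feature map per (image, slot) pair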
visualise.py  (new file, 0 → 100644)  +81 −0  (view file @ 672cd717)
import datetime
import time
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import yaml

from osrt.model import SlotAttentionAutoEncoder
from osrt import data
from osrt.utils.visualize import visualize_slot_attention
from osrt.utils.common import mse2psnr

from torch.utils.data import DataLoader
import torch.nn.functional as F

from tqdm import tqdm


def main():
    # Arguments
    parser = argparse.ArgumentParser(description='Train a 3D scene representation model.')
    parser.add_argument('config', type=str, help="Where to save the checkpoints.")
    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')

    args = parser.parse_args()
    with open(args.config, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.CLoader)

    ### Set random seed.
    torch.manual_seed(args.seed)

    ### Hyperparameters of the model.
    num_slots = cfg["model"]["num_slots"]
    num_iterations = cfg["model"]["iters"]
    base_learning_rate = 0.0004
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resolution = (128, 128)

    #### Create datasets
    vis_dataset = data.get_dataset('test', cfg['data'])
    vis_loader = DataLoader(
        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)

    #### Create model
    model = SlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print('Number of parameters:')
    print(f'Model slot attention: {num_params}')

    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)

    ckpt = {
        'network': model,
        'optimizer': optimizer,
        'global_step': 1639
    }
    #ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
    """
    ckpt = torch.load('~/ckpt.pth')
    model = ckpt['network']
    """
    model.load_state_dict(torch.load('/home/achapin/ckpt.pth')["model_state_dict"])

    image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
    image = F.interpolate(image, size=128)
    image = image.to(device)
    recon_combined, recons, masks, slots = model(image)

    loss = nn.MSELoss()
    input_image = image.permute(0, 2, 3, 1)
    loss_value = loss(recon_combined, input_image)
    psnr = mse2psnr(loss_value)
    print(f"MSE value : {loss_value} VS PSNR {psnr}")

    visualize_slot_attention(num_slots, image, recon_combined, recons, masks,
                             folder_save=args.ckpt, step=1639, save_file=True)


if __name__ == "__main__":
    main()
\ No newline at end of file
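Going by the argument parser above, the script would be invoked roughly as follows (paths are placeholders). Note that the checkpoint it restores is currently hard-coded to /home/achapin/ckpt.pth; the --ckpt flag is only used as the output folder for the saved visualisation:

    python visualise.py path/to/config.yaml --ckpt path/to/output_dir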