Segment-Object-Centric
Commit b79778fe
Authored 2 years ago by Alexandre Chapin

Fix issue with pos encode slot att
parent 803079e6
Showing 3 changed files, with 122 additions and 119 deletions:

osrt/layers.py   1 addition, 118 deletions
osrt/model.py    120 additions, 0 deletions
train_sa.py      1 addition, 1 deletion
osrt/layers.py (+1 −118)
@@ -347,121 +347,6 @@ class TransformerSlotAttention(nn.Module):
         return slots # [batch_size, num_slots, dim]
 
-def unstack_and_split(x, batch_size, num_channels=3):
-    """Unstack batch dimension and split into channels and alpha mask."""
-    unstacked = x.view(batch_size, -1, *x.shape[1:])
-    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
-    return channels, masks
-
-def spatial_flatten(x):
-    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
-
-def spatial_broadcast(slots, resolution):
-    """Broadcast slot features to a 2D grid and collapse slot dimension."""
-    # `slots` has shape: [batch_size, num_slots, slot_size].
-    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
-    grid = slots.repeat(1, resolution[0], resolution[1], 1)
-    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
-    return grid
-
-# TODO : adapt this model
-class SlotAttentionAutoEncoder(nn.Module):
-    """
-    Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
-    Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
-    """
-
-    def __init__(self, resolution, num_slots, num_iterations):
-        """Builds the Slot Attention-based auto-encoder.
-
-        Args:
-            resolution: Tuple of integers specifying width and height of input image.
-            num_slots: Number of slots in Slot Attention.
-            num_iterations: Number of iterations in Slot Attention.
-        """
-        super().__init__()
-        self.resolution = resolution
-        self.num_slots = num_slots
-        self.num_iterations = num_iterations
-
-        self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU()
-        )
-
-        self.decoder_initial_size = (8, 8)
-        self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2), nn.ReLU(),
-            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=2)
-        )
-
-        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
-        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
-
-        self.layer_norm = nn.LayerNorm(64)
-        self.mlp = nn.Sequential(
-            JaxLinear(64, 64), nn.ReLU(),
-            JaxLinear(64, 64)
-        )
-
-        self.slot_attention = SlotAttention(
-            num_slots=self.num_slots,
-            slot_dim=64,
-            hidden_dim=128,
-            iters=self.num_iterations)
-
-    def forward(self, image):
-        # `image` has shape: [batch_size, width, height, num_channels].
-
-        # Convolutional encoder with position embedding.
-        x = self.encoder_cnn(image)  # CNN Backbone.
-        #x = self.encoder_pos(x)  # Position embedding.
-        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
-        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
-        # `x` has shape: [batch_size, width*height, input_size].
-
-        # Slot Attention module.
-        slots = self.slot_attention(x)
-        # `slots` has shape: [batch_size, num_slots, slot_size].
-
-        # Spatial broadcast decoder.
-        x = spatial_broadcast(slots, self.decoder_initial_size)
-        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
-        #x = self.decoder_pos(x)
-        x = self.decoder_cnn(x)
-        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
-
-        # Undo combination of slot and batch dimension; split alpha masks.
-        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
-        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
-        # `masks` has shape: [batch_size, num_slots, width, height, 1].
-
-        # Normalize alpha masks over slots.
-        masks = torch.softmax(masks, dim=1)
-        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
-        # `recon_combined` has shape: [batch_size, width, height, num_channels].
-
-        return recon_combined, recons, masks, slots
-
 def build_grid(resolution):
     ranges = [np.linspace(0., 1., num=res) for res in resolution]
     grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
...
@@ -469,7 +354,7 @@ def build_grid(resolution):
     grid = np.reshape(grid, [resolution[0], resolution[1], -1])
     grid = np.expand_dims(grid, axis=0)
     grid = grid.astype(np.float32)
-    return np.concatenate([grid, 1.0 - grid], axis=-1)
+    return np.concatenate([grid, 1.0 - grid], axis=-1).transpose(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
 
 class SoftPositionEmbed(nn.Module):
     """Adds soft positional embedding with learnable projection."""
...
@@ -487,6 +372,4 @@ class SoftPositionEmbed(nn.Module):
         self.grid = build_grid(resolution)
 
     def forward(self, inputs):
-        print(inputs.shape)
-        print(self.dense(torch.tensor(self.grid).cuda()).shape)
         return inputs + self.dense(torch.tensor(self.grid).cuda())
\ No newline at end of file
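
The one-line change to build_grid is the positional-encoding fix the commit message refers to: the coordinate grid consumed by SoftPositionEmbed was built channel-last ([b, h, w, c]), while the Conv2d feature maps it gets added to are channel-first ([b, c, h, w]). A self-contained sketch of the fixed function follows; the np.stack line sits in a collapsed region of the diff and is assumed here from the reference Slot Attention implementation.

import numpy as np

def build_grid(resolution):
    # Normalized coordinates in [0, 1] along each spatial axis.
    ranges = [np.linspace(0., 1., num=res) for res in resolution]
    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
    grid = np.stack(grid, axis=-1)  # assumed from the reference repo (collapsed in the diff)
    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
    grid = np.expand_dims(grid, axis=0)
    grid = grid.astype(np.float32)
    # Four channels (x, y, 1-x, 1-y), moved to PyTorch's channel-first layout.
    return np.concatenate([grid, 1.0 - grid], axis=-1).transpose(0, 3, 1, 2)

print(build_grid((32, 32)).shape)  # (1, 4, 32, 32)

Without the transpose, the (1, h, w, 4) grid does not line up with the (batch, 64, h, w) encoder features and the elementwise addition fails to broadcast.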
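A side note on the SoftPositionEmbed.forward kept above: it re-wraps self.grid with torch.tensor and hard-codes .cuda() on every call. Below is a hypothetical variant, not part of this commit, that registers the channel-first grid as a buffer (so it moves with the module) and projects its 4 coordinate channels with a 1x1 convolution, one way to keep the learnable projection consistent with the fixed [b, c, h, w] grid:

import torch
import torch.nn as nn

class SoftPositionEmbed2d(nn.Module):
    """Hypothetical channel-first position embedding (a sketch, not the repo's class)."""
    def __init__(self, hidden_size, resolution):
        super().__init__()
        # A 1x1 conv plays the role of the dense projection in channel-first layout.
        self.proj = nn.Conv2d(4, hidden_size, kernel_size=1)
        # Registered buffer: saved in the state dict and moved by .to()/.cuda().
        self.register_buffer("grid", torch.from_numpy(build_grid(resolution)))

    def forward(self, inputs):
        # inputs: [batch, hidden_size, h, w]; the projected grid broadcasts over batch.
        return inputs + self.proj(self.grid)

(build_grid here is the channel-first sketch above.)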
osrt/model.py (+120 −0)
 from torch import nn
+import torch
+import numpy as np
 
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
+from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed
 import osrt.layers as layers
...
@@ -33,3 +36,120 @@ class OSRT(nn.Module):
             raise ValueError(f'Unknown decoder type: {decoder_type}')
 
+def unstack_and_split(x, batch_size, num_channels=3):
+    """Unstack batch dimension and split into channels and alpha mask."""
+    unstacked = x.view(batch_size, -1, *x.shape[1:])
+    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
+    return channels, masks
+
+def spatial_flatten(x):
+    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
+
+def spatial_broadcast(slots, resolution):
+    """Broadcast slot features to a 2D grid and collapse slot dimension."""
+    # `slots` has shape: [batch_size, num_slots, slot_size].
+    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
+    grid = slots.repeat(1, resolution[0], resolution[1], 1)
+    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
+    return grid
+
+# TODO : adapt this model
+class SlotAttentionAutoEncoder(nn.Module):
+    """
+    Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
+    Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
+    """
+
+    def __init__(self, resolution, num_slots, num_iterations):
+        """Builds the Slot Attention-based auto-encoder.
+
+        Args:
+            resolution: Tuple of integers specifying width and height of input image.
+            num_slots: Number of slots in Slot Attention.
+            num_iterations: Number of iterations in Slot Attention.
+        """
+        super().__init__()
+        self.resolution = resolution
+        self.num_slots = num_slots
+        self.num_iterations = num_iterations
+
+        self.encoder_cnn = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU()
+        )
+
+        self.decoder_initial_size = (8, 8)
+        self.decoder_cnn = nn.Sequential(
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2), nn.ReLU(),
+            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=2)
+        )
+
+        self.encoder_pos = SoftPositionEmbed(64, (32, 32))
+        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
+
+        self.layer_norm = nn.LayerNorm(64)
+        self.mlp = nn.Sequential(
+            JaxLinear(64, 64), nn.ReLU(),
+            JaxLinear(64, 64)
+        )
+
+        self.slot_attention = SlotAttention(
+            num_slots=self.num_slots,
+            slot_dim=64,
+            hidden_dim=128,
+            iters=self.num_iterations)
+
+    def forward(self, image):
+        # `image` has shape: [batch_size, num_channels, width, height].
+        print(f"Shape input {image.shape}")
+
+        # Convolutional encoder with position embedding.
+        x = self.encoder_cnn(image)  # CNN Backbone.
+        print(f"Shape after encoder {x.shape}")
+        x = self.encoder_pos(x)  # Position embedding.
+        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
+        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
+        # `x` has shape: [batch_size, width*height, input_size].
+
+        # Slot Attention module.
+        slots = self.slot_attention(x)
+        # `slots` has shape: [batch_size, num_slots, slot_size].
+
+        # Spatial broadcast decoder.
+        x = spatial_broadcast(slots, self.decoder_initial_size)
+        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
+        #x = self.decoder_pos(x)
+        x = self.decoder_cnn(x)
+        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
+
+        # Undo combination of slot and batch dimension; split alpha masks.
+        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
+        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
+        # `masks` has shape: [batch_size, num_slots, width, height, 1].
+
+        # Normalize alpha masks over slots.
+        masks = torch.softmax(masks, dim=1)
+        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
+        # `recon_combined` has shape: [batch_size, width, height, num_channels].
+
+        return recon_combined, recons, masks, slots
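The helpers moved into osrt/model.py implement the spatial-broadcast decoding round trip: each slot is tiled onto a small grid, decoded by the CNN, and the slot dimension is then recovered from the fused batch axis and split into RGB plus an alpha mask. A small self-contained shape check (illustrative only, channel-last as the in-code comments describe; not part of the commit):

import torch

def spatial_broadcast(slots, resolution):
    # [batch, num_slots, slot_dim] -> [batch*num_slots, h, w, slot_dim]
    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
    return slots.repeat(1, resolution[0], resolution[1], 1)

def unstack_and_split(x, batch_size, num_channels=3):
    # [batch*num_slots, h, w, C+1] -> ([batch, num_slots, h, w, C], [batch, num_slots, h, w, 1])
    unstacked = x.view(batch_size, -1, *x.shape[1:])
    return torch.split(unstacked, [num_channels, 1], dim=-1)

slots = torch.randn(2, 6, 64)               # 2 images, 6 slots, 64-dim slots
x = spatial_broadcast(slots, (8, 8))        # torch.Size([12, 8, 8, 64])
decoded = torch.randn(12, 8, 8, 4)          # stand-in for the decoder CNN output (RGB + alpha)
recons, masks = unstack_and_split(decoded, batch_size=2)
print(recons.shape, masks.shape)            # [2, 6, 8, 8, 3] and [2, 6, 8, 8, 1]

The alpha masks are then softmaxed over the slot dimension (dim=1), so the per-slot reconstructions blend into a single image.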
train_sa.py (+1 −1)
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.optim as optim
 import argparse
 import yaml
-from osrt.layers import SlotAttentionAutoEncoder
+from osrt.model import SlotAttentionAutoEncoder
 from osrt import data
 from torch.utils.data import DataLoader
...
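
With the class relocated to osrt.model, callers only need the one-line import swap shown above. A hypothetical minimal setup in the spirit of train_sa.py (hyperparameters are illustrative, not taken from the repo):

import torch.optim as optim
from osrt.model import SlotAttentionAutoEncoder

model = SlotAttentionAutoEncoder(resolution=(128, 128), num_slots=6, num_iterations=3)
optimizer = optim.Adam(model.parameters(), lr=4e-4)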