Commit 349fae70 authored 1 year ago by Alexandre Chapin (project: Segment-Object-Centric)
Implement lightning Slot Attention
Parent: 672cd717
Showing 4 changed files with 108 additions and 152 deletions:

  osrt/model.py                        +67 −19
  runs/clevr3d/slot_att/config.yaml     +1 −4
  train_sa.py                          +36 −127
  visualise.py                          +4 −2
osrt/model.py  +67 −19
+from typing import Any
+from lightning.pytorch.utilities.types import STEP_OUTPUT
 from torch import nn
 import torch
 import torch.nn.functional as F
+import torch.optim as optim
 import numpy as np
@@ -8,7 +11,9 @@ from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
 from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
 import osrt.layers as layers
+from osrt.utils.common import mse2psnr
+import lightning as pl

 class OSRT(nn.Module):
@@ -39,25 +44,7 @@ class OSRT(nn.Module):
             raise ValueError(f'Unknown decoder type: {decoder_type}')

-def unstack_and_split(x, batch_size, num_channels=3):
-    """Unstack batch dimension and split into channels and alpha mask."""
-    unstacked = x.view(batch_size, -1, *x.shape[1:])
-    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
-    return channels, masks
-
-def spatial_flatten(x):
-    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
-
-def spatial_broadcast(slots, resolution):
-    """Broadcast slot features to a 2D grid and collapse slot dimension."""
-    # `slots` has shape: [batch_size, num_slots, slot_size].
-    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
-    grid = slots.repeat(1, resolution[0], resolution[1], 1)
-    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
-    return grid
-
-class SlotAttentionAutoEncoder(nn.Module):
+class LitSlotAttentionAutoEncoder(pl.LightningModule):
     """
     Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
@@ -140,4 +127,65 @@ class SlotAttentionAutoEncoder(nn.Module):
         recon_combined = (recons * masks).sum(dim=1)
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+
+    def configure_optimizers(self) -> Any:
+        optimizer = optim.Adam(self.parameters, lr=1e-3, eps=1e-08)
+        return optimizer
+
+    def one_step(self, image):
+        x = self.encoder_cnn(image).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim=1, end_dim=2), slots=slots)
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
+        x = self.decoder_pos(x)
+        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = F.interpolate(x, image.shape[-2:], mode=self.interpolate_mode)
+        x = x.unflatten(0, (len(image), len(x) // len(image)))
+        recons, masks = x.split((3, 1), dim=2)
+        masks = masks.softmax(dim=1)
+        recon_combined = (recons * masks).sum(dim=1)
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+
+    def training_step(self, batch, criterion):
+        """Perform a single training step."""
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = F.interpolate(input_image, size=128)
+
+        # Get the prediction of the model and compute the loss.
+        preds = self.one_step(input_image)
+        recon_combined, recons, masks, slots = preds
+        input_image = input_image.permute(0, 2, 3, 1)
+        loss_value = criterion(recon_combined, input_image)
+        del recons, masks, slots  # Unused.
+
+        # Get and apply gradients.
+        self.optimizer.zero_grad()
+        loss_value.backward()
+        self.optimizer.step()
+
+        self.log('train_mse', loss_value, on_epoch=True)
+        return loss_value.item()
+
+    def validation_step(self, batch, criterion):
+        """Perform a single eval step."""
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = F.interpolate(input_image, size=128)
+
+        # Get the prediction of the model and compute the loss.
+        preds = self.one_step(input_image)
+        recon_combined, recons, masks, slots = preds
+        input_image = input_image.permute(0, 2, 3, 1)
+        loss_value = criterion(recon_combined, input_image)
+        del recons, masks, slots  # Unused.
+
+        psnr = mse2psnr(loss_value)
+        self.log('val_mse', loss_value)
+        self.log('val_psnr', psnr)
+        return loss_value.item(), psnr.item()
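As committed, these hooks still mirror the old manual loop: configure_optimizers passes the bound method self.parameters instead of calling it, training_step and validation_step take a criterion argument even though the Trainer passes the batch index, they drive zero_grad/backward/step by hand on a self.optimizer that configure_optimizers never stores, one_step reads a slots variable before it is defined, and both steps unpack four values from a five-element return. A minimal sketch of how the same hooks usually look under standard Lightning conventions; the MSE criterion and the reuse of one_step are assumptions, not part of the commit:

# Sketch only: conventional Lightning versions of the committed hooks, not the commit itself.
import torch
import torch.nn.functional as F
import torch.optim as optim
import lightning as pl
from osrt.utils.common import mse2psnr

class LitSlotAttentionAutoEncoder(pl.LightningModule):
    # ... constructor, encoder/decoder and one_step as in the commit ...

    def configure_optimizers(self):
        # parameters() must be called; the bound method itself is not an iterable of parameters.
        return optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)

    def _shared_step(self, batch):
        input_image = torch.squeeze(batch.get('input_images'), dim=1)
        input_image = F.interpolate(input_image, size=128)
        # one_step returns five values; the slot-wise attention maps are unused here.
        recon_combined, recons, masks, slots, _ = self.one_step(input_image)
        return F.mse_loss(recon_combined, input_image.permute(0, 2, 3, 1))

    def training_step(self, batch, batch_idx):
        # Lightning runs backward() and optimizer.step() itself from the returned loss.
        loss = self._shared_step(batch)
        self.log('train_mse', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_mse', loss)
        self.log('val_psnr', mse2psnr(loss))
        return loss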
runs/clevr3d/slot_att/config.yaml  +1 −4
@@ -6,13 +6,10 @@ model:
   model_type: sa
 training:
   num_workers: 2
+  num_gpus: 8
   batch_size: 32
   max_it: 333000000
   warmup_it: 10000
   decay_rate: 0.5
   decay_it: 100000
-  print_every: 1
-  validate_every: 1
-  checkpoint_every: 1
-  visualize_every: 2
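The four interval keys disappear because logging, validation and checkpoint cadence are now controlled by the Trainer and its callbacks, while the new num_gpus key drives device selection in train_sa.py. A small sketch of how such a config block is typically consumed; the "auto" fallback for the single-device case is an assumption, since the committed "default" string may not be accepted as a strategy name by recent Lightning releases:

import yaml
import lightning as pl

with open("runs/clevr3d/slot_att/config.yaml") as f:
    cfg = yaml.safe_load(f)

num_gpus = cfg["training"]["num_gpus"]
trainer = pl.Trainer(
    accelerator="gpu",
    devices=num_gpus,
    # Distributed data parallel only makes sense with more than one device.
    strategy="ddp" if num_gpus > 1 else "auto",
    max_steps=cfg["training"]["max_it"],
    log_every_n_steps=100,
)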
train_sa.py  +36 −127
@@ -2,52 +2,21 @@ import datetime
 import time
 import torch
 import torch.nn as nn
-import torch.optim as optim
 import argparse
 import yaml
-from osrt.model import SlotAttentionAutoEncoder
+from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
 from osrt.utils.visualize import visualize_slot_attention
-from osrt.utils.common import mse2psnr
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
 from tqdm import tqdm

+import lightning as pl
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.callbacks import ModelCheckpoint

-def train_step(batch, model, optimizer, device, criterion):
-    """Perform a single training step."""
-    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
-    input_image = F.interpolate(input_image, size=128)
-
-    # Get the prediction of the model and compute the loss.
-    preds = model(input_image)
-    recon_combined, recons, masks, slots = preds
-    input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = criterion(recon_combined, input_image)
-    del recons, masks, slots  # Unused.
-
-    # Get and apply gradients.
-    optimizer.zero_grad()
-    loss_value.backward()
-    optimizer.step()
-
-    return loss_value.item()
-
-def eval_step(batch, model, device, criterion):
-    """Perform a single eval step."""
-    input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
-    input_image = F.interpolate(input_image, size=128)
-
-    # Get the prediction of the model and compute the loss.
-    preds = model(input_image)
-    recon_combined, recons, masks, slots = preds
-    input_image = input_image.permute(0, 2, 3, 1)
-    loss_value = criterion(recon_combined, input_image)
-    del recons, masks, slots  # Unused.
-
-    psnr = mse2psnr(loss_value)
-    return loss_value.item(), psnr.item()
-
 def main():
     # Arguments
@@ -64,20 +33,17 @@ def main():
         cfg = yaml.load(f, Loader=yaml.CLoader)

     ### Set random seed.
-    torch.manual_seed(args.seed)
+    pl.seed_everything(42, workers=True)

     ### Hyperparameters of the model.
     batch_size = cfg["training"]["batch_size"]
+    num_gpus = cfg["training"]["num_gpus"]
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
-    base_learning_rate = 0.0004
     num_train_steps = cfg["training"]["max_it"]
     warmup_steps = cfg["training"]["warmup_it"]
     decay_rate = cfg["training"]["decay_rate"]
     decay_steps = cfg["training"]["decay_it"]
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    criterion = nn.MSELoss()
     resolution = (128, 128)

     #### Create datasets
@@ -90,71 +56,36 @@ def main():
     val_loader = DataLoader(
         val_dataset, batch_size=batch_size, num_workers=1,
         shuffle=True, worker_init_fn=data.worker_init_fn)

-    vis_dataset = data.get_dataset('test', cfg['data'])
-    vis_loader = DataLoader(
-        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
-        shuffle=True, worker_init_fn=data.worker_init_fn)

     #### Create model
-    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
-    num_params = sum(p.numel() for p in model.parameters())
-    print('Number of parameters:')
-    print(f'Model slot attention: {num_params}')
-    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)
-
-    #### Prepare checkpoint manager.
-    global_step = 0
-    ckpt = {
-        'network': model,
-        'optimizer': optimizer,
-        'global_step': global_step
-    }
-    ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
-    # ckpt = torch.load(args.ckpt + '/ckpt.pth')
-    model = ckpt['network']
-    optimizer = ckpt['optimizer']
-    global_step = ckpt['global_step']
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
+    wandb_logger = WandbLogger()
+
+    checkpoint_callback = ModelCheckpoint(
+        save_top_k=10,
+        monitor="val_psnr",
+        mode="max",
+        dirpath="./checkpoints" if cfg["model"]["model_type"] == "sa" else "./checkpoints_tsa",
+        filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}.pth",
+    )
+
+    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple",
+                         default_root_dir="./logs", logger=wandb_logger,
+                         strategy="ddp" if num_gpus > 1 else "default",
+                         callbacks=[checkpoint_callback], deterministic=True,
+                         log_every_n_steps=100, max_steps=num_train_steps)
+
+    trainer.fit(model, train_loader, val_loader)
+
+if __name__ == "__main__":
+    main()
"""
TODO : setup wandb
if args.wandb:
if run_id is None:
#print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
run_id = wandb.util.generate_id()
print(f
'
Sampled new wandb run_id {run_id}.
'
)
"""
else:
print(f
'
Resuming wandb with existing run_id {run_id}.
'
)
if not epoch % cfg[
"
training
"
][
"
checkpoint_every
"
]:
# Tell in which mode to launch the logging in W&B (for offline cluster)
if args.offline_log:
mode =
"
offline
"
else:
mode =
"
online
"
wandb.init(project=
'
osrt
'
, name=os.path.dirname(args.config),
id=run_id, resume=True, mode=mode, sync_tensorboard=True)
wandb.config = cfg
"""
start
=
time
.
time
()
epochs
=
num_train_steps
//
len
(
train_loader
)
for
epoch
in
range
(
epochs
):
total_loss
=
0
model
.
train
()
for
batch
in
tqdm
(
train_loader
):
# Learning rate warm-up.
if
global_step
<
warmup_steps
:
learning_rate
=
base_learning_rate
*
global_step
/
warmup_steps
else
:
learning_rate
=
base_learning_rate
learning_rate
=
learning_rate
*
(
decay_rate
**
(
global_step
/
decay_steps
))
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'
lr
'
]
=
learning_rate
total_loss
+=
train_step
(
batch
,
model
,
optimizer
,
device
,
criterion
)
global_step
+=
1
total_loss
/=
len
(
train_loader
)
# We save the checkpoints
if
not
epoch
%
cfg
[
"
training
"
][
"
checkpoint_every
"
]:
# Save the checkpoint of the model.
# Save the checkpoint of the model.
ckpt[
'
global_step
'
] = global_step
ckpt[
'
global_step
'
] = global_step
ckpt[
'
model_state_dict
'
] = model.state_dict()
ckpt[
'
model_state_dict
'
] = model.state_dict()
...
@@ -163,27 +94,5 @@ def main():
...
@@ -163,27 +94,5 @@ def main():
# We visualize some test data
# We visualize some test data
if not epoch % cfg[
"
training
"
][
"
visualize_every
"
]:
if not epoch % cfg[
"
training
"
][
"
visualize_every
"
]:
image
=
torch
.
squeeze
(
next
(
iter
(
vis_loader
)).
get
(
'
input_images
'
).
to
(
device
),
dim
=
1
)
image
=
F
.
interpolate
(
image
,
size
=
128
)
"""
image
=
image
.
to
(
device
)
\ No newline at end of file
recon_combined
,
recons
,
masks
,
slots
=
model
(
image
)
visualize_slot_attention
(
num_slots
,
image
,
recon_combined
,
recons
,
masks
,
folder_save
=
args
.
ckpt
,
step
=
global_step
,
save_file
=
True
)
# Log the training loss.
if
not
epoch
%
cfg
[
"
training
"
][
"
print_every
"
]:
print
(
f
"
[TRAIN] Epoch :
{
epoch
}
|| Step:
{
global_step
}
, Loss:
{
total_loss
}
, Time:
{
datetime
.
timedelta
(
seconds
=
time
.
time
()
-
start
)
}
"
)
# We visualize some test data
if
not
epoch
%
cfg
[
"
training
"
][
"
validate_every
"
]:
val_loss
=
0
val_psnr
=
0
model
.
eval
()
for
batch
in
tqdm
(
val_loader
):
mse
,
psnr
=
eval_step
(
batch
,
model
,
device
,
criterion
)
val_loss
+=
mse
val_psnr
+=
psnr
val_loss
/=
len
(
val_loader
)
val_psnr
/=
len
(
val_loader
)
print
(
f
"
[EVAL] Epoch :
{
epoch
}
|| Loss (MSE):
{
val_loss
}
; PSNR:
{
val_psnr
}
, Time:
{
datetime
.
timedelta
(
seconds
=
time
.
time
()
-
start
)
}
"
)
model
.
train
()
if
__name__
==
"
__main__
"
:
main
()
\ No newline at end of file
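With the hand-rolled loop gone, resuming or evaluating a run goes through the Lightning checkpoint written by the ModelCheckpoint callback rather than the old ckpt dict. A brief usage sketch, continuing from the names defined in main(); the checkpoint path is a placeholder, not a file produced by this commit:

# Placeholder path; ModelCheckpoint appends its own ".ckpt" suffix to the configured filename.
ckpt_path = "./checkpoints/last.ckpt"

# Resume training with optimizer state, epoch and global step restored.
trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)

# Or evaluate that checkpoint against the validation set only.
trainer.validate(model, dataloaders=val_loader, ckpt_path=ckpt_path)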
visualise.py  +4 −2
@@ -6,7 +6,7 @@ import torch.optim as optim
 import argparse
 import yaml
-from osrt.model import SlotAttentionAutoEncoder
+from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
 from osrt.utils.visualize import visualize_slot_attention
 from osrt.utils.common import mse2psnr
@@ -15,6 +15,8 @@ from torch.utils.data import DataLoader
 import torch.nn.functional as F
 from tqdm import tqdm

+# TODO : setup with lightning
+
 def main():
     # Arguments
     parser = argparse.ArgumentParser(
@@ -48,7 +50,7 @@ def main():
         shuffle=True, worker_init_fn=data.worker_init_fn)

     #### Create model
-    model = SlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
+    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg)
     num_params = sum(p.numel() for p in model.parameters())
     print('Number of parameters:')
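The added TODO marks the remaining gap: constructing LitSlotAttentionAutoEncoder directly gives untrained weights, so visualisation would normally restore a Trainer checkpoint first. A hedged sketch of that step, reusing the names already defined in this script; the checkpoint path and output folder are placeholders, and the five-way unpacking assumes the forward kept from the original autoencoder:

import torch

ckpt_path = "./checkpoints/last.ckpt"  # placeholder, not a file produced by this commit

# Build the module as above, then restore the weights saved by ModelCheckpoint.
model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg)
state = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state["state_dict"])
model.eval()

# image: a (B, 3, H, W) batch prepared by this script's dataloader.
with torch.no_grad():
    recon_combined, recons, masks, slots, attn = model(image)
visualize_slot_attention(10, image, recon_combined, recons, masks,
                         folder_save="./visualisations", step=0, save_file=True)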