
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Target project: achapin/Segment-Object-Centric
Commits on Source (2)
Showing with 97 additions and 45 deletions
@@ -5,5 +5,7 @@
/wandb
logs/*
logs
data
data/*
results
*__pycache__
.visualisation_0.png

[image file, 34.8 KiB]

@@ -38,7 +38,7 @@ def main():
    resolution = (128, 128)

    #### Create datasets
-   test_dataset = data.get_dataset('val', cfg['data'])
+   test_dataset = data.get_dataset('test', cfg['data'])
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)
......
@@ -5,6 +5,8 @@ import numpy as np
import math
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch.nn.functional as F
@@ -303,6 +305,66 @@ def fourier_encode(x, max_freq, num_bands = 4):
    x = torch.cat((x, orig_x), dim = -1)
    return x


class AutoEncoder(nn.Module):
    def __init__(self, patch_size, image_size, emb_dim):
        super().__init__()
        self.patchify = nn.Conv2d(3, emb_dim, patch_size, patch_size)
        self.head = nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size // patch_size)

    def encode(self, img):
        return self.patchify(img)

    def decode(self, feature):
        feature = feature.reshape(feature.shape[0], feature.shape[1], -1).permute(1, 0, 2)
        return self.patch2img(self.head(feature))
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=5, padding=2, stride=2), nn.ReLU(inplace=True),  # Added a stride to reduce memory impact
            nn.Conv2d(64, 64, kernel_size=5, padding=2, stride=2), nn.ReLU(inplace=True),  # Added a stride to reduce memory impact
            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
        )
        self.encoder_pos = PositionEmbeddingImplicit(64)
        self.layer_norm = nn.LayerNorm(64)
        self.mlp = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 64)
        )

    def forward(self, x):
        x = self.encoder_cnn(x).movedim(1, -1)
        x = self.encoder_pos(x)
        x = self.mlp(self.layer_norm(x))
        return x
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder_initial_size = (8, 8)
        self.decoder_cnn = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 4, kernel_size=3)
        )
        self.decoder_pos = PositionEmbeddingImplicit(64)

    def forward(self, x):
        x = self.decoder_pos(x)
        x = self.decoder_cnn(x.movedim(-1, 1))
        return x
### New transformer implementation of SlotAttention, inspired by https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
class TransformerSlotAttention(nn.Module):
    """
......
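A minimal shape-check sketch for the new Encoder/Decoder pair (illustration only, not part of the commit; it assumes the 128x128 resolution and 64-dim features used elsewhere in this repo):

import torch
from osrt.layers import Encoder, Decoder

encoder, decoder = Encoder(), Decoder()
img = torch.randn(2, 3, 128, 128)                          # (batch, channels, H, W)
feat = encoder(img)                                        # (2, 32, 32, 64): the two stride-2 convs halve H and W twice
grid = torch.randn(2, *decoder.decoder_initial_size, 64)   # stand-in for slots broadcast over the 8x8 grid
out = decoder(grid)                                        # (2, 4, H', W'): 3 RGB channels + 1 mask logit, larger than 128x128 and later resized with F.interpolate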
@@ -9,7 +9,7 @@ import numpy as np
from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
+from osrt.layers import SlotAttention, TransformerSlotAttention, Encoder, Decoder
import osrt.layers as layers
from osrt.utils.common import mse2psnr
@@ -66,32 +66,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
        self.criterion = nn.MSELoss()

-       self.encoder_cnn = nn.Sequential(
-           nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-           nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-           nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-           nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
-       )
-       self.decoder_initial_size = (8, 8)
-       self.decoder_cnn = nn.Sequential(
-           nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-           nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-           nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-           nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-           nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
-           nn.ConvTranspose2d(64, 4, kernel_size=3)
-       )
-       self.encoder_pos = PositionEmbeddingImplicit(64)
-       self.decoder_pos = PositionEmbeddingImplicit(64)
-       self.layer_norm = nn.LayerNorm(64)
-       self.mlp = nn.Sequential(
-           nn.Linear(64, 64),
-           nn.ReLU(inplace=True),
-           nn.Linear(64, 64)
-       )
+       self.encoder = Encoder()
+       self.decoder = Decoder()

        model_type = cfg['model']['model_type']
        if model_type == 'sa':
@@ -111,15 +87,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
                depth=self.num_iterations)  # in a way, the depth of the transformer corresponds to the number of iterations in the original model

    def forward(self, image):
-       x = self.encoder_cnn(image).movedim(1, -1)
-       x = self.encoder_pos(x)
-       x = self.mlp(self.layer_norm(x))
+       x = self.encoder(image)

        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-       x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
-       x = self.decoder_pos(x)
-       x = self.decoder_cnn(x.movedim(-1, 1))
+       x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
+       x = self.decoder(x)

        x = F.interpolate(x, image.shape[-2:], mode='bilinear')
        x = x.unflatten(0, (len(image), len(x) // len(image)))
@@ -135,14 +107,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
        return optimizer

    def one_step(self, image):
-       x = self.encoder_cnn(image).movedim(1, -1)
-       x = self.encoder_pos(x)
-       x = self.mlp(self.layer_norm(x))
+       x = self.encoder(image)

        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-       x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
-       x = self.decoder_pos(x)
-       x = self.decoder_cnn(x.movedim(-1, 1))
+       x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
+       x = self.decoder(x)

        x = F.interpolate(x, image.shape[-2:], mode='bilinear')
@@ -152,7 +121,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
        masks = masks.softmax(dim = 1)
        recon_combined = (recons * masks).sum(dim = 1)

-       return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+       return recon_combined, recons, masks, slots, (attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None)

    def training_step(self, batch, batch_idx):
        """Perform a single training step."""
......
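For reference, a standalone sketch of the broadcast-decode recombination that forward and one_step perform above; the 3-vs-1 channel split is an assumption about the collapsed code, and the tensors here are stand-ins:

import torch

B, K, H, W = 2, 10, 128, 128
decoded = torch.randn(B * K, 4, H, W)          # stand-in for the decoder output after F.interpolate
decoded = decoded.unflatten(0, (B, K))         # (B, K, 4, H, W): one RGB+mask prediction per slot
recons, masks = decoded.split((3, 1), dim=2)   # assumed split of the 4 channels into RGB and a mask logit
masks = masks.softmax(dim=1)                   # normalise the masks across the K slots, as in one_step
recon_combined = (recons * masks).sum(dim=1)   # (B, 3, H, W) composite reconstruction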
outputs/visualisation_12000.png [image file, 47.5 KiB]

data:
  dataset: clevr3d
model:
-  num_slots: 6
+  num_slots: 10
  iters: 3
  model_type: sa
training:
  num_workers: 2
  num_gpus: 1
-  batch_size: 64
+  batch_size: 8
  max_it: 333000000
  warmup_it: 10000
  decay_rate: 0.5
......
data:
  dataset: clevr3d
model:
  num_slots: 10
  iters: 3
  model_type: tsa
training:
  num_workers: 2
  num_gpus: 1
  batch_size: 32
  max_it: 333000000
  warmup_it: 10000
  decay_rate: 0.5
  decay_it: 100000
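A hedged sketch of reading one of these run configs; the file path is illustrative only, but the keys mirror the YAML above and feed the num_slots fix in the evaluation script below:

import yaml

with open("runs/clevr3d/tsa/config.yaml") as f:   # path is an assumption, not taken from the diff
    cfg = yaml.safe_load(f)

num_slots = cfg["model"]["num_slots"]        # 10
num_iterations = cfg["model"]["iters"]       # 3
model_type = cfg["model"]["model_type"]      # 'sa' or 'tsa' selects the SlotAttention variant
batch_size = cfg["training"]["batch_size"]   # 32 for the transformer variant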
@@ -50,7 +50,7 @@ def main():
        shuffle=True, worker_init_fn=data.worker_init_fn)

    #### Create model
-   model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
+   model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
    checkpoint = torch.load(args.ckpt)
    model.load_state_dict(checkpoint['state_dict'])
......