Commit f8834019 authored by Alexandre Chapin

Solve problem of memory

parent 2322dc27
Showing 95 additions and 45 deletions
.visualisation_0.png: image added (34.8 KiB)
@@ -38,7 +38,7 @@ def main():
    resolution = (128, 128)

    #### Create datasets
-    test_dataset = data.get_dataset('val', cfg['data'])
+    test_dataset = data.get_dataset('test', cfg['data'])
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
        shuffle=True, worker_init_fn=data.worker_init_fn)
......
File added ×4 (empty contents: {})
@@ -5,6 +5,8 @@ import numpy as np
import math
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
+import torch.nn.functional as F
@@ -303,6 +305,66 @@ def fourier_encode(x, max_freq, num_bands = 4):
    x = torch.cat((x, orig_x), dim = -1)
    return x

+class AutoEncoder(nn.Module):
+    def __init__(self, patch_size, image_size, emb_dim):
+        super().__init__()
+        self.patchify = nn.Conv2d(3, emb_dim, patch_size, patch_size)
+        self.head = nn.Linear(emb_dim, 3 * patch_size ** 2)
+        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size // patch_size)
+
+    def encode(self, img):
+        return self.patchify(img)
+
+    def decode(self, feature):
+        feature = feature.reshape(feature.shape[0], feature.shape[1], -1).permute(1, 0, 2)
+        return self.patch2img(self.head(feature))
+
+class Encoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.encoder_cnn = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2, stride=2), nn.ReLU(inplace=True),  # Added a stride to reduce memory impact
+            nn.Conv2d(64, 64, kernel_size=5, padding=2, stride=2), nn.ReLU(inplace=True),  # Added a stride to reduce memory impact
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
+        )
+        self.encoder_pos = PositionEmbeddingImplicit(64)
+        self.layer_norm = nn.LayerNorm(64)
+        self.mlp = nn.Sequential(
+            nn.Linear(64, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, 64)
+        )
+
+    def forward(self, x):
+        x = self.encoder_cnn(x).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+        return x
+
+class Decoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.decoder_initial_size = (8, 8)
+        self.decoder_cnn = nn.Sequential(
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 4, kernel_size=3)
+        )
+        self.decoder_pos = PositionEmbeddingImplicit(64)
+
+    def forward(self, x):
+        x = self.decoder_pos(x)
+        x = self.decoder_cnn(x.movedim(-1, 1))
+        return x
+
+### New transformer implementation of SlotAttention, inspired by https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
+class TransformerSlotAttention(nn.Module):
+    """
......
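To see why the two stride-2 convolutions help with memory, the sketch below runs the new Encoder on a dummy batch at the 128x128 resolution used elsewhere in this commit. It is illustrative only: it assumes osrt.layers exposes Encoder as added above and that PositionEmbeddingImplicit preserves the feature-map shape, and the printed shape is an expectation rather than a measured result.

import torch
from osrt.layers import Encoder  # new module added in this commit (assumed importable)

# Dummy batch at the 128x128 resolution used in the evaluation script.
x = torch.randn(2, 3, 128, 128)

enc = Encoder()
feats = enc(x)  # CNN + implicit position embedding + per-token MLP

# With stride=2 in two of the convolutions, the 128x128 grid becomes 32x32,
# so the token sequence fed to slot attention shrinks by a factor of 16
# (32*32 = 1024 tokens instead of 128*128 = 16384).
print(feats.shape)  # expected: torch.Size([2, 32, 32, 64])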
@@ -9,7 +9,7 @@ import numpy as np
from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
+from osrt.layers import SlotAttention, TransformerSlotAttention, Encoder, Decoder
import osrt.layers as layers
from osrt.utils.common import mse2psnr
@@ -66,32 +66,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
        self.criterion = nn.MSELoss()

-        self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
-        )
-        self.decoder_initial_size = (8, 8)
-        self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 4, kernel_size=3)
-        )
-        self.encoder_pos = PositionEmbeddingImplicit(64)
-        self.decoder_pos = PositionEmbeddingImplicit(64)
-        self.layer_norm = nn.LayerNorm(64)
-        self.mlp = nn.Sequential(
-            nn.Linear(64, 64),
-            nn.ReLU(inplace=True),
-            nn.Linear(64, 64)
-        )
+        self.encoder = Encoder()
+        self.decoder = Decoder()

        model_type = cfg['model']['model_type']
        if model_type == 'sa':
@@ -111,15 +87,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
                depth=self.num_iterations)  # in a way, the depth of the transformer corresponds to the number of iterations in the original model

    def forward(self, image):
-        x = self.encoder_cnn(image).movedim(1, -1)
-        x = self.encoder_pos(x)
-        x = self.mlp(self.layer_norm(x))
+        x = self.encoder(image)

        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
-        x = self.decoder_pos(x)
-        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
+        x = self.decoder(x)

        x = F.interpolate(x, image.shape[-2:], mode='bilinear')
        x = x.unflatten(0, (len(image), len(x) // len(image)))
@@ -135,14 +107,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
        return optimizer

    def one_step(self, image):
-        x = self.encoder_cnn(image).movedim(1, -1)
-        x = self.encoder_pos(x)
-        x = self.mlp(self.layer_norm(x))
+        x = self.encoder(image)

        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
-        x = self.decoder_pos(x)
-        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
+        x = self.decoder(x)

        x = F.interpolate(x, image.shape[-2:], mode='bilinear')
@@ -152,7 +121,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
        masks = masks.softmax(dim = 1)
        recon_combined = (recons * masks).sum(dim = 1)

-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None

    def training_step(self, batch, batch_idx):
        """Perform a single training step."""
......
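For anyone following the refactor of forward/one_step, here is a shape-only walkthrough of what happens after slot attention. All tensors are random stand-ins (the real path uses self.encoder, self.slot_attention and self.decoder), and the 3+1 channel split is inferred from the 4-channel decoder output and the mask softmax shown above; treat it as an illustration, not the module's API.

import torch
import torch.nn.functional as F

# Illustrative shapes only: batch of 2 images, 6 slots, 64-dim slot features.
batch, num_slots, slot_dim = 2, 6, 64
image = torch.randn(batch, 3, 128, 128)
slots = torch.randn(batch, num_slots, slot_dim)
decoder_initial_size = (8, 8)  # Decoder.decoder_initial_size in this commit

# Each slot is broadcast onto the 8x8 decoder grid before the transposed convs.
x = slots.reshape(-1, 1, 1, slot_dim).expand(-1, *decoder_initial_size, -1)
print(x.shape)  # torch.Size([12, 8, 8, 64]): batch*num_slots independent decodes

# Stand-in for self.decoder(x): 4 output channels (RGB + mask); the exact spatial
# size depends on the transposed-conv stack, then it is resized to the input size.
decoded = torch.randn(batch * num_slots, 4, 149, 149)
decoded = F.interpolate(decoded, image.shape[-2:], mode='bilinear')
decoded = decoded.unflatten(0, (batch, num_slots))

# Split into per-slot RGB reconstructions and alpha masks, then composite.
recons, masks = decoded.split((3, 1), dim=2)
masks = masks.softmax(dim=1)
recon_combined = (recons * masks).sum(dim=1)
print(recon_combined.shape)  # torch.Size([2, 3, 128, 128])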
outputs/visualisation_12000.png: image added (47.5 KiB)
data:
  dataset: clevr3d
model:
-  num_slots: 6
+  num_slots: 10
  iters: 3
  model_type: sa
training:
  num_workers: 2
  num_gpus: 1
-  batch_size: 64
+  batch_size: 8
  max_it: 333000000
  warmup_it: 10000
  decay_rate: 0.5
......
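The batch-size drop from 64 to 8 works together with the strided encoder above: both shrink the largest intermediate activations. A rough, illustrative float32 estimate for a single 64-channel feature map (ignoring gradients, optimizer state and every other layer):

# Back-of-the-envelope size of one 64-channel activation map in MiB (float32).
# Purely illustrative; not a profiler measurement.
def feat_map_mib(batch, channels, height, width, bytes_per_elem=4):
    return batch * channels * height * width * bytes_per_elem / 2**20

print(feat_map_mib(64, 64, 128, 128))  # old config, full resolution: 256.0 MiB
print(feat_map_mib(8, 64, 32, 32))     # new config, after two stride-2 convs: 2.0 MiB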
data:
  dataset: clevr3d
model:
  num_slots: 10
  iters: 3
  model_type: tsa
training:
  num_workers: 2
  num_gpus: 1
  batch_size: 32
  max_it: 333000000
  warmup_it: 10000
  decay_rate: 0.5
  decay_it: 100000
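A minimal sketch of how these two configs drive the model choice, mirroring the model_type check in the Lightning module above. The config path is hypothetical, the loader is plain PyYAML, and only the keys shown in the configs are used.

import yaml

# Hypothetical path; the repository's actual config location may differ.
with open('runs/clevr3d/tsa/config.yaml') as f:
    cfg = yaml.safe_load(f)

model_type = cfg['model']['model_type']
num_slots = cfg['model']['num_slots']
num_iterations = cfg['model']['iters']

if model_type == 'sa':
    # Original iterative Slot Attention module.
    print(f"SlotAttention: {num_slots} slots, {num_iterations} iterations")
elif model_type == 'tsa':
    # Transformer variant: its depth plays the role of the iteration count
    # in the original model (see the comment in the diff above).
    print(f"TransformerSlotAttention: {num_slots} slots, depth={num_iterations}")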
@@ -50,7 +50,7 @@ def main():
        shuffle=True, worker_init_fn=data.worker_init_fn)

    #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)

    checkpoint = torch.load(args.ckpt)
    model.load_state_dict(checkpoint['state_dict'])
......
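Replacing the hard-coded 10 with num_slots read from the config avoids building an evaluation model whose slot parameters no longer match the checkpoint. The self-contained sketch below reproduces that failure mode with a stand-in module; TinyModel is hypothetical, and only the shape-mismatch behaviour of load_state_dict is the point.

import torch
import torch.nn as nn

# TinyModel is a stand-in for LitSlotAttentionAutoEncoder: just enough state
# to show that the slot count baked into the model must match the checkpoint.
class TinyModel(nn.Module):
    def __init__(self, num_slots):
        super().__init__()
        self.slots = nn.Parameter(torch.zeros(num_slots, 64))

ckpt = {'state_dict': TinyModel(num_slots=10).state_dict()}

TinyModel(num_slots=10).load_state_dict(ckpt['state_dict'])      # matches: loads fine
try:
    TinyModel(num_slots=6).load_state_dict(ckpt['state_dict'])   # stale hard-coded value
except RuntimeError as err:
    print("size mismatch:", err)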