diff --git a/.visualisation_0.png b/.visualisation_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f0184272777793b0d4181822f23ecf15a2b2a1d
Binary files /dev/null and b/.visualisation_0.png differ
diff --git a/evaluate_sa.py b/evaluate_sa.py
index af75b4bb89159700ede32361e378238a24693ac0..e5e06eb7ec4990daa29ba21aead49de35b00f270 100644
--- a/evaluate_sa.py
+++ b/evaluate_sa.py
@@ -38,7 +38,7 @@ def main():
     resolution = (128, 128)
 
     #### Create datasets
-    test_dataset = data.get_dataset('val', cfg['data'])
+    test_dataset = data.get_dataset('test', cfg['data'])
     test_dataloader = DataLoader(
         test_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
         shuffle=True, worker_init_fn=data.worker_init_fn)
diff --git a/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 b/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0
new file mode 100644
index 0000000000000000000000000000000000000000..9e9c05c6e308e39ca9c9298dc2b9b8a32a4b516b
Binary files /dev/null and b/lightning_logs/version_2/events.out.tfevents.1690294616.achapin-Precision-5570.88157.0 differ
diff --git a/lightning_logs/version_2/hparams.yaml b/lightning_logs/version_2/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_2/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 b/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0
new file mode 100644
index 0000000000000000000000000000000000000000..a7cc5deb3a254b64b177f15573a663f3ad9d0de6
Binary files /dev/null and b/lightning_logs/version_3/events.out.tfevents.1690294661.achapin-Precision-5570.88480.0 differ
diff --git a/lightning_logs/version_3/hparams.yaml b/lightning_logs/version_3/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_3/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 b/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0
new file mode 100644
index 0000000000000000000000000000000000000000..65a1f5c0cfc4c7d12dea17678f70732aaaf2731c
Binary files /dev/null and b/lightning_logs/version_4/events.out.tfevents.1690295005.achapin-Precision-5570.89889.0 differ
diff --git a/lightning_logs/version_4/hparams.yaml b/lightning_logs/version_4/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_4/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 b/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0
new file mode 100644
index 0000000000000000000000000000000000000000..73d3b9d0a715b4f3a4b30267a21646af46b3fabe
Binary files /dev/null and b/lightning_logs/version_5/events.out.tfevents.1690362262.achapin-Precision-5570.9831.0 differ
diff --git a/lightning_logs/version_5/hparams.yaml b/lightning_logs/version_5/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_5/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/osrt/layers.py b/osrt/layers.py
index 24c363f866c6b003ff4c1e297b0e84a2bbd90258..9b7c55543c99c37cf3361a3934e1d23acf559437 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -5,6 +5,8 @@ import numpy as np
 import math
 
 from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
 import torch.nn.functional as F
@@ -303,6 +305,66 @@ def fourier_encode(x, max_freq, num_bands = 4):
     x = torch.cat((x, orig_x), dim = -1)
     return x
 
+
+class AutoEncoder(nn.Module):
+    def __init__(self, patch_size, image_size, emb_dim):
+        super().__init__()
+        self.patchify = nn.Conv2d(3, emb_dim, patch_size, patch_size)
+        self.head = nn.Linear(emb_dim, 3 * patch_size ** 2)
+        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)
+
+    def encode(self, img):
+        return self.patchify(img)
+
+    def decode(self, feature):
+        feature = feature.reshape(feature.shape[0], feature.shape[1], -1).permute(1, 0, 2)
+        return self.patch2img(self.head(feature))
+
+class Encoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.encoder_cnn = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2, stride=2), nn.ReLU(inplace=True), # Added a stride to reduce memory impact
+            nn.Conv2d(64, 64, kernel_size=5, padding=2, stride=2), nn.ReLU(inplace=True), # Added a stride to reduce memory impact
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
+        )
+        self.encoder_pos = PositionEmbeddingImplicit(64)
+
+        self.layer_norm = nn.LayerNorm(64)
+        self.mlp = nn.Sequential(
+            nn.Linear(64, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, 64)
+        )
+
+    def forward(self, x):
+        x = self.encoder_cnn(x).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.decoder_initial_size = (8, 8)
+        self.decoder_cnn = nn.Sequential(
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 4, kernel_size=3)
+        )
+
+        self.decoder_pos = PositionEmbeddingImplicit(64)
+
+    def forward(self, x):
+        x = self.decoder_pos(x)
+        x = self.decoder_cnn(x.movedim(-1, 1))
+        return x
+
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
diff --git a/osrt/model.py b/osrt/model.py
index d176d9abd4d115b3c6268625988327e9ada1f67b..932abd9224e956af29d9eb1d05b426b77461e465 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -9,7 +9,7 @@ import numpy as np
 
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
+from osrt.layers import SlotAttention, TransformerSlotAttention, Encoder, Decoder
 import osrt.layers as layers
 
 from osrt.utils.common import mse2psnr
@@ -66,32 +66,8 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         self.criterion = nn.MSELoss()
 
-        self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
-        )
-
-        self.decoder_initial_size = (8, 8)
-        self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=1), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
-            nn.ConvTranspose2d(64, 4, kernel_size=3)
-        )
-
-        self.encoder_pos = PositionEmbeddingImplicit(64)
-        self.decoder_pos = PositionEmbeddingImplicit(64)
-
-        self.layer_norm = nn.LayerNorm(64)
-        self.mlp = nn.Sequential(
-            nn.Linear(64, 64),
-            nn.ReLU(inplace=True),
-            nn.Linear(64, 64)
-        )
+        self.encoder = Encoder()
+        self.decoder = Decoder()
 
         model_type = cfg['model']['model_type']
         if model_type == 'sa':
@@ -111,15 +87,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
             depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
 
     def forward(self, image):
-        x = self.encoder_cnn(image).movedim(1, -1)
-        x = self.encoder_pos(x)
-        x = self.mlp(self.layer_norm(x))
-
+        x = self.encoder(image)
         slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
-        x = self.decoder_pos(x)
-        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
+        x = self.decoder(x)
 
         x = F.interpolate(x, image.shape[-2:], mode='bilinear')
         x = x.unflatten(0, (len(image), len(x) // len(image)))
@@ -135,14 +107,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         return optimizer
 
     def one_step(self, image):
-        x = self.encoder_cnn(image).movedim(1, -1)
-        x = self.encoder_pos(x)
-        x = self.mlp(self.layer_norm(x))
+        x = self.encoder(image)
 
         slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
-        x = self.decoder_pos(x)
-        x = self.decoder_cnn(x.movedim(-1, 1))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
+        x = self.decoder(x)
 
         x = F.interpolate(x, image.shape[-2:], mode='bilinear')
 
@@ -152,7 +121,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         masks = masks.softmax(dim = 1)
         recon_combined = (recons * masks).sum(dim = 1)
 
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
 
     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
diff --git a/outputs/visualisation_12000.png b/outputs/visualisation_12000.png
new file mode 100644
index 0000000000000000000000000000000000000000..b8e3a04cb6766c6d03e3f695d2f6dd99d7643aa7
Binary files /dev/null and b/outputs/visualisation_12000.png differ
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 161aaf628a68e847d514e2e78bdf5867ceb2167d..82c0f8e78d858da06e976697291a7189012fba51 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -1,13 +1,13 @@
 data:
   dataset: clevr3d
 model:
-  num_slots: 6
+  num_slots: 10
   iters: 3
   model_type: sa
 training:
   num_workers: 2
   num_gpus: 1
-  batch_size: 64
+  batch_size: 8
   max_it: 333000000
   warmup_it: 10000
   decay_rate: 0.5
diff --git a/runs/clevr3d/slot_att/config_tsa.yaml b/runs/clevr3d/slot_att/config_tsa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f03fc97062dd583a0884019890f8fbd0778280b1
--- /dev/null
+++ b/runs/clevr3d/slot_att/config_tsa.yaml
@@ -0,0 +1,15 @@
+data:
+  dataset: clevr3d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: tsa
+training:
+  num_workers: 2
+  num_gpus: 1
+  batch_size: 32
+  max_it: 333000000
+  warmup_it: 10000
+  decay_rate: 0.5
+  decay_it: 100000
+
diff --git a/visualize_sa.py b/visualize_sa.py
index a1d4ce15204654a4018d62ecc25e34a6341c7a72..881d9d474000ee763ed1877310a79ebf2c755119 100644
--- a/visualize_sa.py
+++ b/visualize_sa.py
@@ -50,7 +50,7 @@ def main():
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg).to(device)
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
     checkpoint = torch.load(args.ckpt)
     model.load_state_dict(checkpoint['state_dict'])
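
Note (not part of the patch): below is a minimal smoke-test sketch of the Encoder and Decoder refactored into osrt/layers.py, mirroring how the patched osrt/model.py uses them. It assumes the classes are importable exactly as in the patched imports, that PositionEmbeddingImplicit accepts the channels-last tensors it already receives in the existing code, and that the slot dimension is 64 so broadcast slots match the decoder's first ConvTranspose2d; the batch size, resolution, and num_slots are illustrative only.

    import torch
    from osrt.layers import Encoder, Decoder

    encoder = Encoder()
    decoder = Decoder()

    # Encode a dummy batch: the two stride-2 convs downsample 128x128 -> 32x32,
    # and the encoder returns channels-last features, i.e. (batch, 32, 32, 64).
    images = torch.randn(2, 3, 128, 128)
    features = encoder(images)
    print(features.shape)

    # Broadcast dummy slots over the decoder's 8x8 initial grid, as in
    # LitSlotAttentionAutoEncoder.forward, then decode to per-slot 4-channel
    # maps (RGB + mask logits); the model later resizes them with F.interpolate.
    batch_size, num_slots, slot_dim = 2, 10, 64
    slots = torch.randn(batch_size * num_slots, slot_dim)
    x = slots.reshape(-1, 1, 1, slot_dim).expand(-1, *decoder.decoder_initial_size, -1)
    out = decoder(x)
    print(out.shape)  # (batch_size * num_slots, 4, H', W')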