diff --git a/eval_sa.py b/eval_sa.py
new file mode 100644
index 0000000000000000000000000000000000000000..858462dc755c04c79e1a287b98734f28e606d21d
--- /dev/null
+++ b/eval_sa.py
@@ -0,0 +1,55 @@
+from osrt import data
+from osrt.model import SlotAttentionAutoEncoder
+import torch
+import matplotlib.pyplot as plt
+from PIL import Image as Image
+import argparse
+import yaml
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+from osrt.utils.visualize import visualize_slot_attention
+
+if __name__ == "__main__":
+    # Arguments
+    parser = argparse.ArgumentParser(
+        description='Evaluate a Slot Attention autoencoder.'
+    )
+    parser.add_argument('config', type=str, help="Path to the config file.")
+    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
+    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
+    parser.add_argument('--ckpt', type=str, default='./ckpt.pth', help='Model checkpoint path.')
+
+    args = parser.parse_args()
+    with open(args.config, 'r') as f:
+        cfg = yaml.load(f, Loader=yaml.CLoader)
+
+    # Hyperparameters.
+    seed = 0
+    batch_size = 1
+    num_slots = 7
+    num_iterations = 3
+    resolution = (128, 128)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations)
+    model = torch.load(args.ckpt, map_location=device)['network']
+    print(model)
+    model.eval()
+
+
+
+    eval_dataset = data.get_dataset('train', cfg['data'])
+    eval_loader = DataLoader(
+        eval_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
+        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+
+    model = model.to(device)
+
+    image = torch.squeeze(next(iter(eval_loader)).get('input_images').to(device), dim=1)
+    image = F.interpolate(image, size=128)
+    image = image.to(device)
+    recon_combined, recons, masks, slots = model(image)
+
+    visualize_slot_attention(num_slots, image, recon_combined, recons, masks)
+
+
+
diff --git a/osrt/layers.py b/osrt/layers.py
index 4e9f0d1fa345f72d1c50e17bc47417d4e9f062b7..0daf93ccfb3560a2c083cccc3d2624a0cee61149 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -267,31 +267,58 @@ class SlotAttention(nn.Module):
         self.num_slots = num_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))

-#############################################
-#############################################
-#############################################
-#############################################
-#############################################
-#############################################
-### New implementations
+### Utils for SlotAttentionAutoEncoder
+def build_grid(resolution):
+    ranges = [np.linspace(0., 1., num=res) for res in resolution]
+    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
+    grid = np.stack(grid, axis=-1)
+    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
+    grid = np.expand_dims(grid, axis=0)
+    grid = grid.astype(np.float32)
+    return np.concatenate([grid, 1.0 - grid], axis=-1)
+
+class SoftPositionEmbed(nn.Module):
+    """Adds soft positional embedding with learnable projection.
+    Implementation extracted from the official repo: https://github.com/google-research/google-research/blob/master/slot_attention/model.py"""
+
+    def __init__(self, hidden_size, resolution):
+        """Builds the soft position embedding layer.
+        Args:
+            hidden_size: Size of input feature dimension.
+            resolution: Tuple of integers specifying width and height of grid.
+ """ + super().__init__() + self.dense = JaxLinear(4, hidden_size) + self.grid = build_grid(resolution) + + def forward(self, inputs): + return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w] + +### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py class TransformerSlotAttention(nn.Module): """ An extension of Slot Attention using self-attention """ - def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8, + def __init__(self, depth, heads, dim_head, mlp_dim, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, eps=1e-8, randomize_initial_slots=False): super().__init__() self.num_slots = num_slots self.batch_slots = [] - self.iters = iters self.scale = slot_dim ** -0.5 self.slot_dim = slot_dim + self.depth = depth + self.num_heads = 8 + self.randomize_initial_slots = randomize_initial_slots self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim)) + #def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., selfatt=True, kv_dim=None): + self.transformer_stage_1 = Transformer(dim=384, depth=2, heads=8) + self.transformer_stage_2 = Transformer(dim=384, depth=2, heads=8) + self.eps = eps self.to_q = JaxLinear(slot_dim, slot_dim, bias=False) @@ -345,31 +372,3 @@ class TransformerSlotAttention(nn.Module): slots = slots + self.mlp(self.norm_pre_mlp(slots)) return slots # [batch_size, num_slots, dim] - - -def build_grid(resolution): - ranges = [np.linspace(0., 1., num=res) for res in resolution] - grid = np.meshgrid(*ranges, sparse=False, indexing="ij") - grid = np.stack(grid, axis=-1) - grid = np.reshape(grid, [resolution[0], resolution[1], -1]) - grid = np.expand_dims(grid, axis=0) - grid = grid.astype(np.float32) - return np.concatenate([grid, 1.0 - grid], axis=-1).transpose(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w] - -class SoftPositionEmbed(nn.Module): - """Adds soft positional embedding with learnable projection. - Implementation extracted from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py""" - - def __init__(self, hidden_size, resolution): - """Builds the soft position embedding layer. - - Args: - hidden_size: Size of input feature dimension. - resolution: Tuple of integers specifying width and height of grid. 
- """ - super().__init__() - self.dense = JaxLinear(4, hidden_size) - self.grid = build_grid(resolution) - - def forward(self, inputs): - return inputs + self.dense(torch.tensor(self.grid).cuda()) \ No newline at end of file diff --git a/osrt/model.py b/osrt/model.py index 45821355ff17b56ebf0769373ce5211de3999cbf..938d990377e61104dd90a39b2069917d1b0babab 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -90,20 +90,20 @@ class SlotAttentionAutoEncoder(nn.Module): self.decoder_initial_size = (8, 8) self.decoder_cnn = nn.Sequential( - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2), nn.ReLU(), - nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=2) + nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1) ) - self.encoder_pos = SoftPositionEmbed(64, (32, 32)) + self.encoder_pos = SoftPositionEmbed(64, self.resolution) self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size) self.layer_norm = nn.LayerNorm(64) @@ -115,17 +115,16 @@ class SlotAttentionAutoEncoder(nn.Module): self.slot_attention = SlotAttention( num_slots=self.num_slots, + input_dim=64, slot_dim=64, hidden_dim=128, iters=self.num_iterations) def forward(self, image): # `image` has shape: [batch_size, num_channels, width, height]. - print(f"Shape input {image.shape}") # Convolutional encoder with position embedding. x = self.encoder_cnn(image) # CNN Backbone. - print(f"Shape after encoder {x.shape}") - x = self.encoder_pos(x) # Position embedding. + x = self.encoder_pos(x).permute(0, 2, 3, 1) # Position embedding. x = spatial_flatten(x) # Flatten spatial dimensions (treat image as set). x = self.mlp(self.layer_norm(x)) # Feedforward network on set. # `x` has shape: [batch_size, width*height, input_size]. @@ -135,10 +134,11 @@ class SlotAttentionAutoEncoder(nn.Module): # `slots` has shape: [batch_size, num_slots, slot_size]. # Spatial broadcast decoder. - x = spatial_broadcast(slots, self.decoder_initial_size) + x = spatial_broadcast(slots, self.decoder_initial_size).permute(0, 3, 1, 2) + # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size]. - #x = self.decoder_pos(x) - x = self.decoder_cnn(x) + x = self.decoder_pos(x) + x = self.decoder_cnn(x).permute(0, 2, 3, 1) # `x` has shape: [batch_size*num_slots, width, height, num_channels+1]. # Undo combination of slot and batch dimension; split alpha masks. 
diff --git a/osrt/utils/visualize.py b/osrt/utils/visualize.py
index 53c59485edf7584ef3ef822b94903d7bd9f49f20..af5129083898c32e2ee13f24b9c3ee8c7584c9f2 100644
--- a/osrt/utils/visualize.py
+++ b/osrt/utils/visualize.py
@@ -88,4 +88,30 @@ def draw_visualization_grid(columns, outfile, row_labels=None, name=None):
     plt.savefig(f'{outfile}.png')
     plt.close()

-
+def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, save_file=False):
+    fig, ax = plt.subplots(1, num_slots + 2, figsize=(15, 2))
+    image = image.squeeze(0)
+    recon_combined = recon_combined.squeeze(0)
+    recons = recons.squeeze(0)
+    masks = masks.squeeze(0)
+    image = image.permute(1, 2, 0).cpu().numpy()
+    recon_combined = recon_combined.cpu().detach().numpy()
+    recons = recons.cpu().detach().numpy()
+    masks = masks.cpu().detach().numpy()
+
+    if not save_file:
+        ax[0].imshow(image)
+        ax[0].set_title('Image')
+        ax[1].imshow(recon_combined)
+        ax[1].set_title('Recon.')
+        for i in range(num_slots):
+            picture = recons[i] * masks[i] + (1 - masks[i])
+            ax[i + 2].imshow(picture)
+            ax[i + 2].set_title('Slot %s' % str(i + 1))
+        for i in range(len(ax)):
+            ax[i].grid(False)
+            ax[i].axis('off')
+        plt.show()
+    else:
+        # TODO: save the figure to a PNG file
+        pass
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 44592b758dea1ca61e4a92be44eb3202cbe97763..aade5ddcd330c1fd1cdb69c88ae3504ae140970e 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -5,7 +5,7 @@ model:
   iters: 3
 training:
   num_workers: 2
-  batch_size: 64
+  batch_size: 32
   visualize_every: 5000
   validate_every: 5000
   checkpoint_every: 1000
diff --git a/train_sa.py b/train_sa.py
index 3afbecd956ba117968bda77e3b69fdde9f6e1eab..35aa44dac636980dc38c8e7aa3896940bceb2316 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -9,6 +9,7 @@ from osrt.model import SlotAttentionAutoEncoder
 from osrt import data

 from torch.utils.data import DataLoader
+import torch.nn.functional as F


 def l2_loss(prediction, target):
@@ -17,10 +18,12 @@ def l2_loss(prediction, target):

 def train_step(batch, model, optimizer, device):
     """Perform a single training step."""
     input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
+    input_image = F.interpolate(input_image, size=128)

     # Get the prediction of the model and compute the loss.
     preds = model(input_image)
     recon_combined, recons, masks, slots = preds
+    input_image = input_image.permute(0, 2, 3, 1)
     loss_value = l2_loss(input_image, recon_combined)
     del recons, masks, slots  # Unused.