diff --git a/.visualisation_1639.png b/.visualisation_1639.png
new file mode 100644
index 0000000000000000000000000000000000000000..fece79284e2d145e76f53cc37c395ef07a98cebf
Binary files /dev/null and b/.visualisation_1639.png differ
diff --git a/osrt/layers.py b/osrt/layers.py
index 39f246a20754e7ec7c048abb7002644d144f8608..deaf9d8419b9e18adc6ffffccabeac5afc687f64 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -5,6 +5,7 @@ import numpy as np
 
 import math
 from einops import rearrange, repeat
+import torch.nn.functional as F
 
 
 __USE_DEFAULT_INIT__ = False
@@ -194,10 +195,11 @@ class SlotAttention(nn.Module):
     @edit : we changed the code as to make it possible to handle a different number of slots depending on the input images
     """
     def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
-                 randomize_initial_slots=False):
+                 randomize_initial_slots=False, gain = 1, temperature_factor = 1):
         super().__init__()
 
         self.num_slots = num_slots
+        self.temperature_factor = temperature_factor
         self.batch_slots = []
         self.iters = iters
         self.scale = slot_dim ** -0.5
@@ -207,24 +209,31 @@ class SlotAttention(nn.Module):
         self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
 
         self.eps = eps
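+        # Learned Gaussian over the initial slots: when forward() receives no slots,
+        # they are sampled as mu + exp(log_sigma) * N(0, I), as in the original Slot Attention.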
+        self.slots_mu        = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 1, self.slot_dim)))
+        self.slots_log_sigma = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 1, self.slot_dim)))
+
+        self.to_q = nn.Linear(slot_dim, slot_dim, bias=False)
+        self.to_k = nn.Linear(input_dim, slot_dim, bias=False)
+        self.to_v = nn.Linear(input_dim, slot_dim, bias=False)
 
-        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
-        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
-        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)
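+        # Plain nn.Linear projections with explicit Xavier init (configurable gain) replace the former JaxLinear layers.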
+        nn.init.xavier_uniform_(self.to_q.weight, gain = gain)
+        nn.init.xavier_uniform_(self.to_k.weight, gain = gain)
+        nn.init.xavier_uniform_(self.to_v.weight, gain = gain)
 
         self.gru = nn.GRUCell(slot_dim, slot_dim)
 
         self.mlp = nn.Sequential(
-            JaxLinear(slot_dim, hidden_dim),
+            nn.Linear(slot_dim, hidden_dim),
             nn.ReLU(inplace=True),
-            JaxLinear(hidden_dim, slot_dim)
+            nn.Linear(hidden_dim, slot_dim)
         )
 
         self.norm_input   = nn.LayerNorm(input_dim)
         self.norm_slots   = nn.LayerNorm(slot_dim)
         self.norm_pre_mlp = nn.LayerNorm(slot_dim)
 
-    def forward(self, inputs, masks=None):
+    def forward(self, inputs, slots=None):
         """
         Args:
             inputs: set-latent representation [batch_size, num_inputs, dim]
@@ -232,74 +241,56 @@ class SlotAttention(nn.Module):
         batch_size, num_inputs, dim = inputs.shape
 
         inputs = self.norm_input(inputs)
-        
-        # Initialize the slots. Shape: [batch_size, num_slots, slot_dim].
-        if self.randomize_initial_slots:
-            slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
-            slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
-        else:
-            slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
 
         k, v = self.to_k(inputs), self.to_v(inputs)
 
+        if slots is None:
+            slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_dim, device = self.slots_mu.device)
+
         # Multiple rounds of attention.
         for _ in range(self.iters):
             slots_prev = slots
-            norm_slots = self.norm_slots(slots)
+            slots = self.norm_slots(slots)
 
-            q = self.to_q(norm_slots)
+            q = self.to_q(slots)
+            q *= self.scale
 
-            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale # Dot product and normalization
+            attn_logits = torch.bmm(q, k.transpose(-1, -2))
+
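+            # Softmax over the slot dimension (dim=1): slots compete for every input feature,
+            # with temperature_factor controlling how sharp that competition is.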
+            attn_pixelwise = F.softmax(attn_logits / self.temperature_factor, dim = 1)
 
-            if masks != None:
-                temp_masks = masks.unsqueeze(1)
-                attention_masking = torch.where(temp_masks == 1.0, float("-inf"), temp_masks).to(device=dots.device)
-                dots += attention_masking
             # shape: [batch_size, num_slots, num_inputs]
-            attn = dots.softmax(dim=1) + self.eps
+            attn_slotwise = F.normalize(attn_pixelwise + self.eps, p = 1, dim = -1)
 
-            # Weighted mean
-            attn = attn / attn.sum(dim=-1, keepdim=True)
-            updates = torch.einsum('bjd,bij->bid', v, attn) # shape: [batch_size, num_inputs, slot_dim] 
+            # shape: [batch_size, num_slots, slot_dim]
+            updates = torch.bmm(attn_slotwise, v)
             
             # Slot update
             slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
             slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
             slots = slots + self.mlp(self.norm_pre_mlp(slots))
 
-        return slots # [batch_size, num_slots, dim]
+        return slots, attn_logits, attn_slotwise # slots: [batch_size, num_slots, slot_dim]
 
     def change_slots_number(self, num_slots): 
         self.num_slots = num_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))
 
-### Utils for SlotAttentionAutoEncoder
-def build_grid(resolution):
-    ranges = [np.linspace(0., 1., num=res) for res in resolution]
-    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
-    grid = np.stack(grid, axis=-1)
-    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
-    grid = np.expand_dims(grid, axis=0)
-    grid = grid.astype(np.float32)
-    return np.concatenate([grid, 1.0 - grid], axis=-1)
-
-class SoftPositionEmbed(nn.Module):
-    """Adds soft positional embedding with learnable projection.
-    Implementation extracted from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py"""
-
-    def __init__(self, hidden_size, resolution):
-        """Builds the soft position embedding layer.
 
-        Args:
-            hidden_size: Size of input feature dimension.
-            resolution: Tuple of integers specifying width and height of grid.
-        """
+class PositionEmbeddingImplicit(nn.Module):
+    """
+    Position embedding extracted from
+    https://github.com/vadimkantorov/yet_another_pytorch_slot_attention/blob/master/models.py
+    """
+    def __init__(self, hidden_dim):
         super().__init__()
-        self.dense = JaxLinear(4, hidden_size)
-        self.grid = build_grid(resolution)
+        self.dense = nn.Linear(4, hidden_dim)
 
-    def forward(self, inputs):
-        return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) #  from [b, h, w, c] to [b, c, h, w]
+    def forward(self, x):
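+        # Build a normalised coordinate grid (each position gets [coords, 1 - coords]) over the
+        # spatial dims and project it to hidden_dim before adding it to the features.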
+        spatial_shape = x.shape[-3:-1]
+        grid = torch.stack(torch.meshgrid(*[torch.linspace(0., 1., r, device = x.device) for r in spatial_shape], indexing = 'ij'), dim = -1)
+        grid = torch.cat([grid, 1 - grid], dim = -1)
+        return x + self.dense(grid)
     
 def fourier_encode(x, max_freq, num_bands = 4):
     x = x.unsqueeze(-1)
@@ -313,7 +304,6 @@ def fourier_encode(x, max_freq, num_bands = 4):
     x = torch.cat((x, orig_x), dim = -1)
     return x
 
-
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
@@ -393,4 +383,4 @@ class TransformerSlotAttention(nn.Module):
                 x_d = self_attn(inputs) + inputs
                 inputs = self_ff(x_d) + x_d
 
-        return slots # [batch_size, num_slots, dim]
+        return slots, None, None # [batch_size, num_slots, dim]
diff --git a/osrt/model.py b/osrt/model.py
index 0102d6153d02118b3866e8bb6fd4adb8b1d0dcb9..3ecc2c26c107c32ab5bb4189565393639566b5b2 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -1,13 +1,16 @@
 from torch import nn
 import torch
+import torch.nn.functional as F
+
 import numpy as np
 
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
-from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed, TransformerSlotAttention
-
+from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention
 import osrt.layers as layers
 
+
 class OSRT(nn.Module):
     def __init__(self, cfg):
         super().__init__()
@@ -75,39 +78,30 @@ class SlotAttentionAutoEncoder(nn.Module):
         self.num_iterations = num_iterations
 
         self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU()
+            nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True)
         )
 
         self.decoder_initial_size = (8, 8)
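+        # Four stride-2 transposed convs upsample the 8x8 broadcast grid by 16x;
+        # forward() then interpolates the decoder output to the exact input resolution.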
         self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1)
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 4, kernel_size=3)
         )
 
-        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
-        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
+        self.encoder_pos = PositionEmbeddingImplicit(64)
+        self.decoder_pos = PositionEmbeddingImplicit(64)
 
         self.layer_norm = nn.LayerNorm(64)
         self.mlp = nn.Sequential(
-            JaxLinear(64, 64),
-            nn.ReLU(),
-            JaxLinear(64, 64)
+            nn.Linear(64, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, 64)
         )
 
         model_type = cfg['model']['model_type']
@@ -128,35 +122,22 @@ class SlotAttentionAutoEncoder(nn.Module):
                 depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
 
     def forward(self, image):
-        # `image` has shape: [batch_size, num_channels, width, height].
-        # Convolutional encoder with position embedding.
-        x = self.encoder_cnn(image)  # CNN Backbone.
-        x = self.encoder_pos(x).permute(0, 2, 3, 1)  # Position embedding.
-        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
-        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
-        # `x` has shape: [batch_size, width*height, input_size].
-
-        # Slot Attention module.
-        slots = self.slot_attention(x)
-        # `slots` has shape: [batch_size, num_slots, slot_size].
-
-        # Spatial broadcast decoder.
-        x = spatial_broadcast(slots, self.decoder_initial_size).permute(0, 3, 1, 2)
-
-        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
+        x = self.encoder_cnn(image).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+        
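+        # Flatten the spatial grid into a set of feature vectors before Slot Attention.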
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
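+        # Spatial broadcast: tile each slot over the initial decoder grid (one feature map per slot).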
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
         x = self.decoder_pos(x)
-        x = self.decoder_cnn(x).permute(0, 2, 3, 1)
-        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
+        x = self.decoder_cnn(x.movedim(-1, 1))
+        
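+        # Resize the decoded maps to the input resolution before splitting RGB and alpha
+        # (self.interpolate_mode is expected to be set in __init__).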
+        x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode)
 
-        # Undo combination of slot and batch dimension; split alpha masks.
-        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
-        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
-        # `masks` has shape: [batch_size, num_slots, width, height, 1].
+        x = x.unflatten(0, (len(image), len(x) // len(image)))
 
-        # Normalize alpha masks over slots.
-        masks = torch.softmax(masks, dim=1)
-        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
-        # `recon_combined` has shape: [batch_size, width, height, num_channels].
+        recons, masks = x.split((3, 1), dim = 2)
+        masks = masks.softmax(dim = 1)
+        recon_combined = (recons * masks).sum(dim = 1)
 
-        return recon_combined, recons, masks, slots
+        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
 
diff --git a/visualise.py b/visualise.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c7d2ea84833a8fefb9855379961b6e9f1b6cb0
--- /dev/null
+++ b/visualise.py
@@ -0,0 +1,81 @@
+import datetime
+import time
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import argparse
+import yaml
+
+from osrt.model import SlotAttentionAutoEncoder
+from osrt import data
+from osrt.utils.visualize import visualize_slot_attention
+from osrt.utils.common import mse2psnr
+
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+from tqdm import tqdm
+
+def main():
+    # Arguments
+    parser = argparse.ArgumentParser(
+        description='Visualise a trained Slot Attention autoencoder on a test batch.'
+    )
+    parser.add_argument('config', type=str, help="Path to the config file.")
+    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
+    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
+    parser.add_argument('--ckpt', type=str, default=".", help='Directory containing ckpt.pth; also used to save visualisations.')
+
+    args = parser.parse_args()
+    with open(args.config, 'r') as f:
+        cfg = yaml.load(f, Loader=yaml.CLoader)
+
+    ### Set random seed.
+    torch.manual_seed(args.seed)
+
+    ### Hyperparameters of the model.
+    num_slots = cfg["model"]["num_slots"]
+    num_iterations = cfg["model"]["iters"]
+    base_learning_rate = 0.0004
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    resolution = (128, 128)
+    
+    #### Create datasets
+    
+    vis_dataset = data.get_dataset('test', cfg['data'])
+    vis_loader = DataLoader(
+        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
+
+    #### Create model
+    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
+    num_params = sum(p.numel() for p in model.parameters())
+
+    print('Number of parameters:')
+    print(f'Model slot attention: {num_params}')
+
+    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)
+
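+    # Load the trained weights from the checkpoint directory given by --ckpt (assumes a file named ckpt.pth).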
+    ckpt = torch.load(args.ckpt + '/ckpt.pth', map_location=device)
+    model.load_state_dict(ckpt["model_state_dict"])
+
+    image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
+    image = F.interpolate(image, size=128)
+    image = image.to(device)
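+    # The autoencoder returns the combined reconstruction, per-slot reconstructions and masks,
+    # the slots, and the slot-wise attention maps.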
+    recon_combined, recons, masks, slots, attn_slotwise = model(image)
+    loss = nn.MSELoss()
+    loss_value = loss(recon_combined, image)
+    psnr = mse2psnr(loss_value)
+    print(f"MSE: {loss_value} | PSNR: {psnr}")
+    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=1639, save_file=True)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file