diff --git a/osrt/layers.py b/osrt/layers.py
index f681910a4461948a637dbab4eff83b4b590c8da4..4e9f0d1fa345f72d1c50e17bc47417d4e9f062b7 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -347,121 +347,6 @@ class TransformerSlotAttention(nn.Module):
         return slots # [batch_size, num_slots, dim]
 
 
-
-def unstack_and_split(x, batch_size, num_channels=3):
-    """Unstack batch dimension and split into channels and alpha mask."""
-    unstacked = x.view(batch_size, -1, *x.shape[1:])
-    channels, masks = torch.split(unstacked, [num_channels, 1], dim=-1)
-    return channels, masks
-
-def spatial_flatten(x):
-    return x.view(-1, x.shape[1] * x.shape[2], x.shape[-1])
-
-def spatial_broadcast(slots, resolution):
-    """Broadcast slot features to a 2D grid and collapse slot dimension."""
-    # `slots` has shape: [batch_size, num_slots, slot_size].
-    slots = slots.view(-1, slots.shape[-1])[:, None, None, :]
-    grid = slots.repeat(1, resolution[0], resolution[1], 1)
-    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
-    return grid
-
-# TODO : adapt this model
-class SlotAttentionAutoEncoder(nn.Module):
-    """
-    Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
-
-    Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
-
-    """
-
-    def __init__(self, resolution, num_slots, num_iterations):
-        """Builds the Slot Attention-based auto-encoder.
-
-        Args:
-        resolution: Tuple of integers specifying width and height of input image.
-        num_slots: Number of slots in Slot Attention.
-        num_iterations: Number of iterations in Slot Attention.
-        """
-        super().__init__()
-        self.resolution = resolution
-        self.num_slots = num_slots
-        self.num_iterations = num_iterations
-
-        self.encoder_cnn = nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=5, padding=2),
-            nn.ReLU()
-        )
-
-        self.decoder_initial_size = (8, 8)
-        self.decoder_cnn = nn.Sequential(
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=2)
-        )
-
-        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
-        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
-
-        self.layer_norm = nn.LayerNorm(64)
-        self.mlp = nn.Sequential(
-            JaxLinear(64, 64),
-            nn.ReLU(),
-            JaxLinear(64, 64)
-        )
-
-        self.slot_attention = SlotAttention(
-            num_slots=self.num_slots,
-            slot_dim=64,
-            hidden_dim=128,
-            iters=self.num_iterations)
-
-    def forward(self, image):
-        # `image` has shape: [batch_size, width, height, num_channels].
-
-        # Convolutional encoder with position embedding.
-        x = self.encoder_cnn(image)  # CNN Backbone.
-        #x = self.encoder_pos(x)  # Position embedding.
-        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as set).
-        x = self.mlp(self.layer_norm(x))  # Feedforward network on set.
-        # `x` has shape: [batch_size, width*height, input_size].
-
-        # Slot Attention module.
-        slots = self.slot_attention(x)
-        # `slots` has shape: [batch_size, num_slots, slot_size].
-
-        # Spatial broadcast decoder.
-        x = spatial_broadcast(slots, self.decoder_initial_size)
-        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
-        #x = self.decoder_pos(x)
-        x = self.decoder_cnn(x)
-        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
-
-        # Undo combination of slot and batch dimension; split alpha masks.
-        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
-        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
-        # `masks` has shape: [batch_size, num_slots, width, height, 1].
-
-        # Normalize alpha masks over slots.
-        masks = torch.softmax(masks, dim=1)
-        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
-        # `recon_combined` has shape: [batch_size, width, height, num_channels].
-
-        return recon_combined, recons, masks, slots
-
 def build_grid(resolution):
     ranges = [np.linspace(0., 1., num=res) for res in resolution]
     grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
@@ -469,7 +354,7 @@ def build_grid(resolution):
     grid = np.reshape(grid, [resolution[0], resolution[1], -1])
     grid = np.expand_dims(grid, axis=0)
     grid = grid.astype(np.float32)
-    return np.concatenate([grid, 1.0 - grid], axis=-1)
+    return np.concatenate([grid, 1.0 - grid], axis=-1).transpose(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
 
 class SoftPositionEmbed(nn.Module):
     """Adds soft positional embedding with learnable projection.
@@ -487,6 +372,4 @@ class SoftPositionEmbed(nn.Module):
         self.grid = build_grid(resolution)
 
     def forward(self, inputs):
-        print(inputs.shape)
-        print(self.dense(torch.tensor(self.grid).cuda()).shape)
         return inputs + self.dense(torch.tensor(self.grid).cuda())
\ No newline at end of file
diff --git a/osrt/model.py b/osrt/model.py
index 6480fa8fee535bb17709ae274679dad0927338df..45821355ff17b56ebf0769373ce5211de3999cbf 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -1,7 +1,10 @@
 from torch import nn
+import torch
+import numpy as np
 
 from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
+from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed
 
 import osrt.layers as layers
 
@@ -33,3 +36,120 @@ class OSRT(nn.Module):
             raise ValueError(f'Unknown decoder type: {decoder_type}')
 
 
+
+def unstack_and_split(x, batch_size, num_channels=3):
+    """Unstack batch dimension and split channels-first maps into channels and alpha mask."""
+    unstacked = x.view(batch_size, -1, *x.shape[1:])  # [batch_size, num_slots, num_channels+1, height, width]
+    channels, masks = torch.split(unstacked, [num_channels, 1], dim=2)  # split along the channel axis
+    return channels, masks
+
+def spatial_flatten(x):
+    """Flatten spatial dimensions: [batch_size, channels, height, width] -> [batch_size, height*width, channels]."""
+    return x.flatten(2).permute(0, 2, 1)
+
+def spatial_broadcast(slots, resolution):
+    """Broadcast slot features to a 2D grid and collapse slot dimension."""
+    # `slots` has shape: [batch_size, num_slots, slot_size].
+    slots = slots.view(-1, slots.shape[-1])[:, :, None, None]
+    grid = slots.repeat(1, 1, resolution[0], resolution[1])
+    # `grid` has shape: [batch_size*num_slots, slot_size, height, width] (channels-first for the decoder).
+    return grid
+
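+# Illustrative shapes (assumed values, not taken from this repo): slots of shape
+# [2, 7, 64] broadcast with spatial_broadcast(slots, (8, 8)) yield a [14, 64, 8, 8]
+# grid, i.e. one 8x8 channels-first feature map per (batch, slot) pair.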
+
+# TODO: adapt this model
+class SlotAttentionAutoEncoder(nn.Module):
+    """
+    Slot Attention as introduced by Locatello et al., with an auto-encoder to extract image embeddings.
+
+    Implementation inspired by the official repo: https://github.com/google-research/google-research/blob/master/slot_attention/model.py
+    """
+
+    def __init__(self, resolution, num_slots, num_iterations):
+        """Builds the Slot Attention-based auto-encoder.
+
+        Args:
+        resolution: Tuple of integers specifying width and height of input image.
+        num_slots: Number of slots in Slot Attention.
+        num_iterations: Number of iterations in Slot Attention.
+        """
+        super().__init__()
+        self.resolution = resolution
+        self.num_slots = num_slots
+        self.num_iterations = num_iterations
+
+        self.encoder_cnn = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, kernel_size=5, padding=2),
+            nn.ReLU()
+        )
+
+        self.decoder_initial_size = (8, 8)
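+        # With output_padding=1 each stride-2 layer exactly doubles the spatial size
+        # (mirroring the SAME-padded TF reference), so the decoder upsamples
+        # 8 -> 16 -> 32 -> 64 -> 128; this assumes 128x128 inputs.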
+        self.decoder_cnn = nn.Sequential(
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2),
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=1)
+        )
+
+        self.encoder_pos = SoftPositionEmbed(64, self.resolution)  # grid must match the encoder's output resolution
+        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
+
+        self.layer_norm = nn.LayerNorm(64)
+        self.mlp = nn.Sequential(
+            JaxLinear(64, 64),
+            nn.ReLU(),
+            JaxLinear(64, 64)
+        )
+
+        self.slot_attention = SlotAttention(
+            num_slots=self.num_slots,
+            slot_dim=64,
+            hidden_dim=128,
+            iters=self.num_iterations)
+
+    def forward(self, image):
+        # `image` has shape: [batch_size, num_channels, height, width].
+
+        # Convolutional encoder with position embedding.
+        x = self.encoder_cnn(image)  # CNN backbone, preserves spatial size.
+        x = self.encoder_pos(x)  # Position embedding.
+        x = spatial_flatten(x)  # Flatten spatial dimensions (treat image as a set).
+        x = self.mlp(self.layer_norm(x))  # Feedforward network on the set.
+        # `x` has shape: [batch_size, height*width, input_size].
+
+        # Slot Attention module.
+        slots = self.slot_attention(x)
+        # `slots` has shape: [batch_size, num_slots, slot_size].
+
+        # Spatial broadcast decoder.
+        x = spatial_broadcast(slots, self.decoder_initial_size)
+        # `x` has shape: [batch_size*num_slots, slot_size, height_init, width_init].
+        # x = self.decoder_pos(x)  # TODO: re-enable once the decoder position embedding is adapted.
+        x = self.decoder_cnn(x)
+        # `x` has shape: [batch_size*num_slots, num_channels+1, height, width].
+
+        # Undo combination of slot and batch dimension; split alpha masks.
+        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
+        # `recons` has shape: [batch_size, num_slots, num_channels, height, width].
+        # `masks` has shape: [batch_size, num_slots, 1, height, width].
+
+        # Normalize alpha masks over slots.
+        masks = torch.softmax(masks, dim=1)
+        recon_combined = torch.sum(recons * masks, dim=1)  # Recombine image.
+        # `recon_combined` has shape: [batch_size, num_channels, height, width].
+
+        return recon_combined, recons, masks, slots
+
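+
+if __name__ == "__main__":
+    # Minimal smoke test (illustration only). The 128x128 resolution is an assumption
+    # matching the decoder's 8 -> 128 upsampling path; .cuda() is needed because
+    # SoftPositionEmbed moves its grid to the GPU, and this assumes the positional
+    # projection accepts the channels-first grid produced by build_grid.
+    model = SlotAttentionAutoEncoder(resolution=(128, 128), num_slots=7, num_iterations=3).cuda()
+    image = torch.randn(2, 3, 128, 128).cuda()
+    recon_combined, recons, masks, slots = model(image)
+    print(recon_combined.shape, recons.shape, masks.shape, slots.shape)
+    # Expected: [2, 3, 128, 128], [2, 7, 3, 128, 128], [2, 7, 1, 128, 128], [2, 7, 64]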
diff --git a/train_sa.py b/train_sa.py
index c087f342d00020730875aed50ad2bab29b1a5750..3afbecd956ba117968bda77e3b69fdde9f6e1eab 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.optim as optim
 import argparse
 import yaml
-from osrt.layers import SlotAttentionAutoEncoder
+from osrt.model import SlotAttentionAutoEncoder
 from osrt import data
 
 from torch.utils.data import DataLoader