diff --git a/osrt/layers.py b/osrt/layers.py
index 0daf93ccfb3560a2c083cccc3d2624a0cee61149..0cfcff18aff1d98d7d7ee8759df3fed3acf20fd9 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -300,43 +300,43 @@ class TransformerSlotAttention(nn.Module):
     """
     An extension of Slot Attention using self-attention
     """
-    def __init__(self, depth, heads, dim_head, mlp_dim, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, eps=1e-8,
+    """def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
+                 randomize_initial_slots=False):"""
+    def __init__(self, num_slots=10, depth=6, input_dim=768, slot_dim=1536, hidden_dim=3072, cross_heads=1, self_heads=6,
                  randomize_initial_slots=False):
         super().__init__()
 
         self.num_slots = num_slots
+        self.input_dim = input_dim
         self.batch_slots = []
         self.scale = slot_dim ** -0.5
-        self.slot_dim = slot_dim
+        self.slot_dim = slot_dim  # latent dimension of the slots
+        self.hidden_dim = hidden_dim
         self.depth = depth
-        self.num_heads = 8
+        self.self_heads = self_heads
+        self.cross_heads = cross_heads
 
+        ### Cross-attention layers
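+        # One block per depth: the slots (queries) attend to the input tokens (keys/values),
+        # followed by a feed-forward update of the slots.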
+        self.cs_layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.cs_layers.append(nn.ModuleList([
+                # selfatt=False: cross-attention; kv_dim matches the input token dimension
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads=self.cross_heads, dim_head=self.hidden_dim,
+                                                 selfatt=False, kv_dim=self.input_dim)),
+                PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
+            ]))
+
+        ### Self-attention layers
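+        # One block per intermediate depth: self-attention plus feed-forward over the input
+        # tokens, refining them between cross-attention rounds.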
+        self.sf_layers = nn.ModuleList([])
+        for _ in range(depth-1):
+            self.sf_layers.append(nn.ModuleList([
+                PreNorm(self.input_dim, Attention(self.input_dim, heads=self.self_heads, dim_head=self.hidden_dim)),
+                PreNorm(self.input_dim, FeedForward(self.input_dim, self.hidden_dim))
+            ]))
 
+        ### Initialize slots
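+        # Slots start from a learned parameter; forward() can randomize them per batch
+        # when randomize_initial_slots is set.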
         self.randomize_initial_slots = randomize_initial_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
 
-        #def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., selfatt=True, kv_dim=None):
-        self.transformer_stage_1 = Transformer(dim=384, depth=2, heads=8)
-        self.transformer_stage_2 = Transformer(dim=384, depth=2, heads=8)
-
-        self.eps = eps
-
-        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
-        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
-        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)
-
-        self.gru = nn.GRUCell(slot_dim, slot_dim)
-
-        self.mlp = nn.Sequential(
-            JaxLinear(slot_dim, hidden_dim),
-            nn.ReLU(inplace=True),
-            JaxLinear(hidden_dim, slot_dim)
-        )
-
-        self.norm_input   = nn.LayerNorm(input_dim)
-        self.norm_slots   = nn.LayerNorm(slot_dim)
-        self.norm_pre_mlp = nn.LayerNorm(slot_dim)
-
     def forward(self, inputs):
         """
         Args:
@@ -352,23 +352,21 @@ class TransformerSlotAttention(nn.Module):
         else:
             slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
 
-        k, v = self.to_k(inputs), self.to_v(inputs)
+        # TODO: add a positional encoding to the inputs before the attention rounds
 
-        for _ in range(self.iters):
-            slots_prev = slots
-            norm_slots = self.norm_slots(slots)
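+        # Alternate cross-attention (slots attend to the inputs) and self-attention over the
+        # inputs for `depth` rounds, updating both residually (Perceiver-style).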
+        for i in range(self.depth):
+            cross_attn, cross_ff = self.cs_layers[i]
 
-            q = self.to_q(norm_slots)
+            # Cross-attention step: update the slots from the current input tokens.
 
-            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
+            x = cross_attn(slots, z=inputs) + slots  # Attention above takes the cross-attention context as `z`
+            slots = cross_ff(x) + x
 
-            # shape: [batch_size, num_slots, num_inputs]
-            attn = dots.softmax(dim=1) + self.eps
-            attn = attn / attn.sum(dim=-1, keepdim=True)
-            updates = torch.einsum('bjd,bij->bid', v, attn) # shape: [batch_size, num_inputs, slot_dim] 
-            
-            slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
-            slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
-            slots = slots + self.mlp(self.norm_pre_mlp(slots))
+            # Refine the input tokens with self-attention, except after the final cross-attention.
+            if i != self.depth - 1:
+                self_attn, self_ff = self.sf_layers[i]
+                x_d = self_attn(inputs) + inputs
+                inputs = self_ff(x_d) + x_d
 
         return slots # [batch_size, num_slots, dim]
diff --git a/train_sa.py b/train_sa.py
index 173fb148ef7d20620ec47bf5373b7a70626c8ca5..f4d00dcf8e3e295198a0425166a70251c5ef1e2c 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -68,13 +68,13 @@ def main():
     
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
-        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
     
     vis_dataset = data.get_dataset('test', cfg['data'])
     vis_loader = DataLoader(
-        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
-        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
 
     model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
     num_params = sum(p.numel() for p in model.parameters())
@@ -98,8 +98,8 @@ def main():
     global_step = ckpt['global_step']
 
     start = time.time()
-    for _ in range(num_train_steps):
-        batch = next(iter(train_loader))
+    for batch in train_loader:
 
         # Learning rate warm-up.
         if global_step < warmup_steps: