diff --git a/osrt/encoder.py b/osrt/encoder.py
index e5292bad070d61e9b673ce9fb9ff59989377e1f2..deb591831b95b4daf60af4a019e520ec67519552 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -112,6 +112,20 @@ class OSRTEncoder(nn.Module):
         slot_latents = self.slot_attention(set_latents)
         return slot_latents
 
+class SlotTransformer(nn.Module):
+    def __init__(self, pos_start_octave=0, num_slots=6, slot_dim=1536, slot_iters=1,
+                 randomize_initial_slots=False):
+        super().__init__()
+        self.srt_encoder = ImprovedSRTEncoder(num_conv_blocks=3, num_att_blocks=5,
+                                             pos_start_octave=pos_start_octave)
+
+        self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters,
+                                            randomize_initial_slots=randomize_initial_slots)
+
+    def forward(self, images, camera_pos, rays):
+        set_latents = self.srt_encoder(images, camera_pos, rays)
+        slot_latents = self.slot_attention(set_latents)
+        return slot_latents
 
 class FeatureMasking(nn.Module):
     def __init__(self, 
@@ -345,7 +359,7 @@ class FeatureMasking(nn.Module):
         ):
 
         ### Prepare the points of the batch
-        in_points = torch.as_tensor(self.transform.apply_coords(points, im_size))
+        in_points = torch.as_tensor(self.resize.apply_coords(points, im_size))
         in_labels = torch.ones(in_points.shape[0], dtype=torch.int)
         
         point_coords = in_points[:, None, :]