From c515e6d579c88d42d148c78eb9901773c274acc6 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Tue, 18 Jul 2023 17:07:07 +0200
Subject: [PATCH] Resolve issue on points coords

---
 osrt/encoder.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/osrt/encoder.py b/osrt/encoder.py
index e5292ba..deb5918 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -112,6 +112,20 @@ class OSRTEncoder(nn.Module):
         slot_latents = self.slot_attention(set_latents)
         return slot_latents
 
+class SlotTransformer(nn.Module):
+    def __init__(self, pos_start_octave=0, num_slots=6, slot_dim=1536, slot_iters=1,
+                 randomize_initial_slots=False):
+        super().__init__()
+        self.srt_encoder = ImprovedSRTEncoder(num_conv_blocks=3, num_att_blocks=5,
+                                             pos_start_octave=pos_start_octave)
+
+        self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters,
+                                            randomize_initial_slots=randomize_initial_slots)
+
+    def forward(self, images, camera_pos, rays):
+        set_latents = self.srt_encoder(images, camera_pos, rays)
+        slot_latents = self.slot_attention(set_latents)
+        return slot_latents
 
 class FeatureMasking(nn.Module):
     def __init__(self, 
@@ -345,7 +359,7 @@ class FeatureMasking(nn.Module):
         ):
 
         ### Prepare the points of the batch
-        in_points = torch.as_tensor(self.transform.apply_coords(points, im_size))
+        in_points = torch.as_tensor(self.resize.apply_coords(points, im_size))
         in_labels = torch.ones(in_points.shape[0], dtype=torch.int)
         
         point_coords = in_points[:, None, :]
-- 
GitLab