Patch for osrt/encoder.py, two hunks:
  1. Adds SlotTransformer, an nn.Module that runs an ImprovedSRTEncoder and
     feeds its output to a SlotAttention module.
  2. In FeatureMasking, the points of the batch are now mapped with
     self.resize.apply_coords instead of self.transform.apply_coords.
NOTE(review): line breaks, indentation and the position of blank context
lines were reconstructed from a whitespace-collapsed copy of this patch so
that they agree with the hunk line counts; confirm against the original.

diff --git a/osrt/encoder.py b/osrt/encoder.py
index e5292bad070d61e9b673ce9fb9ff59989377e1f2..deb591831b95b4daf60af4a019e520ec67519552 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -112,6 +112,20 @@ class OSRTEncoder(nn.Module):
         slot_latents = self.slot_attention(set_latents)
         return slot_latents
 
+class SlotTransformer(nn.Module):  # Chains an ImprovedSRTEncoder with a SlotAttention module.
+    def __init__(self, pos_start_octave=0, num_slots=6, slot_dim=1536, slot_iters=1,
+                 randomize_initial_slots=False):
+        super().__init__()
+        self.srt_encoder = ImprovedSRTEncoder(num_conv_blocks=3, num_att_blocks=5,  # block counts are fixed here, not constructor arguments
+                                              pos_start_octave=pos_start_octave)
+
+        self.slot_attention = SlotAttention(num_slots, slot_dim=slot_dim, iters=slot_iters,  # NOTE(review): slot_dim presumably has to match the encoder output width; confirm
+                                            randomize_initial_slots=randomize_initial_slots)
+
+    def forward(self, images, camera_pos, rays):  # Encode the inputs, then run slot attention on the encoder output.
+        set_latents = self.srt_encoder(images, camera_pos, rays)
+        slot_latents = self.slot_attention(set_latents)
+        return slot_latents  # whatever self.slot_attention returns, unchanged
 
 class FeatureMasking(nn.Module):
     def __init__(self,
@@ -345,7 +359,7 @@ class FeatureMasking(nn.Module):
     ):
 
         ### Prepare the points of the batch
-        in_points = torch.as_tensor(self.transform.apply_coords(points, im_size))
+        in_points = torch.as_tensor(self.resize.apply_coords(points, im_size))  # coordinates now come from self.resize
         in_labels = torch.ones(in_points.shape[0], dtype=torch.int)
 
         point_coords = in_points[:, None, :]