Skip to content
Snippets Groups Projects
Commit 5a01f6ca authored by Alexandre Chapin :race_car:
Browse files

Change device slots

parent 8d45a44b
No related branches found
No related tags found
No related merge requests found
......@@ -174,7 +174,6 @@ class FeatureMasking(nn.Module):
num_slots = 100000
for b in range(B):
latents_batch = torch.empty((1, dim), device=self.mask_generator.device)
# TODO : set a new number of slots
for n in range(N):
embeds = masks[b][n]["embeddings"]
num_slots = min(len(embeds), num_slots)
......@@ -270,11 +269,11 @@ class SamAutomaticMask(nn.Module):
input_size = 0 # depends on the image size
self.token_dim = (self.image_encoder.img_size // patch_size)**2
self.tokenizer = nn.Sequential(
nn.Linear(input_size, 100),
nn.Linear(self.token_dim, 3500),
nn.ReLU(),
nn.Linear(100, 50),
nn.Linear(3500, 2500),
nn.ReLU(),
nn.Linear(50, self.token_dim),
nn.Linear(2500, 2048),
)
# Space positional embedding
......
......@@ -229,10 +229,10 @@ class SlotAttention(nn.Module):
inputs = self.norm_input(inputs)
if self.randomize_initial_slots:
slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
else:
slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
k, v = self.to_k(inputs), self.to_v(inputs)
......
from torch import nn
from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder
from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
import osrt.layers as layers
......@@ -17,6 +17,8 @@ class OSRT(nn.Module):
self.encoder = ImprovedSRTEncoder(**cfg['encoder_kwargs'])
elif encoder_type == 'osrt':
self.encoder = OSRTEncoder(**cfg['encoder_kwargs'])
elif encoder_type == 'sam':
self.encoder = FeatureMasking(**cfg['encoder_kwargs'])
else:
raise ValueError(f'Unknown encoder type: {encoder_type}')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment