Commit 3bbd64ae authored by Alexandre Chapin

Change loading dataset

parent a3f58aaf
@@ -300,43 +300,43 @@ class TransformerSlotAttention(nn.Module):
     """
     An extension of Slot Attention using self-attention
     """
-    def __init__(self, depth, heads, dim_head, mlp_dim, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, eps=1e-8,
-                 randomize_initial_slots=False):
+    """def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
+                 randomize_initial_slots=False):"""
+    def __init__(self, num_slots=10, depth=6, input_dim=768, slot_dim=1536, hidden_dim=3072, cross_heads=1, self_heads=6,
+                 randomize_initial_slots=False):
         super().__init__()
         self.num_slots = num_slots
+        self.input_dim = input_dim
         self.batch_slots = []
         self.scale = slot_dim ** -0.5
-        self.slot_dim = slot_dim
+        self.slot_dim = slot_dim  # latent_dim
+        self.hidden_dim = hidden_dim
         self.depth = depth
-        self.num_heads = 8
+        self.self_head = self_heads
+        self.cross_heads = cross_heads
+
+        ### Cross-attention layers
+        self.cs_layers = nn.ModuleList([])
+        for _ in range(depth):
+            # def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
+            self.cs_layers.append(nn.ModuleList([
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads=self.cross_heads, dim_head=self.hidden_dim, selfatt=False)),
+                PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
+            ]))
+
+        ### Self-attention layers
+        self.sf_layers = nn.ModuleList([])
+        for _ in range(depth - 1):
+            self.sf_layers.append(nn.ModuleList([
+                PreNorm(self.input_dim, Attention(self.input_dim, heads=self.self_head, dim_head=self.hidden_dim)),
+                PreNorm(self.input_dim, FeedForward(self.input_dim, self.hidden_dim))
+            ]))
+
+        ### Initialize slots
         self.randomize_initial_slots = randomize_initial_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
-
-        # def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., selfatt=True, kv_dim=None):
-        self.transformer_stage_1 = Transformer(dim=384, depth=2, heads=8)
-        self.transformer_stage_2 = Transformer(dim=384, depth=2, heads=8)
-
-        self.eps = eps
-        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
-        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
-        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)
-        self.gru = nn.GRUCell(slot_dim, slot_dim)
-        self.mlp = nn.Sequential(
-            JaxLinear(slot_dim, hidden_dim),
-            nn.ReLU(inplace=True),
-            JaxLinear(hidden_dim, slot_dim)
-        )
-        self.norm_input = nn.LayerNorm(input_dim)
-        self.norm_slots = nn.LayerNorm(slot_dim)
-        self.norm_pre_mlp = nn.LayerNorm(slot_dim)
 
     def forward(self, inputs):
         """
         Args:
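For context, the new constructor swaps the old Transformer-style arguments (heads, dim_head, mlp_dim) for explicit cross_heads/self_heads counts and builds depth cross-attention blocks on the slot stream plus depth - 1 self-attention blocks on the input stream. A construction call under the new defaults might look like the sketch below; the import path is hypothetical, and the forward pass (next hunk) still carries a TODO, so this only illustrates the signature.

# Hypothetical import path -- adjust to wherever TransformerSlotAttention lives in this repo.
from model.slot_attention import TransformerSlotAttention

slot_attn = TransformerSlotAttention(
    num_slots=10, depth=6,              # 6 cross-attention blocks, 5 self-attention blocks
    input_dim=768, slot_dim=1536, hidden_dim=3072,
    cross_heads=1, self_heads=6,
    randomize_initial_slots=False)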
@@ -352,23 +352,21 @@ class TransformerSlotAttention(nn.Module):
         else:
             slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
 
-        k, v = self.to_k(inputs), self.to_v(inputs)
-
-        for _ in range(self.iters):
-            slots_prev = slots
-            norm_slots = self.norm_slots(slots)
-            q = self.to_q(norm_slots)
-            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
-
-            # shape: [batch_size, num_slots, num_inputs]
-            attn = dots.softmax(dim=1) + self.eps
-            attn = attn / attn.sum(dim=-1, keepdim=True)
-            updates = torch.einsum('bjd,bij->bid', v, attn)  # shape: [batch_size, num_inputs, slot_dim]
-
-            slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
-            slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
-            slots = slots + self.mlp(self.norm_pre_mlp(slots))
+        ############### TODO : adapt this part of code
+        # data = torch.cat((data, enc_pos.reshape(b,-1,enc_pos.shape[-1])), dim = -1) TODO : add a positional encoding here
+        x0 = repeat(self.latents, 'n d -> b n d', b = b)
+        for i in range(self.depth):
+            cross_attn, cross_ff = self.cs_layers[i]
+
+            # cross attention only happens once for Perceiver IO
+            x = cross_attn(x0, context = data, mask = mask) + x0
+            x0 = cross_ff(x) + x
+
+            if i != self.depth - 1:
+                self_attn, self_ff = self.layers[i]
+                x_d = self_attn(data) + data
+                data = self_ff(x_d) + x_d
 
         return slots  # [batch_size, num_slots, dim]
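As committed, the new forward body is still the block copied from a Perceiver-style module: it references self.latents, self.layers, data, mask and b, none of which are defined in the new __init__ or the method's arguments, hence the TODO at the top of the hunk. Below is a minimal, self-contained sketch of the wiring the new __init__ appears to set up (slots cross-attend to the input tokens at every depth, and the tokens are refined by self-attention in between). It uses standard nn.MultiheadAttention as a stand-in for the repo's PreNorm/Attention/FeedForward blocks, so it illustrates the intended pattern rather than the repo's implementation.

import torch
import torch.nn as nn

class SlotCrossSelfSketch(nn.Module):
    """Illustrative stand-in only, not this repo's module."""
    def __init__(self, num_slots=10, depth=6, input_dim=768, slot_dim=1536, self_heads=6):
        super().__init__()
        self.depth = depth
        self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
        # One cross-attention block per depth level: slots are the queries, inputs the keys/values.
        self.cross = nn.ModuleList([
            nn.MultiheadAttention(slot_dim, num_heads=1, kdim=input_dim, vdim=input_dim,
                                  batch_first=True)
            for _ in range(depth)])
        # Self-attention over the input tokens between iterations (depth - 1 blocks).
        self.self_attn = nn.ModuleList([
            nn.MultiheadAttention(input_dim, num_heads=self_heads, batch_first=True)
            for _ in range(depth - 1)])

    def forward(self, inputs):                       # inputs: [B, N, input_dim]
        slots = self.initial_slots.expand(inputs.shape[0], -1, -1)
        for i in range(self.depth):
            upd, _ = self.cross[i](slots, inputs, inputs)
            slots = slots + upd                      # residual update of the slots
            if i != self.depth - 1:
                upd, _ = self.self_attn[i](inputs, inputs, inputs)
                inputs = inputs + upd                # residual refinement of the input tokens
        return slots                                 # [B, num_slots, slot_dim]

# SlotCrossSelfSketch()(torch.randn(2, 196, 768)).shape  ->  torch.Size([2, 10, 1536])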
@@ -68,13 +68,13 @@ def main():
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
-        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
 
     vis_dataset = data.get_dataset('test', cfg['data'])
     vis_loader = DataLoader(
-        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
-        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
 
     model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
     num_params = sum(p.numel() for p in model.parameters())
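The two keyword arguments removed from both loaders are standard torch.utils.data.DataLoader options: pin_memory=True stages batches in page-locked host memory to speed up CPU-to-GPU copies, and persistent_workers=True keeps worker processes alive between epochs (it requires num_workers > 0). If they ever need to come back behind a switch, a config-driven variant could look like the sketch below; the cfg keys used here are hypothetical, not ones this repo defines.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical config keys for illustration only.
cfg = {"training": {"num_workers": 0, "pin_memory": False, "persistent_workers": False}}
train_dataset = TensorDataset(torch.randn(16, 3))
train_loader = DataLoader(
    train_dataset, batch_size=4, shuffle=True,
    num_workers=cfg["training"]["num_workers"],
    pin_memory=cfg["training"]["pin_memory"],
    # persistent_workers is only valid when num_workers > 0
    persistent_workers=cfg["training"]["persistent_workers"] and cfg["training"]["num_workers"] > 0)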
@@ -98,8 +98,8 @@ def main():
         global_step = ckpt['global_step']
 
     start = time.time()
-    for _ in range(num_train_steps):
-        batch = next(iter(train_loader))
+    for batch in train_loader:
+        #batch = next(iter(train_loader))
 
         # Learning rate warm-up.
         if global_step < warmup_steps:
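Note the behavioural change: the old loop ran for num_train_steps and re-created the DataLoader iterator on every step with next(iter(train_loader)), while the new loop makes a single pass over the dataset. If a fixed step budget is still wanted, one option (a sketch, not part of this commit) is to wrap the per-batch loop in an outer loop that stops at num_train_steps:

import torch
from torch.utils.data import DataLoader, TensorDataset

train_loader = DataLoader(TensorDataset(torch.randn(32, 3)), batch_size=8, shuffle=True)
num_train_steps, global_step = 10, 0          # stand-ins for the script's real values

while global_step < num_train_steps:          # keep cycling epochs until the step budget is met
    for batch in train_loader:
        # ... warm-up, forward/backward and optimizer step would go here ...
        global_step += 1
        if global_step >= num_train_steps:
            break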