diff --git a/osrt/layers.py b/osrt/layers.py
index 0daf93ccfb3560a2c083cccc3d2624a0cee61149..0cfcff18aff1d98d7d7ee8759df3fed3acf20fd9 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -300,43 +300,43 @@ class TransformerSlotAttention(nn.Module):
     """
     An extension of Slot Attention using self-attention
     """
-    def __init__(self, depth, heads, dim_head, mlp_dim, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, eps=1e-8,
+    def __init__(self, num_slots=10, depth=6, input_dim=768, slot_dim=1536, hidden_dim=3072, cross_heads=1, self_heads=6,
                  randomize_initial_slots=False):
         super().__init__()
 
         self.num_slots = num_slots
+        self.input_dim = input_dim
         self.batch_slots = []
         self.scale = slot_dim ** -0.5
-        self.slot_dim = slot_dim
+        self.slot_dim = slot_dim  # latent dimension of the slots
+        self.hidden_dim = hidden_dim
         self.depth = depth
-        self.num_heads = 8
+        self.self_heads = self_heads
+        self.cross_heads = cross_heads
+
+        ### Cross-attention layers: the slots query the input tokens
+        self.cs_layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.cs_layers.append(nn.ModuleList([
+                PreNorm(self.slot_dim, Attention(self.slot_dim, heads=self.cross_heads, dim_head=self.hidden_dim,
+                                                 selfatt=False, kv_dim=self.input_dim)),
+                PreNorm(self.slot_dim, FeedForward(self.slot_dim, self.hidden_dim))
+            ]))
+
+        ### Self-attention layers over the input tokens
+        self.sf_layers = nn.ModuleList([])
+        for _ in range(depth - 1):
+            self.sf_layers.append(nn.ModuleList([
+                PreNorm(self.input_dim, Attention(self.input_dim, heads=self.self_heads, dim_head=self.hidden_dim)),
+                PreNorm(self.input_dim, FeedForward(self.input_dim, self.hidden_dim))
+            ]))
+
         ### Initialize slots
         self.randomize_initial_slots = randomize_initial_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
 
-        #def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., selfatt=True, kv_dim=None):
-        self.transformer_stage_1 = Transformer(dim=384, depth=2, heads=8)
-        self.transformer_stage_2 = Transformer(dim=384, depth=2, heads=8)
-
-        self.eps = eps
-
-        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
-        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
-        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)
-
-        self.gru = nn.GRUCell(slot_dim, slot_dim)
-
-        self.mlp = nn.Sequential(
-            JaxLinear(slot_dim, hidden_dim),
-            nn.ReLU(inplace=True),
-            JaxLinear(hidden_dim, slot_dim)
-        )
-
-        self.norm_input = nn.LayerNorm(input_dim)
-        self.norm_slots = nn.LayerNorm(slot_dim)
-        self.norm_pre_mlp = nn.LayerNorm(slot_dim)
-
     def forward(self, inputs):
         """
         Args:
@@ -352,23 +352,21 @@ class TransformerSlotAttention(nn.Module):
         else:
             slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
 
-        k, v = self.to_k(inputs), self.to_v(inputs)
+        # TODO: add a positional encoding to the inputs here
 
-        for _ in range(self.iters):
-            slots_prev = slots
-            norm_slots = self.norm_slots(slots)
-
-            q = self.to_q(norm_slots)
-            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
-
-            # shape: [batch_size, num_slots, num_inputs]
-            attn = dots.softmax(dim=1) + self.eps
-            attn = attn / attn.sum(dim=-1, keepdim=True)
-            updates = torch.einsum('bjd,bij->bid', v, attn) # shape: [batch_size, num_inputs, slot_dim]
-
-            slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
-            slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
-            slots = slots + self.mlp(self.norm_pre_mlp(slots))
+        for i in range(self.depth):
+            cross_attn, cross_ff = self.cs_layers[i]
+
+            # cross-attention: update the slots from the encoded inputs
+            x = cross_attn(slots, z=inputs) + slots
+            slots = cross_ff(x) + x
+
+            # self-attention over the inputs between consecutive cross-attention blocks
+            if i != self.depth - 1:
+                self_attn, self_ff = self.sf_layers[i]
+                x_d = self_attn(inputs) + inputs
+                inputs = self_ff(x_d) + x_d
 
         return slots # [batch_size, num_slots, dim]
diff --git a/train_sa.py b/train_sa.py
index 173fb148ef7d20620ec47bf5373b7a70626c8ca5..f4d00dcf8e3e295198a0425166a70251c5ef1e2c 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -68,13 +68,13 @@ def main():
 
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
-        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
 
     vis_dataset = data.get_dataset('test', cfg['data'])
     vis_loader = DataLoader(
-        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
-        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
 
     model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
     num_params = sum(p.numel() for p in model.parameters())
@@ -98,8 +98,8 @@ def main():
         global_step = ckpt['global_step']
 
     start = time.time()
-    for _ in range(num_train_steps):
-        batch = next(iter(train_loader))
+    for batch in train_loader:
+
 
         # Learning rate warm-up.
         if global_step < warmup_steps:
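Note (not part of the patch): a minimal smoke test of the reworked TransformerSlotAttention, assuming the constructor and forward pass shown in the hunks above and that the class is importable from osrt.layers; the batch size, token count, and smaller depth below are arbitrary placeholders chosen to keep the check lightweight.

# Illustrative smoke test: random input features in, slot embeddings out.
# Shapes follow the defaults in the diff (input_dim=768, slot_dim=1536).
import torch
from osrt.layers import TransformerSlotAttention

model = TransformerSlotAttention(num_slots=6, depth=2, input_dim=768, slot_dim=1536)

inputs = torch.randn(2, 512, 768)  # [batch_size, num_inputs, input_dim]
with torch.no_grad():
    slots = model(inputs)

print(slots.shape)  # expected: torch.Size([2, 6, 1536]) -> [batch_size, num_slots, slot_dim]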