Commit 8f98651f authored by Alexandre Chapin

Change masking print

parent fd38c7cc
@@ -58,13 +58,13 @@ if __name__ == '__main__':
     model.to(device)
     images= []
-    for j in range(1):
+    for j in range(8):
         image = images_path[j]
         img = cv2.imread(image)
         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         images.append(np.expand_dims(img, axis=0))
     images_np = np.array(images)
-    images_np = images_np.reshape(1, 1, images_np.shape[2], images_np.shape[3], images_np.shape[4])
+    images_np = images_np.reshape(2, 4, images_np.shape[2], images_np.shape[3], images_np.shape[4])
     h, w, _ = images_np[0][0].shape
     points = model.mask_generator.points_grid
...
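Note (not part of the commit): a minimal sketch of what the new reshape does, using eight dummy 240x320 frames in place of the cv2-loaded images. Eight (1, H, W, 3) arrays are stacked and regrouped into a (2, 4, H, W, 3) batch, i.e. two scenes with four views each.

import numpy as np

H, W = 240, 320                                   # assumed image size, for illustration only
images = [np.expand_dims(np.zeros((H, W, 3), dtype=np.uint8), axis=0)
          for _ in range(8)]                      # stand-ins for the cv2-loaded frames
images_np = np.array(images)                      # shape: (8, 1, H, W, 3)
images_np = images_np.reshape(2, 4, images_np.shape[2], images_np.shape[3], images_np.shape[4])
print(images_np.shape)                            # (2, 4, 240, 320, 3)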
@@ -178,13 +178,13 @@ class FeatureMasking(nn.Module):
             num_slots = min(len(embeds), num_slots)
             for embed in embeds:
                 latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0)
-            masks_batch.append(torch.zeros_like(latents_batch))
+            masks_batch.append(torch.zeros_like(latents_batch.squeeze(-1)))
             embedding_batch.append(latents_batch)
-        print(masks_batch.shape)
         set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0)
-        attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=float("-inf"))
+        attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)
         print(f"Set latents shape {set_latents.shape}")
         print(f"Mask shape {attention_mask.shape}")
-        print(set_latents)
         # TODO: create a mask to reduce the size afterward
         # Here we create the masks and set to zeros the values out of the scope
...
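Note (not part of the commit): a minimal sketch of how pad_sequence builds the padded latents and the matching attention mask, with made-up per-image slot counts and an embedding dimension of 16. Real positions stay 0 in the mask and padded positions take padding_value, which this commit switches from float("-inf") to 1.0.

import torch
from torch.nn.utils.rnn import pad_sequence

embedding_batch = [torch.randn(3, 16), torch.randn(5, 16)]          # assumed per-image embeddings
masks_batch = [torch.zeros(e.shape[0]) for e in embedding_batch]    # 0 marks a real entry

set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0)
attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)

print(set_latents.shape)     # torch.Size([2, 5, 16])
print(attention_mask)        # tensor([[0., 0., 0., 1., 1.], [0., 0., 0., 0., 0.]])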
@@ -243,12 +243,14 @@ class SlotAttention(nn.Module):
             q = self.to_q(norm_slots)
             dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
+            attention_masking = torch.zeros_like(dots)
+            #attention_masking[:, : ] = float("-inf")
             print(f"numslots {self.num_slots}, num_inputs {num_inputs}, input dim {dim}, slot dim {self.slot_dim}")
             print(f"Shape of k and v {k.shape} {v.shape}")
             print(f"dots shape {dots.shape}")
             print(f"masks shape {masks.shape}")
-            if masks != None:
-                dots += masks
+            #if masks != None:
+            # dots += masks
             # shape: [batch_size, num_slots, num_inputs]
             attn = dots.softmax(dim=1) + self.eps
             attn = attn / attn.sum(dim=-1, keepdim=True)
...
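Note (not part of the commit): a minimal sketch, with made-up sizes, of what the now commented-out `dots += masks` would do with an additive -inf mask. Because the softmax above runs over the slot dimension (dim=1), an input column that is -inf for every slot comes out all-NaN after the softmax (exp(-inf) is 0 for each slot, so the normaliser is 0), which may be related to why the addition is disabled here.

import torch

batch, num_slots, num_inputs = 1, 2, 4            # made-up sizes
dots = torch.randn(batch, num_slots, num_inputs)

masks = torch.zeros(batch, num_slots, num_inputs)
masks[:, :, -1] = float("-inf")                   # pretend the last input position is padding

attn = (dots + masks).softmax(dim=1)              # softmax over slots, as in the diff
print(attn[0, :, -1])                             # tensor([nan, nan]) for the fully masked column
print(attn[0, :, 0])                              # unmasked columns still sum to 1 over slots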