Skip to content
Snippets Groups Projects
Commit 5bfbc749 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Device pb

parent 409c42bd
No related branches found
No related tags found
No related merge requests found
...@@ -125,9 +125,9 @@ class MaskDecoder(nn.Module): ...@@ -125,9 +125,9 @@ class MaskDecoder(nn.Module):
# Expand per-image data in batch direction to be per-mask # Expand per-image data in batch direction to be per-mask
if image_embeddings.shape[0] != tokens.shape[0]: if image_embeddings.shape[0] != tokens.shape[0]:
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device) src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(dense_prompt_embeddings.device)
else: else:
src = image_embeddings src = image_embeddings.to(dense_prompt_embeddings.device)
src = src + dense_prompt_embeddings src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment