From 5bfbc749363ee0dfd86e24b7750cfeec68e79519 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Wed, 19 Jul 2023 09:10:59 +0200 Subject: [PATCH] Device pb --- osrt/sam/mask_decoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py index 49f54cc..9967d16 100644 --- a/osrt/sam/mask_decoder.py +++ b/osrt/sam/mask_decoder.py @@ -125,9 +125,9 @@ class MaskDecoder(nn.Module): # Expand per-image data in batch direction to be per-mask if image_embeddings.shape[0] != tokens.shape[0]: - src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device) + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(dense_prompt_embeddings.device) else: - src = image_embeddings + src = image_embeddings.to(dense_prompt_embeddings.device) src = src + dense_prompt_embeddings pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) -- GitLab