From 5bfbc749363ee0dfd86e24b7750cfeec68e79519 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 19 Jul 2023 09:10:59 +0200
Subject: [PATCH] Device pb

---
 osrt/sam/mask_decoder.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index 49f54cc..9967d16 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -125,9 +125,9 @@ class MaskDecoder(nn.Module):
 
         # Expand per-image data in batch direction to be per-mask
         if image_embeddings.shape[0] != tokens.shape[0]:
-            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device)
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(dense_prompt_embeddings.device)
         else:
-            src = image_embeddings
+            src = image_embeddings.to(dense_prompt_embeddings.device)
         src = src + dense_prompt_embeddings
 
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
-- 
GitLab