From dca811317a4e7879999345c024cb466988c8b693 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 19 Jul 2023 16:16:10 +0200
Subject: [PATCH] Change device scale

---
 osrt/sam/sam.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/osrt/sam/sam.py b/osrt/sam/sam.py
index 8074cff..764b485 100644
--- a/osrt/sam/sam.py
+++ b/osrt/sam/sam.py
@@ -109,10 +109,10 @@ class Sam(nn.Module):
                 masks=image_record.get("mask_inputs", None),
             )
             low_res_masks, iou_predictions = self.mask_decoder(
-                image_embeddings=curr_embedding.unsqueeze(0),
-                image_pe=self.prompt_encoder.get_dense_pe(),
-                sparse_prompt_embeddings=sparse_embeddings,
-                dense_prompt_embeddings=dense_embeddings,
+                image_embeddings=curr_embedding.unsqueeze(0).to(self.device),
+                image_pe=self.prompt_encoder.get_dense_pe().to(self.device),
+                sparse_prompt_embeddings=sparse_embeddings.to(self.device),
+                dense_prompt_embeddings=dense_embeddings.to(self.device),
                 multimask_output=multimask_output,
             )
             masks = self.postprocess_masks(
-- 
GitLab