From dca811317a4e7879999345c024cb466988c8b693 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Wed, 19 Jul 2023 16:16:10 +0200 Subject: [PATCH] Change device scale --- osrt/sam/sam.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/osrt/sam/sam.py b/osrt/sam/sam.py index 8074cff..764b485 100644 --- a/osrt/sam/sam.py +++ b/osrt/sam/sam.py @@ -109,10 +109,10 @@ class Sam(nn.Module): masks=image_record.get("mask_inputs", None), ) low_res_masks, iou_predictions = self.mask_decoder( - image_embeddings=curr_embedding.unsqueeze(0), - image_pe=self.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, + image_embeddings=curr_embedding.unsqueeze(0).to(self.device), + image_pe=self.prompt_encoder.get_dense_pe().to(self.device), + sparse_prompt_embeddings=sparse_embeddings.to(self.device), + dense_prompt_embeddings=dense_embeddings.to(self.device), multimask_output=multimask_output, ) masks = self.postprocess_masks( -- GitLab