diff --git a/osrt/sam/sam.py b/osrt/sam/sam.py index 8074cff6b40addc6b66f7ab4962218eef20da13c..764b485fbf061b373447f10c0deebfbe6b06efeb 100644 --- a/osrt/sam/sam.py +++ b/osrt/sam/sam.py @@ -109,10 +109,10 @@ class Sam(nn.Module): masks=image_record.get("mask_inputs", None), ) low_res_masks, iou_predictions = self.mask_decoder( - image_embeddings=curr_embedding.unsqueeze(0), - image_pe=self.prompt_encoder.get_dense_pe(), - sparse_prompt_embeddings=sparse_embeddings, - dense_prompt_embeddings=dense_embeddings, + image_embeddings=curr_embedding.unsqueeze(0).to(self.device), + image_pe=self.prompt_encoder.get_dense_pe().to(self.device), + sparse_prompt_embeddings=sparse_embeddings.to(self.device), + dense_prompt_embeddings=dense_embeddings.to(self.device), multimask_output=multimask_output, ) masks = self.postprocess_masks(