Skip to content
Snippets Groups Projects
Commit dca81131 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Change device scale

parent d82172c8
No related branches found
No related tags found
No related merge requests found
...@@ -109,10 +109,10 @@ class Sam(nn.Module): ...@@ -109,10 +109,10 @@ class Sam(nn.Module):
masks=image_record.get("mask_inputs", None), masks=image_record.get("mask_inputs", None),
) )
low_res_masks, iou_predictions = self.mask_decoder( low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0), image_embeddings=curr_embedding.unsqueeze(0).to(self.device),
image_pe=self.prompt_encoder.get_dense_pe(), image_pe=self.prompt_encoder.get_dense_pe().to(self.device),
sparse_prompt_embeddings=sparse_embeddings, sparse_prompt_embeddings=sparse_embeddings.to(self.device),
dense_prompt_embeddings=dense_embeddings, dense_prompt_embeddings=dense_embeddings.to(self.device),
multimask_output=multimask_output, multimask_output=multimask_output,
) )
masks = self.postprocess_masks( masks = self.postprocess_masks(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment