diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index b6848f569564763e721afbf1d6f28267e287a0b2..4043e79cfda41621128c0d7ce761031be7fae36f 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -132,9 +132,7 @@ class MaskDecoder(nn.Module):
 
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
-
-        print(f"Src device : {src.device}")
-        print(f"Transformer MaskDecoder device : {next(self.transformer.parameters()).device}")
+
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index 28fafea52288603fea275f3a100790471825c34a..74ae9be3d727b7e90c1117d85134e742a166d1f5 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -199,11 +199,14 @@ class Attention(nn.Module):
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-
-        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
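+        # No parameters are registered yet at this point in __init__, so the
+        # device falls back to CPU here; the projections are moved with the
+        # owning module whenever .to(device) is called on it.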
+        device = next(self.parameters(), torch.empty(0)).device
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim).to(device)
 
     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
         b, n, c = x.shape