diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index 74ae9be3d727b7e90c1117d85134e742a166d1f5..860f80ac93af6b3fc73364aa2d8f8d06c0df763b 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -199,7 +199,7 @@ class Attention(nn.Module):
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-        device = super().parameters().device
+        device = next(self.parameters(), torch.empty(0)).device  # fall back to CPU if no parameters are registered yet
         self.q_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
         self.k_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
         self.v_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
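
Reviewer note: at this point in Attention.__init__ no submodules have been registered, so a bare next(self.parameters()) would raise StopIteration; the replacement line therefore passes torch.empty(0) as the default and falls back to the CPU device. This relies on torch being imported at module level, as in the upstream SAM transformer.py. A minimal standalone sketch of the pattern (the module instance m and the attribute name proj are illustrative only):

    import torch
    from torch import nn

    m = nn.Module()                                        # no parameters registered yet
    device = next(m.parameters(), torch.empty(0)).device   # default kicks in, yields the CPU device
    print(device)                                          # cpu

    m.proj = nn.Linear(4, 4)                               # registering a submodule adds parameters
    device = next(m.parameters(), torch.empty(0)).device   # now reports the device of proj's weight
    print(device)                                          # cpu, or cuda:N if the Linear was moved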