diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py index 74ae9be3d727b7e90c1117d85134e742a166d1f5..860f80ac93af6b3fc73364aa2d8f8d06c0df763b 100644 --- a/osrt/sam/transformer.py +++ b/osrt/sam/transformer.py @@ -199,7 +199,7 @@ class Attention(nn.Module): self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." - device = super().parameters().device + device = next(super().parameters()).device self.q_proj = nn.Linear(embedding_dim, self.internal_dim).to(device) self.k_proj = nn.Linear(embedding_dim, self.internal_dim).to(device) self.v_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)