diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index b6848f569564763e721afbf1d6f28267e287a0b2..4043e79cfda41621128c0d7ce761031be7fae36f 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -132,9 +132,7 @@ class MaskDecoder(nn.Module):
 
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
-
-        print(f"Src device : {src.device}")
-        print(f"Transformer MaskDecoder device : {next(self.transformer.parameters()).device}")
+        # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
 
diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
index 28fafea52288603fea275f3a100790471825c34a..74ae9be3d727b7e90c1117d85134e742a166d1f5 100644
--- a/osrt/sam/transformer.py
+++ b/osrt/sam/transformer.py
@@ -199,11 +199,11 @@ class Attention(nn.Module):
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-
-        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+        device = next(self.parameters(), torch.empty(0)).device  # no parameters registered yet, so this falls back to CPU
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim).to(device)
 
     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
         b, n, c = x.shape
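
Note on the transformer.py hunk: a module owns no registered parameters at the top of its own __init__, so a device can never be inferred from self.parameters() there (the generator is empty, and nn.Module has no .device attribute). The conventional PyTorch pattern is to build the module on CPU and move it wholesale afterwards, which makes per-layer .to(device) calls in __init__ unnecessary. A minimal, self-contained sketch of that pattern; the Attention class here is a simplified stand-in, not the osrt/sam implementation:

import torch
from torch import nn

class Attention(nn.Module):  # simplified stand-in for osrt.sam.transformer.Attention
    def __init__(self, embedding_dim: int, internal_dim: int) -> None:
        super().__init__()
        # Created on CPU; self.parameters() yields nothing until this
        # assignment, which is why device inference inside __init__ fails.
        self.q_proj = nn.Linear(embedding_dim, internal_dim)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
attn = Attention(256, 128).to(device)   # .to() moves every registered parameter
print(next(attn.parameters()).device)   # e.g. cuda:0 when a GPU is available

Recent PyTorch releases also accept a device= keyword on nn.Linear, so a target device passed into __init__ could create the projections in place directly.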