Commit bf7f214a authored by Alexandre Chapin

Handle devices

parent 05b96da0
@@ -132,9 +132,7 @@ class MaskDecoder(nn.Module):
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
-        print(f"Src device : {src.device}")
-        print(f"Transformer MaskDecoder device : {next(self.transformer.parameters()).device}")
 
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
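The two lines removed above were temporary debug prints reporting which device the image features and the decoder's transformer parameters live on. If such a check is still useful, the sketch below (with a hypothetical helper name, module_device) shows the usual idiom: an nn.Module has no .device attribute, so the device is read from its first registered parameter.

import torch
from torch import nn

def module_device(m: nn.Module) -> torch.device:
    # Hypothetical helper: modules expose no .device attribute, so read the
    # device of the first registered parameter (assumes the module has one).
    return next(m.parameters()).device

# Usage sketch with stand-in shapes; `proj` plays the role of self.transformer.
proj = nn.Linear(256, 256)
src = torch.randn(2, 256)
assert src.device == module_device(proj), "input and module are on different devices"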
@@ -199,11 +199,11 @@ class Attention(nn.Module):
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
-        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+        device = super().parameters().device
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim).to(device)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim).to(device)
 
     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
         b, n, c = x.shape
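A note on the added line `device = super().parameters().device`: parameters() returns a generator, which has no .device attribute, and at this point in __init__ no parameters have been registered yet. A minimal sketch of a more common pattern, assuming the constructor signature shown above: build the projections without any explicit device and move the whole module once with nn.Module.to(device), which recursively relocates every registered submodule and parameter.

import torch
from torch import nn

class Attention(nn.Module):
    # Sketch of the constructor above; the real class also implements the
    # attention forward pass.
    def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
        super().__init__()
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
        # Create the layers on the default device; placement is the caller's job.
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# .to(device) moves every submodule and parameter at once, so no per-layer
# .to(device) calls are needed inside __init__.
attn = Attention(embedding_dim=256, num_heads=8).to(device)
print(next(attn.parameters()).device)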