Commit 409c42bd authored by Alexandre Chapin

Change devices in mask decoder

parent 9402c9e6
@@ -118,13 +118,14 @@ class MaskDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Predicts masks. See 'forward' for more details."""
         # Concatenate output tokens
+        device = sparse_prompt_embeddings.device
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
-        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
-        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1).to(device)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1).to(device)
 
         # Expand per-image data in batch direction to be per-mask
         if image_embeddings.shape[0] != tokens.shape[0]:
-            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device)
         else:
             src = image_embeddings
         src = src + dense_prompt_embeddings
@@ -141,10 +142,11 @@ class MaskDecoder(nn.Module):
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
+        upscaled_embedding = upscaled_embedding.to(device)
         hyper_in_list: List[torch.Tensor] = []
         for i in range(self.num_mask_tokens):
             hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
-        hyper_in = torch.stack(hyper_in_list, dim=1)
+        hyper_in = torch.stack(hyper_in_list, dim=1).to(device)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
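The change follows the usual PyTorch pattern for avoiding cross-device errors: read the device from one input tensor and move the other tensors onto it before combining them. Below is a minimal standalone sketch of that pattern; the tensor shapes and the placeholder tensors are illustrative, not taken from the repository.

import torch

# Illustrative sketch: resolve a potential device mismatch before torch.cat.
sparse_prompt_embeddings = torch.randn(2, 4, 256)   # stands in for the prompt encoder output
device = sparse_prompt_embeddings.device            # single source of truth for placement

output_tokens = torch.randn(5, 256)                 # learned tokens that may live on another device
output_tokens = output_tokens.unsqueeze(0).expand(
    sparse_prompt_embeddings.size(0), -1, -1
).to(device)                                        # move onto the prompts' device before concatenating

tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1).to(device)
print(tokens.shape, tokens.device)                  # torch.Size([2, 9, 256])

Tensor.to returns the input tensor unchanged when it is already on the target device, so the extra .to(device) calls added in the commit are cheap when everything is already co-located.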