Skip to content
Snippets Groups Projects
Commit 409c42bd authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Change devices in mask decoder

parent 9402c9e6
No related branches found
No related tags found
No related merge requests found
...@@ -118,13 +118,14 @@ class MaskDecoder(nn.Module): ...@@ -118,13 +118,14 @@ class MaskDecoder(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details.""" """Predicts masks. See 'forward' for more details."""
# Concatenate output tokens # Concatenate output tokens
device = sparse_prompt_embeddings.device
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1).to(device)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1).to(device)
# Expand per-image data in batch direction to be per-mask # Expand per-image data in batch direction to be per-mask
if image_embeddings.shape[0] != tokens.shape[0]: if image_embeddings.shape[0] != tokens.shape[0]:
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device)
else: else:
src = image_embeddings src = image_embeddings
src = src + dense_prompt_embeddings src = src + dense_prompt_embeddings
...@@ -141,10 +142,11 @@ class MaskDecoder(nn.Module): ...@@ -141,10 +142,11 @@ class MaskDecoder(nn.Module):
# Upscale mask embeddings and predict masks using the mask tokens # Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w) src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src) upscaled_embedding = self.output_upscaling(src)
upscaled_embedding = upscaled_embedding.to(device)
hyper_in_list: List[torch.Tensor] = [] hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens): for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1) hyper_in = torch.stack(hyper_in_list, dim=1).to(device)
b, c, h, w = upscaled_embedding.shape b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment