diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index 347c5c5977a65c36553c487416a281e0e5694356..49f54cc6a7b0f0032433ffd81b85144b7a67c3e0 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -118,13 +118,14 @@ class MaskDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Predicts masks. See 'forward' for more details."""
         # Concatenate output tokens
+        device = sparse_prompt_embeddings.device
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
-        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
-        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1).to(device)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1).to(device)
 
         # Expand per-image data in batch direction to be per-mask
         if image_embeddings.shape[0] != tokens.shape[0]:
-            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).to(device)
         else:
             src = image_embeddings
         src = src + dense_prompt_embeddings
@@ -141,10 +142,11 @@ class MaskDecoder(nn.Module):
         # Upscale mask embeddings and predict masks using the mask tokens
         src = src.transpose(1, 2).view(b, c, h, w)
         upscaled_embedding = self.output_upscaling(src)
+        upscaled_embedding = upscaled_embedding.to(device)
         hyper_in_list: List[torch.Tensor] = []
         for i in range(self.num_mask_tokens):
             hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
-        hyper_in = torch.stack(hyper_in_list, dim=1)
+        hyper_in = torch.stack(hyper_in_list, dim=1).to(device)
         b, c, h, w = upscaled_embedding.shape
         masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
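
The patch pins every intermediate tensor to the device of the incoming prompt embeddings, so the learned tokens (which may still live on the CPU) cannot collide with GPU-resident prompts. A minimal standalone sketch of that device-alignment pattern is below; the helper name align_to_prompt_device and the dummy shapes are illustrative only and not part of the repository.

import torch

def align_to_prompt_device(output_tokens: torch.Tensor,
                           sparse_prompt_embeddings: torch.Tensor) -> torch.Tensor:
    # Move the expanded output tokens onto the prompts' device before
    # concatenating, mirroring the .to(device) calls added in predict_masks.
    device = sparse_prompt_embeddings.device
    output_tokens = (
        output_tokens.unsqueeze(0)
        .expand(sparse_prompt_embeddings.size(0), -1, -1)
        .to(device)
    )
    return torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

# Dummy usage: tokens start on the CPU, prompts may sit on the GPU.
tokens_cpu = torch.randn(5, 256)   # e.g. iou token + mask tokens
prompts = torch.randn(2, 3, 256)   # batch of sparse prompt embeddings
if torch.cuda.is_available():
    prompts = prompts.cuda()
combined = align_to_prompt_device(tokens_cpu, prompts)
assert combined.device == prompts.device  # shape (2, 8, 256) on the prompts' device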