Skip to content
Snippets Groups Projects
Commit aea78759 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Resolve problem with device

parent c515e6d5
No related branches found
No related tags found
No related merge requests found
......@@ -186,7 +186,7 @@ class PositionEmbeddingRandom(nn.Module):
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = coords @ self.positional_encoding_gaussian_matrix.to(coords.device)
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment