Skip to content
Snippets Groups Projects
Commit 2015a9ae authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Add device

parent e666be99
No related branches found
No related tags found
No related merge requests found
......@@ -169,7 +169,7 @@ class FeatureMasking(nn.Module):
randomize_initial_slots=randomize_initial_slots)
def forward(self, images, camera_pos=None, rays=None, extract_masks=False):
def forward(self, images, camera_pos=None, rays=None,extract_masks=False):
"""
Args:
images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
......@@ -191,7 +191,7 @@ class FeatureMasking(nn.Module):
im_size = self.resize.apply_image(images[0]).shape[-3:-1]
### Pre-process images for the image encoder (Resize and Pad)
images = torch.stack([self.preprocess(x) for x in images])
images = torch.stack([self.preprocess(x) for x in images], device=self.mask_generator.device)
### Encode images
image_embeddings, embed_no_red = self.mask_generator.image_encoder(images, before_channel_reduc=True) # [B x N, C, H, W]
......
......@@ -124,9 +124,9 @@ def train_sam(
data_time.update(time.time() - end)
### Extract input data
input_images = data.get('input_images').to(fabric.device)
input_camera_pos = data.get('input_camera_pos').to(fabric.device)
input_rays = data.get('input_rays').to(fabric.device)
input_images = data.get('input_images')
input_camera_pos = data.get('input_camera_pos')
input_rays = data.get('input_rays')
### Encode input informations and extract masks
if isinstance(model.encoder, FeatureMasking):
......@@ -260,9 +260,6 @@ def main(cfg) -> None:
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)
train_loader = fabric.to_device(train_loader)
val_loader = fabric.to_device(val_loader)
test_loader = fabric.to_device(test_loader)
vis_loader_val = DataLoader(val_dataset, batch_size=12, num_workers=num_workers)
data_vis_val = next(iter(vis_loader_val)) # Validation set data for visualization
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment