Commit 05b96da0 authored by Alexandre Chapin

Logs

parent d1a1be78
@@ -151,10 +151,6 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path)
-        device = self.mask_generator.device
-        self.mask_generator.image_encoder.to(device)
-        self.mask_generator.prompt_encoder.to(device)
-        self.mask_generator.mask_decoder.to(device)
         self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
         self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
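The four removed lines moved SAM's sub-modules to a device one by one. For context, a single .to(device) call on the parent module has the same effect, since the SAM model returned by sam_model_registry is an nn.Module whose children follow it. A minimal sketch, assuming the standard segment-anything package; the checkpoint path and device choice are illustrative, not the repository's values:

import torch
from segment_anything import sam_model_registry

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # hypothetical checkpoint path
sam.to(device)  # moves image_encoder, prompt_encoder and mask_decoder in one call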
@@ -133,8 +133,8 @@ class MaskDecoder(nn.Module):
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
         self.transformer.to(src.device)
+        print(f"Src device : {src.device}")
+        print(f"Transformer MaskDecoder device : {next(self.transformer.parameters()).device}")
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
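The two print statements are device-placement logs: next(self.transformer.parameters()).device reports where the transformer's weights live, which can be compared against the device of the incoming tensor. A small self-contained sketch of the same check; report_devices is a hypothetical helper name, not something defined in this repository:

import torch
import torch.nn as nn

def report_devices(module: nn.Module, src: torch.Tensor) -> None:
    # Log where the input tensor and the module's weights currently live.
    print(f"Src device : {src.device}")
    print(f"Module device : {next(module.parameters()).device}")

report_devices(nn.Linear(4, 4), torch.randn(2, 4))  # both report CPU here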
@@ -126,6 +126,7 @@ def main():
     optimizer = AdamW(params, lr=lr_scheduler.get_cur_lr(0))
     model, optimizer = fabric.setup(model, optimizer)
+    model = fabric.to_device(model)
     #########################
     ### Training
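The training setup goes through Lightning Fabric: fabric.setup wraps the model and optimizer (and already places the model on the configured device), while fabric.to_device moves tensors or modules explicitly. A minimal sketch with a toy model and assumed hyperparameters, not the project's actual script:

import torch
from torch.optim import AdamW
from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices=1)
fabric.launch()

model = torch.nn.Linear(8, 2)                      # stand-in for the real model
optimizer = AdamW(model.parameters(), lr=1e-4)     # assumed learning rate
model, optimizer = fabric.setup(model, optimizer)  # also places the model on the device
batch = fabric.to_device(torch.randn(4, 8))        # explicit move for tensors or modules

Since setup already moves the model, the extra to_device call added in the diff appears to be a belt-and-suspenders step rather than a required one.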