diff --git a/osrt/encoder.py b/osrt/encoder.py
index 706f516ca8fbee0391d4cbeac27475a5cc84ee22..deb591831b95b4daf60af4a019e520ec67519552 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -151,10 +151,6 @@ class FeatureMasking(nn.Module):
         # We first initialize the automatic mask generator from SAM
         # TODO : change the loading here !!!!
         self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path)
-        device = self.mask_generator.device
-        self.mask_generator.image_encoder.to(device)
-        self.mask_generator.prompt_encoder.to(device)
-        self.mask_generator.mask_decoder.to(device)
 
         self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
         self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
index d0739c84374b9e95db54cdf98d4141bd4652c99e..b6848f569564763e721afbf1d6f28267e287a0b2 100644
--- a/osrt/sam/mask_decoder.py
+++ b/osrt/sam/mask_decoder.py
@@ -133,8 +133,8 @@ class MaskDecoder(nn.Module):
         pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
         b, c, h, w = src.shape
-
-        self.transformer.to(src.device)
+        print(f"Src device : {src.device}")
+        print(f"Transformer MaskDecoder device : {next(self.transformer.parameters()).device}")
 
         # Run the transformer
         hs, src = self.transformer(src, pos_src, tokens)
         iou_token_out = hs[:, 0, :]
diff --git a/train.py b/train.py
index 4351afbf4d9fb84de942c02be0546bf81887c61a..ac9f416e040ceccae8f007c97434d01528a93e86 100755
--- a/train.py
+++ b/train.py
@@ -126,6 +126,7 @@ def main():
     optimizer = AdamW(params, lr=lr_scheduler.get_cur_lr(0))
 
     model, optimizer = fabric.setup(model, optimizer)
+    model = fabric.to_device(model)
 
     #########################
     ### Training
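The four lines removed from `FeatureMasking.__init__` were effectively no-ops: `device = self.mask_generator.device` reads the device the SAM model already sits on, so each per-submodule `.to(device)` call moves a component to where it already is. Even if explicit placement were still wanted, a single `nn.Module.to()` call covers every registered submodule. A minimal self-contained sketch (the `TinySam` class is a hypothetical stand-in, not the repo's SAM model):

```python
import torch
import torch.nn as nn

class TinySam(nn.Module):
    """Illustrative stand-in for the model built by sam_model_registry."""
    def __init__(self) -> None:
        super().__init__()
        self.image_encoder = nn.Linear(4, 4)
        self.prompt_encoder = nn.Linear(4, 4)
        self.mask_decoder = nn.Linear(4, 4)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinySam()
model.to(device)  # recursively moves all three submodules in one call
assert all(p.device.type == device.type for p in model.parameters())
```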
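The two `print` calls added in `MaskDecoder` are one-off probes for the device mismatch that the removed `self.transformer.to(src.device)` used to paper over at every forward pass. A reusable alternative is sketched below with a hypothetical `report_module_devices` helper (not part of the repo), which prints the device set of each direct child of a module:

```python
import torch.nn as nn

def report_module_devices(module: nn.Module, name: str = "model") -> None:
    """Print which device(s) hold the parameters of each direct child."""
    for child_name, child in module.named_children():
        devices = sorted({str(p.device) for p in child.parameters()})
        print(f"{name}.{child_name}: {devices if devices else 'no parameters'}")

# Example: report_module_devices(mask_decoder, "mask_decoder") would print
# lines such as "mask_decoder.transformer: ['cuda:0']".
```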
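In `train.py`, `fabric.setup(model, optimizer)` already moves the module to `fabric.device` by default (`move_to_device=True` in Lightning Fabric), so the added `fabric.to_device(model)` acts as an idempotent safeguard for parameters created outside or after `setup()`, such as lazily loaded SAM submodules. A minimal sketch of that flow, with a toy module standing in for the OSRT model and illustrative `Fabric(...)` arguments:

```python
import torch.nn as nn
from lightning.fabric import Fabric
from torch.optim import AdamW

fabric = Fabric(accelerator="auto", devices=1)
fabric.launch()

model = nn.Linear(8, 8)  # toy stand-in for the OSRT model
optimizer = AdamW(model.parameters(), lr=1e-4)

# setup() wraps both objects and, by default, places the module on
# fabric.device; to_device() is then a no-op for already-moved parameters.
model, optimizer = fabric.setup(model, optimizer)
model = fabric.to_device(model)
print(next(model.parameters()).device)
```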