diff --git a/train.py b/train.py
index c0e0dd87e3ec4226aadb1e765322b95bb1811ae9..f96e01777ce57df927419885f19f2e145ba21789 100755
--- a/train.py
+++ b/train.py
@@ -132,14 +132,18 @@ if __name__ == '__main__':
     print('Model created.')
 
     if world_size > 1:
-        model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained 
-        model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
-        model.encoder = FullyShardedDataParallel(
-            model_encoder_ddp(), 
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
-        model.decoder = FullyShardedDataParallel(
-            model_decoder_ddp(), 
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+        if args.strategy == "fsdp":
+            model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained 
+            model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
+            model.encoder = FullyShardedDataParallel(
+                model_encoder_ddp,
+                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            model.decoder = FullyShardedDataParallel(
+                model_decoder_ddp,
+                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+        else:
+            model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained 
+            model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
 
         encoder_module = model.encoder.module
         decoder_module = model.decoder.module