From ef39193a9ba0be54f1d6e8cbc0329d103b5ddf68 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 12 Jul 2023 14:14:45 +0200
Subject: [PATCH] Select FSDP or DDP wrapping based on args.strategy

---
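Notes: the new branch keys off an args.strategy command-line option that picks
FullyShardedDataParallel ("fsdp") or plain DistributedDataParallel wrapping when
world_size > 1; the option itself is not declared in this patch. A minimal
argparse sketch of how such a flag could be declared (the flag name, choices and
default are assumptions, not part of this change):

    import argparse

    parser = argparse.ArgumentParser(description='train.py options (sketch)')
    # Hypothetical flag: "fsdp" wraps encoder/decoder with FullyShardedDataParallel,
    # any other value falls back to plain DistributedDataParallel.
    parser.add_argument('--strategy', type=str, choices=['fsdp', 'ddp'], default='ddp',
                        help='parallelism strategy used when world_size > 1 (assumed flag)')
    args = parser.parse_args()

With such a flag the script could be launched per node as, e.g.,
torchrun --nproc_per_node=<num_gpus> train.py --strategy fsdp
(the launch command is likewise an assumption, not taken from this repository).
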
 train.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/train.py b/train.py
index c0e0dd8..f96e017 100755
--- a/train.py
+++ b/train.py
@@ -132,14 +132,18 @@ if __name__ == '__main__':
     print('Model created.')
 
     if world_size > 1:
-        model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained 
-        model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
-        model.encoder = FullyShardedDataParallel(
-            model_encoder_ddp(), 
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
-        model.decoder = FullyShardedDataParallel(
-            model_decoder_ddp(), 
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+        if args.strategy == "fsdp":
+            model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
+            model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
+            model.encoder = FullyShardedDataParallel(
+                model_encoder_ddp,
+                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            model.decoder = FullyShardedDataParallel(
+                model_decoder_ddp,
+                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+        else:
+            model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
+            model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
 
         encoder_module = model.encoder.module
         decoder_module = model.decoder.module
-- 
GitLab