From ef39193a9ba0be54f1d6e8cbc0329d103b5ddf68 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 12 Jul 2023 14:14:45 +0200
Subject: [PATCH] Check strategy

---
 train.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/train.py b/train.py
index c0e0dd8..f96e017 100755
--- a/train.py
+++ b/train.py
@@ -132,14 +132,18 @@ if __name__ == '__main__':
     print('Model created.')
 
     if world_size > 1:
-        model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
-        model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
-        model.encoder = FullyShardedDataParallel(
-            model_encoder_ddp(),
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
-        model.decoder = FullyShardedDataParallel(
-            model_decoder_ddp(),
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+        if args.strategy == "fsdp":
+            model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
+            model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
+            model.encoder = FullyShardedDataParallel(
+                model_encoder_ddp(),
+                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            model.decoder = FullyShardedDataParallel(
+                model_decoder_ddp(),
+                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+        else:
+            model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
+            model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
 
         encoder_module = model.encoder.module
         decoder_module = model.decoder.module
-- 
GitLab