diff --git a/train.py b/train.py
index c0e0dd87e3ec4226aadb1e765322b95bb1811ae9..f96e01777ce57df927419885f19f2e145ba21789 100755
--- a/train.py
+++ b/train.py
@@ -132,14 +132,18 @@ if __name__ == '__main__':
     print('Model created.')
 
     if world_size > 1:
-        model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
-        model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
-        model.encoder = FullyShardedDataParallel(
-            model_encoder_ddp(),
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
-        model.decoder = FullyShardedDataParallel(
-            model_decoder_ddp(),
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+        if args.strategy == "fsdp":
+            model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
+            model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
+            model.encoder = FullyShardedDataParallel(
+                model_encoder_ddp(),
+                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            model.decoder = FullyShardedDataParallel(
+                model_decoder_ddp(),
+                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+        else:
+            model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
+            model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
         encoder_module = model.encoder.module
         decoder_module = model.decoder.module
 
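
The hunk above branches on args.strategy, so the change assumes a --strategy command-line option is defined elsewhere in train.py; that definition is not part of this diff. A minimal argparse sketch of how such a flag could look is below; the option name, choices, and default are assumptions for illustration, not the repository's actual code.

    import argparse

    parser = argparse.ArgumentParser(description='train.py (sketch)')
    # Hypothetical declaration backing args.strategy in the hunk above;
    # the real train.py may define it with different choices or a different default.
    parser.add_argument('--strategy', choices=['ddp', 'fsdp'], default='ddp',
                        help='Multi-GPU wrapping strategy: plain DDP or FSDP sharding')
    args = parser.parse_args()

With a definition like this, running "python train.py --strategy fsdp" would take the FullyShardedDataParallel branch, while the default would keep the plain DistributedDataParallel wrapping.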