Skip to content
Snippets Groups Projects
Commit ef39193a authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Check strategy

parent 92ed8ddf
No related branches found
No related tags found
No related merge requests found
...@@ -132,14 +132,18 @@ if __name__ == '__main__': ...@@ -132,14 +132,18 @@ if __name__ == '__main__':
print('Model created.') print('Model created.')
if world_size > 1: if world_size > 1:
model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained if args.strategy == "fsdp":
model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False) model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
model.encoder = FullyShardedDataParallel( model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
model_encoder_ddp(), model.encoder = FullyShardedDataParallel(
fsdp_auto_wrap_policy=default_auto_wrap_policy) model_encoder_ddp(),
model.decoder = FullyShardedDataParallel( fsdp_auto_wrap_policy=default_auto_wrap_policy)
model_decoder_ddp(), model.decoder = FullyShardedDataParallel(
fsdp_auto_wrap_policy=default_auto_wrap_policy) model_decoder_ddp(),
fsdp_auto_wrap_policy=default_auto_wrap_policy)
else:
model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
encoder_module = model.encoder.module encoder_module = model.encoder.module
decoder_module = model.decoder.module decoder_module = model.decoder.module
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment