Commit 75048e18 authored by Alexandre Chapin

Fix FSDP policy

parent ef39193a
@@ -2,7 +2,10 @@
 import torch
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
-from torch.distributed.fsdp.wrap import default_auto_wrap_policy
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+)
+import functools
 import numpy as np
 #import bitsandbytes as bnb
@@ -15,9 +18,13 @@
 import yaml
 from osrt import data
 from osrt.model import OSRT
 from osrt.trainer import SRTTrainer, OSRTSamTrainer
+from osrt.layers import Transformer
 from osrt.checkpoint import Checkpoint
 from osrt.utils.common import init_ddp
+from segment_anything.modeling.image_encoder import Block
+from segment_anything.modeling.transformer import TwoWayTransformer
 from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity,schedule
@@ -135,12 +142,20 @@ if __name__ == '__main__':
     if args.strategy == "fsdp":
         model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
         model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
+        custom_auto_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={
+                Transformer,
+                TwoWayTransformer,
+                Block
+            },
+        )
         model.encoder = FullyShardedDataParallel(
             model_encoder_ddp(),
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            fsdp_auto_wrap_policy=custom_auto_wrap_policy)
         model.decoder = FullyShardedDataParallel(
             model_decoder_ddp(),
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            fsdp_auto_wrap_policy=custom_auto_wrap_policy)
     else:
         model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
         model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
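The pattern this commit adopts (a `functools.partial` over `transformer_auto_wrap_policy`, passed to `FullyShardedDataParallel` so that each transformer block becomes its own FSDP unit) can be exercised in isolation. The sketch below is illustrative only: it substitutes `torch.nn.TransformerEncoderLayer` for the project's `Transformer`, `TwoWayTransformer`, and `Block` classes, and uses the `auto_wrap_policy` keyword from recent PyTorch releases (older releases spelled it `fsdp_auto_wrap_policy`, as in the diff above).

```python
# Minimal sketch of FSDP transformer auto-wrapping, assuming PyTorch >= 1.12
# and a process group launched via torchrun. nn.TransformerEncoderLayer is a
# stand-in for the project's Transformer / TwoWayTransformer / Block classes.
import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def main():
    dist.init_process_group("nccl")
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)

    # Toy model: a stack of transformer layers, mirroring an encoder.
    model = nn.Sequential(
        *[nn.TransformerEncoderLayer(d_model=256, nhead=8) for _ in range(4)]
    ).cuda(rank)

    # Wrap every TransformerEncoderLayer in its own FSDP unit -- the same idea
    # as passing {Transformer, TwoWayTransformer, Block} in the commit above.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={nn.TransformerEncoderLayer},
    )
    fsdp_model = FSDP(model, auto_wrap_policy=wrap_policy, device_id=rank)

    # One dummy forward pass: (seq_len, batch, d_model).
    x = torch.randn(16, 8, 256, device=rank)
    out = fsdp_model(x)
    print(f"rank {dist.get_rank()}: output shape {tuple(out.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Launched with torchrun (e.g. `torchrun --nproc_per_node=2 <script>.py`, script name here being hypothetical), each transformer layer is sharded as its own FSDP unit instead of the whole model being flattened into a single parameter group.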