Commit 75048e18 authored by Alexandre Chapin

Fix FSDP auto wrap policy

parent ef39193a
@@ -2,7 +2,10 @@ import torch
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
-from torch.distributed.fsdp.wrap import default_auto_wrap_policy
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+)
+import functools
 import numpy as np
 #import bitsandbytes as bnb
@@ -15,9 +18,13 @@ import yaml
 from osrt import data
 from osrt.model import OSRT
 from osrt.trainer import SRTTrainer, OSRTSamTrainer
+from osrt.layers import Transformer
 from osrt.checkpoint import Checkpoint
 from osrt.utils.common import init_ddp
+from segment_anything.modeling.image_encoder import Block
+from segment_anything.modeling.transformer import TwoWayTransformer
 from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity, schedule
@@ -135,12 +142,20 @@ if __name__ == '__main__':
     if args.strategy == "fsdp":
         model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # find_unused_parameters=True because the ViT is not trained
         model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
+        custom_auto_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={
+                Transformer,
+                TwoWayTransformer,
+                Block,
+            },
+        )
         model.encoder = FullyShardedDataParallel(
             model_encoder_ddp,  # pass the wrapped module itself; calling it here would invoke forward() with no inputs
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            fsdp_auto_wrap_policy=custom_auto_wrap_policy)
         model.decoder = FullyShardedDataParallel(
             model_decoder_ddp,
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            fsdp_auto_wrap_policy=custom_auto_wrap_policy)
     else:
         model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # find_unused_parameters=True because the ViT is not trained
         model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
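
For context, the sketch below (not part of the commit) shows the usual way transformer_auto_wrap_policy is combined with functools.partial and FSDP. TinyBlock, TinyTransformer, and shard_model are hypothetical stand-ins for this repo's Transformer, TwoWayTransformer, and Block classes; the sketch also assumes a process group has already been initialized (e.g. via torchrun), and it uses the auto_wrap_policy keyword that newer PyTorch releases expose in place of the older fsdp_auto_wrap_policy name seen in the diff.

# Minimal sketch; TinyBlock / TinyTransformer / shard_model are hypothetical,
# not classes from this repository.
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class TinyBlock(nn.Module):
    """Stand-in for one transformer layer (e.g. a SAM image-encoder Block)."""
    def __init__(self, dim: int = 64):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x):
        y, _ = self.attn(x, x, x)
        return self.mlp(y)


class TinyTransformer(nn.Module):
    """Stand-in for the model being sharded."""
    def __init__(self, depth: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(TinyBlock() for _ in range(depth))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def shard_model(model: nn.Module) -> FSDP:
    # transformer_auto_wrap_policy turns every module whose class appears in
    # transformer_layer_cls into its own FSDP unit, so each layer's parameters
    # are sharded, gathered, and freed independently instead of all at once.
    policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TinyBlock},
    )
    # Assumes torch.distributed.init_process_group() has already run
    # (e.g. the script was launched with torchrun).
    return FSDP(model, auto_wrap_policy=policy)


# Usage: sharded = shard_model(TinyTransformer())

Listing the layer classes explicitly is what distinguishes this policy from the size-based default_auto_wrap_policy the commit replaces: wrap points follow the model's transformer structure rather than a parameter-count threshold.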