From 75048e1856d9fc4763e1a8f5039ad61b02882922 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 12 Jul 2023 14:36:04 +0200
Subject: [PATCH] Fix FSDP policy

---
 train.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index f96e017..40a1412 100755
--- a/train.py
+++ b/train.py
@@ -2,7 +2,10 @@ import torch
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
-from torch.distributed.fsdp.wrap import default_auto_wrap_policy
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+)
+import functools
 import numpy as np
 
 #import bitsandbytes as bnb
@@ -15,9 +18,13 @@ import yaml
 from osrt import data
 from osrt.model import OSRT
 from osrt.trainer import SRTTrainer, OSRTSamTrainer
+from osrt.layers import Transformer
 from osrt.checkpoint import Checkpoint
 from osrt.utils.common import init_ddp
+from segment_anything.modeling.image_encoder import Block
+from segment_anything.modeling.transformer import TwoWayTransformer
+
 
 from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity,schedule
 
 
@@ -135,12 +142,20 @@ if __name__ == '__main__':
     if args.strategy == "fsdp":
         model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
         model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
+        custom_auto_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={
+                Transformer,
+                TwoWayTransformer,
+                Block
+            },
+        )
         model.encoder = FullyShardedDataParallel(
            model_encoder_ddp(),
-           fsdp_auto_wrap_policy=default_auto_wrap_policy)
+           fsdp_auto_wrap_policy=custom_auto_wrap_policy)
         model.decoder = FullyShardedDataParallel(
            model_decoder_ddp(),
-           fsdp_auto_wrap_policy=default_auto_wrap_policy)
+           fsdp_auto_wrap_policy=custom_auto_wrap_policy)
     else:
         model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
         model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
--
GitLab