diff --git a/train.py b/train.py
index f96e01777ce57df927419885f19f2e145ba21789..40a14127d066a4e2548e5969ad57acdd83b64191 100755
--- a/train.py
+++ b/train.py
@@ -2,7 +2,10 @@ import torch
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
-from torch.distributed.fsdp.wrap import default_auto_wrap_policy
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+)
+import functools
 import numpy as np
 #import bitsandbytes as bnb
@@ -15,9 +18,13 @@ import yaml
 from osrt import data
 from osrt.model import OSRT
 from osrt.trainer import SRTTrainer, OSRTSamTrainer
+from osrt.layers import Transformer
 from osrt.checkpoint import Checkpoint
 from osrt.utils.common import init_ddp
+from segment_anything.modeling.image_encoder import Block
+from segment_anything.modeling.transformer import TwoWayTransformer
+
 from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity,schedule
@@ -135,12 +142,20 @@ if __name__ == '__main__':
     if args.strategy == "fsdp":
         model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
         model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
+        custom_auto_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls={
+                Transformer,
+                TwoWayTransformer,
+                Block
+            },
+        )
         model.encoder = FullyShardedDataParallel(
             model_encoder_ddp(),
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            fsdp_auto_wrap_policy=custom_auto_wrap_policy)
         model.decoder = FullyShardedDataParallel(
             model_decoder_ddp(),
-            fsdp_auto_wrap_policy=default_auto_wrap_policy)
+            fsdp_auto_wrap_policy=custom_auto_wrap_policy)
     else:
         model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained
         model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
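
The change above replaces the size-based default_auto_wrap_policy with a transformer_auto_wrap_policy built via functools.partial, so FSDP shards at the granularity of whole transformer blocks (Transformer, TwoWayTransformer, SAM's Block) rather than by parameter count. Below is a minimal, self-contained sketch of the same pattern, not the repo's code: it assumes a recent PyTorch release where the FSDP constructor keyword is spelled auto_wrap_policy (the diff targets an older spelling, fsdp_auto_wrap_policy), and ToyBlock is a hypothetical stand-in for the repo's transformer layer classes.

# Minimal sketch of FSDP wrapping with a transformer auto-wrap policy.
# Assumptions: torch.distributed has already been initialized (e.g. via
# dist.init_process_group("nccl")) and a recent PyTorch version is in use,
# where the keyword is `auto_wrap_policy`. `ToyBlock` is a hypothetical
# placeholder for the real Transformer / TwoWayTransformer / Block classes.
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    """Stand-in transformer block; only its class identity matters for wrapping."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        x = x + self.attn(x, x, x, need_weights=False)[0]
        return x + self.mlp(x)


def shard_model(model: nn.Module) -> FSDP:
    # Wrap every submodule whose class appears in `transformer_layer_cls` in its
    # own FSDP unit, so its parameters are all-gathered only around that block's
    # forward/backward instead of for the whole model at once.
    policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={ToyBlock},
    )
    return FSDP(model, auto_wrap_policy=policy)


# Usage (inside an initialized process group, one rank per GPU):
#   encoder = nn.Sequential(*[ToyBlock() for _ in range(4)]).cuda()
#   encoder = shard_model(encoder)

Compared with the size-based default policy, block-granular wrapping keeps peak memory closer to one block's full parameters plus the sharded remainder, which is typically the point of using it for ViT-style encoders such as SAM's image encoder.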