diff --git a/train.py b/train.py
index f96e01777ce57df927419885f19f2e145ba21789..40a14127d066a4e2548e5969ad57acdd83b64191 100755
--- a/train.py
+++ b/train.py
@@ -2,7 +2,10 @@ import torch
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel
 from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload
-from torch.distributed.fsdp.wrap import default_auto_wrap_policy
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+)
+import functools
 
 import numpy as np
 #import bitsandbytes as bnb
@@ -15,9 +18,14 @@ import yaml
 from osrt import data
 from osrt.model import OSRT
 from osrt.trainer import SRTTrainer, OSRTSamTrainer
+from osrt.layers import Transformer
 from osrt.checkpoint import Checkpoint
 from osrt.utils.common import init_ddp
 
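+# SAM transformer blocks referenced by the FSDP auto-wrap policy set up below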
+from segment_anything.modeling.image_encoder import Block
+from segment_anything.modeling.transformer import TwoWayTransformer
+
 from torch.profiler import profile, tensorboard_trace_handler, ProfilerActivity,schedule
 
 
@@ -135,12 +142,22 @@ if __name__ == '__main__':
         if args.strategy == "fsdp":
             model_encoder_ddp = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained 
             model_decoder_ddp = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)
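+            # Wrap every instance of the listed transformer block classes in its
+            # own nested FSDP unit rather than relying on a size-based policy.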
+            custom_auto_wrap_policy = functools.partial(
+                transformer_auto_wrap_policy,
+                transformer_layer_cls={
+                    Transformer,
+                    TwoWayTransformer,
+                    Block
+                },
+            )
             model.encoder = FullyShardedDataParallel(
-                model_encoder_ddp(), 
-                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+                model_encoder_ddp,
+                auto_wrap_policy=custom_auto_wrap_policy)
             model.decoder = FullyShardedDataParallel(
-                model_decoder_ddp(), 
-                fsdp_auto_wrap_policy=default_auto_wrap_policy)
+                model_decoder_ddp,
+                auto_wrap_policy=custom_auto_wrap_policy)
         else:
             model.encoder = DistributedDataParallel(model.encoder, device_ids=[rank], output_device=rank, find_unused_parameters=True) # Set find_unused_parameters to True because the ViT is not trained 
             model.decoder = DistributedDataParallel(model.decoder, device_ids=[rank], output_device=rank, find_unused_parameters=False)