diff --git a/main_custom.py b/main_custom.py
index 78b8bb060d995465ba1fe325afd2a7b3c19bf0a5..fb68fa6e6facf3798ad984c373cdc8d0cc2549f8 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -175,15 +175,28 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m
 
 def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
         criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'):
-    for e in range(1, epochs + 1):
-        train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
-              wandb=wandb)
-        if e % eval_inter == 0:
-            eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
-                 wandb=wandb)
-        if e % save_inter == 0:
-            save(model, 'model_common_' + str(e) + '.pt')
-    save_pred(model, data_val, forward, output)
+
+    if forward =='transfer' :
+        for e in range(1, epochs + 1):
+            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'rt',
+                  wandb=wandb)
+            if e % eval_inter == 0:
+                eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'both',
+                     wandb=wandb)
+            if e % save_inter == 0:
+                save(model, 'model_common_' + str(e) + '.pt')
+        save_pred(model, data_val, 'rt', output)
+
+    else :
+        for e in range(1, epochs + 1):
+            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
+                  wandb=wandb)
+            if e % eval_inter == 0:
+                eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
+                     wandb=wandb)
+            if e % save_inter == 0:
+                save(model, 'model_common_' + str(e) + '.pt')
+        save_pred(model, data_val, forward, output)
 
 
 def main(args):
@@ -202,7 +215,11 @@ def main(args):
                                                                    path_test=args.dataset_test,
                                                                    batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='iapuc')
     elif args.forward == 'rt':
-        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
+        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
+                                                               batch_size=args.batch_size, length=25)
+
+    elif args.forward == 'transfer':
+        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.cvs','database/data_holdout.cvs'],
                                                                batch_size=args.batch_size, length=25)
 
         _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val,