diff --git a/config_common.py b/config_common.py
index 527364198df4c98c9164b256493d228642b3cc9b..09dea8f600b833a05a64f4dd0faad4fc2a6b6027 100644
--- a/config_common.py
+++ b/config_common.py
@@ -26,6 +26,7 @@ def load_args():
     parser.add_argument('--output', type=str, default='output/out.csv')
     parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
     parser.add_argument('--activation', type=str,default='relu')
+    parser.add_argument('--file', action=argparse.BooleanOptionalAction)
     args = parser.parse_args()
 
     return args
diff --git a/main_custom.py b/main_custom.py
index e54c69efd5349982013a2e757f48dd27348e8a4a..b0866f7bdb1c1051dbf188a9999d50fb23ef3f2f 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -175,8 +175,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m
 
 
 def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
-        criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'):
-
+        criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv', file=False):
     if forward =='transfer' :
         for e in range(1, epochs + 1):
             train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'rt',
@@ -197,7 +196,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
                      wandb=wandb)
             if e % save_inter == 0:
                 save(model, 'model_common_' + str(e) + '.pt')
-        save_pred(model, data_val, forward, output)
+        save_pred(model, data_val, forward, output, file=file)
 
 
 def main(args):
@@ -244,7 +243,7 @@ def main(args):
     run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
         data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
         criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
-        wandb=args.wandb, forward=args.forward, output=args.output)
+        wandb=args.wandb, forward=args.forward, output=args.output, file=args.file)
 
     if args.wandb is not None:
         wdb.finish()