diff --git a/config_common.py b/config_common.py
index 7bf9ab098f7c5f93fb7060c025066f5dc675c1a3..527364198df4c98c9164b256493d228642b3cc9b 100644
--- a/config_common.py
+++ b/config_common.py
@@ -23,6 +23,7 @@ def load_args():
     parser.add_argument('--dataset_train', type=str, default='database/data_DIA_ISA_55_train.pkl')
     parser.add_argument('--dataset_val', type=str, default='database/data_DIA_ISA_55_test.pkl')
     parser.add_argument('--dataset_test', type=str, default='database/data_DIA_ISA_55_test.pkl')
+    parser.add_argument('--output', type=str, default='output/out.csv')
     parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
     parser.add_argument('--activation', type=str,default='relu')
     args = parser.parse_args()
diff --git a/main_custom.py b/main_custom.py
index 1f1d51e3b9487ee450321292b650762da4baf28f..fcc907494f8c2419cfd47376f0c5a9f4fac3a94c 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -174,7 +174,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m
 
 
 def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
-        criterion_intensity, metric_rt, metric_intensity, forward, wandb=None):
+        criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'):
     for e in range(1, epochs + 1):
         train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
               wandb=wandb)
@@ -183,7 +183,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
                  wandb=wandb)
         if e % save_inter == 0:
             save(model, 'model_common_' + str(e) + '.pt')
-    save_pred(model, data_val, forward, 'output/out.csv')
+    save_pred(model, data_val, forward, output)
 
 
 def main(args):
@@ -221,7 +221,7 @@ def main(args):
     run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
         data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
         criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
-        wandb=args.wandb, forward=args.forward)
+        wandb=args.wandb, forward=args.forward, output=args.output)
 
     if args.wandb is not None:
         wdb.finish()