diff --git a/main.py b/main.py
index 92b6b22e6449f84534e383273fcb78e23944a51d..14e0159a8e38f410906a198602332c1a5268304e 100644
--- a/main.py
+++ b/main.py
@@ -84,7 +84,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
         if e % save_inter == 0:
             save(model, 'model_common_' + str(e) + '.pt')
     model.load_state_dict(torch.load(output.strip('.csv')+'pt', weights_only=True))
-    save_pred(model, data_test, output)
+    save_pred(model, data_test, output, criterion_rt, metric_rt, wandb)
 
 
 def main(args):