diff --git a/main.py b/main.py index 92b6b22e6449f84534e383273fcb78e23944a51d..14e0159a8e38f410906a198602332c1a5268304e 100644 --- a/main.py +++ b/main.py @@ -84,7 +84,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, if e % save_inter == 0: save(model, 'model_common_' + str(e) + '.pt') model.load_state_dict(torch.load(output.strip('.csv')+'pt', weights_only=True)) - save_pred(model, data_test, output) + save_pred(model, data_test, output, criterion_rt, metric_rt, wandb) def main(args):