diff --git a/main_custom.py b/main_custom.py index 283c0fccb065e8989ccb6d79b03d084be6494003..ae2ec0990e35f56a78b72b1d35d93780e2a5ade5 100644 --- a/main_custom.py +++ b/main_custom.py @@ -185,7 +185,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, wandb=wandb) if e % save_inter == 0: save(model, 'model_common_' + str(e) + '.pt') - save_pred(model, data_val, 'rt', output) + save_pred(model, data_val, 'both', output) else : for e in range(1, epochs + 1):