diff --git a/main.py b/main.py index 83e77e4f47ca14827453c43f3473db0fc9185fbc..04f370c3a40baf34ddbe27e1ad3759f6c05e9b5f 100644 --- a/main.py +++ b/main.py @@ -162,15 +162,15 @@ def save_pred(model, data_val, output_path, criterion_rt, metric_rt, wandb=None loss_rt = criterion_rt(rt, pr_rt) losses_rt += loss_rt.item() - dist_rt = metric_rt(rt, pred_rt) + dist_rt = metric_rt(rt, pr_rt) dist_rt_acc += dist_rt.item() if wandb is not None: wdb.log({"test rt loss": losses_rt / len(data_val), "test rt mean metric": dist_rt_acc / len(data_val)}) - print('val rt loss', losses_rt / len(data_val), - "val rt mean metric : ", + print('test rt loss', losses_rt / len(data_val), + "test rt mean metric : ", dist_rt_acc / len(data_val)) data_frame['rt pred'] = pred_rt