From a5b96b61e5171f298a4f219ffd2a656ca5caabd1 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Mon, 21 Oct 2024 16:48:38 +0200 Subject: [PATCH] datasets --- main_custom.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/main_custom.py b/main_custom.py index daf954f..fec2d98 100644 --- a/main_custom.py +++ b/main_custom.py @@ -116,6 +116,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m if torch.cuda.is_available(): seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() pred_rt, pred_int = model.forward(seq, charge) + print(rt) print(rt.shape,pred_rt.shape) loss_rt = criterion_rt(rt, pred_rt) loss_int = criterion_intensity(intensity, pred_int) @@ -199,8 +200,8 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, else : for e in range(1, epochs + 1): - train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, - wandb=wandb) + # train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, + # wandb=wandb) if e % eval_inter == 0: eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, wandb=wandb) -- GitLab