From a5b96b61e5171f298a4f219ffd2a656ca5caabd1 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 21 Oct 2024 16:48:38 +0200
Subject: [PATCH] datasets

---
 main_custom.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/main_custom.py b/main_custom.py
index daf954f..fec2d98 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -116,6 +116,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m
             if torch.cuda.is_available():
                 seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
             pred_rt, pred_int = model.forward(seq, charge)
+            print(rt)
             print(rt.shape,pred_rt.shape)
             loss_rt = criterion_rt(rt, pred_rt)
             loss_int = criterion_intensity(intensity, pred_int)
@@ -199,8 +200,8 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
 
     else :
         for e in range(1, epochs + 1):
-            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
-                  wandb=wandb)
+            # train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
+            #       wandb=wandb)
             if e % eval_inter == 0:
                 eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
                      wandb=wandb)
-- 
GitLab