From 28365b5a09641922d57ea1302c2256fc357074c8 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 21 Oct 2024 15:44:10 +0200
Subject: [PATCH] datasets

---
 main_custom.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/main_custom.py b/main_custom.py
index df172b3..7048332 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -24,9 +24,7 @@ def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity
     for param in model.parameters():
         param.requires_grad = True
     if forward == 'both':
-        i=0
         for seq, charge, rt, intensity in data_train:
-            i+=seq.shape[0]
             rt, intensity = rt.float(), intensity.float()
             if torch.cuda.is_available():
                 seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()
@@ -43,7 +41,6 @@ def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-            print(i,'/',len(data_train.dataset))
     if wandb is not None:
         wdb.log({"train rt loss": losses_rt / len(data_train),
                  "train int loss": losses_int / len(data_train),
--
GitLab
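
Note (not part of the patch): the change removes the hand-rolled progress counter (i, i += seq.shape[0], and the per-batch print) from the training loop. If progress reporting is still wanted, one common alternative is to wrap the DataLoader in tqdm. The sketch below is only an illustration under assumptions, not the author's code; it assumes data_train is a torch.utils.data.DataLoader yielding (seq, charge, rt, intensity) batches as in train(), and that tqdm is installed. The helper name run_epoch is hypothetical.

    # Illustrative sketch only: progress reporting with tqdm instead of the
    # removed manual counter. Assumes `data_train` yields
    # (seq, charge, rt, intensity) batches, as in the patched train().
    from tqdm import tqdm

    def run_epoch(data_train):  # hypothetical helper, for illustration
        seen = 0
        for seq, charge, rt, intensity in tqdm(data_train, desc="train", unit="batch"):
            seen += seq.shape[0]  # same running sample count the removed `i` tracked
            # ... forward pass, loss computation, optimizer step as in train() ...
        return seen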