diff --git a/main_custom.py b/main_custom.py index a8db44702d0a9f7d04949170d50a68cc70496d01..daf954fe6ec1cd9b8ef1d6d3dcc6e68ef6c14b5c 100644 --- a/main_custom.py +++ b/main_custom.py @@ -116,6 +116,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m if torch.cuda.is_available(): seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() pred_rt, pred_int = model.forward(seq, charge) + print(rt.shape,pred_rt.shape) loss_rt = criterion_rt(rt, pred_rt) loss_int = criterion_intensity(intensity, pred_int) losses_rt += loss_rt.item()