diff --git a/main_custom.py b/main_custom.py index ff87467c0aba7a9f854810fb31f62114f8e55ecc..d04876c0bac24076baec97ee726560ef238b4486 100644 --- a/main_custom.py +++ b/main_custom.py @@ -26,6 +26,7 @@ def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity if forward == 'both': i=0 for seq, charge, rt, intensity in data_train: + i+=seq.shape[0] rt, intensity = rt.float(), intensity.float() if torch.cuda.is_available(): seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda()