diff --git a/main_custom.py b/main_custom.py index d04876c0bac24076baec97ee726560ef238b4486..df172b3d1432ec40df9bf6c9a4063e78bc25399f 100644 --- a/main_custom.py +++ b/main_custom.py @@ -43,7 +43,7 @@ def train(model, data_train, epoch, optimizer, criterion_rt, criterion_intensity optimizer.zero_grad() loss.backward() optimizer.step() - print(i,'/',len(data_train)) + print(i,'/',len(data_train.dataset)) if wandb is not None: wdb.log({"train rt loss": losses_rt / len(data_train), "train int loss": losses_int / len(data_train),