diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index d0ec7c066b0203e53280847e6acbcc6b734feb39..9f33926b5b4d2ecc4972e54c38c9eb9429331e59 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -53,7 +53,7 @@ def train_model(config,args): elif config['optimizer']=='SGD' : optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9) # init training - n_class = data_train.dataset.classes + n_class = len(data_train.dataset.classes) loss_function = nn.CrossEntropyLoss(weight=torch.Tensor([1/n_class,1-1/n_class])) # Load existing checkpoint through `get_checkpoint()` API.