diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index d4eac83711982e631e421a8dfe8b2b0cd3896847..123286fb88f6da44e8218d68de1786baad4f09ae 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -45,7 +45,8 @@ def train_model(config,args): elif config['optimizer']=='SGD' : optimizer = optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9) # init training - loss_function = nn.CrossEntropyLoss() + n_class = data_train.dataset.classes + loss_function = nn.CrossEntropyLoss(weight=torch.Tensor([1/n_class,1-1/n_class])) # Load existing checkpoint through `get_checkpoint()` API. if train.get_checkpoint(): @@ -240,7 +241,7 @@ def main(args, gpus_per_trial=1): ), run_config=RunConfig(storage_path="/lustre/fswork/projects/rech/bun/ucg81ws/these/pseudo_image/image_ref/ray_results_test", - name="base_experiment" + name="weight_val_loss_experiment" ), param_space=config