From 9f54597da51f825205d05e10a99eb2388f83540e Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Wed, 16 Apr 2025 13:42:30 +0200 Subject: [PATCH] fix : max time budget --- image_ref/config.py | 13 +++++++------ image_ref/main.py | 5 ++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/image_ref/config.py b/image_ref/config.py index 7137615..d49ca0b 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -4,24 +4,25 @@ import argparse def load_args_contrastive(): parser = argparse.ArgumentParser() - parser.add_argument('--epoches', type=int, default=100) + parser.add_argument('--epoches', type=int, default=0) parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--test_inter', type=int, default=10) - parser.add_argument('--noise_threshold', type=int, default=500) + parser.add_argument('--noise_threshold', type=int, default=1.2) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--positive_prop', type=int, default=30) + parser.add_argument('--opti', type=str, default='adam') parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--sampler', type=str, default=None) #'balanced' for weighted oversampling - parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive') - parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive') + parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff/npy_image/train_data') + parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff/npy_image/test_data') parser.add_argument('--dataset_test_dir', type=str, default=None) - parser.add_argument('--base_out', type=str, default='output/baseline') + parser.add_argument('--base_out', type=str, default='output/best_model_base_ray') parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') parser.add_argument('--pretrain_path', type=str, default=None) - parser.add_argument('--wandb', type=str, default=None) + parser.add_argument('--wandb', type=str, default='image_ref/best_model_base_ray.pt') args = parser.parse_args() return args \ No newline at end of file diff --git a/image_ref/main.py b/image_ref/main.py index f33a63b..57135d1 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -126,7 +126,10 @@ def run_duo(args): val_loss = [] # init training loss_function = nn.CrossEntropyLoss() - optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) + if args.opti == 'adam': + optimizer = optim.Adam(model.parameters(), lr=args.lr) + else : + optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) # train model for e in range(args.epoches): loss, acc = train_duo(model, data_train, optimizer, loss_function, e, args.wandb) -- GitLab