Skip to content
Snippets Groups Projects
Commit 9f54597d authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : max time budget

parent 58543c49
No related branches found
No related tags found
No related merge requests found
......@@ -4,24 +4,25 @@ import argparse
def load_args_contrastive():
parser = argparse.ArgumentParser()
parser.add_argument('--epoches', type=int, default=100)
parser.add_argument('--epoches', type=int, default=0)
parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--test_inter', type=int, default=10)
parser.add_argument('--noise_threshold', type=int, default=500)
parser.add_argument('--noise_threshold', type=int, default=1.2)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--positive_prop', type=int, default=30)
parser.add_argument('--opti', type=str, default='adam')
parser.add_argument('--model', type=str, default='ResNet18')
parser.add_argument('--sampler', type=str, default=None) #'balanced' for weighted oversampling
parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive')
parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive')
parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff/npy_image/train_data')
parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff/npy_image/test_data')
parser.add_argument('--dataset_test_dir', type=str, default=None)
parser.add_argument('--base_out', type=str, default='output/baseline')
parser.add_argument('--base_out', type=str, default='output/best_model_base_ray')
parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
parser.add_argument('--pretrain_path', type=str, default=None)
parser.add_argument('--wandb', type=str, default=None)
parser.add_argument('--wandb', type=str, default='image_ref/best_model_base_ray.pt')
args = parser.parse_args()
return args
\ No newline at end of file
......@@ -126,7 +126,10 @@ def run_duo(args):
val_loss = []
# init training
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
if args.opti == 'adam':
optimizer = optim.Adam(model.parameters(), lr=args.lr)
else :
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
# train model
for e in range(args.epoches):
loss, acc = train_duo(model, data_train, optimizer, loss_function, e, args.wandb)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment