Skip to content
Snippets Groups Projects
Commit 780938ab authored by Schneider Leo's avatar Schneider Leo
Browse files

add : wandb sweep

parent a33f5639
No related branches found
No related tags found
No related merge requests found
import wandb as wdb
from config import load_args_contrastive
if __name__ == '__main__':
args = load_args_contrastive()
sweep_configuration = {
"program": "sweep_train.py",
"method": "random",
"metric": {"goal": "minimize", "name": "validation loss"},
"parameters": {
"epoches":{"value": 50},
"eval_inter":{"value": 1},
"noise_threshold": {"distribution" : "log_uniform_values", "max": 10000., "min": 0.0001},
"lr": {"distribution" : "log_uniform_values", "max": 0.01, "min": 0.0001},
"batch_size": {"value": 64},
"positive_prop": {"distribution" : "uniform","max": 95., "min": 5.},
"opti": {"value": "adam"},
"model": {"value": "resnet18"},
"sampler": {"values": ["random","balanced"]},
"dataset_train_dir": {"value": "data/processed_data_wiff/npy_image/train_data"},
"dataset_val_dir": {"value": "data/processed_data_wiff/npy_image/test_data"},
"dataset_ref_dir": {"values": ["image_ref/img_ref","image_ref/img_ref_count_th_10","image_ref/img_ref_count_th_5"]},
},
}
sweep_id = wdb.sweep(sweep=sweep_configuration, project="param_sweep_contrastive")
sweep = wdb.controller(sweep_id)
sweep.configure_controller(type="local")
sweep.run()
import os
import wandb as wdb
from dataset_ref import load_data_duo
import torch
import torch.nn as nn
from model import Classification_model_duo_contrastive
import torch.optim as optim
def train_duo(model, data_train, optimizer, loss_function, epoch, wandb):
model.train()
losses = 0.
acc = 0.
for param in model.parameters():
param.requires_grad = True
for imaer, imana, img_ref, label in data_train:
imaer = imaer.float()
imana = imana.float()
img_ref = img_ref.float()
label = label.long()
if torch.cuda.is_available():
imaer = imaer.cuda()
imana = imana.cuda()
img_ref = img_ref.cuda()
label = label.cuda()
pred_logits = model.forward(imaer, imana, img_ref)
pred_class = torch.argmax(pred_logits, dim=1)
acc += (pred_class == label).sum().item()
loss = loss_function(pred_logits, label)
losses += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses = losses / len(data_train.dataset)
acc = acc / len(data_train.dataset)
print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc})
return losses, acc
def val_duo(model, data_test, loss_function, epoch, wandb):
model.eval()
losses = 0.
acc = 0.
acc_contrastive = 0.
for param in model.parameters():
param.requires_grad = False
for imaer, imana, img_ref, label in data_test:
imaer = imaer.float()
imana = imana.float()
img_ref = img_ref.float()
imaer = imaer.transpose(0, 1)
imana = imana.transpose(0, 1)
img_ref = img_ref.transpose(0, 1)
label = label.transpose(0, 1)
label = label.squeeze()
label = label.long()
if torch.cuda.is_available():
imaer = imaer.cuda()
imana = imana.cuda()
img_ref = img_ref.cuda()
label = label.cuda()
label_class = torch.argmin(label).data.cpu().numpy()
pred_logits = model.forward(imaer, imana, img_ref)
pred_class = torch.argmax(pred_logits[:, 0]).tolist()
acc_contrastive += (
torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
acc += (pred_class == label_class)
loss = loss_function(pred_logits, label)
losses += loss.item()
losses = losses / (label.shape[0] * len(data_test.dataset))
acc = acc / (len(data_test.dataset))
acc_contrastive = acc_contrastive / (label.shape[0] * len(data_test.dataset))
print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc,
acc_contrastive))
wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc,
"validation contrastive accuracy": acc_contrastive})
return losses, acc, acc_contrastive
def run_duo(args):
# wandb init
os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
wdb.init(project="param_sweep_contrastive", dir='./wandb_run')
print('Wandb initialised')
# load data
data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
base_dir_val=args.dataset_val_dir,
base_dir_test=None,
batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir,
positive_prop=args.positive_prop, sampler=args.sampler)
# load model
model = Classification_model_duo_contrastive(model=args.model, n_class=2)
model.float()
# move parameters to GPU
if torch.cuda.is_available():
print('Model loaded on GPU')
model = model.cuda()
# init accumulators
best_loss = 100
train_acc = []
train_loss = []
val_acc = []
val_cont_acc = []
val_loss = []
# init training
loss_function = nn.CrossEntropyLoss()
if args.opti == 'adam':
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# train model
for e in range(args.epoches):
loss, acc = train_duo(model, data_train, optimizer, loss_function, e, args.wandb)
train_loss.append(loss)
train_acc.append(acc)
if e % args.eval_inter == 0:
loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e, args.wandb)
val_loss.append(loss)
val_acc.append(acc)
val_cont_acc.append(acc_contrastive)
wdb.finish()
if __name__ == '__main__':
config = wdb.config
print(config)
run_duo(config)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment