Skip to content
Snippets Groups Projects
Commit 7ef2d115 authored by Schneider Leo's avatar Schneider Leo
Browse files

add : wandb integration

add : wiff reader package
parent 65d809f9
No related branches found
No related tags found
No related merge requests found
Showing
with 71 additions and 28 deletions
......@@ -19,6 +19,7 @@ def load_args_contrastive():
parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
parser.add_argument('--pretrain_path', type=str, default=None)
parser.add_argument('--wandb', type=str, default=None)
args = parser.parse_args()
return args
\ No newline at end of file
import os
import wandb as wdb
import matplotlib.pyplot as plt
import numpy as np
......@@ -11,7 +13,7 @@ from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
def train_duo(model, data_train, optimizer, loss_function, epoch):
def train_duo(model, data_train, optimizer, loss_function, epoch, wandb):
model.train()
losses = 0.
acc = 0.
......@@ -36,9 +38,14 @@ def train_duo(model, data_train, optimizer, loss_function, epoch):
losses = losses/len(data_train.dataset)
acc = acc/len(data_train.dataset)
print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
if wandb is not None:
wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc })
return losses, acc
def test_duo(model, data_test, loss_function, epoch):
def test_duo(model, data_test, loss_function, epoch, wandb):
model.eval()
losses = 0.
acc = 0.
......@@ -69,9 +76,23 @@ def test_duo(model, data_test, loss_function, epoch):
acc = acc/(len(data_test.dataset))
acc_contrastive = acc_contrastive /(label.shape[0]*len(data_test.dataset))
print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch,losses,acc,acc_contrastive))
if wandb is not None:
wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc, "validation contrastive accuracy" : acc_contrastive })
return losses,acc,acc_contrastive
def run_duo(args):
#wandb init
if args.wandb is not None:
os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
wdb.init(project="Intensity prediction", dir='./wandb_run', name=args.wandb)
#load data
data_train, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_test=args.dataset_val_dir, batch_size=args.batch_size,
ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop, sampler=args.sampler)
......@@ -100,11 +121,11 @@ def run_duo(args):
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
#train model
for e in range(args.epoches):
loss, acc = train_duo(model,data_train,optimizer,loss_function,e)
loss, acc = train_duo(model,data_train,optimizer,loss_function,e,args.wandb)
train_loss.append(loss)
train_acc.append(acc)
if e%args.eval_inter==0 :
loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e)
loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e,args.wandb)
val_loss.append(loss)
val_acc.append(acc)
val_cont_acc.append(acc_contrastive)
......@@ -112,36 +133,40 @@ def run_duo(args):
save_model(model,args.save_path)
best_loss = loss
# plot and save training figs
plt.clf()
plt.subplot(2, 1, 1)
plt.plot(train_acc, label='train cont acc')
plt.plot(val_cont_acc, label='val cont acc')
plt.plot(val_acc, label='val classification acc')
plt.title('Train and validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(loc="upper left")
plt.ylim(0, 1.05)
plt.tight_layout()
plt.subplot(2, 1, 2)
plt.plot(train_loss, label='train')
plt.plot(val_loss, label='val')
plt.title('Train and validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc="upper left")
plt.tight_layout()
plt.show()
plt.savefig('output/training_plot_contrastive_{}.png'.format(args.positive_prop))
if args.wandb is None:
plt.clf()
plt.subplot(2, 1, 1)
plt.plot(train_acc, label='train cont acc')
plt.plot(val_cont_acc, label='val cont acc')
plt.plot(val_acc, label='val classification acc')
plt.title('Train and validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(loc="upper left")
plt.ylim(0, 1.05)
plt.tight_layout()
plt.subplot(2, 1, 2)
plt.plot(train_loss, label='train')
plt.plot(val_loss, label='val')
plt.title('Train and validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(loc="upper left")
plt.tight_layout()
plt.show()
plt.savefig('output/training_plot_contrastive_{}.png'.format(args.positive_prop))
#load and evaluate best model
load_model(model, args.save_path)
make_prediction_duo(model,data_test_batch, 'output/confusion_matrix_contractive_{}_bis.png'.format(args.positive_prop),
'output/confidence_matrix_contractive_{}_bis.png'.format(args.positive_prop))
if args.wandb is not None:
wdb.finish()
def make_prediction_duo(model, data, f_name, f_name2):
for imaer, imana, img_ref, label in data:
......
*.db
*.opendb
*.pyd
*.zip
*.c
*.nogit
*.mgf
*.speclib.txt
**/cython/**/*.c
**/cython/test/*
**/__pycache__/*
**/build/*
**/dist/*
**/*.egg-info/*
**/.DS_Store
.vscode/*
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment