Commit 0773a8e5 authored by Schneider Leo

add: test set (in addition to the existing val set)

parent 7ef2d115
@@ -12,9 +12,10 @@ def load_args_contrastive():
     parser.add_argument('--batch_size', type=int, default=64)
     parser.add_argument('--positive_prop', type=int, default=30)
     parser.add_argument('--model', type=str, default='ResNet18')
-    parser.add_argument('--sampler', type=str, default=None)
+    parser.add_argument('--sampler', type=str, default=None)  # 'balanced' for weighted oversampling
     parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive')
     parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive')
+    parser.add_argument('--dataset_test_dir', type=str, default=None)
     parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
     parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
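A quick orientation for the new option (not part of the commit itself): the test set stays disabled unless the flag is passed. A minimal sketch, assuming the parser lives in an importable config module:

    from config import load_args_contrastive  # module name is an assumption

    args = load_args_contrastive()
    print(args.dataset_test_dir)  # None by default: no test loader is built, test evaluation is skipped
    # enable it from the CLI with, e.g.: --dataset_test_dir <path_to_test_images>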
@@ -154,7 +154,7 @@ class ImageFolderDuo(data.Dataset):
     def __len__(self):
         return len(self.imlist)

-def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None):
+def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None):

@@ -182,9 +182,13 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None):
         print('Default val transform')
     train_dataset = ImageFolderDuo(root=base_dir_train, transform=train_transform, ref_dir = ref_dir, positive_prop=positive_prop, ref_transform=ref_transform)
-    val_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform)
-    if sampler =='weighted' :
+    val_dataset = ImageFolderDuo_Batched(root=base_dir_val, transform=val_transform, ref_dir = ref_dir, ref_transform=ref_transform)
+    if base_dir_test is not None :
+        test_dataset = ImageFolderDuo_Batched(root=base_dir_test, transform=val_transform, ref_dir=ref_dir,
+                                              ref_transform=ref_transform)
+    if sampler =='balanced' :
         y_train_label = np.array([i for (_,_,i) in train_dataset.imlist])
         class_sample_count = np.array([len(np.where(y_train_label == t)[0]) for t in np.unique(y_train_label)])
         weight = 1. / class_sample_count

@@ -211,7 +215,7 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None):
             pin_memory=False,
         )
-    data_loader_test = data.DataLoader(
+    data_loader_val = data.DataLoader(
         dataset=val_dataset,
         batch_size=1,
         shuffle=shuffle,

@@ -220,7 +224,19 @@ def load_data_duo(base_dir_train, base_dir_test, batch_size, shuffle=True, noise_threshold=0, ref_dir = None, positive_prop=None, sampler=None):
         pin_memory=False,
     )
-    return data_loader_train, data_loader_test
+    if base_dir_test is not None :
+        data_loader_test = data.DataLoader(
+            dataset=test_dataset,
+            batch_size=1,
+            shuffle=shuffle,
+            num_workers=0,
+            collate_fn=None,
+            pin_memory=False,
+        )
+    else :
+        data_loader_test = None
+    return data_loader_train, data_loader_val, data_loader_test

 class ImageFolderDuo_Batched(data.Dataset):
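The 'balanced' branch is cut off by the hunk above right after the per-class weights. A minimal sketch of how such weights are commonly expanded into torch's WeightedRandomSampler; this is an assumption about the elided code, not necessarily this repository's exact implementation:

    import numpy as np
    import torch
    from torch.utils import data

    y_train_label = np.array([0, 0, 0, 1])  # toy stand-in for the imlist labels
    class_sample_count = np.array([len(np.where(y_train_label == t)[0])
                                   for t in np.unique(y_train_label)])
    weight = 1. / class_sample_count                          # rarer class -> larger weight
    samples_weight = torch.from_numpy(weight[y_train_label])  # one weight per sample
    sampler = data.WeightedRandomSampler(samples_weight, len(samples_weight))
    # data.DataLoader(train_dataset, batch_size=..., sampler=sampler) then
    # oversamples the minority class; shuffle must be False when a sampler is passed.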
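And a usage sketch of the new three-way return (directory values are the config defaults; actually running this requires the datasets on disk):

    data_train, data_val, data_test = load_data_duo(
        base_dir_train='data/processed_data/npy_image/data_training_contrastive',
        base_dir_val='data/processed_data/npy_image/data_test_contrastive',
        base_dir_test=None,  # no test set supplied
        batch_size=64,
        ref_dir='image_ref/img_ref',
        positive_prop=30,
    )
    assert data_test is None  # the third loader is None when base_dir_test is None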
@@ -13,6 +13,7 @@ from sklearn.metrics import confusion_matrix
 import seaborn as sn
 import pandas as pd

+
 def train_duo(model, data_train, optimizer, loss_function, epoch, wandb):
     model.train()
     losses = 0.
@@ -20,32 +21,32 @@ def train_duo(model, data_train, optimizer, loss_function, epoch, wandb):
     for param in model.parameters():
         param.requires_grad = True
-    for imaer,imana, img_ref, label in data_train:
+    for imaer, imana, img_ref, label in data_train:
         label = label.long()
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
             img_ref = img_ref.cuda()
             label = label.cuda()
-        pred_logits = model.forward(imaer,imana,img_ref)
-        pred_class = torch.argmax(pred_logits,dim=1)
-        acc += (pred_class==label).sum().item()
-        loss = loss_function(pred_logits,label)
+        pred_logits = model.forward(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
         losses += loss.item()
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-    losses = losses/len(data_train.dataset)
-    acc = acc/len(data_train.dataset)
-    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
+    losses = losses / len(data_train.dataset)
+    acc = acc / len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
     if wandb is not None:
-        wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc })
+        wdb.log({"train loss": losses, 'train epoch': epoch, "train contrastive accuracy": acc})
     return losses, acc

-def test_duo(model, data_test, loss_function, epoch, wandb):
+
+def val_duo(model, data_test, loss_function, epoch, wandb):
     model.eval()
     losses = 0.
     acc = 0.
@@ -53,11 +54,11 @@ def test_duo(model, data_test, loss_function, epoch, wandb):
     for param in model.parameters():
         param.requires_grad = False
-    for imaer,imana, img_ref, label in data_test:
-        imaer = imaer.transpose(0,1)
-        imana = imana.transpose(0,1)
-        img_ref = img_ref.transpose(0,1)
-        label = label.transpose(0,1)
+    for imaer, imana, img_ref, label in data_test:
+        imaer = imaer.transpose(0, 1)
+        imana = imana.transpose(0, 1)
+        img_ref = img_ref.transpose(0, 1)
+        label = label.transpose(0, 1)
         label = label.squeeze()
         label = label.long()
         if torch.cuda.is_available():
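For intuition about the transposes above, a toy illustration (all shapes are assumptions; the batched val loader appears to yield batch_size=1 with the per-species pair dimension second, which transpose(0, 1) promotes to the batch dimension):

    import torch

    imaer = torch.randn(1, 9, 2, 224, 224)  # hypothetical (batch, n_pairs, C, H, W)
    print(imaer.transpose(0, 1).shape)      # torch.Size([9, 1, 2, 224, 224])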
@@ -66,75 +67,86 @@ def test_duo(model, data_test, loss_function, epoch, wandb):
             img_ref = img_ref.cuda()
             label = label.cuda()
         label_class = torch.argmin(label).data.cpu().numpy()
-        pred_logits = model.forward(imaer,imana,img_ref)
-        pred_class = torch.argmax(pred_logits[:,0]).tolist()
-        acc_contrastive += (torch.argmax(pred_logits,dim=1).data.cpu().numpy()==label.data.cpu().numpy()).sum().item()
-        acc += (pred_class==label_class)
-        loss = loss_function(pred_logits,label)
+        pred_logits = model.forward(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits[:, 0]).tolist()
+        acc_contrastive += (
+                torch.argmax(pred_logits, dim=1).data.cpu().numpy() == label.data.cpu().numpy()).sum().item()
+        acc += (pred_class == label_class)
+        loss = loss_function(pred_logits, label)
         losses += loss.item()
-    losses = losses/(label.shape[0]*len(data_test.dataset))
-    acc = acc/(len(data_test.dataset))
-    acc_contrastive = acc_contrastive /(label.shape[0]*len(data_test.dataset))
-    print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch,losses,acc,acc_contrastive))
+    losses = losses / (label.shape[0] * len(data_test.dataset))
+    acc = acc / (len(data_test.dataset))
+    acc_contrastive = acc_contrastive / (label.shape[0] * len(data_test.dataset))
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f} acc contrastive : {:.3f}'.format(epoch, losses, acc,
+                                                                                      acc_contrastive))
     if wandb is not None:
-        wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc, "validation contrastive accuracy" : acc_contrastive })
-    return losses,acc,acc_contrastive
+        wdb.log({"validation loss": losses, 'validation epoch': epoch, "validation classification accuracy": acc,
+                 "validation contrastive accuracy": acc_contrastive})
+    return losses, acc, acc_contrastive

+
 def run_duo(args):
-    #wandb init
+    # wandb init
     if args.wandb is not None:
         os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
         os.environ["WANDB_MODE"] = "offline"
         os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
-        wdb.init(project="Intensity prediction", dir='./wandb_run', name=args.wandb)
-    #load data
-    data_train, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir, base_dir_test=args.dataset_val_dir, batch_size=args.batch_size,
-                                                ref_dir=args.dataset_ref_dir, positive_prop=args.positive_prop, sampler=args.sampler)
-    #load model
-    model = Classification_model_duo_contrastive(model = args.model, n_class=2)
+        wdb.init(project="contrastive_classification", dir='./wandb_run', name=args.wandb)
+    # load data
+    data_train, data_val_batch, data_test_batch = load_data_duo(base_dir_train=args.dataset_train_dir,
+                                                                base_dir_val=args.dataset_val_dir,
+                                                                base_dir_test=args.dataset_test_dir,
+                                                                batch_size=args.batch_size,
+                                                                ref_dir=args.dataset_ref_dir,
+                                                                positive_prop=args.positive_prop, sampler=args.sampler)
+    # load model
+    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
     model.double()
-    #load weight
-    if args.pretrain_path is not None :
+    # load weight
+    if args.pretrain_path is not None:
         print('Model weight loaded')
-        load_model(model,args.pretrain_path)
-    #move parameters to GPU
+        load_model(model, args.pretrain_path)
+    # move parameters to GPU
     if torch.cuda.is_available():
         print('Model loaded on GPU')
         model = model.cuda()
-    #init accumulators
+    # init accumulators
     best_loss = 100
-    train_acc=[]
-    train_loss=[]
-    val_acc=[]
-    val_cont_acc=[]
-    val_loss=[]
-    #init training
+    train_acc = []
+    train_loss = []
+    val_acc = []
+    val_cont_acc = []
+    val_loss = []
+    # init training
     loss_function = nn.CrossEntropyLoss()
     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
-    #train model
+    # train model
     for e in range(args.epoches):
-        loss, acc = train_duo(model,data_train,optimizer,loss_function,e,args.wandb)
+        loss, acc = train_duo(model, data_train, optimizer, loss_function, e, args.wandb)
         train_loss.append(loss)
         train_acc.append(acc)
-        if e%args.eval_inter==0 :
-            loss, acc, acc_contrastive = test_duo(model,data_test_batch,loss_function,e,args.wandb)
+        if e % args.eval_inter == 0:
+            loss, acc, acc_contrastive = val_duo(model, data_val_batch, loss_function, e, args.wandb)
             val_loss.append(loss)
             val_acc.append(acc)
             val_cont_acc.append(acc_contrastive)
-            if loss < best_loss :
-                save_model(model,args.save_path)
+            if loss < best_loss:
+                save_model(model, args.save_path)
                 best_loss = loss
+        if e % args.test_inter == 0 and args.dataset_test_dir is not None:
+            loss, acc, acc_contrastive = val_duo(model, data_test_batch, loss_function, e, args.wandb)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            val_cont_acc.append(acc_contrastive)
     # plot and save training figs
     if args.wandb is None:
         plt.clf()
         plt.subplot(2, 1, 1)
         plt.plot(train_acc, label='train cont acc')
@@ -159,10 +171,16 @@ def run_duo(args):
         plt.show()
         plt.savefig('output/training_plot_contrastive_{}.png'.format(args.positive_prop))
-    #load and evaluate best model
+    # load and evaluate best model
     load_model(model, args.save_path)
-    make_prediction_duo(model,data_test_batch, 'output/confusion_matrix_contractive_{}_bis.png'.format(args.positive_prop),
-                        'output/confidence_matrix_contractive_{}_bis.png'.format(args.positive_prop))
+    if args.dataset_test_dir is not None:
+        make_prediction_duo(model, data_test_batch,
+                            'output/confusion_matrix_contractive_{}_bis_test.png'.format(args.positive_prop),
+                            'output/confidence_matrix_contractive_{}_bis_test.png'.format(args.positive_prop))
+    make_prediction_duo(model, data_val_batch,
+                        'output/confusion_matrix_contractive_{}_bis_val.png'.format(args.positive_prop),
+                        'output/confidence_matrix_contractive_{}_bis_val.png'.format(args.positive_prop))
     if args.wandb is not None:
         wdb.finish()
@@ -177,26 +195,25 @@ def make_prediction_duo(model, data, f_name, f_name2):
     y_true = []
     soft_max = nn.Softmax(dim=1)
     # iterate over test data
-    for imaer,imana,img_ref, label in data:
-        imaer = imaer.transpose(0,1)
-        imana = imana.transpose(0,1)
-        img_ref = img_ref.transpose(0,1)
-        label = label.transpose(0,1)
+    for imaer, imana, img_ref, label in data:
+        imaer = imaer.transpose(0, 1)
+        imana = imana.transpose(0, 1)
+        img_ref = img_ref.transpose(0, 1)
+        label = label.transpose(0, 1)
         label = label.squeeze()
         label = label.long()
         specie = torch.argmin(label)
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
             img_ref = img_ref.cuda()
             label = label.cuda()
-        output = model(imaer,imana,img_ref)
+        output = model(imaer, imana, img_ref)
         confidence = soft_max(output)
-        confidence_pred_list[specie].append(confidence[:,0].data.cpu().numpy())
-        #Mono class output (only most postive paire)
-        output = torch.argmax(output[:,0])
+        confidence_pred_list[specie].append(confidence[:, 0].data.cpu().numpy())
+        # Mono-class output (only the most positive pair)
+        output = torch.argmax(output[:, 0])
         label = torch.argmin(label)
         y_pred.append(output.tolist())
         y_true.append(label.tolist())  # Save Truth
@@ -205,9 +222,9 @@ def make_prediction_duo(model, data, f_name, f_name2):
     # Build confusion matrix
     classes = data.dataset.classes
     cf_matrix = confusion_matrix(y_true, y_pred)
-    confidence_matrix = np.zeros((n_class,n_class))
+    confidence_matrix = np.zeros((n_class, n_class))
     for i in range(n_class):
-        confidence_matrix[i]=np.mean(confidence_pred_list[i],axis=0)
+        confidence_matrix[i] = np.mean(confidence_pred_list[i], axis=0)
     df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
                          columns=[i for i in classes])
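For reference, the division feeding df_cm row-normalizes the confusion matrix, so each cell becomes the fraction of a true class predicted as each label. A self-contained illustration with made-up counts:

    import numpy as np

    cf_matrix = np.array([[8, 2],
                          [1, 9]])  # made-up counts: rows = true class, cols = prediction
    normed = cf_matrix / np.sum(cf_matrix, axis=1)[:, None]
    print(normed)  # [[0.8 0.2]
                   #  [0.1 0.9]]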
@@ -230,12 +247,12 @@ def save_model(model, path):
     print('Model saved')
     torch.save(model.state_dict(), path)

 def load_model(model, path):
     model.load_state_dict(torch.load(path, weights_only=True))

 if __name__ == '__main__':
     args = load_args_contrastive()
     print(args)
     run_duo(args)
\ No newline at end of file
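Finally, a round-trip sketch for the save/load helpers above; nn.Linear stands in for the real Classification_model_duo_contrastive, and the path is the config default:

    import torch.nn as nn

    model = nn.Linear(4, 2).double()  # hypothetical stand-in module
    save_model(model, 'output/best_model_constrastive.pt')
    load_model(model, 'output/best_model_constrastive.pt')  # weights_only=True restricts unpickling to tensors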