Skip to content
Snippets Groups Projects
Commit 74e27806 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : dataloader batched

parent 5f123ed9
No related branches found
No related tags found
No related merge requests found
......@@ -110,22 +110,26 @@ def run_duo(args):
plt.savefig('output/training_plot_contrastive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
#load and evaluate best model
load_model(model, args.save_path)
make_prediction_duo(model,data_test, 'output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model))
make_prediction_duo(model,data_test, 'output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
def make_prediction_duo(model, data, f_name):
n_class = len(data[0][2])
confidence_pred_list = [[] for i in range(n_class)]
y_pred = []
y_true = []
# iterate over test data
for imaer,imana,img_ref, label in data:
label = label.long()
specie = torch.argmax(label)
if torch.cuda.is_available():
imaer = imaer.cuda()
imana = imana.cuda()
img_ref = img_ref.cuda()
label = label.cuda()
output = model(imaer,imana,img_ref)
confidence_pred_list[specie].append(output.data.cpu().numpy())
output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
y_pred.extend(output)
......@@ -135,6 +139,10 @@ def make_prediction_duo(model, data, f_name):
# Build confusion matrix
cf_matrix = confusion_matrix(y_true, y_pred)
confidence_matrix = np.zeros((n_class,n_class))
for i in range(n_class):
confidence_matrix[i]=np.mean(confidence_pred_list[i],axis=0)
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in range(2)],
columns=['True','False'])
print('Saving Confusion Matrix')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment