Skip to content
Snippets Groups Projects
Commit 499776f2 authored by Schneider Leo's avatar Schneider Leo
Browse files

debugging confusion matrix

parent 77519ed7
No related branches found
No related tags found
No related merge requests found
Showing
with 8 additions and 11 deletions
...@@ -4,14 +4,14 @@ import argparse ...@@ -4,14 +4,14 @@ import argparse
def load_args(): def load_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--epoches', type=int, default=10) parser.add_argument('--epoches', type=int, default=3)
parser.add_argument('--save_inter', type=int, default=50) parser.add_argument('--save_inter', type=int, default=50)
parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--noise_threshold', type=int, default=0)
parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model', type=str, default='ResNet18')
parser.add_argument('--model_type', type=str, default='solo') parser.add_argument('--model_type', type=str, default='duo')
parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training') parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training')
parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--output', type=str, default='output/out.csv')
parser.add_argument('--save_path', type=str, default='output/best_model.pt') parser.add_argument('--save_path', type=str, default='output/best_model.pt')
......
confusion_matrix.png

42.9 KiB

...@@ -90,10 +90,10 @@ def run(args): ...@@ -90,10 +90,10 @@ def run(args):
plt.plot(train_acc) plt.plot(train_acc)
plt.ylim(0, 1.05) plt.ylim(0, 1.05)
plt.show() plt.show()
plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model)) plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
load_model(model, args.save_path) load_model(model, args.save_path)
make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model)) make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
def make_prediction(model, data, f_name): def make_prediction(model, data, f_name):
...@@ -121,7 +121,7 @@ def make_prediction(model, data, f_name): ...@@ -121,7 +121,7 @@ def make_prediction(model, data, f_name):
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes], df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
columns=[i for i in classes]) columns=[i for i in classes])
plt.figure(figsize=(12, 7)) plt.figure(figsize=(12, 7))
sn.heatmap(df_cm, annot=True) sn.heatmap(df_cm, annot=cf_matrix)
plt.savefig(f_name) plt.savefig(f_name)
...@@ -206,6 +206,7 @@ def run_duo(args): ...@@ -206,6 +206,7 @@ def run_duo(args):
plt.plot(train_acc) plt.plot(train_acc)
plt.ylim(0, 1.05) plt.ylim(0, 1.05)
plt.show() plt.show()
plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
load_model(model, args.save_path) load_model(model, args.save_path)
...@@ -215,7 +216,6 @@ def run_duo(args): ...@@ -215,7 +216,6 @@ def run_duo(args):
def make_prediction_duo(model, data, f_name): def make_prediction_duo(model, data, f_name):
y_pred = [] y_pred = []
y_true = [] y_true = []
print('Building confusion matrix')
# iterate over test data # iterate over test data
for imaer,imana, label in data: for imaer,imana, label in data:
label = label.long() label = label.long()
...@@ -233,17 +233,14 @@ def make_prediction_duo(model, data, f_name): ...@@ -233,17 +233,14 @@ def make_prediction_duo(model, data, f_name):
# constant for classes # constant for classes
classes = data.dataset.dataset.classes classes = data.dataset.dataset.classes
print('Prediction made')
# Build confusion matrix # Build confusion matrix
print(len(y_true),len(y_pred)) print(len(y_true),len(y_pred))
cf_matrix = confusion_matrix(y_true, y_pred) cf_matrix = confusion_matrix(y_true, y_pred)
print('CM made') df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
df_cm = pd.DataFrame(cf_matrix[:, None], index=[i for i in classes],
columns=[i for i in classes]) columns=[i for i in classes])
print('Saving Confusion Matrix') print('Saving Confusion Matrix')
plt.figure(figsize=(14, 9)) plt.figure(figsize=(14, 9))
sn.heatmap(df_cm, annot=True) sn.heatmap(df_cm, annot=cf_matrix)
plt.savefig(f_name) plt.savefig(f_name)
......
output/confusion_matrix_noise_0_lr_0.001_model_ResNet18.png

41.1 KiB

output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18.png

41.2 KiB

output/confusion_matrix_noise_100_lr_0.001_model_ResNet18.png

40.9 KiB

output/confusion_matrix_noise_200_lr_0.001_model_ResNet18.png

41.1 KiB

output/confusion_matrix_noise_500_lr_0.001_model_ResNet18.png

37.5 KiB

output/training_plot_noise_0_lr_0.001_model_ResNet18.png

31.8 KiB

output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png

23 KiB

output/training_plot_noise_1000_lr_0.001_model_ResNet18.png

30.3 KiB

output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png

19.5 KiB

output/training_plot_noise_100_lr_0.001_model_ResNet18.png

34 KiB

output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png

15.8 KiB

output/training_plot_noise_200_lr_0.001_model_ResNet18.png

30.8 KiB

output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png

16.8 KiB

output/training_plot_noise_500_lr_0.001_model_ResNet18.png

28.6 KiB

output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png

24.1 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment