diff --git a/config/config.py b/config/config.py index 789f0bc1a495a8581be2fb0087b012598b3d7233..5c79da036ed9738cfab704960f83d42b757f4656 100644 --- a/config/config.py +++ b/config/config.py @@ -4,14 +4,14 @@ import argparse def load_args(): parser = argparse.ArgumentParser() - parser.add_argument('--epoches', type=int, default=10) + parser.add_argument('--epoches', type=int, default=3) parser.add_argument('--save_inter', type=int, default=50) parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--model', type=str, default='ResNet18') - parser.add_argument('--model_type', type=str, default='solo') + parser.add_argument('--model_type', type=str, default='duo') parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training') parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--save_path', type=str, default='output/best_model.pt') diff --git a/confusion_matrix.png b/confusion_matrix.png deleted file mode 100644 index 7491c351ae66b3d362c23635172dd3e878d2f4e5..0000000000000000000000000000000000000000 Binary files a/confusion_matrix.png and /dev/null differ diff --git a/main.py b/main.py index 6ef07536e7096d4f2f3bacb3c49e23114cd27000..07cc7e36e67402d8cb9a4e174185f327c93cbe8d 100644 --- a/main.py +++ b/main.py @@ -90,10 +90,10 @@ def run(args): plt.plot(train_acc) plt.ylim(0, 1.05) plt.show() - plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model)) + plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) load_model(model, args.save_path) - make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model)) + make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) def make_prediction(model, data, f_name): @@ -121,7 +121,7 @@ def make_prediction(model, data, f_name): df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes], columns=[i for i in classes]) plt.figure(figsize=(12, 7)) - sn.heatmap(df_cm, annot=True) + sn.heatmap(df_cm, annot=cf_matrix) plt.savefig(f_name) @@ -206,6 +206,7 @@ def run_duo(args): plt.plot(train_acc) plt.ylim(0, 1.05) plt.show() + plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) load_model(model, args.save_path) @@ -215,7 +216,6 @@ def run_duo(args): def make_prediction_duo(model, data, f_name): y_pred = [] y_true = [] - print('Building confusion matrix') # iterate over test data for imaer,imana, label in data: label = label.long() @@ -233,17 +233,14 @@ def make_prediction_duo(model, data, f_name): # constant for classes classes = data.dataset.dataset.classes - print('Prediction made') # Build confusion matrix print(len(y_true),len(y_pred)) cf_matrix = confusion_matrix(y_true, y_pred) - print('CM made') - df_cm = pd.DataFrame(cf_matrix[:, None], index=[i for i in classes], + df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes], columns=[i for i in classes]) - print('Saving Confusion Matrix') plt.figure(figsize=(14, 9)) - sn.heatmap(df_cm, annot=True) + sn.heatmap(df_cm, annot=cf_matrix) plt.savefig(f_name) diff --git a/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18.png b/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18.png deleted file mode 100644 index 2866a20cc0bf19e813058ff39960c49557ea6ff8..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18.png b/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18.png deleted file mode 100644 index b175002946ec8e9b712143e363f6c2e65a426199..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18.png b/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18.png deleted file mode 100644 index 9e0c15a72bbf4cc95868ab13d67c7199c3e99a2a..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18.png b/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18.png deleted file mode 100644 index 93d65729d4af52d5c6700fef541b7e9df5ed56bd..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18.png b/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18.png deleted file mode 100644 index d01920766b80e52671bcdfc44c31f212efef199c..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/training_plot_noise_0_lr_0.001_model_ResNet18.png b/output/training_plot_noise_0_lr_0.001_model_ResNet18.png deleted file mode 100644 index 8d632305510f9b83a1a8f4db0c412bcb7afb6c7b..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_0_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index 17ca5bed2008086a79ecb6dc0965b20e2e974cda..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_1000_lr_0.001_model_ResNet18.png b/output/training_plot_noise_1000_lr_0.001_model_ResNet18.png deleted file mode 100644 index 631f4034300fa4751be70113fce7cc394796f93d..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_1000_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index f03d158e15ee5506c60cbda34d1ce23b9075f13b..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_100_lr_0.001_model_ResNet18.png b/output/training_plot_noise_100_lr_0.001_model_ResNet18.png deleted file mode 100644 index 0c7d62967f7889f32fd4aa8fa589b9b2c3e79358..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_100_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index ebeb22dbc884d142078bb09c4d5a895d509f03ae..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_200_lr_0.001_model_ResNet18.png b/output/training_plot_noise_200_lr_0.001_model_ResNet18.png deleted file mode 100644 index 59dd340de955c982d48eb6a73a7f33520ed80d95..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_200_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index 3fe32753980bae1b7abce911e55ba211c19bcd03..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_500_lr_0.001_model_ResNet18.png b/output/training_plot_noise_500_lr_0.001_model_ResNet18.png deleted file mode 100644 index 4e90cdee00f63d2ae6a55a6d1a824dbaed3deca1..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_500_lr_0.001_model_ResNet18.png and /dev/null differ diff --git a/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index dfdf3eac2b40ae77c74d394f36e64f8026343d0d..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png and /dev/null differ