diff --git a/config/config.py b/config/config.py index f94bc82fa140737f02ccc5c055271728b1225240..406b5af4bad923b2286e7ca4dfe6226ce4d4e33b 100644 --- a/config/config.py +++ b/config/config.py @@ -11,6 +11,7 @@ def load_args(): parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--model', type=str, default='ResNet18') + parser.add_argument('--model_type', type=str, default='solo') parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training') parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--save_path', type=str, default='output/best_model.pt') diff --git a/main.py b/main.py index daadc8f34a950b286462f9d955ce7e972598e83b..a4074a243a26c9b71bdecac735ea861dd5ea86b6 100644 --- a/main.py +++ b/main.py @@ -206,10 +206,10 @@ def run_duo(args): plt.plot(train_acc) plt.ylim(0, 1.05) plt.show() - plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model)) + plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) load_model(model, args.save_path) - make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model)) + make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) def make_prediction_duo(model, data, f_name): @@ -254,4 +254,7 @@ def load_model(model, path): if __name__ == '__main__': args = load_args() - run(args) \ No newline at end of file + if args.model_type=='duo': + run_duo(args) + else : + run(args) \ No newline at end of file