diff --git a/config/config.py b/config/config.py index d898b6bbef2f8ea9c6a9e3bebd3380057cba4ca1..f94bc82fa140737f02ccc5c055271728b1225240 100644 --- a/config/config.py +++ b/config/config.py @@ -10,6 +10,7 @@ def load_args(): parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training') parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--save_path', type=str, default='output/best_model.pt') diff --git a/main.py b/main.py index 7246be316d9cdecd8e030ea3a537680a21c0b8b7..af15fa96b21716975db10cf994d2c6a784c4c387 100644 --- a/main.py +++ b/main.py @@ -60,7 +60,7 @@ def test(model, data_test, loss_function, epoch): def run(args): data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size) - model = Classification_model(n_class=len(data_train.dataset.dataset.classes)) + model = Classification_model(model = args.model, n_class=len(data_train.dataset.dataset.classes)) if args.pretrain_path is not None : load_model(model,args.pretrain_path) if torch.cuda.is_available(): @@ -89,10 +89,10 @@ def run(args): plt.plot(train_acc) plt.plot(train_acc) plt.show() - plt.savefig('output/training_plot_{}_.png'.format(args.output)) + plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model)) load_model(model, args.save_path) - make_prediction(model,data_test, 'output/confusion_matrix_{}_.png'.format(args.output)) + make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model)) def make_prediction(model, data, f_name): diff --git a/models/model.py b/models/model.py index f520cf2990f9c123f69bb976da47e1ffeb2f97d0..5853484a01d4e3bbe0b6923922b5c7327f4eded6 100644 --- a/models/model.py +++ b/models/model.py @@ -262,10 +262,11 @@ def resnet152(num_classes=1000,**kwargs): class Classification_model(nn.Module): - def __init__(self, n_class, *args, **kwargs): + def __init__(self, model, n_class, *args, **kwargs): super().__init__(*args, **kwargs) self.n_class = n_class - self.im_encoder = resnet18(num_classes=self.n_class) + if model =='Resnet18': + self.im_encoder = resnet18(num_classes=self.n_class) def forward(self, input):