diff --git a/config/config.py b/config/config.py
index d898b6bbef2f8ea9c6a9e3bebd3380057cba4ca1..f94bc82fa140737f02ccc5c055271728b1225240 100644
--- a/config/config.py
+++ b/config/config.py
@@ -10,6 +10,7 @@ def load_args():
     parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
     parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--model', type=str, default='ResNet18')
     parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training')
     parser.add_argument('--output', type=str, default='output/out.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model.pt')
diff --git a/main.py b/main.py
index 7246be316d9cdecd8e030ea3a537680a21c0b8b7..af15fa96b21716975db10cf994d2c6a784c4c387 100644
--- a/main.py
+++ b/main.py
@@ -60,7 +60,7 @@ def test(model, data_test, loss_function, epoch):
 
 def run(args):
     data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
-    model = Classification_model(n_class=len(data_train.dataset.dataset.classes))
+    model = Classification_model(model = args.model, n_class=len(data_train.dataset.dataset.classes))
     if args.pretrain_path is not None :
         load_model(model,args.pretrain_path)
     if torch.cuda.is_available():
@@ -89,10 +89,10 @@ def run(args):
     plt.plot(train_acc)
     plt.plot(train_acc)
     plt.show()
-    plt.savefig('output/training_plot_{}_.png'.format(args.output))
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model))
 
     load_model(model, args.save_path)
-    make_prediction(model,data_test, 'output/confusion_matrix_{}_.png'.format(args.output))
+    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_thresold,args.lr,args.model))
 
 
 def make_prediction(model, data, f_name):
diff --git a/models/model.py b/models/model.py
index f520cf2990f9c123f69bb976da47e1ffeb2f97d0..5853484a01d4e3bbe0b6923922b5c7327f4eded6 100644
--- a/models/model.py
+++ b/models/model.py
@@ -262,10 +262,11 @@ def resnet152(num_classes=1000,**kwargs):
 
 class Classification_model(nn.Module):
 
-    def __init__(self, n_class, *args, **kwargs):
+    def __init__(self, model, n_class, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.n_class = n_class
-        self.im_encoder = resnet18(num_classes=self.n_class)
+        if model =='Resnet18':
+            self.im_encoder = resnet18(num_classes=self.n_class)
 
 
     def forward(self, input):