diff --git a/config/config.py b/config/config.py index 73f7754b9ac3d06bb711367bd604291540014af6..ac5144bff32e2b1267e03e264abdc6b6e3cd8d43 100644 --- a/config/config.py +++ b/config/config.py @@ -10,6 +10,8 @@ def load_args(): parser.add_argument('--noise_threshold', type=int, default=1000) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--optim', type = str, default = "Adam") + parser.add_argument('--beta1', type=float, default=0.9) + parser.add_argument('--beta2', type=float, default=0.999) parser.add_argument('--momentum', type=float, default=0.9) parser.add_argument('--classes_names', type=list, default = ["Citrobacter freundii","Citrobacter koseri","Enterobacter asburiae","Enterobacter cloacae","Enterobacter hormaechei","Escherichia coli","Klebsiella aerogenes","Klebsiella michiganensis","Klebsiella oxytoca","Klebsiella pneumoniae","Klebsiella quasipneumoniae","Proteus mirabilis","Salmonella enterica"]) parser.add_argument('--classes_numbers', type=list, default = [51,12,9,10,86,231,20,13,24,96,11,39,11]) diff --git a/main.py b/main.py index b89b3512c0ac037e729e5ca2db6f0456168cab94..7d02e1057741c7ba49b051a1bade633c1a58a371 100644 --- a/main.py +++ b/main.py @@ -230,7 +230,7 @@ def run_duo(args): if args.optim == "SGD": optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) elif args.optim == "Adam": - optimizer = optim.Adam(model.parameters(), lr=args.lr) + optimizer = optim.Adam(model.parameters(), lr=args.lr, betas = (args.beta1,args.beta2)) else: raise Exception("Unusual args.optim") #train model diff --git a/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png index 6a95ebc7a9edc4eedb7f58bec1ada010ea1239a3..b374ff872f64700de0913ce9275856810d89259f 100644 Binary files a/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png and b/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png differ diff --git a/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png index 1b668b54cf0ce6880930d69b1c5fad4462b24c7f..5a3fcbbd21927c8c8b53442e57df342a8edf1316 100644 Binary files a/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png and b/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png differ