diff --git a/config/config.py b/config/config.py index ac5144bff32e2b1267e03e264abdc6b6e3cd8d43..bd022bed2184fc2a6f5d14db3000c9a6e7a7c58a 100644 --- a/config/config.py +++ b/config/config.py @@ -1,21 +1,18 @@ import argparse -import torch def load_args(): parser = argparse.ArgumentParser() - - parser.add_argument('--epoches', type=int, default=20) + parser.add_argument('--epoches', type=int, default=50) parser.add_argument('--eval_inter', type=int, default=1) - parser.add_argument('--noise_threshold', type=int, default=1000) + parser.add_argument('--augment_args', nargs = '+', type = float, default = [0,0,0,0.1,0.1,0.1,0.1]) + parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--optim', type = str, default = "Adam") - parser.add_argument('--beta1', type=float, default=0.9) - parser.add_argument('--beta2', type=float, default=0.999) + parser.add_argument('--beta1', type=float, default=0.938) + parser.add_argument('--beta2', type=float, default=0.9928) parser.add_argument('--momentum', type=float, default=0.9) - parser.add_argument('--classes_names', type=list, default = ["Citrobacter freundii","Citrobacter koseri","Enterobacter asburiae","Enterobacter cloacae","Enterobacter hormaechei","Escherichia coli","Klebsiella aerogenes","Klebsiella michiganensis","Klebsiella oxytoca","Klebsiella pneumoniae","Klebsiella quasipneumoniae","Proteus mirabilis","Salmonella enterica"]) - parser.add_argument('--classes_numbers', type=list, default = [51,12,9,10,86,231,20,13,24,96,11,39,11]) - parser.add_argument('--weighted_entropy', type=bool, default = True) + parser.add_argument('--weighted_entropy', type=bool, default = False) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model_type', type=str, default='duo') @@ -23,6 +20,7 @@ def load_args(): parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--save_path', type=str, default='output/best_model.pt') parser.add_argument('--pretrain_path', type=str, default=None) + parser.add_argument('--random_state',type = int, default = 42) args = parser.parse_args() return args diff --git a/dataset/dataset.py b/dataset/dataset.py index e7e526e831e2b8fb2a1e9c8d8caf00ee8571372a..5119a732aa05708a4dbec17f85de3a05eec1dee0 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -30,8 +30,42 @@ class Log_normalisation: def __call__(self, x): return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon) -class Random_shift_rt(): - pass +class Random_erasing: + """with a probability prob, erases a proportion prop of the image""" + def __init__(self, prob, prop): + self.prob = prob + self.prop = prop + + def __call__(self,x): + if np.random.rand() > self.prob: + return x*(torch.rand_like(x) > self.prop) + return x + + +class Random_int_noise: + """With a probability prob, adds a gaussian noise to the image """ + def __init__(self, prob, maximum): + self.prob = prob + self.minimum = 1/maximum + self.delta = maximum-self.minimum + + def __call__(self, x): + if np.random.rand() > self.prob: + return x*(self.minimum + torch.rand_like(x)*self.delta) + return x + +class Random_shift_rt: + """With a probability prob, shifts verticaly the image depending on a gaussian distribution""" + def __init__(self, prob, mean, std): + self.prob = prob + self.mean = torch.tensor(float(mean)) + self.std = float(std) + + def __call__(self,x): + if np.random.rand()>self.prob: + shift = torch.normal(self.mean,self.std) + return transforms.functional.affine(x,0,[0,shift],1,[0,0]) + return x def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): @@ -175,6 +209,7 @@ class ImageFolderDuo(data.Dataset): if self.transform is not None: imgAER = self.transform(imgAER) imgANA = self.transform(imgANA) + if self.target_transform is not None: target = self.target_transform(target) return imgAER, imgANA, target @@ -184,7 +219,10 @@ class ImageFolderDuo(data.Dataset): def load_data_duo(base_dir, batch_size, args, shuffle=True): train_transform = transforms.Compose( - [transforms.Resize((224, 224)), + [Random_erasing(args.augment_args[0], args.augment_args[3]), + Random_int_noise(args.augment_args[1], args.augment_args[4]), + Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]), + transforms.Resize((224, 224)), Threshold_noise(args.noise_threshold), Log_normalisation(), transforms.Normalize(0.5, 0.5)]) @@ -199,8 +237,9 @@ def load_data_duo(base_dir, batch_size, args, shuffle=True): train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform) val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform) - repart_weights = [args.classes_names[i] for i in range(len(args.classes_names)) for k in range(args.classes_numbers[i])] - iTrain,iVal = train_test_split(range(len(train_dataset)),test_size = 0.2, shuffle = shuffle, stratify = repart_weights, random_state=42) + classes_name = os.listdir(args.dataset_dir) + repart_weights = [name for name in classes_name for k in range(len(os.listdir(args.dataset_dir+"/"+name))//2)] + iTrain,iVal = train_test_split(range(len(train_dataset)),test_size = 0.2, shuffle = shuffle, stratify = repart_weights, random_state=args.random_state) train_dataset = torch.utils.data.Subset(train_dataset, iTrain) val_dataset = torch.utils.data.Subset(val_dataset, iVal) # generator1 = torch.Generator().manual_seed(42) @@ -217,7 +256,6 @@ def load_data_duo(base_dir, batch_size, args, shuffle=True): collate_fn=None, pin_memory=False, ) - data_loader_test = data.DataLoader( dataset=val_dataset, batch_size=batch_size, diff --git a/main.py b/main.py index 7d02e1057741c7ba49b051a1bade633c1a58a371..6aa0b89809a21b52e9f2e7c924533ef89aae8052 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import numpy as np from config.config import load_args from dataset.dataset import load_data, load_data_duo +from os import listdir import torch import torch.nn as nn from models.model import Classification_model, Classification_model_duo @@ -198,6 +199,7 @@ def run_duo(args): memPrint(1) data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size,args=args) #load model + memPrint(2) model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes)) memPrint(3) @@ -220,7 +222,7 @@ def run_duo(args): val_loss=[] #init training if args.weighted_entropy: - loss_weights = torch.tensor([1/n for n in args.classes_numbers]) + loss_weights = torch.tensor([2/len(listdir(args.dataset_dir+"/"+class_name)) for class_name in listdir(args.dataset_dir)]) if torch.cuda.is_available(): loss_weights = loss_weights.cuda() loss_function = nn.CrossEntropyLoss(loss_weights) @@ -245,23 +247,25 @@ def run_duo(args): if loss < best_loss : save_model(model,args.save_path) best_loss = loss + best_acc = acc # plot and save training figs plt.plot(train_loss) plt.plot(val_loss) plt.plot(train_loss) plt.plot(train_loss) - plt.ylim(0, 0.025) - plt.show() - plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) + plt.savefig(f'output/training_plot_model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') #load and evaluate best model load_model(model, args.save_path) - make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) - + make_prediction_duo(model,data_test, f'output/model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_"+str(args.momentum) if args.optim=="SGD" else "_betas_" + str(args.beta1)+ "_" +str(args.beta2))}.png') + return best_loss,best_acc def make_prediction_duo(model, data, f_name): y_pred = [] y_true = [] + softmaxes = [] + classes = data.dataset.dataset.classes + print(classes) # iterate over test data for imaer,imana, label in data: label = label.long() @@ -270,17 +274,28 @@ def make_prediction_duo(model, data, f_name): if torch.cuda.is_available(): imaer = imaer.cuda() imana = imana.cuda() - label = label.cuda() output = model(imaer,imana) - + softmaxes.extend(nn.functional.softmax(output.detach(),1).cpu().numpy()) output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy() y_pred.extend(output) - label = label.data.cpu().numpy() - y_true.extend(label) # Save Truth + y_true.extend(label.numpy()) # Save Truth # constant for classes - classes = data.dataset.dataset.classes + confiance = np.zeros((len(classes),len(classes))) + variance = np.zeros((len(classes),len(classes))) + for i in range(len(softmaxes)): + confiance[y_true[i],:] += softmaxes[i] + denominator = confiance.sum(1,keepdims=True) + print(confiance) + print(denominator) + confiance = confiance/denominator + + for i in range(len(softmaxes)): + variance[y_true[i],:] += (confiance[y_true[i],:]-softmaxes[i])**2 + variance /= denominator + variance = variance**0.5 + # Build confusion matrix print(len(y_true),len(y_pred)) cf_matrix = confusion_matrix(y_true, y_pred) @@ -289,7 +304,19 @@ def make_prediction_duo(model, data, f_name): print('Saving Confusion Matrix') plt.figure(figsize=(14, 9)) sn.heatmap(df_cm, annot=cf_matrix) - plt.savefig(f_name) + confuName = f_name.split("/") + confuName[-1] = "confusion_matrix_"+confuName[-1] + plt.savefig('/'.join(confuName)) + + print('Saving Confidence Matrix') + confiance_df = pd.DataFrame(confiance, index=[i for i in classes], + columns=[i for i in classes]) + plt.figure(figsize=(14, 9)) + sn.heatmap(confiance_df, annot=confiance.astype("<U4")+np.full(confiance.shape,"\u00B1")+variance.astype("<U4"),fmt='') + confiName = f_name.split("/") + confiName[-1] = "confiance_matrix_"+confiName[-1] + plt.savefig('/'.join(confiName)) + def save_model(model, path): diff --git a/models/model.py b/models/model.py index f8f0bf69cc9f883857d42034b2abab201498d50f..7e86a989a765edc422fda2706ba05c2d5eb3ed65 100644 --- a/models/model.py +++ b/models/model.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torchvision def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): """3x3 convolution with padding""" diff --git a/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index f41c7b675da69dc7810225603dc42e3e1b1cb654..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index b374ff872f64700de0913ce9275856810d89259f..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index 68fc78ad740a4f9816ae5ab7d88fd59335a19924..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index 958aae0a2189e7fa0f1fc2eecd45463504a8f2ac..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index 428cb6c883a0b19996d2e61b8c0627737217d20c..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index bdf5e978c89ff252061f9e0d045ff31c61f4e970..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index 5a3fcbbd21927c8c8b53442e57df342a8edf1316..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index d8a704f018ec661846dfca786117fee547aa431c..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index 8221c44e612d806c78ccce7c7c29da89cc46a5db..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index afd1849cb0c80de776239ffbe72a026ce04e91b6..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png and /dev/null differ