diff --git a/config/config.py b/config/config.py
index ac5144bff32e2b1267e03e264abdc6b6e3cd8d43..bd022bed2184fc2a6f5d14db3000c9a6e7a7c58a 100644
--- a/config/config.py
+++ b/config/config.py
@@ -1,21 +1,18 @@
 import argparse
-import torch
 
 
 def load_args():
     parser = argparse.ArgumentParser()
-
-    parser.add_argument('--epoches', type=int, default=20)
+    parser.add_argument('--epoches', type=int, default=50)
     parser.add_argument('--eval_inter', type=int, default=1)
-    parser.add_argument('--noise_threshold', type=int, default=1000)
+    parser.add_argument('--augment_args', nargs='+', type=float, default=[0, 0, 0, 0.1, 0.1, 0.1, 0.1],
+                        help='erase prob, noise prob, shift prob, erase proportion, noise max factor, shift mean, shift std')
+    parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
     parser.add_argument('--optim', type = str, default = "Adam")
-    parser.add_argument('--beta1', type=float, default=0.9)
-    parser.add_argument('--beta2', type=float, default=0.999)
+    parser.add_argument('--beta1', type=float, default=0.938)
+    parser.add_argument('--beta2', type=float, default=0.9928)
     parser.add_argument('--momentum', type=float, default=0.9)
-    parser.add_argument('--classes_names', type=list, default = ["Citrobacter freundii","Citrobacter koseri","Enterobacter asburiae","Enterobacter cloacae","Enterobacter hormaechei","Escherichia coli","Klebsiella aerogenes","Klebsiella michiganensis","Klebsiella oxytoca","Klebsiella pneumoniae","Klebsiella quasipneumoniae","Proteus mirabilis","Salmonella enterica"])
-    parser.add_argument('--classes_numbers', type=list, default = [51,12,9,10,86,231,20,13,24,96,11,39,11])
-    parser.add_argument('--weighted_entropy', type=bool, default = True)
+    parser.add_argument('--weighted_entropy', type=bool, default=False)  # note: type=bool treats any non-empty CLI value as True
     parser.add_argument('--batch_size', type=int, default=128)
     parser.add_argument('--model', type=str, default='ResNet18')
     parser.add_argument('--model_type', type=str, default='duo')
@@ -23,6 +20,7 @@ def load_args():
     parser.add_argument('--output', type=str, default='output/out.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model.pt')
     parser.add_argument('--pretrain_path', type=str, default=None)
+    parser.add_argument('--random_state', type=int, default=42)
     args = parser.parse_args()
 
     return args
diff --git a/dataset/dataset.py b/dataset/dataset.py
index e7e526e831e2b8fb2a1e9c8d8caf00ee8571372a..5119a732aa05708a4dbec17f85de3a05eec1dee0 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -30,8 +30,42 @@ class Log_normalisation:
     def __call__(self, x):
         return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon)
 
-class Random_shift_rt():
-    pass
+class Random_erasing:
+    """With probability prob, zeroes (erases) roughly a proportion prop of the image's pixels."""
+    def __init__(self, prob, prop):
+        self.prob = prob
+        self.prop = prop
+    
+    def __call__(self,x):
+        if np.random.rand() < self.prob:
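+            # the boolean mask keeps pixels whose uniform draw exceeds prop, so roughly a proportion prop is zeroed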
+            return x*(torch.rand_like(x) > self.prop)
+        return x
+        
+        
+class Random_int_noise:
+    """With probability prob, multiplies the image element-wise by uniform random factors in [1/maximum, maximum]."""
+    def __init__(self, prob, maximum):
+        self.prob = prob
+        self.minimum = 1/maximum
+        self.delta = maximum-self.minimum
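+        # per-pixel multiplicative factors are drawn uniformly from [1/maximum, maximum]
+        # (presumes maximum >= 1; otherwise minimum exceeds maximum and delta is negative)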
+        
+    def __call__(self, x):
+        if np.random.rand() < self.prob:
+            return x*(self.minimum + torch.rand_like(x)*self.delta)
+        return x
+    
+class Random_shift_rt:
+    """With probability prob, shifts the image vertically by an offset drawn from a Gaussian with the given mean and std."""
+    def __init__(self, prob, mean, std):
+        self.prob = prob
+        self.mean = torch.tensor(float(mean))
+        self.std = float(std)
+        
+    def __call__(self,x):
+        if np.random.rand() < self.prob:
+            shift = float(torch.normal(self.mean, self.std))
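+            # torchvision affine: angle=0, translate=[horizontal, vertical]=[0, shift], scale=1, shear=[0, 0]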
+            return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0])
+        return x
 
 
 def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
@@ -175,6 +209,7 @@ class ImageFolderDuo(data.Dataset):
         if self.transform is not None:
             imgAER = self.transform(imgAER)
             imgANA = self.transform(imgANA)
+            
         if self.target_transform is not None:
             target = self.target_transform(target)
         return imgAER, imgANA, target
@@ -184,7 +219,10 @@ class ImageFolderDuo(data.Dataset):
 
 def load_data_duo(base_dir, batch_size, args, shuffle=True):
     train_transform = transforms.Compose(
-        [transforms.Resize((224, 224)),
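+        # the augmentations below run on the raw intensity tensors, before resizing and normalisation;
+        # their probabilities and magnitudes come from --augment_args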
+        [Random_erasing(args.augment_args[0], args.augment_args[3]),
+         Random_int_noise(args.augment_args[1], args.augment_args[4]),
+         Random_shift_rt(args.augment_args[2], args.augment_args[5], args.augment_args[6]),
+         transforms.Resize((224, 224)),
          Threshold_noise(args.noise_threshold),
          Log_normalisation(),
          transforms.Normalize(0.5, 0.5)])
@@ -199,8 +237,9 @@ def load_data_duo(base_dir, batch_size, args, shuffle=True):
 
     train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
     val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
-    repart_weights = [args.classes_names[i] for i in range(len(args.classes_names)) for k in range(args.classes_numbers[i])]
-    iTrain,iVal = train_test_split(range(len(train_dataset)),test_size = 0.2, shuffle = shuffle, stratify = repart_weights, random_state=42)
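+    # one stratification label per sample (each sample is an AER/ANA pair of files, hence the //2);
+    # assumes os.listdir walks the class folders in the same order as ImageFolderDuo enumerates them,
+    # so label i matches the class of dataset sample i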
+    classes_name = os.listdir(args.dataset_dir)
+    repart_weights = [name for name in classes_name for k in range(len(os.listdir(args.dataset_dir+"/"+name))//2)]
+    iTrain,iVal = train_test_split(range(len(train_dataset)),test_size = 0.2, shuffle = shuffle, stratify = repart_weights, random_state=args.random_state)
     train_dataset = torch.utils.data.Subset(train_dataset, iTrain)
     val_dataset = torch.utils.data.Subset(val_dataset, iVal)
     # generator1 = torch.Generator().manual_seed(42)
@@ -217,7 +256,6 @@ def load_data_duo(base_dir, batch_size, args, shuffle=True):
         collate_fn=None,
         pin_memory=False,
     )
-
     data_loader_test = data.DataLoader(
         dataset=val_dataset,
         batch_size=batch_size,
diff --git a/main.py b/main.py
index 7d02e1057741c7ba49b051a1bade633c1a58a371..6aa0b89809a21b52e9f2e7c924533ef89aae8052 100644
--- a/main.py
+++ b/main.py
@@ -2,6 +2,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 from config.config import load_args
 from dataset.dataset import load_data, load_data_duo
+from os import listdir
 import torch
 import torch.nn as nn
 from models.model import Classification_model, Classification_model_duo
@@ -198,6 +199,7 @@ def run_duo(args):
     memPrint(1)
     data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size,args=args)
     #load model
+    
     memPrint(2)
     model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes))
     memPrint(3)
@@ -220,7 +222,7 @@ def run_duo(args):
     val_loss=[]
     #init training
     if args.weighted_entropy:
-        loss_weights = torch.tensor([1/n for n in args.classes_numbers])
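+        # weight each class by the inverse of its sample count; every sample contributes two files (AER and ANA), hence 2/len(...)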
+        loss_weights = torch.tensor([2/len(listdir(args.dataset_dir+"/"+class_name)) for class_name in listdir(args.dataset_dir)])
         if torch.cuda.is_available():
             loss_weights = loss_weights.cuda()
         loss_function = nn.CrossEntropyLoss(loss_weights)
@@ -245,23 +247,25 @@ def run_duo(args):
             if loss  < best_loss :
                 save_model(model,args.save_path)
                 best_loss = loss
+                best_acc = acc
     # plot and save training figs
     plt.plot(train_loss)
     plt.plot(val_loss)
-    plt.plot(train_loss)
-    plt.plot(train_loss)
-    plt.ylim(0, 0.025)
-    plt.show()
 
-    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
+    run_name = f'model_{args.model}_noise_{args.noise_threshold}_lr_{args.lr}_optim_{args.optim + ("_momentum_" + str(args.momentum) if args.optim == "SGD" else "_betas_" + str(args.beta1) + "_" + str(args.beta2))}'
+    plt.savefig(f'output/training_plot_{run_name}.png')
     #load and evaluate best model
     load_model(model, args.save_path)
-    make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
-
+    make_prediction_duo(model, data_test, f'output/{run_name}.png')
+    return best_loss,best_acc
 
 def make_prediction_duo(model, data, f_name):
     y_pred = []
     y_true = []
+    softmaxes = []
+    classes = data.dataset.dataset.classes
+    print(classes)
     # iterate over test data
     for imaer,imana, label in data:
         label = label.long()
@@ -270,17 +274,28 @@ def make_prediction_duo(model, data, f_name):
         if torch.cuda.is_available():
             imaer = imaer.cuda()
             imana = imana.cuda()
-            label = label.cuda()
         output = model(imaer,imana)
-        
+        softmaxes.extend(nn.functional.softmax(output.detach(),1).cpu().numpy())
         output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
         y_pred.extend(output)
 
-        label = label.data.cpu().numpy()
-        y_true.extend(label)  # Save Truth
+        y_true.extend(label.numpy())  # Save Truth
     # constant for classes
 
-    classes = data.dataset.dataset.classes
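+    # confiance[i] accumulates the softmax outputs of samples whose true class is i; after normalisation
+    # it holds the mean predicted distribution per class, and variance ends up holding the element-wise
+    # standard deviation of those softmax outputs (after the final square root)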
+    confiance = np.zeros((len(classes),len(classes)))
+    variance = np.zeros((len(classes),len(classes)))
+    for i in range(len(softmaxes)):
+        confiance[y_true[i],:] += softmaxes[i]
+    denominator = confiance.sum(1,keepdims=True)
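+    # every softmax vector sums to 1, so denominator[i] is the number of test samples whose true class is i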
+    print(confiance)
+    print(denominator)
+    confiance = confiance/denominator
+    
+    for i in range(len(softmaxes)):
+        variance[y_true[i],:] += (confiance[y_true[i],:]-softmaxes[i])**2
+    variance /= denominator
+    variance = variance**0.5
+    
     # Build confusion matrix
     print(len(y_true),len(y_pred))
     cf_matrix = confusion_matrix(y_true, y_pred)
@@ -289,7 +304,19 @@ def make_prediction_duo(model, data, f_name):
     print('Saving Confusion Matrix')
     plt.figure(figsize=(14, 9))
     sn.heatmap(df_cm, annot=cf_matrix)
-    plt.savefig(f_name)
+    confuName = f_name.split("/")
+    confuName[-1] = "confusion_matrix_"+confuName[-1]
+    plt.savefig('/'.join(confuName))
+    
+    print('Saving Confidence Matrix')
+    confiance_df = pd.DataFrame(confiance, index=[i for i in classes],
+                         columns=[i for i in classes])
+    plt.figure(figsize=(14, 9))
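+    # annotate each cell as "mean±std"; astype("<U4") truncates the formatted numbers to 4 characters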
+    sn.heatmap(confiance_df, annot=confiance.astype("<U4")+np.full(confiance.shape,"\u00B1")+variance.astype("<U4"),fmt='')
+    confiName = f_name.split("/")
+    confiName[-1] = "confiance_matrix_"+confiName[-1]
+    plt.savefig('/'.join(confiName))
+
 
 
 def save_model(model, path):
diff --git a/models/model.py b/models/model.py
index f8f0bf69cc9f883857d42034b2abab201498d50f..7e86a989a765edc422fda2706ba05c2d5eb3ed65 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torchvision
 
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
     """3x3 convolution with padding"""
diff --git a/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index f41c7b675da69dc7810225603dc42e3e1b1cb654..0000000000000000000000000000000000000000
Binary files a/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index b374ff872f64700de0913ce9275856810d89259f..0000000000000000000000000000000000000000
Binary files a/output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index 68fc78ad740a4f9816ae5ab7d88fd59335a19924..0000000000000000000000000000000000000000
Binary files a/output/confusion_matrix_noise_100_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index 958aae0a2189e7fa0f1fc2eecd45463504a8f2ac..0000000000000000000000000000000000000000
Binary files a/output/confusion_matrix_noise_200_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index 428cb6c883a0b19996d2e61b8c0627737217d20c..0000000000000000000000000000000000000000
Binary files a/output/confusion_matrix_noise_500_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index bdf5e978c89ff252061f9e0d045ff31c61f4e970..0000000000000000000000000000000000000000
Binary files a/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index 5a3fcbbd21927c8c8b53442e57df342a8edf1316..0000000000000000000000000000000000000000
Binary files a/output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index d8a704f018ec661846dfca786117fee547aa431c..0000000000000000000000000000000000000000
Binary files a/output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index 8221c44e612d806c78ccce7c7c29da89cc46a5db..0000000000000000000000000000000000000000
Binary files a/output/training_plot_noise_200_lr_0.001_model_ResNet18_duo.png and /dev/null differ
diff --git a/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png
deleted file mode 100644
index afd1849cb0c80de776239ffbe72a026ce04e91b6..0000000000000000000000000000000000000000
Binary files a/output/training_plot_noise_500_lr_0.001_model_ResNet18_duo.png and /dev/null differ