diff --git a/dataset/dataset.py b/dataset/dataset.py
index e5a1a7b22715124b657750e5074fe2981a78e108..b937657cb0d030af8b56bc9579c3dee3cdfec6ca 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -86,7 +86,7 @@ def default_loader(path):
     return Image.open(path).convert('RGB')
 
 def remove_aer_ana(l):
-    l = l.map(lambda x : x.split('_')[0])
+    l = map(lambda x : x.split('_')[0],l)
     return list(OrderedDict.fromkeys(l))
 
 def make_dataset_custom(
@@ -118,7 +118,7 @@ def make_dataset_custom(
     if extensions is not None:
 
         def is_valid_file(x: str) -> bool:
-            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
+            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
 
     is_valid_file = cast(Callable[[str], bool], is_valid_file)
 
@@ -136,7 +136,7 @@ def make_dataset_custom(
                 fname_aer = fname + '_AER.png'
                 path_ana = os.path.join(root, fname_ana)
                 path_aer = os.path.join(root, fname_aer)
-                if is_valid_file(path_ana) and is_valid_file(path_aer):
+                if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
                     item = path_aer, path_ana, class_index
                     instances.append(item)
 
@@ -161,12 +161,12 @@ class ImageFolderDuo(data.Dataset):
         self.transform = transform
         self.target_transform = target_transform
         self.loader = loader
-        self.classes = torchvision.datasets.folder.find_classes(root)
+        self.classes = torchvision.datasets.folder.find_classes(root)[0]
 
     def __getitem__(self, index):
         impathAER, impathANA, target = self.imlist[index]
-        imgAER = self.loader(os.path.join(self.root, impathAER))
-        imgANA = self.loader(os.path.join(self.root, impathANA))
+        imgAER = self.loader(impathAER)
+        imgANA = self.loader(impathANA)
         if self.transform is not None:
             imgAER = self.transform(imgAER)
             imgANA = self.transform(imgANA)
@@ -196,8 +196,8 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
          Log_normalisation(),
          transforms.Normalize(0.5, 0.5)])
     print('Default val transform')
-    train_dataset = torchvision.datasets.ImageFolderDuo(root=base_dir, transform=train_transform)
-    val_dataset = torchvision.datasets.ImageFolderDuo(root=base_dir, transform=val_transform)
+    train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
+    val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
     generator1 = torch.Generator().manual_seed(42)
     indices = torch.randperm(len(train_dataset), generator=generator1)
     val_size = len(train_dataset) // 5
diff --git a/main.py b/main.py
index 13946a3ba0746cf702c8be77b993710b04517c48..84240bfe981ded28e301f923e21854b160622b47 100644
--- a/main.py
+++ b/main.py
@@ -2,10 +2,10 @@ import matplotlib.pyplot as plt
 import numpy as np
 
 from config.config import load_args
-from dataset.dataset import load_data
+from dataset.dataset import load_data, load_data_duo
 import torch
 import torch.nn as nn
-from models.model import Classification_model
+from models.model import Classification_model, Classification_model_duo
 import torch.optim as optim
 from sklearn.metrics import confusion_matrix
 import seaborn as sn
@@ -88,6 +88,7 @@ def run(args):
     plt.plot(val_acc)
     plt.plot(train_acc)
     plt.plot(train_acc)
+    plt.ylim(0, 1.05)
     plt.show()
     plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
 
@@ -124,6 +125,121 @@ def make_prediction(model, data, f_name):
     plt.savefig(f_name)
 
 
+def train_duo(model, data_train, optimizer, loss_function, epoch):
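+    """Run one training epoch over the paired AER/ANA loader; returns per-sample loss and accuracy."""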
+    model.train()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = True
+
+    for imaer,imana, label in data_train:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            label = label.cuda()
+        pred_logits = model(imaer,imana)
+        pred_class = torch.argmax(pred_logits,dim=1)
+        acc += (pred_class==label).sum().item()
+        loss = loss_function(pred_logits,label)
+        losses += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    losses = losses/len(data_train.dataset)
+    acc = acc/len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
+    return losses, acc
+
+def test_duo(model, data_test, loss_function, epoch):
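+    """Evaluate the model on the paired AER/ANA loader; returns per-sample loss and accuracy."""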
+    model.eval()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for imaer,imana, label in data_test:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            label = label.cuda()
+        pred_logits = model(imaer,imana)
+        pred_class = torch.argmax(pred_logits,dim=1)
+        acc += (pred_class==label).sum().item()
+        loss = loss_function(pred_logits,label)
+        losses += loss.item()
+    losses = losses/len(data_test.dataset)
+    acc = acc/len(data_test.dataset)
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
+    return losses,acc
+
+def run_duo(args):
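+    """Train and evaluate the two-branch classifier on paired AER/ANA images, keeping the checkpoint with the best validation accuracy."""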
+    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
+    model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes))
+    if args.pretrain_path is not None :
+        load_model(model,args.pretrain_path)
+    if torch.cuda.is_available():
+        model = model.cuda()
+    best_acc = 0
+    train_acc=[]
+    train_loss=[]
+    val_acc=[]
+    val_loss=[]
+    loss_function = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
+
+    for e in range(args.epoches):
+        loss, acc = train_duo(model,data_train,optimizer,loss_function,e)
+        train_loss.append(loss)
+        train_acc.append(acc)
+        if e%args.eval_inter==0 :
+            loss, acc = test_duo(model,data_test,loss_function,e)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            if acc > best_acc :
+                save_model(model,args.save_path)
+                best_acc = acc
+    plt.plot(train_acc)
+    plt.plot(val_acc)
+    plt.ylim(0, 1.05)
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
+    plt.show()
+
+    load_model(model, args.save_path)
+    make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
+
+
+def make_prediction_duo(model, data, f_name):
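+    """Run the model over the test loader and save a row-normalised confusion matrix heatmap to f_name."""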
+    model.eval()
+    y_pred = []
+    y_true = []
+
+    # iterate over test data
+    for imaer,imana, label in data:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+        output = model(imaer,imana)
+
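+        # exp() is monotonic, so this simply takes the arg-max class index of the logits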
+        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
+        y_pred.extend(output)
+
+        label = label.data.cpu().numpy()
+        y_true.extend(label)  # Save Truth
+    # constant for classes
+
+    classes = data.dataset.dataset.classes
+
+    # Build confusion matrix
+    cf_matrix = confusion_matrix(y_true, y_pred)
+    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
+                         columns=[i for i in classes])
+    plt.figure(figsize=(12, 7))
+    sn.heatmap(df_cm, annot=True)
+    plt.savefig(f_name)
+
+
 def save_model(model, path):
     print('Model saved')
     torch.save(model.state_dict(), path)
@@ -135,4 +251,4 @@ def load_model(model, path):
 
 if __name__ == '__main__':
     args = load_args()
-    run(args)
\ No newline at end of file
+    run_duo(args)
\ No newline at end of file
diff --git a/models/model.py b/models/model.py
index 192120347961c6c5e1555d3e80c67f2389228206..cea9caa2a300f4832e708b618bdca4fff20aca10 100644
--- a/models/model.py
+++ b/models/model.py
@@ -270,4 +270,22 @@ class Classification_model(nn.Module):
 
 
     def forward(self, input):
-        return self.im_encoder(input)
\ No newline at end of file
+        return self.im_encoder(input)
+
+class Classification_model_duo(nn.Module):
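+    """Two-branch classifier: one shared image encoder is applied to the AER and ANA inputs and the concatenated outputs feed a linear predictor."""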
+
+    def __init__(self, model, n_class, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_class = n_class
+        if model == 'ResNet18':
+            self.im_encoder = resnet18(num_classes=self.n_class)
+        else:
+            raise ValueError('Unsupported model: {}'.format(model))
+
+        self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)
+
+
+    def forward(self, input_aer, input_ana):
+        out_aer =  self.im_encoder(input_aer)
+        out_ana = self.im_encoder(input_ana)
+        out = torch.concat([out_aer,out_ana],dim=1)
+        return self.predictor(out)
+
diff --git a/output/training_plot.png b/output/training_plot.png
deleted file mode 100644
index 98bd0695f3fa500dbed85dc1907b46a8c6efac46..0000000000000000000000000000000000000000
Binary files a/output/training_plot.png and /dev/null differ