diff --git a/barlow_twin_like/config.py b/barlow_twin_like/config.py
index 2f38fd186d86959d688dc56eb38a9c27306dd621..ee4a58b211b7b5a047f62cf8d7ab87cd8acef510 100644
--- a/barlow_twin_like/config.py
+++ b/barlow_twin_like/config.py
@@ -9,15 +9,16 @@ def load_args_barlow():
     parser.add_argument('--eval_inter', type=int, default=1)
     parser.add_argument('--test_inter', type=int, default=10)
     parser.add_argument('--lr', type=float, default=0.001)
-    parser.add_argument('--batch_size', type=int, default=256)
+    parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--lambd', type=float, default=0.005, help='weight of the off-diagonal (redundancy-reduction) term in the Barlow Twins loss')
     parser.add_argument('--opti', type=str, default='adam')
     parser.add_argument('--model', type=str, default='ResNet18')
-    parser.add_argument('--projector', type=str, default='1024-512-256-128')
+    parser.add_argument('--projector', type=str, default='256-128-64', help='dash-separated layer sizes of the projector MLP')
     parser.add_argument('--sampler', type=str, default=None)
-    parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/train_data')
-    parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/test_data')
+    parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/train data')
+    parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/val data')
     parser.add_argument('--dataset_test_dir', type=str, default=None)
-    parser.add_argument('--base_out', type=str, default='output/best_model_base_ray')
+    parser.add_argument('--base_out', type=str, default='output/barlow_model')
     parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
     parser.add_argument('--output', type=str, default='output/out_barlow.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model_barlow.pt')
diff --git a/barlow_twin_like/dataset_barlow.py b/barlow_twin_like/dataset_barlow.py
index 274ecdf0bf89399f3b96d63addcac4c3e0e9fa7b..611fa78f05d08a8cc85f7ff776ee15f2a19471a2 100644
--- a/barlow_twin_like/dataset_barlow.py
+++ b/barlow_twin_like/dataset_barlow.py
@@ -177,13 +177,11 @@ class ImageFolder(data.Dataset):
         return len(self.imlist)
 
 class ImageFolderDuo(data.Dataset):
-    def __init__(self, root, transform=None, target_transform=None,
-                 flist_reader=make_dataset_base, loader=npy_loader, ref_transform=None):
+    def __init__(self, root, transform=None,
+                 flist_reader=make_dataset_base, loader=npy_loader):
         self.root = root
         self.imlist = flist_reader(root)
         self.transform = transform
-        self.target_transform = target_transform
-        self.ref_transform = ref_transform
         self.loader = loader
         self.classes = torchvision.datasets.folder.find_classes(root)[0]
 
@@ -201,18 +199,26 @@ class ImageFolderDuo(data.Dataset):
 
 def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None):
 
+
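+    # shared deterministic transform for train/val/test: resize to 224x224 and normalise
+    # with mean/std 0.5 (assumes npy_loader already returns torch tensors, so no ToTensor step)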
+    transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         transforms.Normalize(0.5, 0.5)])
+
     print('Default val transform')
 
-    train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir)
-    val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir)
+    train_dataset = ImageFolder(root=base_dir_train, ref_dir=ref_dir, transform=transform, ref_transform=transform)
+    val_dataset = ImageFolder(root=base_dir_val, ref_dir=ref_dir, transform=transform, ref_transform=transform)
 
-    train_dataset_classifier = ImageFolderDuo(root=base_dir_train)
-    val_dataset_classifier  = ImageFolderDuo(root=base_dir_val)
+    train_dataset_classifier = ImageFolderDuo(root=base_dir_train, transform=transform)
+    val_dataset_classifier = ImageFolderDuo(root=base_dir_val, transform=transform)
 
     if base_dir_test is not None :
-        test_dataset = ImageFolder(root=base_dir_test,  ref_dir=ref_dir)
+        test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir, transform=transform, ref_transform=transform)
+
+        test_dataset_classifier = ImageFolderDuo(root=base_dir_test, transform=transform)
-        test_dataset_classifier = ImageFolderDuo(root=base_dir_test)
 
 
     if sampler =='balanced' :
diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py
index 43cc536b2a91e7cee1aac6d2734f9754f5aa6609..2542c308a697f554e67e36a3e4a8631222fb67ba 100644
--- a/barlow_twin_like/main.py
+++ b/barlow_twin_like/main.py
@@ -1,8 +1,11 @@
 import os
-
+import seaborn as sn
 import numpy as np
+import pandas as pd
 import torch
 import wandb as wdb
+from matplotlib import pyplot as plt
+from sklearn.metrics import confusion_matrix
 from torch import optim, nn
 
 from model import BarlowTwins, BaseClassifier
@@ -71,14 +74,17 @@ def train_classification(model, classifier, data_train, optimizer, epoch, wandb)
     for param in classifier.parameters():
         param.requires_grad = True
 
-    for img, label in data_train:
-        img = img.float()
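+    # ImageFolderDuo yields two paired images per sample; each view is embedded separately
+    # and the classifier operates on the pair of representations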
+    for imgana, imgaer, label in data_train:
+        imgana = imgana.float()
+        imgaer = imgaer.float()
         label = label.long()
         if torch.cuda.is_available():
-            img = img.cuda()
+            imgana = imgana.cuda()
+            imgaer = imgaer.cuda()
             label = label.cuda()
-        representation = model(img)
-        pred_logits = classifier(representation)
+        representation_ana = model.compute_representation(imgana)
+        representation_aer = model.compute_representation(imgaer)
+        pred_logits = classifier(representation_ana, representation_aer)
         pred_class = torch.argmax(pred_logits, dim=1)
         acc += (pred_class == label).sum().item()
         loss = loss_function(pred_logits, label)
@@ -104,14 +110,17 @@ def test_classification(model, classifier, data_val, epoch, wandb):
     for param in classifier.parameters():
         param.requires_grad = False
 
-    for img, label in data_val:
-        img = img.float()
+    for imgana, imgaer, label in data_val:
+        imgana = imgana.float()
+        imgaer = imgaer.float()
         label = label.long()
         if torch.cuda.is_available():
-            img = img.cuda()
+            imgana = imgana.cuda()
+            imgaer = imgaer.cuda()
             label = label.cuda()
-        representation = model(img)
-        pred_logits = classifier(representation)
+        representation_ana = model.compute_representation(imgana)
+        representation_aer = model.compute_representation(imgaer)
+        pred_logits = classifier(representation_ana, representation_aer)
         pred_class = torch.argmax(pred_logits, dim=1)
         acc += (pred_class == label).sum().item()
         loss = loss_function(pred_logits, label)
@@ -125,6 +134,38 @@ def test_classification(model, classifier, data_val, epoch, wandb):
 
     return losses, acc
 
+def make_prediction_duo(model, classifier, data, f_name):
+    """Run the duo classifier over a dataloader and save a confusion-matrix heatmap."""
+    y_pred = []
+    y_true = []
+    # iterate over the evaluation data; no gradients are needed here
+    with torch.no_grad():
+        for imgana, imgaer, label in data:
+            imgana = imgana.float()
+            imgaer = imgaer.float()
+            if torch.cuda.is_available():
+                imgana = imgana.cuda()
+                imgaer = imgaer.cuda()
+            representation_ana = model.compute_representation(imgana)
+            representation_aer = model.compute_representation(imgaer)
+            pred_logits = classifier(representation_ana, representation_aer)
+            pred_class = torch.argmax(pred_logits, dim=1)
+            y_pred += pred_class.tolist()
+            y_true += label.tolist()  # save ground truth
+
+    # build the confusion matrix, row-normalised so each row sums to 1
+    classes = data.dataset.classes
+    cf_matrix = confusion_matrix(y_true, y_pred)
+    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None],
+                         index=classes, columns=classes)
+    print('Saving Confusion Matrix')
+    plt.figure(figsize=(14, 9))
+    sn.heatmap(df_cm, annot=cf_matrix)  # colour by proportion, annotate with raw counts
+    plt.savefig(f_name)
+    plt.close()
+
 
 def run():
     args = load_args_barlow()
@@ -153,8 +194,9 @@ def run():
                       sampler=args.sampler))
 
     # load model
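+    # infer the number of classes from the dataset's folder-derived class list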
+    n_classes = len(data_val_classifier.dataset.classes)
     model = BarlowTwins(args)
-    classifier = BaseClassifier(args)
+    classifier = BaseClassifier(args, n_classes=n_classes)
     model.float()
     classifier.float()
     # load weight
@@ -173,8 +215,8 @@ def run():
         optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
 
     best_loss = np.inf
-    for e in args.epoches:
-        loss = train_representation(model, data_train, optimizer, e, args.wandb)
+    for e in range(args.epoches):
+        _ = train_representation(model, data_train, optimizer, e, args.wandb)
         if e % args.eval_inter == 0:
             loss = test_representation(model, data_val, e, args.wandb)
             if loss < best_loss:
@@ -185,9 +227,16 @@ def run():
     for param in model.parameters():  # freezing representations before classifier training
         param.requires_grad = False
 
-    for e in args.classification_epoches:
+    # the representation optimizer only tracks the (now frozen) encoder parameters,
+    # so give the classifier head its own optimizer or its weights never update
+    classifier_optimizer = optim.Adam(classifier.parameters(), lr=args.lr)
+    for e in range(args.classification_epoches):
-        train_classification(model, classifier, data_train_classifier, optimizer, e, args.wandb)
+        train_classification(model, classifier, data_train_classifier, classifier_optimizer, e, args.wandb)
-        test_classification()
+        test_classification(model, classifier, data_val_classifier, e, args.wandb)
+
+    make_prediction_duo(model, classifier, data_val_classifier, args.base_out + '_confusion_matrix_val.png')
+
+    wdb.finish()
+
 
 if __name__ == '__main__':
     run()
\ No newline at end of file
diff --git a/barlow_twin_like/model.py b/barlow_twin_like/model.py
index 35cb5bb183a546d08f5734d924a339f1297fbd93..8a8a82efe25ebbd30ca89b8420f8eaf380722cc4 100644
--- a/barlow_twin_like/model.py
+++ b/barlow_twin_like/model.py
@@ -280,7 +280,7 @@ class BarlowTwins(nn.Module):
         self.backbone.fc = nn.Identity() #remove final fc layer
 
         # projector
-        sizes = [2048] + list(map(int, args.projector.split('-')))
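+        # ResNet18's final feature dimension is 512 (2048 applied to ResNet50)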
+        sizes = [512] + list(map(int, args.projector.split('-')))
         layers = []
         for i in range(len(sizes) - 2):
             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
@@ -300,8 +300,8 @@ class BarlowTwins(nn.Module):
         c = self.bn(z1).T @ self.bn(z2)
 
         # sum the cross-correlation matrix between all gpus
-        c.div_(self.args.batch_size)
-        torch.distributed.all_reduce(c)
+        # keep the batch-size normalisation so the diagonal of c is a true correlation (target 1);
+        # the distributed all-reduce is dropped here since training is single-process
+        c.div_(self.args.batch_size)
+        # torch.distributed.all_reduce(c)
 
         on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
         off_diag = off_diagonal(c).pow_(2).sum()
@@ -316,7 +316,7 @@ class BaseClassifier(nn.Module):
     def __init__(self, args,n_classes):
         super().__init__()
         self.classifier = nn.Sequential(
-            nn.Linear(list(map(int, args.projector.split('-')))[-1]*2,n_classes)
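+            # 1024 = 2 x 512: the classifier consumes two concatenated backbone
+            # representations (assuming compute_representation returns ResNet18 features)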
+            nn.Linear(1024, n_classes)
         )
 
     def forward(self, y1, y2):