diff --git a/barlow_twin_like/config.py b/barlow_twin_like/config.py index 2f38fd186d86959d688dc56eb38a9c27306dd621..ee4a58b211b7b5a047f62cf8d7ab87cd8acef510 100644 --- a/barlow_twin_like/config.py +++ b/barlow_twin_like/config.py @@ -9,15 +9,16 @@ def load_args_barlow(): parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--test_inter', type=int, default=10) parser.add_argument('--lr', type=float, default=0.001) - parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--lambd', type=float, default=0.005) parser.add_argument('--opti', type=str, default='adam') parser.add_argument('--model', type=str, default='ResNet18') - parser.add_argument('--projector', type=str, default='1024-512-256-128') + parser.add_argument('--projector', type=str, default='256-128-64') parser.add_argument('--sampler', type=str, default=None) - parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/train_data') - parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/test_data') + parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/train data') + parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/val data') parser.add_argument('--dataset_test_dir', type=str, default=None) - parser.add_argument('--base_out', type=str, default='output/best_model_base_ray') + parser.add_argument('--base_out', type=str, default='output/barlow_model') parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') parser.add_argument('--output', type=str, default='output/out_barlow.csv') parser.add_argument('--save_path', type=str, default='output/best_model_barlow.pt') diff --git a/barlow_twin_like/dataset_barlow.py b/barlow_twin_like/dataset_barlow.py index 274ecdf0bf89399f3b96d63addcac4c3e0e9fa7b..611fa78f05d08a8cc85f7ff776ee15f2a19471a2 100644 --- a/barlow_twin_like/dataset_barlow.py +++ b/barlow_twin_like/dataset_barlow.py @@ -177,13 +177,11 @@ class ImageFolder(data.Dataset): return len(self.imlist) class ImageFolderDuo(data.Dataset): - def __init__(self, root, transform=None, target_transform=None, - flist_reader=make_dataset_base, loader=npy_loader, ref_transform=None): + def __init__(self, root, transform=None, + flist_reader=make_dataset_base, loader=npy_loader): self.root = root self.imlist = flist_reader(root) self.transform = transform - self.target_transform = target_transform - self.ref_transform = ref_transform self.loader = loader self.classes = torchvision.datasets.folder.find_classes(root)[0] @@ -201,18 +199,26 @@ class ImageFolderDuo(data.Dataset): def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None): + + transform = transforms.Compose( + [transforms.Resize((224, 224)), + transforms.Normalize(0.5, 0.5)]) + print('Default val transform') - train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir) - val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir) + train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir,transform=transform, ref_transform=transform) + val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir,transform=transform, ref_transform=transform) - train_dataset_classifier = ImageFolderDuo(root=base_dir_train) - val_dataset_classifier = ImageFolderDuo(root=base_dir_val) + train_dataset_classifier = ImageFolderDuo(root=base_dir_train,transform=transform) + val_dataset_classifier = ImageFolderDuo(root=base_dir_val,transform=transform) if base_dir_test is not None : - test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir) + test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir,transform=transform, ref_transform=transform) + + test_dataset_classifier = ImageFolderDuo(root=base_dir_test,transform=transform) + + - test_dataset_classifier = ImageFolderDuo(root=base_dir_test) if sampler =='balanced' : diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py index 43cc536b2a91e7cee1aac6d2734f9754f5aa6609..2542c308a697f554e67e36a3e4a8631222fb67ba 100644 --- a/barlow_twin_like/main.py +++ b/barlow_twin_like/main.py @@ -1,8 +1,11 @@ import os - +import seaborn as sn import numpy as np +import pandas as pd import torch import wandb as wdb +from matplotlib import pyplot as plt +from sklearn.metrics import confusion_matrix from torch import optim, nn from model import BarlowTwins, BaseClassifier @@ -71,14 +74,17 @@ def train_classification(model, classifier, data_train, optimizer, epoch, wandb) for param in classifier.parameters(): param.requires_grad = True - for img, label in data_train: - img = img.float() + for imgana, imgaer, label in data_train: + imgana = imgana.float() + imgaer = imgaer.float() label = label.long() if torch.cuda.is_available(): - img = img.cuda() + imgana = imgana.cuda() + imgaer = imgaer.cuda() label = label.cuda() - representation = model(img) - pred_logits = classifier(representation) + representation_ana = model.compute_representation(imgana) + representation_aer = model.compute_representation(imgaer) + pred_logits = classifier(representation_ana, representation_aer) pred_class = torch.argmax(pred_logits, dim=1) acc += (pred_class == label).sum().item() loss = loss_function(pred_logits, label) @@ -104,14 +110,17 @@ def test_classification(model, classifier, data_val, epoch, wandb): for param in classifier.parameters(): param.requires_grad = False - for img, label in data_val: - img = img.float() + for imgana, imgaer, label in data_val: + imgana = imgana.float() + imgaer = imgaer.float() label = label.long() if torch.cuda.is_available(): - img = img.cuda() + imgana = imgana.cuda() + imgaer = imgaer.cuda() label = label.cuda() - representation = model(img) - pred_logits = classifier(representation) + representation_ana = model.compute_representation(imgana) + representation_aer = model.compute_representation(imgaer) + pred_logits = classifier(representation_ana, representation_aer) pred_class = torch.argmax(pred_logits, dim=1) acc += (pred_class == label).sum().item() loss = loss_function(pred_logits, label) @@ -125,6 +134,38 @@ def test_classification(model, classifier, data_val, epoch, wandb): return losses, acc +def make_prediction_duo(model,classifier, data, f_name): + y_pred = [] + y_true = [] + # iterate over test data + for imgana, imgaer, label in data: + imgana = imgana.float() + imgaer = imgaer.float() + if torch.cuda.is_available(): + imgana = imgana.cuda() + imgaer = imgaer.cuda() + representation_ana = model.compute_representation(imgana) + representation_aer = model.compute_representation(imgaer) + pred_logits = classifier(representation_ana, representation_aer) + pred_class = torch.argmax(pred_logits, dim=1) + y_pred+=pred_class.tolist() + y_true+=label.tolist() # Save Truth + # constant for classes + + # Build confusion matrix + classes = data.dataset.classes + cf_matrix = confusion_matrix(y_true, y_pred) + + df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes], + columns=[i for i in classes]) + print('Saving Confusion Matrix') + plt.clf() + plt.figure(figsize=(14, 9)) + sn.heatmap(df_cm, annot=cf_matrix) + plt.savefig(f_name) + + + def run(): args = load_args_barlow() @@ -153,8 +194,9 @@ def run(): sampler=args.sampler)) # load model + n_classes = len(data_val_classifier.dataset.classes) model = BarlowTwins(args) - classifier = BaseClassifier(args) + classifier = BaseClassifier(args,n_classes=n_classes) model.float() classifier.float() # load weight @@ -173,8 +215,8 @@ def run(): optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) best_loss = np.inf - for e in args.epoches: - loss = train_representation(model, data_train, optimizer, e, args.wandb) + for e in range(args.epoches): + _ = train_representation(model, data_train, optimizer, e, args.wandb) if e % args.eval_inter == 0: loss = test_representation(model, data_val, e, args.wandb) if loss < best_loss: @@ -185,9 +227,16 @@ def run(): for param in model.parameters(): # freezing representations before classifier training param.requires_grad = False - for e in args.classification_epoches: + for e in range(args.classification_epoches): train_classification(model, classifier, data_train_classifier, optimizer, e, args.wandb) - test_classification() + test_classification(model, classifier, data_val_classifier, e, args.wandb) + + make_prediction_duo(model, classifier, data_val_classifier, args.base_out+'_confusion_matrix_val.png') + + wdb.finish() + + + if __name__ == '__main__': run() \ No newline at end of file diff --git a/barlow_twin_like/model.py b/barlow_twin_like/model.py index 35cb5bb183a546d08f5734d924a339f1297fbd93..8a8a82efe25ebbd30ca89b8420f8eaf380722cc4 100644 --- a/barlow_twin_like/model.py +++ b/barlow_twin_like/model.py @@ -280,7 +280,7 @@ class BarlowTwins(nn.Module): self.backbone.fc = nn.Identity() #remove final fc layer # projector - sizes = [2048] + list(map(int, args.projector.split('-'))) + sizes = [512] + list(map(int, args.projector.split('-'))) layers = [] for i in range(len(sizes) - 2): layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False)) @@ -300,8 +300,8 @@ class BarlowTwins(nn.Module): c = self.bn(z1).T @ self.bn(z2) # sum the cross-correlation matrix between all gpus - c.div_(self.args.batch_size) - torch.distributed.all_reduce(c) + # c.div_(self.args.batch_size) + # torch.distributed.all_reduce(c) on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() off_diag = off_diagonal(c).pow_(2).sum() @@ -316,7 +316,7 @@ class BaseClassifier(nn.Module): def __init__(self, args,n_classes): super().__init__() self.classifier = nn.Sequential( - nn.Linear(list(map(int, args.projector.split('-')))[-1]*2,n_classes) + nn.Linear(1024,n_classes) ) def forward(self, y1, y2):