diff --git a/barlow_twin_like/dataset_barlow.py b/barlow_twin_like/dataset_barlow.py index 611fa78f05d08a8cc85f7ff776ee15f2a19471a2..3b8c77720633a4537e6caf7be09746170ef0ebc2 100644 --- a/barlow_twin_like/dataset_barlow.py +++ b/barlow_twin_like/dataset_barlow.py @@ -16,6 +16,22 @@ from torch.utils.data import WeightedRandomSampler IMG_EXTENSIONS = ".npy" + +class Random_shift_rt: + """With a probability prob, shifts verticaly the image depending on a gaussian distribution""" + + def __init__(self, prob, mean, std): + self.prob = prob + self.mean = torch.tensor(float(mean)) + self.std = float(std) + + def __call__(self, x): + if np.random.rand() < self.prob: + shift = torch.normal(self.mean, self.std) + return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0]) + return x + + def npy_loader(path): sample = torch.from_numpy(np.load(path)) sample = sample.unsqueeze(0) @@ -200,22 +216,27 @@ class ImageFolderDuo(data.Dataset): def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None): - transform = transforms.Compose( + train_transform = transforms.Compose( + [Random_shift_rt(1,0,15), + transforms.Resize((224, 224)), + transforms.Normalize(0.5, 0.5)]) + + val_transform = transforms.Compose( [transforms.Resize((224, 224)), - transforms.Normalize(0.5, 0.5)]) + transforms.Normalize(0.5, 0.5)]) print('Default val transform') - train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir,transform=transform, ref_transform=transform) - val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir,transform=transform, ref_transform=transform) + train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir,transform=train_transform, ref_transform=train_transform) + val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir,transform=train_transform, ref_transform=val_transform) - train_dataset_classifier = ImageFolderDuo(root=base_dir_train,transform=transform) - val_dataset_classifier = ImageFolderDuo(root=base_dir_val,transform=transform) + train_dataset_classifier = ImageFolderDuo(root=base_dir_train,transform=train_transform) + val_dataset_classifier = ImageFolderDuo(root=base_dir_val,transform=val_transform) if base_dir_test is not None : - test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir,transform=transform, ref_transform=transform) + test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir,transform=val_transform, ref_transform=val_transform) - test_dataset_classifier = ImageFolderDuo(root=base_dir_test,transform=transform) + test_dataset_classifier = ImageFolderDuo(root=base_dir_test,transform=val_transform)