From 2886a78122baa0bb1f707fe4e0742ae10d51b7ec Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Wed, 21 May 2025 13:16:45 +0200 Subject: [PATCH] add : rt shift transform --- barlow_twin_like/dataset_barlow.py | 37 +++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/barlow_twin_like/dataset_barlow.py b/barlow_twin_like/dataset_barlow.py index 611fa78..3b8c777 100644 --- a/barlow_twin_like/dataset_barlow.py +++ b/barlow_twin_like/dataset_barlow.py @@ -16,6 +16,22 @@ from torch.utils.data import WeightedRandomSampler IMG_EXTENSIONS = ".npy" + +class Random_shift_rt: + """With a probability prob, shifts verticaly the image depending on a gaussian distribution""" + + def __init__(self, prob, mean, std): + self.prob = prob + self.mean = torch.tensor(float(mean)) + self.std = float(std) + + def __call__(self, x): + if np.random.rand() < self.prob: + shift = torch.normal(self.mean, self.std) + return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0]) + return x + + def npy_loader(path): sample = torch.from_numpy(np.load(path)) sample = sample.unsqueeze(0) @@ -200,22 +216,27 @@ class ImageFolderDuo(data.Dataset): def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None): - transform = transforms.Compose( + train_transform = transforms.Compose( + [Random_shift_rt(1,0,15), + transforms.Resize((224, 224)), + transforms.Normalize(0.5, 0.5)]) + + val_transform = transforms.Compose( [transforms.Resize((224, 224)), - transforms.Normalize(0.5, 0.5)]) + transforms.Normalize(0.5, 0.5)]) print('Default val transform') - train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir,transform=transform, ref_transform=transform) - val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir,transform=transform, ref_transform=transform) + train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir,transform=train_transform, ref_transform=train_transform) + val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir,transform=train_transform, ref_transform=val_transform) - train_dataset_classifier = ImageFolderDuo(root=base_dir_train,transform=transform) - val_dataset_classifier = ImageFolderDuo(root=base_dir_val,transform=transform) + train_dataset_classifier = ImageFolderDuo(root=base_dir_train,transform=train_transform) + val_dataset_classifier = ImageFolderDuo(root=base_dir_val,transform=val_transform) if base_dir_test is not None : - test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir,transform=transform, ref_transform=transform) + test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir,transform=val_transform, ref_transform=val_transform) - test_dataset_classifier = ImageFolderDuo(root=base_dir_test,transform=transform) + test_dataset_classifier = ImageFolderDuo(root=base_dir_test,transform=val_transform) -- GitLab