Skip to content
Snippets Groups Projects
Commit 2886a781 authored by Schneider Leo's avatar Schneider Leo
Browse files

add : rt shift transform

parent 18fe8085
No related branches found
No related tags found
No related merge requests found
...@@ -16,6 +16,22 @@ from torch.utils.data import WeightedRandomSampler ...@@ -16,6 +16,22 @@ from torch.utils.data import WeightedRandomSampler
IMG_EXTENSIONS = ".npy" IMG_EXTENSIONS = ".npy"
class Random_shift_rt:
"""With a probability prob, shifts verticaly the image depending on a gaussian distribution"""
def __init__(self, prob, mean, std):
self.prob = prob
self.mean = torch.tensor(float(mean))
self.std = float(std)
def __call__(self, x):
if np.random.rand() < self.prob:
shift = torch.normal(self.mean, self.std)
return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0])
return x
def npy_loader(path): def npy_loader(path):
sample = torch.from_numpy(np.load(path)) sample = torch.from_numpy(np.load(path))
sample = sample.unsqueeze(0) sample = sample.unsqueeze(0)
...@@ -200,22 +216,27 @@ class ImageFolderDuo(data.Dataset): ...@@ -200,22 +216,27 @@ class ImageFolderDuo(data.Dataset):
def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None): def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None):
transform = transforms.Compose( train_transform = transforms.Compose(
[Random_shift_rt(1,0,15),
transforms.Resize((224, 224)),
transforms.Normalize(0.5, 0.5)])
val_transform = transforms.Compose(
[transforms.Resize((224, 224)), [transforms.Resize((224, 224)),
transforms.Normalize(0.5, 0.5)]) transforms.Normalize(0.5, 0.5)])
print('Default val transform') print('Default val transform')
train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir,transform=transform, ref_transform=transform) train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir,transform=train_transform, ref_transform=train_transform)
val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir,transform=transform, ref_transform=transform) val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir,transform=train_transform, ref_transform=val_transform)
train_dataset_classifier = ImageFolderDuo(root=base_dir_train,transform=transform) train_dataset_classifier = ImageFolderDuo(root=base_dir_train,transform=train_transform)
val_dataset_classifier = ImageFolderDuo(root=base_dir_val,transform=transform) val_dataset_classifier = ImageFolderDuo(root=base_dir_val,transform=val_transform)
if base_dir_test is not None : if base_dir_test is not None :
test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir,transform=transform, ref_transform=transform) test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir,transform=val_transform, ref_transform=val_transform)
test_dataset_classifier = ImageFolderDuo(root=base_dir_test,transform=transform) test_dataset_classifier = ImageFolderDuo(root=base_dir_test,transform=val_transform)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment