diff --git a/barlow_twin_like/config.py b/barlow_twin_like/config.py index a9aeb92a95bc0c883ca972f72bfdd30d1b30eaf8..c7d078e3cf4e67f065bd778c667637d8c7ebe35a 100644 --- a/barlow_twin_like/config.py +++ b/barlow_twin_like/config.py @@ -9,6 +9,7 @@ def load_args_barlow(): parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--test_inter', type=int, default=10) parser.add_argument('--lr', type=float, default=0.001) + parser.add_argument('--prob_erasing', type=float, default=0.3) parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--lambd', type=float, default=0.005) parser.add_argument('--opti', type=str, default='adam') diff --git a/barlow_twin_like/dataset_barlow.py b/barlow_twin_like/dataset_barlow.py index 3b8c77720633a4537e6caf7be09746170ef0ebc2..d0f93128071f206e35897974b6a53a566c6d43b9 100644 --- a/barlow_twin_like/dataset_barlow.py +++ b/barlow_twin_like/dataset_barlow.py @@ -12,6 +12,7 @@ from typing import Callable, cast, Dict, List, Optional, Tuple, Union from pathlib import Path from collections import OrderedDict +from torch.distributions.utils import probs_to_logits from torch.utils.data import WeightedRandomSampler IMG_EXTENSIONS = ".npy" @@ -32,6 +33,18 @@ class Random_shift_rt: return x +class Random_erasing_pixel: + def __init__(self,prob): + self.prob=prob + + def __call__(self,x): + shape = x.size() + # Generate a random tensor of size (3, 3) with values 0 or 1 + prob_array = torch.zeros_like(x)+self.prob + value_mask = torch.bernoulli(prob_array) + + return x*value_mask + def npy_loader(path): sample = torch.from_numpy(np.load(path)) sample = sample.unsqueeze(0) @@ -213,11 +226,12 @@ class ImageFolderDuo(data.Dataset): def __len__(self): return len(self.imlist) -def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None): +def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, prob_erasing=0, shuffle=True,ref_dir = None, sampler=None): train_transform = transforms.Compose( [Random_shift_rt(1,0,15), + Random_erasing_pixel(prob=prob_erasing), transforms.Resize((224, 224)), transforms.Normalize(0.5, 0.5)]) diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py index 02357705e7b03e1d133ff61be56a4007fa5cba64..7d9db6a62de8894938d5c5b7158e30584920bb56 100644 --- a/barlow_twin_like/main.py +++ b/barlow_twin_like/main.py @@ -254,6 +254,7 @@ def run(): base_dir_val=args.dataset_val_dir, base_dir_test=args.dataset_test_dir, batch_size=args.batch_size, + prob_erasing=args.prob_erasing, ref_dir=args.dataset_ref_dir, sampler=args.sampler)) diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py index c48992f6b1e3b0de54b312ca3d38aa2b55915f52..68b5da14a288d30ebc590ee326b290e80b028c83 100644 --- a/image_ref/dataset_ref.py +++ b/image_ref/dataset_ref.py @@ -52,6 +52,19 @@ class Intensity_shift: return torch.mul(x,intensity_ratio_map) +class Random_shift_rt: + """With a probability prob, shifts verticaly the image depending on a gaussian distribution""" + + def __init__(self, prob, mean, std): + self.prob = prob + self.mean = torch.tensor(float(mean)) + self.std = float(std) + + def __call__(self, x): + if np.random.rand() < self.prob: + shift = torch.normal(self.mean, self.std) + return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0]) + return x def default_loader(path): return Image.open(path).convert('RGB') @@ -189,8 +202,7 @@ def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuff ref_transform = transforms.Compose( [transforms.Resize((224, 224)), - Threshold_noise(0), - Log_normalisation(), + Random_shift_rt(1,0,10), transforms.Normalize(0.5, 0.5) ])