Skip to content
Snippets Groups Projects
Commit 5b8b27c9 authored by Schneider Leo's avatar Schneider Leo
Browse files

add : peak erasing transform

parent 45055883
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ def load_args_barlow(): ...@@ -9,6 +9,7 @@ def load_args_barlow():
parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--test_inter', type=int, default=10) parser.add_argument('--test_inter', type=int, default=10)
parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--prob_erasing', type=float, default=0.3)
parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--lambd', type=float, default=0.005) parser.add_argument('--lambd', type=float, default=0.005)
parser.add_argument('--opti', type=str, default='adam') parser.add_argument('--opti', type=str, default='adam')
......
...@@ -12,6 +12,7 @@ from typing import Callable, cast, Dict, List, Optional, Tuple, Union ...@@ -12,6 +12,7 @@ from typing import Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path from pathlib import Path
from collections import OrderedDict from collections import OrderedDict
from torch.distributions.utils import probs_to_logits
from torch.utils.data import WeightedRandomSampler from torch.utils.data import WeightedRandomSampler
IMG_EXTENSIONS = ".npy" IMG_EXTENSIONS = ".npy"
...@@ -32,6 +33,18 @@ class Random_shift_rt: ...@@ -32,6 +33,18 @@ class Random_shift_rt:
return x return x
class Random_erasing_pixel:
def __init__(self,prob):
self.prob=prob
def __call__(self,x):
shape = x.size()
# Generate a random tensor of size (3, 3) with values 0 or 1
prob_array = torch.zeros_like(x)+self.prob
value_mask = torch.bernoulli(prob_array)
return x*value_mask
def npy_loader(path): def npy_loader(path):
sample = torch.from_numpy(np.load(path)) sample = torch.from_numpy(np.load(path))
sample = sample.unsqueeze(0) sample = sample.unsqueeze(0)
...@@ -213,11 +226,12 @@ class ImageFolderDuo(data.Dataset): ...@@ -213,11 +226,12 @@ class ImageFolderDuo(data.Dataset):
def __len__(self): def __len__(self):
return len(self.imlist) return len(self.imlist)
def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None): def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, prob_erasing=0, shuffle=True,ref_dir = None, sampler=None):
train_transform = transforms.Compose( train_transform = transforms.Compose(
[Random_shift_rt(1,0,15), [Random_shift_rt(1,0,15),
Random_erasing_pixel(prob=prob_erasing),
transforms.Resize((224, 224)), transforms.Resize((224, 224)),
transforms.Normalize(0.5, 0.5)]) transforms.Normalize(0.5, 0.5)])
......
...@@ -254,6 +254,7 @@ def run(): ...@@ -254,6 +254,7 @@ def run():
base_dir_val=args.dataset_val_dir, base_dir_val=args.dataset_val_dir,
base_dir_test=args.dataset_test_dir, base_dir_test=args.dataset_test_dir,
batch_size=args.batch_size, batch_size=args.batch_size,
prob_erasing=args.prob_erasing,
ref_dir=args.dataset_ref_dir, ref_dir=args.dataset_ref_dir,
sampler=args.sampler)) sampler=args.sampler))
......
...@@ -52,6 +52,19 @@ class Intensity_shift: ...@@ -52,6 +52,19 @@ class Intensity_shift:
return torch.mul(x,intensity_ratio_map) return torch.mul(x,intensity_ratio_map)
class Random_shift_rt:
"""With a probability prob, shifts verticaly the image depending on a gaussian distribution"""
def __init__(self, prob, mean, std):
self.prob = prob
self.mean = torch.tensor(float(mean))
self.std = float(std)
def __call__(self, x):
if np.random.rand() < self.prob:
shift = torch.normal(self.mean, self.std)
return transforms.functional.affine(x, 0, [0, shift], 1, [0, 0])
return x
def default_loader(path): def default_loader(path):
return Image.open(path).convert('RGB') return Image.open(path).convert('RGB')
...@@ -189,8 +202,7 @@ def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuff ...@@ -189,8 +202,7 @@ def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuff
ref_transform = transforms.Compose( ref_transform = transforms.Compose(
[transforms.Resize((224, 224)), [transforms.Resize((224, 224)),
Threshold_noise(0), Random_shift_rt(1,0,10),
Log_normalisation(),
transforms.Normalize(0.5, 0.5) transforms.Normalize(0.5, 0.5)
]) ])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment