Skip to content
Snippets Groups Projects
Commit f01e80b6 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : dataset ray base

parent ab986e15
No related branches found
No related tags found
No related merge requests found
...@@ -34,6 +34,11 @@ class Random_shift_rt(): ...@@ -34,6 +34,11 @@ class Random_shift_rt():
pass pass
def npy_loader(path):
sample = torch.from_numpy(np.load(path))
sample = sample.unsqueeze(0)
return sample
def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
train_transform = transforms.Compose( train_transform = transforms.Compose(
[transforms.Grayscale(num_output_channels=1), [transforms.Grayscale(num_output_channels=1),
...@@ -52,8 +57,8 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): ...@@ -52,8 +57,8 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
Log_normalisation(), Log_normalisation(),
transforms.Normalize(0.5, 0.5)]) transforms.Normalize(0.5, 0.5)])
print('Default val transform') print('Default val transform')
train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform) train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform,loader=npy_loader)
val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform) val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform, loader=npy_loader)
train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=42, shuffle=True, train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=42, shuffle=True,
stratify=True) stratify=True)
...@@ -86,10 +91,6 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): ...@@ -86,10 +91,6 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
def default_loader(path): def default_loader(path):
return Image.open(path).convert('RGB') return Image.open(path).convert('RGB')
def npy_loader(path):
sample = torch.from_numpy(np.load(path))
sample = sample.unsqueeze(0)
return sample
def remove_aer_ana(l): def remove_aer_ana(l):
l = map(lambda x : x.split('_')[0],l) l = map(lambda x : x.split('_')[0],l)
......
...@@ -26,7 +26,7 @@ def train_model(config,args): ...@@ -26,7 +26,7 @@ def train_model(config,args):
) )
# load model # load model
model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.classes)) model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.dataset.classes))
# move parameters to GPU # move parameters to GPU
model.float() model.float()
...@@ -144,7 +144,7 @@ def test_model(best_result, args): ...@@ -144,7 +144,7 @@ def test_model(best_result, args):
noise_threshold=best_result.config['noise']) noise_threshold=best_result.config['noise'])
# load model # load model
model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.classes)) model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.dataset.classes))
model.float() model.float()
# load weight # load weight
checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt") checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment