diff --git a/dataset/dataset.py b/dataset/dataset.py index 9f84e631c11a608f0602633cc0fc6b93a90b611b..5e054286fb4747a8d18e3614ae9074a1b32c6190 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -34,6 +34,11 @@ class Random_shift_rt(): pass +def npy_loader(path): + sample = torch.from_numpy(np.load(path)) + sample = sample.unsqueeze(0) + return sample + def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): train_transform = transforms.Compose( [transforms.Grayscale(num_output_channels=1), @@ -52,8 +57,8 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): Log_normalisation(), transforms.Normalize(0.5, 0.5)]) print('Default val transform') - train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform) - val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform) + train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform,loader=npy_loader) + val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform, loader=npy_loader) train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=42, shuffle=True, stratify=True) @@ -86,10 +91,6 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): def default_loader(path): return Image.open(path).convert('RGB') -def npy_loader(path): - sample = torch.from_numpy(np.load(path)) - sample = sample.unsqueeze(0) - return sample def remove_aer_ana(l): l = map(lambda x : x.split('_')[0],l) diff --git a/main_ray.py b/main_ray.py index 6734ecafe76008e5b48c822efa667ef6b75c5260..2a32229022c9f1059969bfbd0e45f60123ae4ae2 100644 --- a/main_ray.py +++ b/main_ray.py @@ -26,7 +26,7 @@ def train_model(config,args): ) # load model - model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.classes)) + model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.dataset.classes)) # move parameters to GPU model.float() @@ -144,7 +144,7 @@ def test_model(best_result, args): noise_threshold=best_result.config['noise']) # load model - model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.classes)) + model = Classification_model_duo(model=args.model, n_class=len(data_test.dataset.dataset.classes)) model.float() # load weight checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")