diff --git a/dataset/dataset.py b/dataset/dataset.py index 435b2044c07fb404cde98fdfe256b3b4f729d66a..ec59fead2ec74552a18246e363b644762b764d7a 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -28,18 +28,33 @@ class Random_shift_rt(): def load_data(base_dir, batch_size, shuffle=True, transform=None): if transform is None : - transform = transforms.Compose( + train_transform = transforms.Compose( [transforms.Grayscale(num_output_channels=1), transforms.ToTensor(), transforms.Resize((224,224)), + Threshold_noise(500), Log_normalisation(), transforms.Normalize(0.5, 0.5)]) - print('default transform') - dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=transform) + print('Default train transform') + + val_transform = transforms.Compose( + [transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Resize((224,224)), + Threshold_noise(500), + Log_normalisation(), + transforms.Normalize(0.5, 0.5)]) + print('Default val transform') + train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform) + val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform) generator1 = torch.Generator().manual_seed(42) - data_train, data_test = random_split(dataset, [0.8, 0.2], generator=generator1) + indices = torch.randperm(len(train_dataset),generator=generator1) + val_size = len(train_dataset) // 5 + train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size]) + val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:]) + data_loader_train = DataLoader( - dataset=data_train, + dataset=train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0, @@ -48,7 +63,7 @@ def load_data(base_dir, batch_size, shuffle=True, transform=None): ) data_loader_test = DataLoader( - dataset=data_test, + dataset=val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0, diff --git a/output.png b/output.png deleted file mode 100644 index c06de6a4542828356cb331664b0b1017ef83949e..0000000000000000000000000000000000000000 Binary files a/output.png and /dev/null differ