diff --git a/config/config.py b/config/config.py index bad9ecbb069ec9c572c7594f4be0f9a3aec353b1..d898b6bbef2f8ea9c6a9e3bebd3380057cba4ca1 100644 --- a/config/config.py +++ b/config/config.py @@ -7,6 +7,7 @@ def load_args(): parser.add_argument('--epoches', type=int, default=100) parser.add_argument('--save_inter', type=int, default=50) parser.add_argument('--eval_inter', type=int, default=1) + parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training') diff --git a/dataset/dataset.py b/dataset/dataset.py index ec59fead2ec74552a18246e363b644762b764d7a..6f807d03000d5328897ff4d009413ab6799a6a0a 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -26,25 +26,24 @@ class Random_shift_rt(): pass -def load_data(base_dir, batch_size, shuffle=True, transform=None): - if transform is None : - train_transform = transforms.Compose( - [transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Resize((224,224)), - Threshold_noise(500), - Log_normalisation(), - transforms.Normalize(0.5, 0.5)]) - print('Default train transform') +def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): + train_transform = transforms.Compose( + [transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Resize((224,224)), + Threshold_noise(noise_threshold), + Log_normalisation(), + transforms.Normalize(0.5, 0.5)]) + print('Default train transform') - val_transform = transforms.Compose( - [transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Resize((224,224)), - Threshold_noise(500), - Log_normalisation(), - transforms.Normalize(0.5, 0.5)]) - print('Default val transform') + val_transform = transforms.Compose( + [transforms.Grayscale(num_output_channels=1), + transforms.ToTensor(), + transforms.Resize((224,224)), + Threshold_noise(noise_threshold), + Log_normalisation(), + transforms.Normalize(0.5, 0.5)]) + print('Default val transform') train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform) val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform) generator1 = torch.Generator().manual_seed(42)