diff --git a/config/config.py b/config/config.py index 5c79da036ed9738cfab704960f83d42b757f4656..fca846dabdb432fa0b0e7253be5b9e3642b6bfed 100644 --- a/config/config.py +++ b/config/config.py @@ -12,7 +12,7 @@ def load_args(): parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model_type', type=str, default='duo') - parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training') + parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training') parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--save_path', type=str, default='output/best_model.pt') parser.add_argument('--pretrain_path', type=str, default=None) diff --git a/dataset/dataset.py b/dataset/dataset.py index b937657cb0d030af8b56bc9579c3dee3cdfec6ca..1512488f9aa1f3b9873029c8c2053ca3b10ce310 100644 --- a/dataset/dataset.py +++ b/dataset/dataset.py @@ -1,3 +1,4 @@ +import numpy as np import torch import torchvision import torchvision.transforms as transforms @@ -9,7 +10,7 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union from pathlib import Path from collections import OrderedDict -IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") +IMG_EXTENSIONS = ".npy" class Threshold_noise: """Remove intensities under given threshold""" @@ -85,6 +86,11 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): def default_loader(path): return Image.open(path).convert('RGB') +def npy_loader(path): + sample = torch.from_numpy(np.load(path)) + sample = sample.unsqueeze(0) + return sample + def remove_aer_ana(l): l = map(lambda x : x.split('_')[0],l) return list(OrderedDict.fromkeys(l)) @@ -132,8 +138,8 @@ def make_dataset_custom( for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): fnames_base = remove_aer_ana(fnames) for fname in sorted(fnames_base): - fname_ana = fname+'_ANA.png' - fname_aer = fname + '_AER.png' + fname_ana = fname+'_ANA.npy' + fname_aer = fname + '_AER.npy' path_ana = os.path.join(root, fname_ana) path_aer = os.path.join(root, fname_aer) if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer): @@ -155,7 +161,7 @@ def make_dataset_custom( class ImageFolderDuo(data.Dataset): def __init__(self, root, transform=None, target_transform=None, - flist_reader=make_dataset_custom, loader=default_loader): + flist_reader=make_dataset_custom, loader=npy_loader): self.root = root self.imlist = flist_reader(root) self.transform = transform @@ -180,18 +186,14 @@ class ImageFolderDuo(data.Dataset): def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0): train_transform = transforms.Compose( - [transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Resize((224, 224)), + [transforms.Resize((224, 224)), Threshold_noise(noise_threshold), Log_normalisation(), transforms.Normalize(0.5, 0.5)]) print('Default train transform') val_transform = transforms.Compose( - [transforms.Grayscale(num_output_channels=1), - transforms.ToTensor(), - transforms.Resize((224, 224)), + [transforms.Resize((224, 224)), Threshold_noise(noise_threshold), Log_normalisation(), transforms.Normalize(0.5, 0.5)]) diff --git a/main.py b/main.py index 07cc7e36e67402d8cb9a4e174185f327c93cbe8d..30f7843b0f6285ad5c47bf5d31e0176e2666e70e 100644 --- a/main.py +++ b/main.py @@ -177,6 +177,7 @@ def test_duo(model, data_test, loss_function, epoch): def run_duo(args): data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size) model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes)) + model.double() if args.pretrain_path is not None : load_model(model,args.pretrain_path) if torch.cuda.is_available(): diff --git a/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png b/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index f8866cfa301d83e4352666475741f68bc412cab1..0000000000000000000000000000000000000000 Binary files a/output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png and /dev/null differ diff --git a/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png b/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png deleted file mode 100644 index 7703bfce9a571e063ccddc9c850e35597ccbaa45..0000000000000000000000000000000000000000 Binary files a/output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png and /dev/null differ