Skip to content
Snippets Groups Projects
Commit b7a12ac7 authored by Schneider Leo's avatar Schneider Leo
Browse files

input type to .npy (from .png)

parent cbca2cb8
No related branches found
No related tags found
No related merge requests found
...@@ -12,7 +12,7 @@ def load_args(): ...@@ -12,7 +12,7 @@ def load_args():
parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model', type=str, default='ResNet18')
parser.add_argument('--model_type', type=str, default='duo') parser.add_argument('--model_type', type=str, default='duo')
parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training') parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training')
parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--output', type=str, default='output/out.csv')
parser.add_argument('--save_path', type=str, default='output/best_model.pt') parser.add_argument('--save_path', type=str, default='output/best_model.pt')
parser.add_argument('--pretrain_path', type=str, default=None) parser.add_argument('--pretrain_path', type=str, default=None)
......
import numpy as np
import torch import torch
import torchvision import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
...@@ -9,7 +10,7 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union ...@@ -9,7 +10,7 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path from pathlib import Path
from collections import OrderedDict from collections import OrderedDict
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") IMG_EXTENSIONS = ".npy"
class Threshold_noise: class Threshold_noise:
"""Remove intensities under given threshold""" """Remove intensities under given threshold"""
...@@ -85,6 +86,11 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0): ...@@ -85,6 +86,11 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
def default_loader(path): def default_loader(path):
return Image.open(path).convert('RGB') return Image.open(path).convert('RGB')
def npy_loader(path):
sample = torch.from_numpy(np.load(path))
sample = sample.unsqueeze(0)
return sample
def remove_aer_ana(l): def remove_aer_ana(l):
l = map(lambda x : x.split('_')[0],l) l = map(lambda x : x.split('_')[0],l)
return list(OrderedDict.fromkeys(l)) return list(OrderedDict.fromkeys(l))
...@@ -132,8 +138,8 @@ def make_dataset_custom( ...@@ -132,8 +138,8 @@ def make_dataset_custom(
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
fnames_base = remove_aer_ana(fnames) fnames_base = remove_aer_ana(fnames)
for fname in sorted(fnames_base): for fname in sorted(fnames_base):
fname_ana = fname+'_ANA.png' fname_ana = fname+'_ANA.npy'
fname_aer = fname + '_AER.png' fname_aer = fname + '_AER.npy'
path_ana = os.path.join(root, fname_ana) path_ana = os.path.join(root, fname_ana)
path_aer = os.path.join(root, fname_aer) path_aer = os.path.join(root, fname_aer)
if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer): if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
...@@ -155,7 +161,7 @@ def make_dataset_custom( ...@@ -155,7 +161,7 @@ def make_dataset_custom(
class ImageFolderDuo(data.Dataset): class ImageFolderDuo(data.Dataset):
def __init__(self, root, transform=None, target_transform=None, def __init__(self, root, transform=None, target_transform=None,
flist_reader=make_dataset_custom, loader=default_loader): flist_reader=make_dataset_custom, loader=npy_loader):
self.root = root self.root = root
self.imlist = flist_reader(root) self.imlist = flist_reader(root)
self.transform = transform self.transform = transform
...@@ -180,18 +186,14 @@ class ImageFolderDuo(data.Dataset): ...@@ -180,18 +186,14 @@ class ImageFolderDuo(data.Dataset):
def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0): def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
train_transform = transforms.Compose( train_transform = transforms.Compose(
[transforms.Grayscale(num_output_channels=1), [transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Resize((224, 224)),
Threshold_noise(noise_threshold), Threshold_noise(noise_threshold),
Log_normalisation(), Log_normalisation(),
transforms.Normalize(0.5, 0.5)]) transforms.Normalize(0.5, 0.5)])
print('Default train transform') print('Default train transform')
val_transform = transforms.Compose( val_transform = transforms.Compose(
[transforms.Grayscale(num_output_channels=1), [transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Resize((224, 224)),
Threshold_noise(noise_threshold), Threshold_noise(noise_threshold),
Log_normalisation(), Log_normalisation(),
transforms.Normalize(0.5, 0.5)]) transforms.Normalize(0.5, 0.5)])
......
...@@ -177,6 +177,7 @@ def test_duo(model, data_test, loss_function, epoch): ...@@ -177,6 +177,7 @@ def test_duo(model, data_test, loss_function, epoch):
def run_duo(args): def run_duo(args):
data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size) data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes)) model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes))
model.double()
if args.pretrain_path is not None : if args.pretrain_path is not None :
load_model(model,args.pretrain_path) load_model(model,args.pretrain_path)
if torch.cuda.is_available(): if torch.cuda.is_available():
......
output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png

29.4 KiB

output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png

8.03 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment