Skip to content
Snippets Groups Projects
Commit b7a12ac7 authored by Schneider Leo's avatar Schneider Leo
Browse files

input type to .npy (from .png)

parent cbca2cb8
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,7 @@ def load_args():
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--model', type=str, default='ResNet18')
parser.add_argument('--model_type', type=str, default='duo')
parser.add_argument('--dataset_dir', type=str, default='data/processed_data/png_image/data_training')
parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training')
parser.add_argument('--output', type=str, default='output/out.csv')
parser.add_argument('--save_path', type=str, default='output/best_model.pt')
parser.add_argument('--pretrain_path', type=str, default=None)
......
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
......@@ -9,7 +10,7 @@ from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from pathlib import Path
from collections import OrderedDict
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
IMG_EXTENSIONS = ".npy"
class Threshold_noise:
"""Remove intensities under given threshold"""
......@@ -85,6 +86,11 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
def default_loader(path):
return Image.open(path).convert('RGB')
def npy_loader(path):
sample = torch.from_numpy(np.load(path))
sample = sample.unsqueeze(0)
return sample
def remove_aer_ana(l):
l = map(lambda x : x.split('_')[0],l)
return list(OrderedDict.fromkeys(l))
......@@ -132,8 +138,8 @@ def make_dataset_custom(
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
fnames_base = remove_aer_ana(fnames)
for fname in sorted(fnames_base):
fname_ana = fname+'_ANA.png'
fname_aer = fname + '_AER.png'
fname_ana = fname+'_ANA.npy'
fname_aer = fname + '_AER.npy'
path_ana = os.path.join(root, fname_ana)
path_aer = os.path.join(root, fname_aer)
if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
......@@ -155,7 +161,7 @@ def make_dataset_custom(
class ImageFolderDuo(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
flist_reader=make_dataset_custom, loader=default_loader):
flist_reader=make_dataset_custom, loader=npy_loader):
self.root = root
self.imlist = flist_reader(root)
self.transform = transform
......@@ -180,18 +186,14 @@ class ImageFolderDuo(data.Dataset):
def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
train_transform = transforms.Compose(
[transforms.Grayscale(num_output_channels=1),
transforms.ToTensor(),
transforms.Resize((224, 224)),
[transforms.Resize((224, 224)),
Threshold_noise(noise_threshold),
Log_normalisation(),
transforms.Normalize(0.5, 0.5)])
print('Default train transform')
val_transform = transforms.Compose(
[transforms.Grayscale(num_output_channels=1),
transforms.ToTensor(),
transforms.Resize((224, 224)),
[transforms.Resize((224, 224)),
Threshold_noise(noise_threshold),
Log_normalisation(),
transforms.Normalize(0.5, 0.5)])
......
......@@ -177,6 +177,7 @@ def test_duo(model, data_test, loss_function, epoch):
def run_duo(args):
data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes))
model.double()
if args.pretrain_path is not None :
load_model(model,args.pretrain_path)
if torch.cuda.is_available():
......
output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png

29.4 KiB

output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png

8.03 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment