Commit 7dad5406 authored by Léo Calmettes

modified:   config/config.py
modified:   dataset/dataset.py
modified:   main.py
modified:   models/model.py
parent cfc3262b
import argparse
import torch


def load_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoches', type=int, default=3)
    parser.add_argument('--save_inter', type=int, default=50)
    parser.add_argument('--epoches', type=int, default=20)
    parser.add_argument('--eval_inter', type=int, default=1)
    parser.add_argument('--noise_threshold', type=int, default=0)
    parser.add_argument('--noise_threshold', type=int, default=1000)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--optim', type=str, default="Adam")
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--classes_names', type=list, default=["Citrobacter freundii","Citrobacter koseri","Enterobacter asburiae","Enterobacter cloacae","Enterobacter hormaechei","Escherichia coli","Klebsiella aerogenes","Klebsiella michiganensis","Klebsiella oxytoca","Klebsiella pneumoniae","Klebsiella quasipneumoniae","Proteus mirabilis","Salmonella enterica"])
    parser.add_argument('--classes_numbers', type=list, default=[51,12,9,10,86,231,20,13,24,96,11,39,11])
    parser.add_argument('--weighted_entropy', type=bool, default=True)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--model', type=str, default='ResNet18')
    parser.add_argument('--model_type', type=str, default='duo')
    parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training')
    parser.add_argument('--dataset_dir', type=str, default='data/fused_data/species_training')
    parser.add_argument('--output', type=str, default='output/out.csv')
    parser.add_argument('--save_path', type=str, default='output/best_model.pt')
    parser.add_argument('--pretrain_path', type=str, default=None)
......
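For reference, a minimal sketch of how this config is consumed downstream, based on the run_duo and load_data_duo calls later in this commit (the __main__ wiring itself is an assumption and is not shown in the diff):

# Sketch only: assumed entry point; run_duo is defined in main.py below.
from config.config import load_args

if __name__ == '__main__':
    args = load_args()
    # run_duo(args) then reads args.dataset_dir, args.batch_size,
    # args.classes_numbers, args.optim, args.lr, args.epoches, ...
    run_duo(args)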
@@ -54,7 +54,6 @@ def load_data(base_dir, batch_size, shuffle=True, noise_threshold=0):
    print('Default val transform')
    train_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=train_transform)
    val_dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=val_transform)
    train_dataset, _ = train_test_split(train_dataset, test_size=None, train_size=None, random_state=42, shuffle=True,
                                        stratify=True)
    _, val_dataset = train_test_split(val_dataset, test_size=None, train_size=None, random_state=42, shuffle=True,
@@ -178,34 +177,37 @@ class ImageFolderDuo(data.Dataset):
            imgANA = self.transform(imgANA)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return imgAER, imgANA, target

    def __len__(self):
        return len(self.imlist)
def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
def load_data_duo(base_dir, batch_size, args, shuffle=True):
    train_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         Threshold_noise(noise_threshold),
         Threshold_noise(args.noise_threshold),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])
    print('Default train transform')
    val_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         Threshold_noise(noise_threshold),
         Threshold_noise(args.noise_threshold),
         Log_normalisation(),
         transforms.Normalize(0.5, 0.5)])
    print('Default val transform')
    train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
    val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
    generator1 = torch.Generator().manual_seed(42)
    indices = torch.randperm(len(train_dataset), generator=generator1)
    val_size = len(train_dataset) // 5
    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
    repart_weights = [args.classes_names[i] for i in range(len(args.classes_names)) for k in range(args.classes_numbers[i])]
    iTrain, iVal = train_test_split(range(len(train_dataset)), test_size=0.2, shuffle=shuffle, stratify=repart_weights, random_state=42)
    train_dataset = torch.utils.data.Subset(train_dataset, iTrain)
    val_dataset = torch.utils.data.Subset(val_dataset, iVal)
    # generator1 = torch.Generator().manual_seed(42)
    # indices = torch.randperm(len(train_dataset), generator=generator1)
    # val_size = len(train_dataset) // 5
    # train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
    # val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
    data_loader_train = data.DataLoader(
        dataset=train_dataset,
......
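The new split above replaces the fixed seeded 80/20 random permutation with a stratified split: repart_weights expands classes_names according to classes_numbers to give one species label per sample (this assumes the ImageFolderDuo sample order matches the order implied by those two config lists), and train_test_split then keeps each species' proportion the same in the train and validation subsets. A standalone sketch of the same idea, with shortened illustrative counts:

# Sketch of the stratification used in load_data_duo (illustrative class counts only).
from sklearn.model_selection import train_test_split

classes_names = ["Escherichia coli", "Klebsiella pneumoniae", "Proteus mirabilis"]
classes_numbers = [231, 96, 39]

# One label per sample, in dataset order.
repart_weights = [classes_names[i]
                  for i in range(len(classes_names))
                  for _ in range(classes_numbers[i])]

iTrain, iVal = train_test_split(list(range(len(repart_weights))), test_size=0.2,
                                shuffle=True, stratify=repart_weights, random_state=42)
# iTrain / iVal then index the dataset via torch.utils.data.Subset, so each
# species keeps roughly the same share in train and validation.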
import matplotlib.pyplot as plt
import numpy as np
from config.config import load_args
from dataset.dataset import load_data, load_data_duo
import torch
@@ -18,6 +17,7 @@ def train(model, data_train, optimizer, loss_function, epoch):
    model.train()
    losses = 0.
    acc = 0.
    for param in model.parameters():
        param.requires_grad = True
@@ -133,7 +133,12 @@ def make_prediction(model, data, f_name):
    sn.heatmap(df_cm, annot=cf_matrix)
    plt.savefig(f_name)


def memPrint(balise=""):
    print("marker ", balise)
    print("allocated memory: ", torch.cuda.memory_allocated())
    print("reserved memory: ", torch.cuda.memory_reserved(), "\n")
def train_duo(model, data_train, optimizer, loss_function, epoch):
    model.train()
    losses = 0.
@@ -143,6 +148,8 @@ def train_duo(model, data_train, optimizer, loss_function, epoch):
    for imaer, imana, label in data_train:
        label = label.long()
        imaer = imaer.float()
        imana = imana.float()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
@@ -169,6 +176,9 @@ def test_duo(model, data_test, loss_function, epoch):
    for imaer, imana, label in data_test:
        label = label.long()
        imaer = imaer.float()
        imana = imana.float()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
@@ -185,44 +195,62 @@ def test_duo(model, data_test, loss_function, epoch):
def run_duo(args):
    # load data
    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
    memPrint(1)
    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, args=args)
    # load model
    memPrint(2)
    model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.dataset.classes))
    model.double()
    memPrint(3)
    # model = model.double()
    memPrint(4)
    # load weight
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    memPrint(5)
    # move parameters to GPU
    if torch.cuda.is_available():
        model = model.cuda()
    memPrint(6)
    # init accumulators
    best_acc = 0
    best_loss = 1
    train_acc = []
    train_loss = []
    val_acc = []
    val_loss = []
    # init training
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    if args.weighted_entropy:
        loss_weights = torch.tensor([1 / n for n in args.classes_numbers])
        if torch.cuda.is_available():
            loss_weights = loss_weights.cuda()
        loss_function = nn.CrossEntropyLoss(loss_weights)
    else:
        loss_function = nn.CrossEntropyLoss()
    val_loss_function = nn.CrossEntropyLoss()
    if args.optim == "SGD":
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.optim == "Adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise ValueError("Unsupported optimizer: {}".format(args.optim))
    # train model
    for e in range(args.epoches):
        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc = test_duo(model, data_test, loss_function, e)
            loss, acc = test_duo(model, data_test, val_loss_function, e)
            val_loss.append(loss)
            val_acc.append(acc)
            if acc > best_acc:
            if loss < best_loss:
                save_model(model, args.save_path)
                best_acc = acc
                best_loss = loss
    # plot and save training figs
    plt.plot(train_acc)
    plt.plot(val_acc)
    plt.plot(train_acc)
    plt.plot(train_acc)
    plt.ylim(0, 1.05)
    plt.plot(train_loss)
    plt.plot(val_loss)
    plt.plot(train_loss)
    plt.plot(train_loss)
    plt.ylim(0, 0.025)
    plt.show()
    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold, args.lr, args.model, args.model_type))
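The weighted_entropy branch above weights each class by the inverse of its training count from config, so under-represented species contribute more per sample, while validation keeps an unweighted CrossEntropyLoss so epochs stay comparable. A minimal sketch of that weighting, with shortened illustrative counts:

# Sketch only: inverse-frequency class weights for CrossEntropyLoss (illustrative counts).
import torch
import torch.nn as nn

classes_numbers = [231, 96, 39, 11]
loss_weights = torch.tensor([1 / n for n in classes_numbers], dtype=torch.float32)
loss_function = nn.CrossEntropyLoss(weight=loss_weights)

logits = torch.randn(8, len(classes_numbers))          # batch of 8 samples, 4 classes
labels = torch.randint(0, len(classes_numbers), (8,))
loss = loss_function(logits, labels)                   # errors on rare classes count more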
@@ -237,12 +265,14 @@ def make_prediction_duo(model, data, f_name):
    # iterate over test data
    for imaer, imana, label in data:
        label = label.long()
        imaer = imaer.float()
        imana = imana.float()
        if torch.cuda.is_available():
            imaer = imaer.cuda()
            imana = imana.cuda()
            label = label.cuda()
        output = model(imaer, imana)
        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
        y_pred.extend(output)
......
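One detail in make_prediction_duo above: predictions are taken as torch.max(torch.exp(output), 1)[1]. Since exp is strictly increasing, this is simply the argmax of the raw logits; the exp has no effect on the result, as the small check below illustrates:

import torch

output = torch.randn(4, 13)                   # logits for 4 samples, 13 classes
pred_a = torch.max(torch.exp(output), 1)[1]   # as written in the diff
pred_b = output.argmax(dim=1)                 # equivalent, without the exp
assert torch.equal(pred_a, pred_b)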
@@ -292,11 +292,17 @@ class Classification_model_duo(nn.Module):
        self.n_class = n_class
        if model == 'ResNet18':
            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=1)
        elif model == 'ResNet34':
            self.im_encoder = resnet34(num_classes=self.n_class, in_channels=1)
        elif model == 'ResNet50':
            self.im_encoder = resnet50(num_classes=self.n_class, in_channels=1)
        else:
            raise ValueError("Unsupported model: {}".format(model))
        self.predictor = nn.Linear(in_features=self.n_class * 2, out_features=self.n_class)

    def forward(self, input_aer, input_ana, input_ref):
    def forward(self, input_aer, input_ana):
        out_aer = self.im_encoder(input_aer)
        out_ana = self.im_encoder(input_ana)
        out = torch.concat([out_aer, out_ana], dim=1)
......
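For orientation, the duo model edited above runs both input images through the same ResNet encoder (shared weights), concatenates the two n_class-dimensional outputs, and maps them back to n_class logits with a linear predictor; this commit also drops the input_ref argument from forward. A toy sketch of that structure (the dummy encoder stands in for the repo's single-channel ResNet and is purely illustrative):

import torch
import torch.nn as nn

class DuoSketch(nn.Module):
    # Sketch only: mirrors the shape of Classification_model_duo with a toy encoder.
    def __init__(self, n_class):
        super().__init__()
        self.im_encoder = nn.Sequential(nn.Flatten(), nn.Linear(224 * 224, n_class))
        self.predictor = nn.Linear(2 * n_class, n_class)

    def forward(self, input_aer, input_ana):
        out_aer = self.im_encoder(input_aer)            # same encoder for both views
        out_ana = self.im_encoder(input_ana)
        out = torch.concat([out_aer, out_ana], dim=1)   # (batch, 2 * n_class)
        return self.predictor(out)

model = DuoSketch(n_class=13)
aer = torch.randn(2, 1, 224, 224)
ana = torch.randn(2, 1, 224, 224)
logits = model(aer, ana)                                # shape (2, 13)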
Changed output images (binary):
    output/confusion_matrix_noise_0_lr_0.001_model_ResNet18_duo.png (36.6 KiB → 76.7 KiB)
    output/confusion_matrix_noise_1000_lr_0.001_model_ResNet18_duo.png (36.7 KiB → 77.3 KiB)
    output/confusion_matrix_noise_100_lr_0.001_model_ResNet18_duo.png (36.7 KiB → 76.9 KiB)
    output/training_plot_noise_0_lr_0.001_model_ResNet18_duo.png (18.4 KiB → 19.9 KiB)
    output/training_plot_noise_1000_lr_0.001_model_ResNet18_duo.png (19.4 KiB → 20.6 KiB)
    output/training_plot_noise_100_lr_0.001_model_ResNet18_duo.png (21.3 KiB → 20.1 KiB)