Skip to content
Snippets Groups Projects
Commit 5123f205 authored by Schneider Leo's avatar Schneider Leo
Browse files

model duo

parent 10dfce75
No related branches found
No related tags found
No related merge requests found
......@@ -86,7 +86,7 @@ def default_loader(path):
return Image.open(path).convert('RGB')
def remove_aer_ana(l):
l = l.map(lambda x : x.split('_')[0])
l = map(lambda x : x.split('_')[0],l)
return list(OrderedDict.fromkeys(l))
def make_dataset_custom(
......@@ -118,7 +118,7 @@ def make_dataset_custom(
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
return torchvision.datasets.folder.has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
is_valid_file = cast(Callable[[str], bool], is_valid_file)
......@@ -136,7 +136,7 @@ def make_dataset_custom(
fname_aer = fname + '_AER.png'
path_ana = os.path.join(root, fname_ana)
path_aer = os.path.join(root, fname_aer)
if is_valid_file(path_ana) and is_valid_file(path_aer):
if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
item = path_aer, path_ana, class_index
instances.append(item)
......@@ -161,12 +161,12 @@ class ImageFolderDuo(data.Dataset):
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.classes = torchvision.datasets.folder.find_classes(root)
self.classes = torchvision.datasets.folder.find_classes(root)[0]
def __getitem__(self, index):
impathAER, impathANA, target = self.imlist[index]
imgAER = self.loader(os.path.join(self.root, impathAER))
imgANA = self.loader(os.path.join(self.root, impathANA))
imgAER = self.loader(impathAER)
imgANA = self.loader(impathANA)
if self.transform is not None:
imgAER = self.transform(imgAER)
imgANA = self.transform(imgANA)
......@@ -196,8 +196,8 @@ def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0):
Log_normalisation(),
transforms.Normalize(0.5, 0.5)])
print('Default val transform')
train_dataset = torchvision.datasets.ImageFolderDuo(root=base_dir, transform=train_transform)
val_dataset = torchvision.datasets.ImageFolderDuo(root=base_dir, transform=val_transform)
train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform)
val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform)
generator1 = torch.Generator().manual_seed(42)
indices = torch.randperm(len(train_dataset), generator=generator1)
val_size = len(train_dataset) // 5
......
......@@ -2,10 +2,10 @@ import matplotlib.pyplot as plt
import numpy as np
from config.config import load_args
from dataset.dataset import load_data
from dataset.dataset import load_data, load_data_duo
import torch
import torch.nn as nn
from models.model import Classification_model
from models.model import Classification_model, Classification_model_duo
import torch.optim as optim
from sklearn.metrics import confusion_matrix
import seaborn as sn
......@@ -88,6 +88,7 @@ def run(args):
plt.plot(val_acc)
plt.plot(train_acc)
plt.plot(train_acc)
plt.ylim(0, 1.05)
plt.show()
plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
......@@ -124,6 +125,121 @@ def make_prediction(model, data, f_name):
plt.savefig(f_name)
def train_duo(model, data_train, optimizer, loss_function, epoch):
model.train()
losses = 0.
acc = 0.
for param in model.parameters():
param.requires_grad = True
for imaer,imana, label in data_train:
label = label.long()
if torch.cuda.is_available():
imaer = imaer.cuda()
imana = imana.cuda()
pred_logits = model.forward(imaer,imana)
pred_class = torch.argmax(pred_logits,dim=1)
acc += (pred_class==label).sum().item()
loss = loss_function(pred_logits,label)
losses += loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses = losses/len(data_train.dataset)
acc = acc/len(data_train.dataset)
print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
return losses, acc
def test_duo(model, data_test, loss_function, epoch):
model.eval()
losses = 0.
acc = 0.
for param in model.parameters():
param.requires_grad = False
for imaer,imana, label in data_test:
label = label.long()
if torch.cuda.is_available():
imaer = imaer.cuda()
imana = imana.cuda()
pred_logits = model.forward(imaer,imana)
pred_class = torch.argmax(pred_logits,dim=1)
acc += (pred_class==label).sum().item()
loss = loss_function(pred_logits,label)
losses += loss.item()
losses = losses/len(data_test.dataset)
acc = acc/len(data_test.dataset)
print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
return losses,acc
def run_duo(args):
data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size)
model = Classification_model_duo(model = args.model, n_class=len(data_train.dataset.dataset.classes))
if args.pretrain_path is not None :
load_model(model,args.pretrain_path)
if torch.cuda.is_available():
model = model.cuda()
best_acc = 0
train_acc=[]
train_loss=[]
val_acc=[]
val_loss=[]
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for e in range(args.epoches):
loss, acc = train_duo(model,data_train,optimizer,loss_function,e)
train_loss.append(loss)
train_acc.append(acc)
if e%args.eval_inter==0 :
loss, acc = test_duo(model,data_test,loss_function,e)
val_loss.append(loss)
val_acc.append(acc)
if acc > best_acc :
save_model(model,args.save_path)
best_acc = acc
plt.plot(train_acc)
plt.plot(val_acc)
plt.plot(train_acc)
plt.plot(train_acc)
plt.ylim(0, 1.05)
plt.show()
plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
load_model(model, args.save_path)
make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
def make_prediction_duo(model, data, f_name):
y_pred = []
y_true = []
# iterate over test data
for imaer,imana, label in data:
label = label.long()
if torch.cuda.is_available():
imaer = imaer.cuda()
imana = imana.cuda()
output = model(imaer,imana)
output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
y_pred.extend(output)
label = label.data.cpu().numpy()
y_true.extend(label) # Save Truth
# constant for classes
classes = data.dataset.dataset.classes
# Build confusion matrix
cf_matrix = confusion_matrix(y_true, y_pred)
df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
columns=[i for i in classes])
plt.figure(figsize=(12, 7))
sn.heatmap(df_cm, annot=True)
plt.savefig(f_name)
def save_model(model, path):
print('Model saved')
torch.save(model.state_dict(), path)
......@@ -135,4 +251,4 @@ def load_model(model, path):
if __name__ == '__main__':
args = load_args()
run(args)
\ No newline at end of file
run_duo(args)
\ No newline at end of file
......@@ -270,4 +270,22 @@ class Classification_model(nn.Module):
def forward(self, input):
return self.im_encoder(input)
\ No newline at end of file
return self.im_encoder(input)
class Classification_model_duo(nn.Module):
def __init__(self, model, n_class, *args, **kwargs):
super().__init__(*args, **kwargs)
self.n_class = n_class
if model =='ResNet18':
self.im_encoder = resnet18(num_classes=self.n_class)
self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)
def forward(self, input_aer, input_ana):
out_aer = self.im_encoder(input_aer)
out_ana = self.im_encoder(input_ana)
out = torch.concat([out_aer,out_ana],dim=1)
return self.predictor(out)
output/training_plot.png

31.8 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment