def remove_aer_ana(l):
    """Return the unique base names of paired AER/ANA files, in first-seen order.

    ``make_dataset_custom`` rebuilds the two file names of a sample as
    ``<base>_AER.png`` and ``<base>_ANA.png``, so only the trailing
    ``_AER.png`` / ``_ANA.png`` part must be stripped here.

    Args:
        l: iterable of file names (strings).

    Returns:
        A list of base names with duplicates removed; the order of first
        appearance is preserved.
    """
    # Split on the LAST underscore only: splitting on the first one would
    # truncate bases that themselves contain underscores ("run_01_AER.png"
    # must yield "run_01", not "run"). Names without any underscore are
    # returned unchanged.
    bases = (name.rsplit('_', 1)[0] for name in l)
    # dict keys keep insertion order, which gives an ordered de-duplication.
    return list(dict.fromkeys(bases))
def _model_device(model):
    """Return the device holding the model's parameters.

    Falls back to "CUDA if available, else CPU" for a model without
    parameters.
    """
    first_param = next(model.parameters(), None)
    if first_param is None:
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return first_param.device


def _batch_loss_sum(loss, loss_function, batch_size):
    """Return the summed (not averaged) loss of one batch as a float.

    Criteria with ``reduction='mean'`` (the default assumed when the attribute
    is absent) return a per-sample average, so it is scaled back by the batch
    size; any other reduction is taken as already summed.
    """
    if getattr(loss_function, 'reduction', 'mean') == 'mean':
        return loss.item() * batch_size
    return loss.item()


def train_duo(model, data_train, optimizer, loss_function, epoch):
    """Train the two-input model for one epoch.

    Args:
        model: module called as ``model(imaer, imana)`` and returning logits.
        data_train: iterable of ``(imaer, imana, label)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.
        optimizer: optimizer stepping the model parameters.
        loss_function: criterion called as ``loss_function(logits, label)``.
        epoch: epoch index, only used in the printed message.

    Returns:
        ``(loss, acc)``: per-sample mean loss and accuracy over the epoch.
    """
    model.train()
    for param in model.parameters():
        param.requires_grad = True

    device = _model_device(model)
    total_loss = 0.
    n_correct = 0
    for imaer, imana, label in data_train:
        imaer = imaer.to(device)
        imana = imana.to(device)
        # The labels must live on the same device as the logits, otherwise
        # both the loss and the accuracy comparison fail on GPU.
        label = label.long().to(device)

        pred_logits = model(imaer, imana)
        loss = loss_function(pred_logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a per-sample sum so that dividing by the dataset size
        # below yields a true mean, independent of the batch size.
        total_loss += _batch_loss_sum(loss, loss_function, label.size(0))
        n_correct += (pred_logits.argmax(dim=1) == label).sum().item()

    n_samples = max(len(data_train.dataset), 1)  # avoid dividing by zero
    losses = total_loss / n_samples
    acc = n_correct / n_samples
    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


def test_duo(model, data_test, loss_function, epoch):
    """Evaluate the two-input model on a data loader without updating it.

    Args:
        model: module called as ``model(imaer, imana)`` and returning logits.
        data_test: iterable of ``(imaer, imana, label)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.
        loss_function: criterion called as ``loss_function(logits, label)``.
        epoch: epoch index, only used in the printed message.

    Returns:
        ``(loss, acc)``: per-sample mean loss and accuracy.
    """
    model.eval()
    device = _model_device(model)
    total_loss = 0.
    n_correct = 0
    # no_grad disables graph construction for the whole pass; this replaces
    # flipping requires_grad on every parameter.
    with torch.no_grad():
        for imaer, imana, label in data_test:
            imaer = imaer.to(device)
            imana = imana.to(device)
            label = label.long().to(device)

            pred_logits = model(imaer, imana)
            loss = loss_function(pred_logits, label)

            total_loss += _batch_loss_sum(loss, loss_function, label.size(0))
            n_correct += (pred_logits.argmax(dim=1) == label).sum().item()

    n_samples = max(len(data_test.dataset), 1)  # avoid dividing by zero
    losses = total_loss / n_samples
    acc = n_correct / n_samples
    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
    return losses, acc


# The name matches pytest's default "test*" pattern although this is an
# evaluation routine; opt out so a collector never tries to run it as a test.
test_duo.__test__ = False


def run_duo(args):
    """Train, evaluate and report the two-input (AER + ANA) classifier.

    Reads ``dataset_dir``, ``batch_size``, ``noise_threshold``, ``model``,
    ``pretrain_path``, ``lr``, ``epoches``, ``eval_inter`` and ``save_path``
    from ``args``. Saves the best checkpoint to ``args.save_path`` and writes
    the accuracy curves and the confusion matrix under ``output/``.
    """
    import os  # local import: keeps this function self-contained

    # Forward the noise threshold: it is part of the output file names, so the
    # data must actually be built with it.
    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size,
                                          noise_threshold=args.noise_threshold)
    model = Classification_model_duo(model=args.model, n_class=len(data_train.dataset.dataset.classes))
    if args.pretrain_path is not None:
        load_model(model, args.pretrain_path)
    if torch.cuda.is_available():
        model = model.cuda()

    # Start below any reachable accuracy so the first evaluation always writes
    # a checkpoint for the final load_model call.
    best_acc = -1.0
    train_acc = []
    train_loss = []
    val_acc = []
    val_loss = []
    val_epochs = []  # epochs at which validation ran (x-axis of the val curve)
    loss_function = nn.CrossEntropyLoss()
    # Use the configured learning rate: it is the value reported in the
    # output file names.
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)

    for e in range(args.epoches):
        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
        train_loss.append(loss)
        train_acc.append(acc)
        if e % args.eval_inter == 0:
            loss, acc = test_duo(model, data_test, loss_function, e)
            val_loss.append(loss)
            val_acc.append(acc)
            val_epochs.append(e)
            if acc > best_acc:
                save_model(model, args.save_path)
                best_acc = acc

    plt.plot(range(len(train_acc)), train_acc, label='train acc')
    # Validation only runs every eval_inter epochs: plot it against the epochs
    # it was measured at so both curves share the same x-axis.
    plt.plot(val_epochs, val_acc, label='val acc')
    plt.ylim(0, 1.05)
    plt.legend()
    os.makedirs('output', exist_ok=True)
    # Save before show(): with an interactive backend show() blocks and the
    # figure is gone once its window is closed, leaving an empty image.
    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))
    plt.show()

    load_model(model, args.save_path)
    make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}.png'.format(args.noise_threshold,args.lr,args.model))


def make_prediction_duo(model, data, f_name):
    """Plot the row-normalised confusion matrix of the model and save it.

    Args:
        model: module called as ``model(imaer, imana)`` and returning logits.
        data: iterable of ``(imaer, imana, label)`` batches;
            ``data.dataset.dataset.classes`` provides the class names.
        f_name: path of the image file to write.
    """
    y_pred = []
    y_true = []

    # Inference only: fix batch-norm/dropout behaviour and skip autograd.
    model.eval()
    device = _model_device(model)
    with torch.no_grad():
        for imaer, imana, label in data:
            output = model(imaer.to(device), imana.to(device))
            # argmax of the raw logits; exponentiating first cannot change the
            # winner and may overflow.
            y_pred.extend(output.argmax(dim=1).cpu().tolist())
            y_true.extend(label.long().cpu().tolist())

    classes = data.dataset.dataset.classes

    # Pin the label set so the matrix is always len(classes) x len(classes),
    # even when a class appears in neither the targets nor the predictions.
    cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    row_totals = cf_matrix.sum(axis=1, keepdims=True)
    # Rows without any true sample stay at 0 instead of dividing by zero.
    normalised = np.divide(cf_matrix, row_totals,
                           out=np.zeros(cf_matrix.shape, dtype=float),
                           where=row_totals != 0)
    df_cm = pd.DataFrame(normalised, index=list(classes), columns=list(classes))
    plt.figure(figsize=(12, 7))
    sn.heatmap(df_cm, annot=True)
    plt.savefig(f_name)
class Classification_model_duo(nn.Module):
    """Classifier taking a pair of images (AER and ANA) per sample.

    Both images go through the same encoder (shared weights); the two
    ``n_class``-sized encoder outputs are concatenated and mapped to the final
    ``n_class`` logits by a linear layer.

    Args:
        model: name of the encoder architecture; only ``'ResNet18'`` is
            handled here.
        n_class: number of target classes.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: if ``model`` is not a supported architecture name.
    """

    def __init__(self, model, n_class, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_class = n_class
        if model == 'ResNet18':
            self.im_encoder = resnet18(num_classes=self.n_class)
        else:
            # Fail at construction time: otherwise ``im_encoder`` is never
            # created and the error only surfaces in ``forward``.
            raise ValueError(
                "Unsupported model '{}' for Classification_model_duo "
                "(expected 'ResNet18')".format(model))

        self.predictor = nn.Linear(in_features=self.n_class * 2, out_features=self.n_class)

    def forward(self, input_aer, input_ana):
        """Return the logits for a batch of (AER, ANA) image pairs."""
        out_aer = self.im_encoder(input_aer)
        out_ana = self.im_encoder(input_ana)
        # torch.cat is available in every torch release; torch.concat is only
        # a newer alias of it.
        out = torch.cat([out_aer, out_ana], dim=1)
        return self.predictor(out)