diff --git a/config/config.py b/config/config.py index 4975373d77200ddcd1fee00e9d2b19b59ebd6193..902bbbc1ed3242e377c39d53f81ebedc0fc39d4a 100644 --- a/config/config.py +++ b/config/config.py @@ -32,7 +32,7 @@ def load_args_contrastive(): parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model_type', type=str, default='duo') parser.add_argument('--dataset_dir', type=str, default='../data/processed_data/npy_image/data_training') - parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') + parser.add_argument('--dataset_ref_dir', type=str, default='img_ref') parser.add_argument('--output', type=str, default='../output/out_contrastive.csv') parser.add_argument('--save_path', type=str, default='../output/best_model_constrastive.pt') parser.add_argument('--pretrain_path', type=str, default=None) diff --git a/dataset/dataset_ref.py b/dataset/dataset_ref.py index b77f9c7e09e7e30edd292f3a6afab1e97ddbe46a..45fd094d6321fe1a935149c8846113b4125f2f67 100644 --- a/dataset/dataset_ref.py +++ b/dataset/dataset_ref.py @@ -127,13 +127,15 @@ class ImageFolderDuo(data.Dataset): imgANA = self.loader(impathANA) label_ref = np.random.randint(0,len(self.classes)-1) class_ref = self.classes[label_ref] - path_ref = self.ref_dir + class_ref +'.npy' + path_ref = self.ref_dir +'/'+ class_ref + '.npy' + img_ref = self.loader(path_ref) if self.transform is not None: imgAER = self.transform(imgAER) imgANA = self.transform(imgANA) + img_ref = self.transform(img_ref) if self.target_transform is not None: target = 0 if self.target_transform(target) == label_ref else 1 - img_ref = self.loader(path_ref) + return imgAER, imgANA, img_ref, target def __len__(self): diff --git a/image_ref/img_ref/Citrobacter freundii.npy b/image_ref/img_ref/Citrobacter freundii.npy new file mode 100644 index 0000000000000000000000000000000000000000..9edcae93f12cb1afcc253a6ee7ea0ca106d9c368 Binary files /dev/null and b/image_ref/img_ref/Citrobacter freundii.npy differ diff --git a/image_ref/img_ref/Citrobacter freundii.png b/image_ref/img_ref/Citrobacter freundii.png new file mode 100644 index 0000000000000000000000000000000000000000..bb0c83b2c79f37c029b02ac7ea3ff3515e8f7729 Binary files /dev/null and b/image_ref/img_ref/Citrobacter freundii.png differ diff --git a/image_ref/img_ref/Enterobacter hormaechei.npy b/image_ref/img_ref/Enterobacter hormaechei.npy new file mode 100644 index 0000000000000000000000000000000000000000..633057f00780be8852cda8d4b52782b5c51377b0 Binary files /dev/null and b/image_ref/img_ref/Enterobacter hormaechei.npy differ diff --git a/image_ref/img_ref/Enterobacter hormaechei.png b/image_ref/img_ref/Enterobacter hormaechei.png new file mode 100644 index 0000000000000000000000000000000000000000..67de2dc3a2022ca1dd5ac38b257a12721171a0cb Binary files /dev/null and b/image_ref/img_ref/Enterobacter hormaechei.png differ diff --git a/image_ref/img_ref/Klebsiella oxytoca.npy b/image_ref/img_ref/Klebsiella oxytoca.npy new file mode 100644 index 0000000000000000000000000000000000000000..c43f586d30716e0d4eedd52ab0e4be34459635c2 Binary files /dev/null and b/image_ref/img_ref/Klebsiella oxytoca.npy differ diff --git a/image_ref/img_ref/Klebsiella oxytoca.png b/image_ref/img_ref/Klebsiella oxytoca.png new file mode 100644 index 0000000000000000000000000000000000000000..c38c76634ef58a36a3c75241dace225bf220ba6b Binary files /dev/null and b/image_ref/img_ref/Klebsiella oxytoca.png differ diff --git a/image_ref/img_ref/Klebsiella pneumoniae.npy b/image_ref/img_ref/Klebsiella pneumoniae.npy new file mode 100644 index 0000000000000000000000000000000000000000..701ad33c6557d615d252a6dfcf5f233743e9780c Binary files /dev/null and b/image_ref/img_ref/Klebsiella pneumoniae.npy differ diff --git a/image_ref/img_ref/Klebsiella pneumoniae.png b/image_ref/img_ref/Klebsiella pneumoniae.png new file mode 100644 index 0000000000000000000000000000000000000000..3933455b5802a3c3bc76b0f039e5f7a3fe4ffd83 Binary files /dev/null and b/image_ref/img_ref/Klebsiella pneumoniae.png differ diff --git a/image_ref/img_ref/Proteus mirabilis.npy b/image_ref/img_ref/Proteus mirabilis.npy new file mode 100644 index 0000000000000000000000000000000000000000..35a5f5392fa0be0dbb52fa061bd1f1116a1a4f40 Binary files /dev/null and b/image_ref/img_ref/Proteus mirabilis.npy differ diff --git a/image_ref/img_ref/Proteus mirabilis.png b/image_ref/img_ref/Proteus mirabilis.png new file mode 100644 index 0000000000000000000000000000000000000000..c0276c7d4082bd653e79525861022994c20960bf Binary files /dev/null and b/image_ref/img_ref/Proteus mirabilis.png differ diff --git a/image_ref/main.py b/image_ref/main.py index b659004dffa105a4abcc92af90722fbadcb67362..f4590a29e0fdf414869185ed7775934b4138378c 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt import numpy as np -from config.config import load_args +from config.config import load_args, load_args_contrastive from dataset.dataset_ref import load_data, load_data_duo import torch import torch.nn as nn @@ -146,13 +146,14 @@ def train_duo(model, data_train, optimizer, loss_function, epoch): for param in model.parameters(): param.requires_grad = True - for imaer,imana, label in data_train: + for imaer,imana, img_ref, label in data_train: label = label.long() if torch.cuda.is_available(): imaer = imaer.cuda() imana = imana.cuda() + img_ref = img_ref.cuda() label = label.cuda() - pred_logits = model.forward(imaer,imana) + pred_logits = model.forward(imaer,imana,img_ref) pred_class = torch.argmax(pred_logits,dim=1) acc += (pred_class==label).sum().item() loss = loss_function(pred_logits,label) @@ -172,13 +173,14 @@ def test_duo(model, data_test, loss_function, epoch): for param in model.parameters(): param.requires_grad = False - for imaer,imana, label in data_test: + for imaer,imana, img_ref, label in data_test: label = label.long() if torch.cuda.is_available(): imaer = imaer.cuda() imana = imana.cuda() + img_ref = img_ref.cuda() label = label.cuda() - pred_logits = model.forward(imaer,imana) + pred_logits = model.forward(imaer,imana,img_ref) pred_class = torch.argmax(pred_logits,dim=1) acc += (pred_class==label).sum().item() loss = loss_function(pred_logits,label) @@ -190,7 +192,7 @@ def test_duo(model, data_test, loss_function, epoch): def run_duo(args): #load data - data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, ) + data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, ref_dir=args.dataset_ref_dir) #load model model = Classification_model_duo_contrastive(model = args.model, n_class=len(data_train.dataset.dataset.classes)) model.double() @@ -240,13 +242,14 @@ def make_prediction_duo(model, data, f_name): y_pred = [] y_true = [] # iterate over test data - for imaer,imana, label in data: + for imaer,imana,img_ref, label in data: label = label.long() if torch.cuda.is_available(): imaer = imaer.cuda() imana = imana.cuda() + img_ref = img_ref.cuda() label = label.cuda() - output = model(imaer,imana) + output = model(imaer,imana,img_ref) output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy() y_pred.extend(output) @@ -277,7 +280,7 @@ def load_model(model, path): if __name__ == '__main__': - args = load_args() + args = load_args_contrastive() if args.model_type=='duo': run_duo(args) else : diff --git a/image_ref/model.py b/image_ref/model.py index 0374d1e73f5e94567ab34aec52130b90e9b710dc..1f956957ee41df11683918195c64312f5db2938d 100644 --- a/image_ref/model.py +++ b/image_ref/model.py @@ -286,8 +286,8 @@ class Classification_model_duo_contrastive(nn.Module): def forward(self, input_aer, input_ana, input_ref): - input_ana = torch.concat(input_ana, input_ref, dim=2) - input_aer = torch.concat(input_aer, input_ref, dim=2) + input_ana = torch.concat([input_ana, input_ref], dim=1) + input_aer = torch.concat([input_aer, input_ref], dim=1) out_aer = self.im_encoder(input_aer) out_ana = self.im_encoder(input_ana) out = torch.concat([out_aer,out_ana],dim=1)