Skip to content
Snippets Groups Projects
Commit fdd886a8 authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : model contrastive

add : ref image .npy
parent 85d9b647
No related branches found
No related tags found
No related merge requests found
Showing with 19 additions and 14 deletions
...@@ -32,7 +32,7 @@ def load_args_contrastive(): ...@@ -32,7 +32,7 @@ def load_args_contrastive():
parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model', type=str, default='ResNet18')
parser.add_argument('--model_type', type=str, default='duo') parser.add_argument('--model_type', type=str, default='duo')
parser.add_argument('--dataset_dir', type=str, default='../data/processed_data/npy_image/data_training') parser.add_argument('--dataset_dir', type=str, default='../data/processed_data/npy_image/data_training')
parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') parser.add_argument('--dataset_ref_dir', type=str, default='img_ref')
parser.add_argument('--output', type=str, default='../output/out_contrastive.csv') parser.add_argument('--output', type=str, default='../output/out_contrastive.csv')
parser.add_argument('--save_path', type=str, default='../output/best_model_constrastive.pt') parser.add_argument('--save_path', type=str, default='../output/best_model_constrastive.pt')
parser.add_argument('--pretrain_path', type=str, default=None) parser.add_argument('--pretrain_path', type=str, default=None)
......
...@@ -127,13 +127,15 @@ class ImageFolderDuo(data.Dataset): ...@@ -127,13 +127,15 @@ class ImageFolderDuo(data.Dataset):
imgANA = self.loader(impathANA) imgANA = self.loader(impathANA)
label_ref = np.random.randint(0,len(self.classes)-1) label_ref = np.random.randint(0,len(self.classes)-1)
class_ref = self.classes[label_ref] class_ref = self.classes[label_ref]
path_ref = self.ref_dir + class_ref +'.npy' path_ref = self.ref_dir +'/'+ class_ref + '.npy'
img_ref = self.loader(path_ref)
if self.transform is not None: if self.transform is not None:
imgAER = self.transform(imgAER) imgAER = self.transform(imgAER)
imgANA = self.transform(imgANA) imgANA = self.transform(imgANA)
img_ref = self.transform(img_ref)
if self.target_transform is not None: if self.target_transform is not None:
target = 0 if self.target_transform(target) == label_ref else 1 target = 0 if self.target_transform(target) == label_ref else 1
img_ref = self.loader(path_ref)
return imgAER, imgANA, img_ref, target return imgAER, imgANA, img_ref, target
def __len__(self): def __len__(self):
......
File added
image_ref/img_ref/Citrobacter freundii.png

8.21 KiB

File added
image_ref/img_ref/Enterobacter hormaechei.png

8.63 KiB

File added
image_ref/img_ref/Klebsiella oxytoca.png

8.17 KiB

File added
image_ref/img_ref/Klebsiella pneumoniae.png

8.25 KiB

File added
image_ref/img_ref/Proteus mirabilis.png

7.78 KiB

...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from config.config import load_args from config.config import load_args, load_args_contrastive
from dataset.dataset_ref import load_data, load_data_duo from dataset.dataset_ref import load_data, load_data_duo
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -146,13 +146,14 @@ def train_duo(model, data_train, optimizer, loss_function, epoch): ...@@ -146,13 +146,14 @@ def train_duo(model, data_train, optimizer, loss_function, epoch):
for param in model.parameters(): for param in model.parameters():
param.requires_grad = True param.requires_grad = True
for imaer,imana, label in data_train: for imaer,imana, img_ref, label in data_train:
label = label.long() label = label.long()
if torch.cuda.is_available(): if torch.cuda.is_available():
imaer = imaer.cuda() imaer = imaer.cuda()
imana = imana.cuda() imana = imana.cuda()
img_ref = img_ref.cuda()
label = label.cuda() label = label.cuda()
pred_logits = model.forward(imaer,imana) pred_logits = model.forward(imaer,imana,img_ref)
pred_class = torch.argmax(pred_logits,dim=1) pred_class = torch.argmax(pred_logits,dim=1)
acc += (pred_class==label).sum().item() acc += (pred_class==label).sum().item()
loss = loss_function(pred_logits,label) loss = loss_function(pred_logits,label)
...@@ -172,13 +173,14 @@ def test_duo(model, data_test, loss_function, epoch): ...@@ -172,13 +173,14 @@ def test_duo(model, data_test, loss_function, epoch):
for param in model.parameters(): for param in model.parameters():
param.requires_grad = False param.requires_grad = False
for imaer,imana, label in data_test: for imaer,imana, img_ref, label in data_test:
label = label.long() label = label.long()
if torch.cuda.is_available(): if torch.cuda.is_available():
imaer = imaer.cuda() imaer = imaer.cuda()
imana = imana.cuda() imana = imana.cuda()
img_ref = img_ref.cuda()
label = label.cuda() label = label.cuda()
pred_logits = model.forward(imaer,imana) pred_logits = model.forward(imaer,imana,img_ref)
pred_class = torch.argmax(pred_logits,dim=1) pred_class = torch.argmax(pred_logits,dim=1)
acc += (pred_class==label).sum().item() acc += (pred_class==label).sum().item()
loss = loss_function(pred_logits,label) loss = loss_function(pred_logits,label)
...@@ -190,7 +192,7 @@ def test_duo(model, data_test, loss_function, epoch): ...@@ -190,7 +192,7 @@ def test_duo(model, data_test, loss_function, epoch):
def run_duo(args): def run_duo(args):
#load data #load data
data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, ) data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, ref_dir=args.dataset_ref_dir)
#load model #load model
model = Classification_model_duo_contrastive(model = args.model, n_class=len(data_train.dataset.dataset.classes)) model = Classification_model_duo_contrastive(model = args.model, n_class=len(data_train.dataset.dataset.classes))
model.double() model.double()
...@@ -240,13 +242,14 @@ def make_prediction_duo(model, data, f_name): ...@@ -240,13 +242,14 @@ def make_prediction_duo(model, data, f_name):
y_pred = [] y_pred = []
y_true = [] y_true = []
# iterate over test data # iterate over test data
for imaer,imana, label in data: for imaer,imana,img_ref, label in data:
label = label.long() label = label.long()
if torch.cuda.is_available(): if torch.cuda.is_available():
imaer = imaer.cuda() imaer = imaer.cuda()
imana = imana.cuda() imana = imana.cuda()
img_ref = img_ref.cuda()
label = label.cuda() label = label.cuda()
output = model(imaer,imana) output = model(imaer,imana,img_ref)
output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy() output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
y_pred.extend(output) y_pred.extend(output)
...@@ -277,7 +280,7 @@ def load_model(model, path): ...@@ -277,7 +280,7 @@ def load_model(model, path):
if __name__ == '__main__': if __name__ == '__main__':
args = load_args() args = load_args_contrastive()
if args.model_type=='duo': if args.model_type=='duo':
run_duo(args) run_duo(args)
else : else :
......
...@@ -286,8 +286,8 @@ class Classification_model_duo_contrastive(nn.Module): ...@@ -286,8 +286,8 @@ class Classification_model_duo_contrastive(nn.Module):
def forward(self, input_aer, input_ana, input_ref): def forward(self, input_aer, input_ana, input_ref):
input_ana = torch.concat(input_ana, input_ref, dim=2) input_ana = torch.concat([input_ana, input_ref], dim=1)
input_aer = torch.concat(input_aer, input_ref, dim=2) input_aer = torch.concat([input_aer, input_ref], dim=1)
out_aer = self.im_encoder(input_aer) out_aer = self.im_encoder(input_aer)
out_ana = self.im_encoder(input_ana) out_ana = self.im_encoder(input_ana)
out = torch.concat([out_aer,out_ana],dim=1) out = torch.concat([out_aer,out_ana],dim=1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment