diff --git a/dataset/dataset_ref.py b/dataset/dataset_ref.py index 45fd094d6321fe1a935149c8846113b4125f2f67..849426be3bdebf316f9fdb6bb4ac4ab10ef18198 100644 --- a/dataset/dataset_ref.py +++ b/dataset/dataset_ref.py @@ -133,9 +133,7 @@ class ImageFolderDuo(data.Dataset): imgAER = self.transform(imgAER) imgANA = self.transform(imgANA) img_ref = self.transform(img_ref) - if self.target_transform is not None: - target = 0 if self.target_transform(target) == label_ref else 1 - + target = 0 if target == label_ref else 1 return imgAER, imgANA, img_ref, target def __len__(self): diff --git a/image_ref/main.py b/image_ref/main.py index f4590a29e0fdf414869185ed7775934b4138378c..75780880767b52cfab96176a9e9d8f5a8eb23685 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -194,7 +194,7 @@ def run_duo(args): #load data data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, ref_dir=args.dataset_ref_dir) #load model - model = Classification_model_duo_contrastive(model = args.model, n_class=len(data_train.dataset.dataset.classes)) + model = Classification_model_duo_contrastive(model = args.model, n_class=2) model.double() #load weight if args.pretrain_path is not None : @@ -232,10 +232,10 @@ def run_duo(args): plt.ylim(0, 1.05) plt.show() - plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) + plt.savefig('../output/training_plot_contrastive_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) #load and evaluate best model load_model(model, args.save_path) - make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) + make_prediction_duo(model,data_test, '../output/confusion_matrix_contractive_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type)) def make_prediction_duo(model, data, f_name): @@ -258,12 +258,10 @@ def make_prediction_duo(model, data, f_name): y_true.extend(label) # Save Truth # constant for classes - classes = data.dataset.dataset.classes # Build confusion matrix - print(len(y_true),len(y_pred)) cf_matrix = confusion_matrix(y_true, y_pred) - df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes], - columns=[i for i in classes]) + df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in range(2)], + columns=['True','False']) print('Saving Confusion Matrix') plt.figure(figsize=(14, 9)) sn.heatmap(df_cm, annot=cf_matrix) diff --git a/image_ref/model.py b/image_ref/model.py index 1f956957ee41df11683918195c64312f5db2938d..e73ef1b0cedaaf52ac08dfef10ea3b6855b1adbf 100644 --- a/image_ref/model.py +++ b/image_ref/model.py @@ -280,9 +280,9 @@ class Classification_model_duo_contrastive(nn.Module): super().__init__(*args, **kwargs) self.n_class = n_class if model =='ResNet18': - self.im_encoder = resnet18(num_classes=self.n_class, in_channels=2) + self.im_encoder = resnet18(num_classes=2, in_channels=2) - self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class) + self.predictor = nn.Linear(in_features=2*2,out_features=2) def forward(self, input_aer, input_ana, input_ref):