diff --git a/image_ref/dataset_ref.py b/image_ref/dataset_ref.py index 37054d71a490ca9d2702b791f759539d4e4b92b1..a086f2d316cae2ebdad6ad280a71d12b5d2c4d54 100644 --- a/image_ref/dataset_ref.py +++ b/image_ref/dataset_ref.py @@ -145,8 +145,8 @@ class ImageFolderDuo(data.Dataset): imgAER = self.transform(imgAER) imgANA = self.transform(imgANA) img_ref = self.transform(img_ref) - target = 0 if target == label_ref else 1 - return imgAER, imgANA, img_ref, target + contrastive_target = 0 if target == label_ref else 1 + return imgAER, imgANA, img_ref, contrastive_target def __len__(self): return len(self.imlist) diff --git a/image_ref/main.py b/image_ref/main.py index 704f625d7440e6cb8ff6b563e5495d1c87e6de0d..35fd390c66ccacb2109420a670905ea4e9915591 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -133,7 +133,7 @@ def run_duo(args): plt.show() - plt.savefig('output/training_plot_contrastive_{}.png'.format(args.prop)) + plt.savefig('output/training_plot_contrastive_{}.png'.format(args.positive_prop)) #load and evaluate best model load_model(model, args.save_path)