diff --git a/config/config.py b/config/config.py index 648796581937c1e5a86e13b630d2d07949973256..74a9dd34d61281e9fe92d79e0e048f44cd3342c4 100644 --- a/config/config.py +++ b/config/config.py @@ -9,8 +9,8 @@ def load_args(): parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) - parser.add_argument('--batch_size', type=int, default=64) - parser.add_argument('--model', type=str, default='ResNet18') + parser.add_argument('--batch_size', type=int, default=8) + parser.add_argument('--model', type=str, default='ResNet50') parser.add_argument('--model_type', type=str, default='duo') parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training') parser.add_argument('--output', type=str, default='output/out.csv') diff --git a/image_ref/config.py b/image_ref/config.py index 3f96b9eb85e5be7b488e249318e3a753377e08a5..7a6ca0f915dbdc413a568b1527b028ed32d875c9 100644 --- a/image_ref/config.py +++ b/image_ref/config.py @@ -9,12 +9,12 @@ def load_args_contrastive(): parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--lr', type=float, default=0.001) - parser.add_argument('--batch_size', type=int, default=64) + parser.add_argument('--batch_size', type=int, default=16) parser.add_argument('--positive_prop', type=int, default=None) - parser.add_argument('--model', type=str, default='ResNet18') - parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive') - parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive') - parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') + parser.add_argument('--model', type=str, default='ResNet50') + parser.add_argument('--dataset_train_dir', type=str, default='../data/processed_data/npy_image/data_training_contrastive') + parser.add_argument('--dataset_val_dir', type=str, default='../data/processed_data/npy_image/data_test_contrastive') + parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref') parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') parser.add_argument('--pretrain_path', type=str, default=None) diff --git a/image_ref/main.py b/image_ref/main.py index 0d109924a702ab54673efe14e823f09a5b9487cb..ef6382ea822a3061bd47818011fe3f27aa26abfc 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -84,6 +84,7 @@ def run_duo(args): load_model(model,args.pretrain_path) #move parameters to GPU if torch.cuda.is_available(): + print('model loaded on GPU') model = model.cuda() #init accumulators diff --git a/image_ref/model.py b/image_ref/model.py index e73ef1b0cedaaf52ac08dfef10ea3b6855b1adbf..5302e3b122d79d71372a781aee834f805840f5a4 100644 --- a/image_ref/model.py +++ b/image_ref/model.py @@ -281,6 +281,10 @@ class Classification_model_duo_contrastive(nn.Module): self.n_class = n_class if model =='ResNet18': self.im_encoder = resnet18(num_classes=2, in_channels=2) + if model =='ResNet34': + self.im_encoder = resnet34(num_classes=2, in_channels=2) + if model =='ResNet50': + self.im_encoder = resnet34(num_classes=2, in_channels=2) self.predictor = nn.Linear(in_features=2*2,out_features=2) diff --git a/models/model.py b/models/model.py index f0a3d836adecc0160a2c2068306b54b945bfe13c..ce4e396c7a821d5a2406cb96d829a82b317dfd2c 100644 --- a/models/model.py +++ b/models/model.py @@ -281,6 +281,8 @@ class Classification_model_duo(nn.Module): self.n_class = n_class if model =='ResNet18': self.im_encoder = resnet18(num_classes=self.n_class, in_channels=1) + if model =='ResNet50': + self.im_encoder = resnet50(num_classes=self.n_class, in_channels=1) self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)