Skip to content
Snippets Groups Projects
Commit d0df642b authored by Schneider Leo's avatar Schneider Leo
Browse files

fix : typo out file name

parent 176aa057
No related branches found
No related tags found
No related merge requests found
...@@ -9,8 +9,8 @@ def load_args(): ...@@ -9,8 +9,8 @@ def load_args():
parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--noise_threshold', type=int, default=0)
parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model', type=str, default='ResNet50')
parser.add_argument('--model_type', type=str, default='duo') parser.add_argument('--model_type', type=str, default='duo')
parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training') parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training')
parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--output', type=str, default='output/out.csv')
......
...@@ -9,12 +9,12 @@ def load_args_contrastive(): ...@@ -9,12 +9,12 @@ def load_args_contrastive():
parser.add_argument('--eval_inter', type=int, default=1) parser.add_argument('--eval_inter', type=int, default=1)
parser.add_argument('--noise_threshold', type=int, default=0) parser.add_argument('--noise_threshold', type=int, default=0)
parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--positive_prop', type=int, default=None) parser.add_argument('--positive_prop', type=int, default=None)
parser.add_argument('--model', type=str, default='ResNet18') parser.add_argument('--model', type=str, default='ResNet50')
parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive') parser.add_argument('--dataset_train_dir', type=str, default='../data/processed_data/npy_image/data_training_contrastive')
parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive') parser.add_argument('--dataset_val_dir', type=str, default='../data/processed_data/npy_image/data_test_contrastive')
parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref')
parser.add_argument('--output', type=str, default='output/out_contrastive.csv') parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt') parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
parser.add_argument('--pretrain_path', type=str, default=None) parser.add_argument('--pretrain_path', type=str, default=None)
......
...@@ -84,6 +84,7 @@ def run_duo(args): ...@@ -84,6 +84,7 @@ def run_duo(args):
load_model(model,args.pretrain_path) load_model(model,args.pretrain_path)
#move parameters to GPU #move parameters to GPU
if torch.cuda.is_available(): if torch.cuda.is_available():
print('model loaded on GPU')
model = model.cuda() model = model.cuda()
#init accumulators #init accumulators
......
...@@ -281,6 +281,10 @@ class Classification_model_duo_contrastive(nn.Module): ...@@ -281,6 +281,10 @@ class Classification_model_duo_contrastive(nn.Module):
self.n_class = n_class self.n_class = n_class
if model =='ResNet18': if model =='ResNet18':
self.im_encoder = resnet18(num_classes=2, in_channels=2) self.im_encoder = resnet18(num_classes=2, in_channels=2)
if model =='ResNet34':
self.im_encoder = resnet34(num_classes=2, in_channels=2)
if model =='ResNet50':
self.im_encoder = resnet34(num_classes=2, in_channels=2)
self.predictor = nn.Linear(in_features=2*2,out_features=2) self.predictor = nn.Linear(in_features=2*2,out_features=2)
......
...@@ -281,6 +281,8 @@ class Classification_model_duo(nn.Module): ...@@ -281,6 +281,8 @@ class Classification_model_duo(nn.Module):
self.n_class = n_class self.n_class = n_class
if model =='ResNet18': if model =='ResNet18':
self.im_encoder = resnet18(num_classes=self.n_class, in_channels=1) self.im_encoder = resnet18(num_classes=self.n_class, in_channels=1)
if model =='ResNet50':
self.im_encoder = resnet50(num_classes=self.n_class, in_channels=1)
self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class) self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment