From d0df642bc84a92e5e6fc3375cd964993cab2f0d2 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Thu, 3 Apr 2025 16:40:00 +0200
Subject: [PATCH] fix : typo out file name

---
 config/config.py    |  4 ++--
 image_ref/config.py | 10 +++++-----
 image_ref/main.py   |  1 +
 image_ref/model.py  |  4 ++++
 models/model.py     |  2 ++
 5 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/config/config.py b/config/config.py
index 64879658..74a9dd34 100644
--- a/config/config.py
+++ b/config/config.py
@@ -9,8 +9,8 @@ def load_args():
     parser.add_argument('--eval_inter', type=int, default=1)
     parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
-    parser.add_argument('--batch_size', type=int, default=64)
-    parser.add_argument('--model', type=str, default='ResNet18')
+    parser.add_argument('--batch_size', type=int, default=8)
+    parser.add_argument('--model', type=str, default='ResNet50')
     parser.add_argument('--model_type', type=str, default='duo')
     parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training')
     parser.add_argument('--output', type=str, default='output/out.csv')
diff --git a/image_ref/config.py b/image_ref/config.py
index 3f96b9eb..7a6ca0f9 100644
--- a/image_ref/config.py
+++ b/image_ref/config.py
@@ -9,12 +9,12 @@ def load_args_contrastive():
     parser.add_argument('--eval_inter', type=int, default=1)
     parser.add_argument('--noise_threshold', type=int, default=0)
     parser.add_argument('--lr', type=float, default=0.001)
-    parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--batch_size', type=int, default=16)
     parser.add_argument('--positive_prop', type=int, default=None)
-    parser.add_argument('--model', type=str, default='ResNet18')
-    parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data/npy_image/data_training_contrastive')
-    parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data/npy_image/data_test_contrastive')
-    parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
+    parser.add_argument('--model', type=str, default='ResNet50')
+    parser.add_argument('--dataset_train_dir', type=str, default='../data/processed_data/npy_image/data_training_contrastive')
+    parser.add_argument('--dataset_val_dir', type=str, default='../data/processed_data/npy_image/data_test_contrastive')
+    parser.add_argument('--dataset_ref_dir', type=str, default='../image_ref/img_ref')
     parser.add_argument('--output', type=str, default='output/out_contrastive.csv')
     parser.add_argument('--save_path', type=str, default='output/best_model_constrastive.pt')
     parser.add_argument('--pretrain_path', type=str, default=None)
diff --git a/image_ref/main.py b/image_ref/main.py
index 0d109924..ef6382ea 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -84,6 +84,7 @@ def run_duo(args):
         load_model(model,args.pretrain_path)
     #move parameters to GPU
     if torch.cuda.is_available():
+        print('model loaded on GPU')
         model = model.cuda()
 
     #init accumulators
diff --git a/image_ref/model.py b/image_ref/model.py
index e73ef1b0..5302e3b1 100644
--- a/image_ref/model.py
+++ b/image_ref/model.py
@@ -281,6 +281,10 @@ class Classification_model_duo_contrastive(nn.Module):
         self.n_class = n_class
         if model =='ResNet18':
             self.im_encoder = resnet18(num_classes=2, in_channels=2)
+        if model =='ResNet34':
+            self.im_encoder = resnet34(num_classes=2, in_channels=2)
+        if model =='ResNet50':
+            self.im_encoder = resnet34(num_classes=2, in_channels=2)
 
         self.predictor = nn.Linear(in_features=2*2,out_features=2)
 
diff --git a/models/model.py b/models/model.py
index f0a3d836..ce4e396c 100644
--- a/models/model.py
+++ b/models/model.py
@@ -281,6 +281,8 @@ class Classification_model_duo(nn.Module):
         self.n_class = n_class
         if model =='ResNet18':
             self.im_encoder = resnet18(num_classes=self.n_class, in_channels=1)
+        if model =='ResNet50':
+            self.im_encoder = resnet50(num_classes=self.n_class, in_channels=1)
 
         self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)
 
-- 
GitLab