From 3b64549def313fe48f48875a9cbb773b395f5877 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Fri, 18 Apr 2025 13:15:10 +0200 Subject: [PATCH] fix : float double error --- image_ref/main_ray.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/image_ref/main_ray.py b/image_ref/main_ray.py index 2ead9df..a45a89f 100644 --- a/image_ref/main_ray.py +++ b/image_ref/main_ray.py @@ -80,9 +80,9 @@ def train_model(config,args): param.requires_grad = True for imaer, imana, img_ref, label in data_train: - imaer.float() - imana.float() - img_ref.float() + imaer = imaer.float() + imana = imana.float() + img_ref = img_ref.float() label = label.long() if torch.cuda.is_available(): imaer = imaer.cuda() @@ -112,9 +112,9 @@ def train_model(config,args): param.requires_grad = False for imaer, imana, img_ref, label in data_val_batch: - imaer.float() - imana.float() - img_ref.float() + imaer = imaer.float() + imana = imana.float() + img_ref = img_ref.float() imaer = imaer.transpose(0, 1) imana = imana.transpose(0, 1) img_ref = img_ref.transpose(0, 1) @@ -183,7 +183,7 @@ def test_model(best_result, args): # load model model = Classification_model_duo_contrastive(model=args.model, n_class=2) - model.double() + model.float() # load weight checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt") @@ -209,6 +209,9 @@ def test_model(best_result, args): param.requires_grad = False for imaer, imana, img_ref, label in data_val_batch: + imaer = imaer.float() + imana = imana.float() + img_ref = img_ref.float() imaer = imaer.transpose(0, 1) imana = imana.transpose(0, 1) img_ref = img_ref.transpose(0, 1) -- GitLab