diff --git a/image_ref/main.py b/image_ref/main.py index a6b0988ffbd094513f4f645a15124eb842858385..cbb8986ed53e15f1e4890afba3905573114472b7 100644 --- a/image_ref/main.py +++ b/image_ref/main.py @@ -107,7 +107,7 @@ def run_duo(args): # load model model = Classification_model_duo_contrastive(model=args.model, n_class=2) - model.float() + model.double() # load weight if args.pretrain_path is not None: print('Model weight loaded')