From 32d76e1dfaa8351adec8080dba623ef3d8b1a890 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 11 Feb 2025 10:21:51 +0100
Subject: [PATCH] dataset pred for diann

---
 diann_lib_processing.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/diann_lib_processing.py b/diann_lib_processing.py
index e876c76..1a95c51 100644
--- a/diann_lib_processing.py
+++ b/diann_lib_processing.py
@@ -67,6 +67,9 @@ if __name__ =='__main__':
                              decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate,
                              embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first)
 
+    if torch.cuda.is_available():
+        model = model.cuda()
+
     model.load_state_dict(torch.load(args.model_weigh, weights_only=True))
 
     data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test,
-- 
GitLab