From 32d76e1dfaa8351adec8080dba623ef3d8b1a890 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 11 Feb 2025 10:21:51 +0100 Subject: [PATCH] dataset pred for diann --- diann_lib_processing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/diann_lib_processing.py b/diann_lib_processing.py index e876c76..1a95c51 100644 --- a/diann_lib_processing.py +++ b/diann_lib_processing.py @@ -67,6 +67,9 @@ if __name__ =='__main__': decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate, embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first) + if torch.cuda.is_available(): + model = model.cuda() + model.load_state_dict(torch.load(args.model_weigh, weights_only=True)) data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test, -- GitLab