diff --git a/diann_lib_processing.py b/diann_lib_processing.py index e876c7668764060373f3e38a7ef88d7634077a05..1a95c514440491bbce2e68b6e0a5be222f28ffaf 100644 --- a/diann_lib_processing.py +++ b/diann_lib_processing.py @@ -67,6 +67,9 @@ if __name__ =='__main__': decoder_rt_num_layer=args.decoder_rt_num_layer, drop_rate=args.drop_rate, embedding_dim=args.embedding_dim, acti=args.activation, norm=args.norm_first) + if torch.cuda.is_available(): + model = model.cuda() + model.load_state_dict(torch.load(args.model_weigh, weights_only=True)) data_test = load_data(data_source=args.dataset_test, batch_size=args.batch_size, length=30, mode=args.split_test,