From 4efef23b65355269a4f864ac970c19e39a2aad07 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Mon, 10 Mar 2025 14:07:52 +0100 Subject: [PATCH] model cuda loading --- main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/main.py b/main.py index f38d050..2ad0475 100644 --- a/main.py +++ b/main.py @@ -57,6 +57,8 @@ def test(model, data_test, loss_function, epoch): def run(args): model = Classification_model(n_class=9) + if torch.cuda.is_available(): + model = model.cuda() best_acc = 0 train_acc=[] train_loss=[] -- GitLab