diff --git a/main.py b/main.py
index f38d050666ec339dc6067d20137d3aaaff2be349..2ad0475975b0c326c5254d6059b088dff097dc10 100644
--- a/main.py
+++ b/main.py
@@ -57,6 +57,8 @@ def test(model, data_test, loss_function, epoch):
 
 def run(args):
     model = Classification_model(n_class=9)
+    if torch.cuda.is_available():
+        model = model.cuda()
     best_acc = 0
     train_acc=[]
     train_loss=[]