diff --git a/main.py b/main.py index f38d050666ec339dc6067d20137d3aaaff2be349..2ad0475975b0c326c5254d6059b088dff097dc10 100644 --- a/main.py +++ b/main.py @@ -57,6 +57,8 @@ def test(model, data_test, loss_function, epoch): def run(args): model = Classification_model(n_class=9) + if torch.cuda.is_available(): + model = model.cuda() best_acc = 0 train_acc=[] train_loss=[]