diff --git a/main_ray_tune.py b/main_ray_tune.py index 0ff6221d836adb3e082a082146b04bce94eb7b53..f4cd748b5cff6275807ed1406bab272024875c94 100644 --- a/main_ray_tune.py +++ b/main_ray_tune.py @@ -199,9 +199,9 @@ def test_best_model(best_result, args): if torch.cuda.is_available(): device = "cuda:0" if torch.cuda.device_count() > 1: - print(print(type(best_trained_model.module))) + print(type(best_trained_model)) best_trained_model = torch.nn.DataParallel(best_trained_model) - print(print(type(best_trained_model.module))) + print(type(best_trained_model)) best_trained_model.to(device) criterion_rt = torch.nn.MSELoss()