diff --git a/osrt/utils/training.py b/osrt/utils/training.py
index b65844a042c67c454b59fe525db9f121e54ff619..70392ee09c99551e4927ad3a039ca8d919d648a9 100644
--- a/osrt/utils/training.py
+++ b/osrt/utils/training.py
@@ -285,7 +285,7 @@ def visualize(model, vis_data, out_dir, render_args, fabric, config):
                 columns.append((f'pred seg {angle_deg}°', pred_seg.argmax(-1).numpy(), 'clustering'))
                 if i == 0:
                     ari = compute_adjusted_rand_index(
-                        input_mask.flatten(1, 2).transpose(1, 2)[:, 1:],
+                        input_mask.cpu().flatten(1, 2).transpose(1, 2)[:, 1:],
                         pred_seg.flatten(1, 2).transpose(1, 2))
                     row_labels = ['2D Fg-ARI={:.1f}'.format(x.item() * 100) for x in ari]