diff --git a/osrt/utils/training.py b/osrt/utils/training.py index b65844a042c67c454b59fe525db9f121e54ff619..70392ee09c99551e4927ad3a039ca8d919d648a9 100644 --- a/osrt/utils/training.py +++ b/osrt/utils/training.py @@ -285,7 +285,7 @@ def visualize(model, vis_data, out_dir, render_args, fabric, config): columns.append((f'pred seg {angle_deg}°', pred_seg.argmax(-1).numpy(), 'clustering')) if i == 0: ari = compute_adjusted_rand_index( - input_mask.flatten(1, 2).transpose(1, 2)[:, 1:], + input_mask.cpu().flatten(1, 2).transpose(1, 2)[:, 1:], pred_seg.flatten(1, 2).transpose(1, 2)) row_labels = ['2D Fg-ARI={:.1f}'.format(x.item() * 100) for x in ari]