From 308cc51c4f77d2db9fd0ae85e6e4df982d42f40d Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Thu, 13 Jul 2023 15:51:05 +0200
Subject: [PATCH] Fix issue on computing render

---
 osrt/utils/training.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/osrt/utils/training.py b/osrt/utils/training.py
index b65844a..70392ee 100644
--- a/osrt/utils/training.py
+++ b/osrt/utils/training.py
@@ -285,7 +285,7 @@ def visualize(model, vis_data, out_dir, render_args, fabric, config):
                 columns.append((f'pred seg {angle_deg}°', pred_seg.argmax(-1).numpy(), 'clustering'))
                 if i == 0:
                     ari = compute_adjusted_rand_index(
-                        input_mask.flatten(1, 2).transpose(1, 2)[:, 1:],
+                        input_mask.cpu().flatten(1, 2).transpose(1, 2)[:, 1:],
                         pred_seg.flatten(1, 2).transpose(1, 2))
                     row_labels = ['2D Fg-ARI={:.1f}'.format(x.item() * 100) for x in ari]
 
-- 
GitLab