From 4ace431f6ed7d337887305aa02108ec09bf8f97d Mon Sep 17 00:00:00 2001
From: Robert Spektor <goodfella47@gmail.com>
Date: Sun, 7 May 2023 17:56:15 +0300
Subject: [PATCH] Fix model point sampling in data_loader

---
 core/gdrn_modeling/datasets/data_loader.py        | 3 +--
 core/gdrn_modeling/datasets/data_loader_online.py | 3 +--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/core/gdrn_modeling/datasets/data_loader.py b/core/gdrn_modeling/datasets/data_loader.py
index 1b72e24..cbc1eb2 100644
--- a/core/gdrn_modeling/datasets/data_loader.py
+++ b/core/gdrn_modeling/datasets/data_loader.py
@@ -248,8 +248,7 @@ class GDRN_DatasetFromList(Base_DatasetFromList):
 
         num = min(num, cfg.MODEL.POSE_NET.LOSS_CFG.NUM_PM_POINTS)
         for i in range(len(cur_model_points)):
-            keep_idx = np.arange(num)
-            np.random.shuffle(keep_idx)  # random sampling
+            keep_idx = np.random.choice(len(cur_model_points[i]), num, replace=False)
             cur_model_points[i] = cur_model_points[i][keep_idx, :]
 
         self.model_points[dataset_name] = cur_model_points
diff --git a/core/gdrn_modeling/datasets/data_loader_online.py b/core/gdrn_modeling/datasets/data_loader_online.py
index f2f6961..257985a 100644
--- a/core/gdrn_modeling/datasets/data_loader_online.py
+++ b/core/gdrn_modeling/datasets/data_loader_online.py
@@ -241,8 +241,7 @@ class GDRN_Online_DatasetFromList(Base_DatasetFromList):
 
         num = min(num, cfg.MODEL.POSE_NET.LOSS_CFG.NUM_PM_POINTS)
         for i in range(len(cur_model_points)):
-            keep_idx = np.arange(num)
-            np.random.shuffle(keep_idx)  # random sampling
+            keep_idx = np.random.choice(len(cur_model_points[i]), num, replace=False)
             cur_model_points[i] = cur_model_points[i][keep_idx, :]
 
         self.model_points[dataset_name] = cur_model_points
-- 
GitLab