From 4ace431f6ed7d337887305aa02108ec09bf8f97d Mon Sep 17 00:00:00 2001 From: Robert Spektor <goodfella47@gmail.com> Date: Sun, 7 May 2023 17:56:15 +0300 Subject: [PATCH] Fix model point sampling in data_loader --- core/gdrn_modeling/datasets/data_loader.py | 3 +-- core/gdrn_modeling/datasets/data_loader_online.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/core/gdrn_modeling/datasets/data_loader.py b/core/gdrn_modeling/datasets/data_loader.py index 1b72e24..cbc1eb2 100644 --- a/core/gdrn_modeling/datasets/data_loader.py +++ b/core/gdrn_modeling/datasets/data_loader.py @@ -248,8 +248,7 @@ class GDRN_DatasetFromList(Base_DatasetFromList): num = min(num, cfg.MODEL.POSE_NET.LOSS_CFG.NUM_PM_POINTS) for i in range(len(cur_model_points)): - keep_idx = np.arange(num) - np.random.shuffle(keep_idx) # random sampling + keep_idx = np.random.choice(len(cur_model_points[i]), num, replace=False) cur_model_points[i] = cur_model_points[i][keep_idx, :] self.model_points[dataset_name] = cur_model_points diff --git a/core/gdrn_modeling/datasets/data_loader_online.py b/core/gdrn_modeling/datasets/data_loader_online.py index f2f6961..257985a 100644 --- a/core/gdrn_modeling/datasets/data_loader_online.py +++ b/core/gdrn_modeling/datasets/data_loader_online.py @@ -241,8 +241,7 @@ class GDRN_Online_DatasetFromList(Base_DatasetFromList): num = min(num, cfg.MODEL.POSE_NET.LOSS_CFG.NUM_PM_POINTS) for i in range(len(cur_model_points)): - keep_idx = np.arange(num) - np.random.shuffle(keep_idx) # random sampling + keep_idx = np.random.choice(len(cur_model_points[i]), num, replace=False) cur_model_points[i] = cur_model_points[i][keep_idx, :] self.model_points[dataset_name] = cur_model_points -- GitLab