diff --git a/core/gdrn_modeling/datasets/data_loader.py b/core/gdrn_modeling/datasets/data_loader.py index 1b72e24980834f9c140aa07426969f4eaf740f9d..cbc1eb229970c9cbf947fec9af98b5f387c90820 100644 --- a/core/gdrn_modeling/datasets/data_loader.py +++ b/core/gdrn_modeling/datasets/data_loader.py @@ -248,8 +248,7 @@ class GDRN_DatasetFromList(Base_DatasetFromList): num = min(num, cfg.MODEL.POSE_NET.LOSS_CFG.NUM_PM_POINTS) for i in range(len(cur_model_points)): - keep_idx = np.arange(num) - np.random.shuffle(keep_idx) # random sampling + keep_idx = np.random.choice(len(cur_model_points[i]), num, replace=False) cur_model_points[i] = cur_model_points[i][keep_idx, :] self.model_points[dataset_name] = cur_model_points diff --git a/core/gdrn_modeling/datasets/data_loader_online.py b/core/gdrn_modeling/datasets/data_loader_online.py index f2f6961c44cbb4edd8fc0e430ba5ff3edaa57557..257985afc2248b91cc735c3d385dbd741a65cdc8 100644 --- a/core/gdrn_modeling/datasets/data_loader_online.py +++ b/core/gdrn_modeling/datasets/data_loader_online.py @@ -241,8 +241,7 @@ class GDRN_Online_DatasetFromList(Base_DatasetFromList): num = min(num, cfg.MODEL.POSE_NET.LOSS_CFG.NUM_PM_POINTS) for i in range(len(cur_model_points)): - keep_idx = np.arange(num) - np.random.shuffle(keep_idx) # random sampling + keep_idx = np.random.choice(len(cur_model_points[i]), num, replace=False) cur_model_points[i] = cur_model_points[i][keep_idx, :] self.model_points[dataset_name] = cur_model_points