Skip to content
Snippets Groups Projects
Unverified Commit f3ca1863 authored by Gu Wang's avatar Gu Wang Committed by GitHub
Browse files

Merge pull request #59 from goodfella47/fix-model-points-sampling

Fix biased model point sampling in data_loader
parents f9aca009 4ace431f
No related branches found
No related tags found
No related merge requests found
......@@ -248,8 +248,7 @@ class GDRN_DatasetFromList(Base_DatasetFromList):
num = min(num, cfg.MODEL.POSE_NET.LOSS_CFG.NUM_PM_POINTS)
for i in range(len(cur_model_points)):
keep_idx = np.arange(num)
np.random.shuffle(keep_idx) # random sampling
keep_idx = np.random.choice(len(cur_model_points[i]), num, replace=False)
cur_model_points[i] = cur_model_points[i][keep_idx, :]
self.model_points[dataset_name] = cur_model_points
......
......@@ -241,8 +241,7 @@ class GDRN_Online_DatasetFromList(Base_DatasetFromList):
num = min(num, cfg.MODEL.POSE_NET.LOSS_CFG.NUM_PM_POINTS)
for i in range(len(cur_model_points)):
keep_idx = np.arange(num)
np.random.shuffle(keep_idx) # random sampling
keep_idx = np.random.choice(len(cur_model_points[i]), num, replace=False)
cur_model_points[i] = cur_model_points[i][keep_idx, :]
self.model_points[dataset_name] = cur_model_points
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment