Skip to content
Snippets Groups Projects
Commit 599ab449 authored by Guillaume Duret's avatar Guillaume Duret
Browse files

compute predicted pose for evaluation

parent 195c8b74
No related branches found
No related tags found
No related merge requests found
...@@ -74,9 +74,11 @@ def predict_pose(class_name, image, fps_points): ...@@ -74,9 +74,11 @@ def predict_pose(class_name, image, fps_points):
# showImage(classPred) # let's see our class prediction output # showImage(classPred) # let's see our class prediction output
# ==================== # ====================
population = np.where(classPred > .9)[:2] # .9 #print(classPred)
population = np.where(classPred > 0.1)[:2] # .9
population = list(zip(population[0], population[1])) population = list(zip(population[0], population[1]))
# print(len(population)) # the number of class pixels found print(len(population)) # the number of class pixels found
#print(population)
# ==================== # ====================
hypDict = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: []} hypDict = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: []}
...@@ -156,15 +158,26 @@ if __name__ == '__main__': ...@@ -156,15 +158,26 @@ if __name__ == '__main__':
images_ls, labels_ls, mask_ls, choice_ls = data.getAllValDataFruits(class_name) images_ls, labels_ls, mask_ls, choice_ls = data.getAllValDataFruits(class_name)
print(len(images_ls)) print(len(images_ls))
if not os.path.exists(f"{basePath}/Pose_prediction"):
os.makedirs(f"{basePath}/Pose_prediction")
for i, img in enumerate(images_ls): for i, img in enumerate(images_ls):
img_id = choice_ls[i].split('.png') img_id = choice_ls[i].split('.png')
img_id = int(img_id[0]) img_id = int(img_id[0])
print("id : ", img_id)
try :
r_pre, t_pre = predict_pose(class_name, img, fps)
r = R.from_rotvec(r_pre.reshape(3, ))
r_pre_mx = np.array(r.as_matrix())
t_pre = np.array(t_pre).reshape(3, )
res = np.zeros((3, 4))
res[:3, :3] = r_pre_mx
res[:3, 3] = t_pre
print(res)
np.save(f'{basePath}/Pose_prediction/{img_id}.npy', res) # save
except :
print("The image is not good, mess than 50 pix segmentation")
r_pre, t_pre = predict_pose(class_name, img, fps)
r = R.from_rotvec(r_pre.reshape(3, ))
r_pre_mx = np.array(r.as_matrix())
res = np.zeros((3, 4))
res[:3, :3] = r_pre_mx
res[:3, 3] = t_pre
np.save(f'{basePath}/Pose_prediction/{class_name}/{img_id}.npy', res) # save
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment