diff --git a/predict_pose.py b/predict_pose.py index 9686aadcd6b51174c2de2886330d978ba2559b12..3ae6cd02db5022bc2d01d98537ce30db82125602 100644 --- a/predict_pose.py +++ b/predict_pose.py @@ -74,9 +74,11 @@ def predict_pose(class_name, image, fps_points): # showImage(classPred) # let's see our class prediction output # ==================== - population = np.where(classPred > .9)[:2] # .9 + #print(classPred) + population = np.where(classPred > 0.1)[:2] # .9 population = list(zip(population[0], population[1])) - # print(len(population)) # the number of class pixels found + print(len(population)) # the number of class pixels found + #print(population) # ==================== hypDict = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: []} @@ -156,15 +158,26 @@ if __name__ == '__main__': images_ls, labels_ls, mask_ls, choice_ls = data.getAllValDataFruits(class_name) print(len(images_ls)) + if not os.path.exists(f"{basePath}/Pose_prediction"): + os.makedirs(f"{basePath}/Pose_prediction") + for i, img in enumerate(images_ls): img_id = choice_ls[i].split('.png') img_id = int(img_id[0]) + print("id : ", img_id) + try : + r_pre, t_pre = predict_pose(class_name, img, fps) + r = R.from_rotvec(r_pre.reshape(3, )) + r_pre_mx = np.array(r.as_matrix()) + t_pre = np.array(t_pre).reshape(3, ) + + res = np.zeros((3, 4)) + res[:3, :3] = r_pre_mx + res[:3, 3] = t_pre + print(res) + np.save(f'{basePath}/Pose_prediction/{img_id}.npy', res) # save + + except : + print("The image is not good, mess than 50 pix segmentation") - r_pre, t_pre = predict_pose(class_name, img, fps) - r = R.from_rotvec(r_pre.reshape(3, )) - r_pre_mx = np.array(r.as_matrix()) - res = np.zeros((3, 4)) - res[:3, :3] = r_pre_mx - res[:3, 3] = t_pre - np.save(f'{basePath}/Pose_prediction/{class_name}/{img_id}.npy', res) # save