From a70764f6f8d397e7a4a049c10790de7c539de3f0 Mon Sep 17 00:00:00 2001
From: liuxingyu <lxy17@foxmail.com>
Date: Sat, 25 Mar 2023 17:11:37 +0800
Subject: [PATCH] fix some bugs on depth refinement

---
 core/gdrn_modeling/engine/gdrn_evaluator.py           | 2 +-
 core/gdrn_modeling/models/GDRN.py                     | 2 +-
 core/gdrn_modeling/models/GDRN_Dstream_double_mask.py | 2 +-
 core/gdrn_modeling/models/GDRN_cls.py                 | 2 +-
 core/gdrn_modeling/models/GDRN_cls2reg.py             | 2 +-
 core/gdrn_modeling/models/GDRN_double_mask.py         | 2 +-
 core/gdrn_modeling/models/GDRN_no_region.py           | 2 +-
 7 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/core/gdrn_modeling/engine/gdrn_evaluator.py b/core/gdrn_modeling/engine/gdrn_evaluator.py
index e1cb318..a2bf2f4 100644
--- a/core/gdrn_modeling/engine/gdrn_evaluator.py
+++ b/core/gdrn_modeling/engine/gdrn_evaluator.py
@@ -512,7 +512,7 @@ class GDRN_Evaluator(DatasetEvaluator):
                 rot_est = out_rots[out_i]
                 trans_est = out_transes[out_i]
                 pose_est = np.hstack([rot_est, trans_est.reshape(3, 1)])
-                depth_sensor_crop = _input['roi_depth'][inst_i].cpu().numpy().copy().squeeze()
+                depth_sensor_crop = cv2.resize(_input['roi_depth'][inst_i][-1].cpu().numpy().copy().squeeze(), (self.out_res, self.out_res))
                 depth_sensor_mask_crop = depth_sensor_crop > 0
 
                 net_cfg = cfg.MODEL.POSE_NET
diff --git a/core/gdrn_modeling/models/GDRN.py b/core/gdrn_modeling/models/GDRN.py
index 1a0e3be..98cddad 100644
--- a/core/gdrn_modeling/models/GDRN.py
+++ b/core/gdrn_modeling/models/GDRN.py
@@ -198,7 +198,7 @@ class GDRN(nn.Module):
 
         if not do_loss:  # test
             out_dict = {"rot": pred_ego_rot, "trans": pred_trans}
-            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY:
+            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY or cfg.TEST.USE_DEPTH_REFINE:
                 # TODO: move the pnp/ransac inside forward
                 out_dict.update({"mask": mask, "coor_x": coor_x, "coor_y": coor_y, "coor_z": coor_z, "region": region})
         else:
diff --git a/core/gdrn_modeling/models/GDRN_Dstream_double_mask.py b/core/gdrn_modeling/models/GDRN_Dstream_double_mask.py
index 9d76dc9..cfa1382 100644
--- a/core/gdrn_modeling/models/GDRN_Dstream_double_mask.py
+++ b/core/gdrn_modeling/models/GDRN_Dstream_double_mask.py
@@ -221,7 +221,7 @@ class GDRN_Dstream_DoubleMask(nn.Module):
 
         if not do_loss:  # test
             out_dict = {"rot": pred_ego_rot, "trans": pred_trans}
-            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY:
+            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY or cfg.TEST.USE_DEPTH_REFINE:
                 # TODO: move the pnp/ransac inside forward
                 out_dict.update(
                     {
diff --git a/core/gdrn_modeling/models/GDRN_cls.py b/core/gdrn_modeling/models/GDRN_cls.py
index 66dc9b2..4596fd5 100644
--- a/core/gdrn_modeling/models/GDRN_cls.py
+++ b/core/gdrn_modeling/models/GDRN_cls.py
@@ -211,7 +211,7 @@ class GDRN_CLS(nn.Module):
 
         if not do_loss:  # test
             out_dict = {"rot": pred_ego_rot, "trans": pred_trans}
-            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY:
+            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY or cfg.TEST.USE_DEPTH_REFINE:
                 # TODO: move the pnp/ransac inside forward
                 out_dict.update({"mask": mask, "coor_x": coor_x, "coor_y": coor_y, "coor_z": coor_z, "region": region})
         else:
diff --git a/core/gdrn_modeling/models/GDRN_cls2reg.py b/core/gdrn_modeling/models/GDRN_cls2reg.py
index 20a566e..13bb814 100644
--- a/core/gdrn_modeling/models/GDRN_cls2reg.py
+++ b/core/gdrn_modeling/models/GDRN_cls2reg.py
@@ -211,7 +211,7 @@ class GDRN_CLS2REG(nn.Module):
 
         if not do_loss:  # test
             out_dict = {"rot": pred_ego_rot, "trans": pred_trans}
-            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY:
+            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY or cfg.TEST.USE_DEPTH_REFINE:
                 # TODO: move the pnp/ransac inside forward
                 # TODO: use cls or reg in cfg
                 out_dict.update({"mask": mask, "coor_x": coor_x, "coor_y": coor_y, "coor_z": coor_z, "region": region})
diff --git a/core/gdrn_modeling/models/GDRN_double_mask.py b/core/gdrn_modeling/models/GDRN_double_mask.py
index dad1230..7973d93 100644
--- a/core/gdrn_modeling/models/GDRN_double_mask.py
+++ b/core/gdrn_modeling/models/GDRN_double_mask.py
@@ -200,7 +200,7 @@ class GDRN_DoubleMask(nn.Module):
 
         if not do_loss:  # test
             out_dict = {"rot": pred_ego_rot, "trans": pred_trans}
-            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY:
+            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY or cfg.TEST.USE_DEPTH_REFINE:
                 # TODO: move the pnp/ransac inside forward
                 out_dict.update(
                     {
diff --git a/core/gdrn_modeling/models/GDRN_no_region.py b/core/gdrn_modeling/models/GDRN_no_region.py
index 87c73a8..3758999 100644
--- a/core/gdrn_modeling/models/GDRN_no_region.py
+++ b/core/gdrn_modeling/models/GDRN_no_region.py
@@ -197,7 +197,7 @@ class GDRN_NoRegion(nn.Module):
 
         if not do_loss:  # test
             out_dict = {"rot": pred_ego_rot, "trans": pred_trans}
-            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY:
+            if cfg.TEST.USE_PNP or cfg.TEST.SAVE_RESULTS_ONLY or cfg.TEST.USE_DEPTH_REFINE:
                 # TODO: move the pnp/ransac inside forward
                 out_dict.update({"mask": mask, "coor_x": coor_x, "coor_y": coor_y, "coor_z": coor_z, "region": region})
         else:
-- 
GitLab