diff --git a/README.md b/README.md
index 5fd140e1157070a6baa4d4ed2db00706820a88eb..f22b41d5042a03fea5bb196f5f2fd288adf263b0 100644
--- a/README.md
+++ b/README.md
@@ -72,7 +72,16 @@ The difference between this repo and GDR-Net (CVPR2021) mainly including:
 `./core/gdrn_modeling/test_gdrn.sh <config_path> <gpu_ids> <ckpt_path> (other args)`
 
 ## Pose Refinement
-See [Pose Refinement](https://github.com/shanice-l/gdrnpp_bop2022/tree/pose_refine) for details.
+
+We utilize depth information to further refine the estimated pose.
+We provide two types of refinement: fast refinement and iterative refinement.
+
+For fast refinement, we compare the rendered object depth with the observed depth to refine the translation.
+Run
+
+`./core/gdrn_modeling/test_gdrn_depth_refine.sh <config_path> <gpu_ids> <ckpt_path> (other args)`
+
+For iterative refinement, please check out the pose_refine branch for details.
 
 ## Citing GDRNPP
 
@@ -85,4 +94,5 @@ If you use GDRNPP in your research, please use the following BibTeX entry.
   howpublished = {\url{https://github.com/shanice-l/gdrnpp_bop2022}},
   year = {2022}
 }
-```
\ No newline at end of file
+```
+
diff --git a/configs/_base_/gdrn_base.py b/configs/_base_/gdrn_base.py
index a37894cd84045a109f8964faa3a5764b0a902e41..182eaa5bcf6738d76f740daa985d21040d0cd9dc 100644
--- a/configs/_base_/gdrn_base.py
+++ b/configs/_base_/gdrn_base.py
@@ -167,4 +167,8 @@ TEST = dict(
     # net_ransac_pnp_rot (net_init + ransanc pnp --> net t + pnp R)
     PNP_TYPE="ransac_pnp",
     PRECISE_BN=dict(ENABLED=False, NUM_ITER=200),
+    USE_DEPTH_REFINE=False,
+    DEPTH_REFINE_ITER=2,
+    DEPTH_REFINE_THRESHOLD=0.8,
+    USE_COOR_Z_REFINE=False,
 )
diff --git a/core/gdrn_modeling/engine/engine_utils.py b/core/gdrn_modeling/engine/engine_utils.py
index e4c54b9d990ae45a8f639ae4fab2b69503a5c9f9..5b4dee4df8297422386fa03be7169904f6d97df8 100644
--- a/core/gdrn_modeling/engine/engine_utils.py
+++ b/core/gdrn_modeling/engine/engine_utils.py
@@ -240,6 +240,31 @@ def batch_data_test(cfg, data, device="cuda"):
     return batch
 
 
+def batch_data_inference_roi(cfg, data, device="cuda"):
+    net_cfg = cfg.MODEL.POSE_NET
+    g_head_cfg = net_cfg.GEO_HEAD
+    batch = {}
+    batch["roi_img"] = torch.cat([d["roi_img"] for d in data], dim=0).to(device, non_blocking=True)
+    bs = batch["roi_img"].shape[0]
+
+    batch["roi_cam"] = torch.cat([d["cam"] for d in data], dim=0).to(device, non_blocking=True)
+    batch["roi_center"] = torch.cat([d["bbox_center"] for d in data], dim=0).to(
+        device=device, dtype=torch.float32, non_blocking=True
+    )
+    batch["roi_scale"] = [torch.as_tensor(d["scale"], device=device, dtype=torch.float32) for d in data]
+    batch["roi_scale"] = torch.cat(batch["roi_scale"], dim=0).to(
+        device=device, dtype=torch.float32, non_blocking=True
+    )
+    batch["resize_ratio"] = [torch.as_tensor(d["resize_ratio"], device=device, dtype=torch.float32) for d in data]  # out_res/scale
+    batch["resize_ratio"] = torch.cat(batch["resize_ratio"], dim=0).to(
+        device=device, dtype=torch.float32, non_blocking=True
+    )
+    # get crop&resized K -------------------------------------------
+    roi_crop_xy_batch = batch["roi_center"] - batch["roi_scale"].view(bs, -1) / 2
+    out_res = net_cfg.OUTPUT_RES
+    roi_resize_ratio_batch = out_res / batch["roi_scale"].view(bs, -1)
+    batch["roi_zoom_K"] = get_K_crop_resize(batch["roi_cam"], roi_crop_xy_batch, roi_resize_ratio_batch)
+    return batch
+
+
 def get_renderer(cfg, data_ref, obj_names, gpu_id=None):
     """for rendering the targets (xyz) online."""
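For reference, the `roi_zoom_K` produced by `batch_data_inference_roi` follows the usual pinhole-camera update for a crop followed by a resize: the principal point shifts by the crop offset and the first two rows of K scale by `out_res / scale`. Below is a minimal single-image NumPy sketch of that update; the helper name `crop_resize_K` and the numbers are illustrative only, while the repo's `get_K_crop_resize` operates on batched tensors.

```python
import numpy as np

def crop_resize_K(K, crop_xy, resize_ratio):
    """Illustrative intrinsics update for a crop at crop_xy (top-left corner, pixels)
    followed by resizing the crop by resize_ratio."""
    new_K = K.astype(np.float64).copy()
    new_K[0, 2] -= crop_xy[0]      # principal point follows the crop
    new_K[1, 2] -= crop_xy[1]
    new_K[:2, :] *= resize_ratio   # fx, fy and the shifted cx, cy scale with the resize
    return new_K

# e.g. a 256 px square ROI around bbox_center, resized to OUTPUT_RES = 64 (illustrative values)
K = np.array([[572.4, 0.0, 325.3], [0.0, 573.6, 242.0], [0.0, 0.0, 1.0]])
bbox_center, scale, out_res = np.array([320.0, 240.0]), 256.0, 64
roi_zoom_K = crop_resize_K(K, bbox_center - scale / 2, out_res / scale)
```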
diff --git a/core/gdrn_modeling/engine/gdrn_evaluator.py b/core/gdrn_modeling/engine/gdrn_evaluator.py
index 4b88b7283b822897a6ff15b28e4749d9410b8db2..d47050d8cb30bfc16b0dbe24ba8797f0ebc0f9c7 100644
--- a/core/gdrn_modeling/engine/gdrn_evaluator.py
+++ b/core/gdrn_modeling/engine/gdrn_evaluator.py
@@ -29,7 +29,7 @@
 from lib.utils.mask_utils import binary_mask_to_rle
 from lib.utils.utils import dprint
 from lib.vis_utils.image import grid_show, vis_image_bboxes_cv2, vis_image_mask_cv2
 
-from .engine_utils import batch_data, get_out_coor, get_out_mask
+from .engine_utils import batch_data, get_out_coor, get_out_mask, batch_data_inference_roi
 from .test_utils import eval_cached_results, save_and_eval_results, to_list
 
@@ -61,20 +61,30 @@
         self.models_3d = [
             inout.load_ply(model_path, vertex_scale=self.data_ref.vertex_scale) for model_path in self.model_paths
         ]
-        if cfg.DEBUG:
+        if cfg.DEBUG or cfg.TEST.USE_DEPTH_REFINE:
             from lib.render_vispy.model3d import load_models
             from lib.render_vispy.renderer import Renderer
 
-            self.ren = Renderer(size=(self.data_ref.width, self.data_ref.height), cam=self.data_ref.camera_matrix)
+            if cfg.TEST.USE_DEPTH_REFINE:
+                net_cfg = cfg.MODEL.POSE_NET
+                width = net_cfg.OUTPUT_RES
+                height = width
+            else:
+                width = self.data_ref.width
+                height = self.data_ref.height
+
+            self.ren = Renderer(size=(width, height), cam=self.data_ref.camera_matrix)
             self.ren_models = load_models(
                 model_paths=self.data_ref.model_paths,
                 scale_to_meter=0.001,
                 cache_dir=".cache",
-                texture_paths=self.data_ref.texture_paths,
+                texture_paths=self.data_ref.texture_paths if cfg.DEBUG else None,
                 center=False,
                 use_cache=True,
             )
 
+        self.depth_refine_threshold = cfg.TEST.DEPTH_REFINE_THRESHOLD
+
         # eval cached
         if cfg.VAL.EVAL_CACHED or cfg.VAL.EVAL_PRINT_ONLY:
             eval_cached_results(self.cfg, self._output_dir, obj_ids=self.obj_ids)
@@ -164,6 +174,9 @@
         else:
             raise NotImplementedError
 
+        if cfg.TEST.USE_DEPTH_REFINE:
+            return self.process_depth_refine(inputs, outputs, out_dict)
+
         out_rots = out_dict["rot"].detach().to(self._cpu_device).numpy()
         out_transes = out_dict["trans"].detach().to(self._cpu_device).numpy()
@@ -445,6 +458,122 @@
                 item["time"] = output["time"]
             self._predictions.extend(json_results)
 
+    def process_depth_refine(self, inputs, outputs, out_dict):
+        """
+        Args:
+            inputs: the inputs to a model.
+                It is a list of dict. Each dict corresponds to an image and
+                contains keys like "height", "width", "file_name", "image_id", "scene_id".
+            outputs:
+        """
+        cfg = self.cfg
+        net_cfg = cfg.MODEL.POSE_NET
+        out_coor_x = out_dict["coor_x"].detach()
+        out_coor_y = out_dict["coor_y"].detach()
+        out_coor_z = out_dict["coor_z"].detach()
+        out_xyz = get_out_coor(cfg, out_coor_x, out_coor_y, out_coor_z)
+        out_xyz = out_xyz.to(self._cpu_device)  # .numpy()
+
+        out_mask = get_out_mask(cfg, out_dict["mask"].detach())
+        out_mask = out_mask.to(self._cpu_device)  # .numpy()
+        out_rots = out_dict["rot"].detach().to(self._cpu_device).numpy()
+        out_transes = out_dict["trans"].detach().to(self._cpu_device).numpy()
+
+        zoom_K = batch_data_inference_roi(cfg, inputs)["roi_zoom_K"]
+
+        out_i = -1
+        for i, (_input, output) in enumerate(zip(inputs, outputs)):
+            start_process_time = time.perf_counter()
+            json_results = []
+            for inst_i in range(len(_input["roi_img"])):
+                out_i += 1
+
+                K = _input["cam"][inst_i].cpu().numpy().copy()
+                K_crop = zoom_K[inst_i].cpu().numpy().copy()
+
+                roi_label = _input["roi_cls"][inst_i]  # 0-based label
+                score = _input["score"][inst_i]
+                roi_label, cls_name = self._maybe_adapt_label_cls_name(roi_label)
+                if cls_name is None:
+                    continue
+
+                scene_im_id_split = _input["scene_im_id"][inst_i].split("/")
+                scene_id = scene_im_id_split[0]
+                im_id = int(scene_im_id_split[1])
+                obj_id = self.data_ref.obj2id[cls_name]
+
+                # get pose
+                xyz_i = out_xyz[out_i].permute(1, 2, 0)
+                mask_i = np.squeeze(out_mask[out_i])
+
+                rot_est = out_rots[out_i]
+                trans_est = out_transes[out_i]
+                pose_est = np.hstack([rot_est, trans_est.reshape(3, 1)])
+                depth_sensor_crop = _input["roi_depth"][inst_i].cpu().numpy().copy().squeeze()
+                depth_sensor_mask_crop = depth_sensor_crop > 0
+
+                net_cfg = cfg.MODEL.POSE_NET
+                crop_res = net_cfg.OUTPUT_RES
+
+                for _ in range(cfg.TEST.DEPTH_REFINE_ITER):
+                    self.ren.clear()
+                    self.ren.set_cam(K_crop)
+                    self.ren.draw_model(self.ren_models[self.data_ref.objects.index(cls_name)], pose_est)
+                    ren_im, ren_dp = self.ren.finish()
+                    ren_mask = ren_dp > 0
+
+                    if self.cfg.TEST.USE_COOR_Z_REFINE:
+                        coor_np = xyz_i.numpy()
+                        coor_np_t = coor_np.reshape(-1, 3)
+                        coor_np_t = coor_np_t.T
+                        coor_np_r = rot_est @ coor_np_t
+                        coor_np_r = coor_np_r.T
+                        coor_np_r = coor_np_r.reshape(crop_res, crop_res, 3)
+                        query_img_norm = coor_np_r[:, :, -1] * mask_i.numpy()
+                        query_img_norm = query_img_norm * ren_mask * depth_sensor_mask_crop
+                    else:
+                        query_img = xyz_i
+                        query_img_norm = torch.norm(query_img, dim=-1) * mask_i
+                        query_img_norm = query_img_norm.numpy() * ren_mask * depth_sensor_mask_crop
+                    norm_sum = query_img_norm.sum()
+                    if norm_sum == 0:
+                        continue
+                    query_img_norm /= norm_sum
+                    norm_mask = query_img_norm > (query_img_norm.max() * self.depth_refine_threshold)
+                    yy, xx = np.argwhere(norm_mask).T  # 2 x (N,)
+                    depth_diff = depth_sensor_crop[yy, xx] - ren_dp[yy, xx]
+                    depth_adjustment = np.median(depth_diff)
+
+                    yx_coords = np.meshgrid(np.arange(crop_res), np.arange(crop_res))
+                    yx_coords = np.stack(yx_coords[::-1], axis=-1)  # (crop_res, crop_res, 2yx)
+                    yx_ray_2d = (yx_coords * query_img_norm[..., None]).sum(axis=(0, 1))  # y, x
+                    ray_3d = np.linalg.inv(K_crop) @ (*yx_ray_2d[::-1], 1)
+                    ray_3d /= ray_3d[2]
+
+                    trans_delta = ray_3d[:, None] * depth_adjustment
+                    trans_est = trans_est + trans_delta.reshape(3)
+                    pose_est = np.hstack([rot_est, trans_est.reshape(3, 1)])
+
+                json_results.extend(
+                    self.pose_prediction_to_json(
+                        pose_est, scene_id, im_id, obj_id=obj_id, score=score, pose_time=output["time"], K=K
+                    )
+                )
+            output["time"] += time.perf_counter() - start_process_time
+            # process time for this image
+            for item in json_results:
+                item["time"] = output["time"]
+            self._predictions.extend(json_results)
+
     def evaluate(self):
         # bop toolkit eval in subprocess, no return value
         if self._distributed:
@@ -601,7 +730,7 @@
             # show_titles.extend(["coord_2d_x", "coord_2d_y"])
             # grid_show(show_ims, show_titles, row=1, col=3)
 
-            if cfg.INPUT.WITH_DEPTH:
+            if cfg.INPUT.WITH_DEPTH and "depth" in cfg.MODEL.POSE_NET.NAME.lower():
                 inp = torch.cat([batch["roi_img"], batch["roi_depth"]], dim=1)
             else:
                 inp = batch["roi_img"]
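In short, each fast-refinement iteration in `process_depth_refine` renders the object depth at the current pose, keeps the most confident pixels of the predicted-coordinate weight map, takes the median difference between sensor depth and rendered depth there, and slides the translation along the weighted mean viewing ray by that amount. The following is a condensed NumPy sketch of a single step under those assumptions; the function name `refine_translation_once` and the precomputed `weight` map are illustrative, not symbols from the code above.

```python
import numpy as np

def refine_translation_once(trans_est, ren_dp, depth_sensor_crop, weight, K_crop, threshold=0.8):
    """One fast-refinement step (sketch): shift trans_est along the weighted mean
    viewing ray by the median sensor-vs-rendered depth difference.
    Assumes weight (per-pixel confidence) is nonzero somewhere."""
    weight = weight / weight.sum()                   # normalize the confidence map
    confident = weight > weight.max() * threshold    # DEPTH_REFINE_THRESHOLD region
    yy, xx = np.nonzero(confident)
    depth_adjustment = np.median(depth_sensor_crop[yy, xx] - ren_dp[yy, xx])

    # weighted mean pixel location -> viewing ray normalized to unit depth (z = 1)
    ys, xs = np.mgrid[: weight.shape[0], : weight.shape[1]]
    mean_xy1 = np.array([(xs * weight).sum(), (ys * weight).sum(), 1.0])
    ray = np.linalg.inv(K_crop) @ mean_xy1
    ray /= ray[2]

    return trans_est + ray * depth_adjustment
```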
diff --git a/core/gdrn_modeling/test_gdrn_depth_refine.sh b/core/gdrn_modeling/test_gdrn_depth_refine.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e242c6b4c4c49fd4bc1062d010d0eda6abb21ca0
--- /dev/null
+++ b/core/gdrn_modeling/test_gdrn_depth_refine.sh
@@ -0,0 +1,26 @@
+#!/usr/bin/env bash
+# test
+set -x
+this_dir=$(dirname "$0")
+# commonly used opts:
+
+# MODEL.WEIGHTS: resume or pretrained, or test checkpoint
+CFG=$1
+CUDA_VISIBLE_DEVICES=$2
+IFS=',' read -ra GPUS <<< "$CUDA_VISIBLE_DEVICES"
+# GPUS=($(echo "$CUDA_VISIBLE_DEVICES" | tr ',' '\n'))
+NGPU=${#GPUS[@]}  # echo "${GPUS[0]}"
+echo "use gpu ids: $CUDA_VISIBLE_DEVICES num gpus: $NGPU"
+CKPT=$3
+if [ ! -f "$CKPT" ]; then
+    echo "$CKPT does not exist."
+    exit 1
+fi
+NCCL_DEBUG=INFO
+OMP_NUM_THREADS=1
+MKL_NUM_THREADS=1
+PYTHONPATH="$this_dir/../..":$PYTHONPATH \
+CUDA_VISIBLE_DEVICES=$2 python $this_dir/main_gdrn.py \
+    --config-file $CFG --num-gpus $NGPU --eval-only \
+    --opts MODEL.WEIGHTS=$CKPT INPUT.WITH_DEPTH=True TEST.USE_DEPTH_REFINE=True \
+    ${@:4}
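A hypothetical invocation of the new script (the config and checkpoint paths below are placeholders, not files added by this change); it evaluates on GPU 0 and forces depth loading and depth refinement on via `--opts`:

```bash
# placeholder paths; substitute your own config and checkpoint
./core/gdrn_modeling/test_gdrn_depth_refine.sh \
    configs/gdrn/ycbv/my_config.py \
    0 \
    output/gdrn/ycbv/my_run/model_final.pth
```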