diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..6c04bd539b6db25e1ec259f86fc110f66acc5230
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,115 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+/data
+/output
+/models
+/log
+/lib/dataset/cocoapi
+/lib/nms/build
+*.pyd
+*.obj
+/.idea
+
+*.npz
+
diff --git a/LICENCE b/LICENCE
new file mode 100644
index 0000000000000000000000000000000000000000..d1a6609a839be0cef72a5f22e55842a1d384849f
--- /dev/null
+++ b/LICENCE
@@ -0,0 +1,19 @@
+MIT License
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
index 0df49b08d9ad0371b6b2084cff6650c5563b4dac..3678f80e9d2ad00d2045925bae5a21bb615014e0 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
 <h1 style="text-align:center">
-DriPE: A Dataset for Human Pose Estimation in Real-World Driving Settings
+Multitask Metamodel for Keypoint Visibility Prediction in Human Pose Estimation
 </h1>
 <div style="text-align:center">
 <h3>
@@ -8,106 +8,68 @@ DriPE: A Dataset for Human Pose Estimation in Real-World Driving Settings
 <a href="https://liris.cnrs.fr/page-membre/laure-tougne">Laure Tougne</a>
 <br>
 <br>
-ICCV: International Conference on Computer Vision 2021
-<br>
-Workshop AVVision : Autonomous Vehicle Vision
+International Joint Conference on Computer Vision, Imaging and Computer Graphics Theory and Applications (VISAPP) 
 </h3>
 </div>
 
 # Table of content
 - [Overview](#overview)
-- [Dataset](#dataset)
-- [Networks](#networks)
-- [Evaluation](#evaluation)
+- [Installation](#installation)
+- [Testing](#testing-the-visibility-module)
+- [Training](#training-the-visibility-module)
 - [Citation](#citation)
 - [Acknowledgements](#acknowledgements)
 
 # Overview
 This repository contains the materials presented in the paper
-[DriPE: A Dataset for Human Pose Estimation in Real-World Driving Settings](https://openaccess.thecvf.com/content/ICCV2021W/AVVision/papers/Guesdon_DriPE_A_Dataset_for_Human_Pose_Estimation_in_Real-World_Driving_ICCVW_2021_paper.pdf).
-
-We provide the link to download the DriPE [dataset](#dataset),
-along with trained weights for the three [networks](#networks) presented in this paper: 
-SBl, MSPN and RSN.
-Furthermore, we provide the code to evaluate HPE networks with [mAPK metric](#evaluation), our keypoint-centered metric.
-
-# Dataset
-DriPE dataset can be download [here](http://dionysos.univ-lyon2.fr/~ccrispim/DriPE/DriPE.zip). We provide 10k images, 
-along with keypoint annotations, split as:
-* 6.4k for training
-* 1.3k for validation
-* 1.3k for testing
+[Multitask Metamodel for Keypoint Visibility Prediction in Human Pose Estimation]().
 
-The annotation files follow the COCO annotation style, with 17 keypoints. 
-More information can be found [here](https://cocodataset.org/#format-data).
+![Metamodel](assets/metamodel.png)
 
-##### **DriPE image samples**
-![DriPE image samples](assets/dripe_sample.png)
+Our code is based on the [Simple Baseline](https://github.com/microsoft/human-pose-estimation.pytorch) PyTorch implementation.
+However, several modifications were made to the original code to implement our metamodel.
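+
+For intuition only, here is a minimal, self-contained sketch of how a per-joint visibility
+classification head can be attached on top of a heatmap backbone. It is **not** the exact
+model implemented in this repository; the sizes (2048-d backbone features, one 2048-d linear
+layer, 3 visibility classes, 17 joints) are assumptions taken from the default values in the
+configuration files further down.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class VisibilityHead(nn.Module):
+    """Toy example: predict a visibility class for each joint from pooled backbone features."""
+
+    def __init__(self, in_channels=2048, num_joints=17, num_classes=3, hidden=2048):
+        super().__init__()
+        self.num_joints = num_joints
+        self.num_classes = num_classes
+        self.pool = nn.AdaptiveAvgPool2d(1)          # (N, C, H, W) -> (N, C, 1, 1)
+        self.fc = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(in_channels, hidden),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden, num_joints * num_classes),
+        )
+
+    def forward(self, features):
+        logits = self.fc(self.pool(features))        # (N, num_joints * num_classes)
+        return logits.view(-1, self.num_joints, self.num_classes)
+
+
+# Dummy forward pass: ResNet-50 stage-5 features for a 256x192 crop have shape (N, 2048, 8, 6).
+head = VisibilityHead()
+vis_logits = head(torch.randn(2, 2048, 8, 6))        # -> (2, 17, 3)
+```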
 
-# Networks
-We used in our study three architectures:
-* __SBl__: Simple Baselines for Human Pose Estimation and Tracking (Xiao 2018) [GitHub](https://github.com/microsoft/human-pose-estimation.pytorch)
-* __MSPN__: Rethinking on Multi-Stage Networks for Human Pose Estimation (Li 2019) [GitHub](https://github.com/megvii-detection/MSPN)
-* __RSN__: Learning Delicate Local Representations for Multi-Person Pose Estimation (Cai 2020) [GitHub](https://github.com/caiyuanhao1998/RSN)
-
-We used for training and for inference the code provided by the authors in the three linked repositories.
-Weights of the trained model evaluated in our study can be found [here](http://dionysos.univ-lyon2.fr/~ccrispim/DriPE/models).
-More details about the training can be found in our [paper](https://openaccess.thecvf.com/content/ICCV2021W/AVVision/papers/Guesdon_DriPE_A_Dataset_for_Human_Pose_Estimation_in_Real-World_Driving_ICCVW_2021_paper.pdf).
-
-##### **HPE on the COCO 2017 validation set.**
-AP OKS (\%) | AP | AP<sup>50</sup> | AP<sup>75</sup>  | AP<sup>L</sup> | AR | AR<sup>50</sup> | AR<sup>75</sup> | AR<sup>L</sup>
-:---- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 
-SBl | 72 | 92 | 80 | 77 | 76 | 93 | 82 | 80
-MSPN | __77__ | 94 | 85 | 82 | __80__ | 95 | 87 | 85
-RSN | 76 | 94 | 84 | 81 | 79 | 94 | 85 | 84
+# Installation
+- Clone this repository
+- Follow the instructions on the [Simple Baseline repository](https://github.com/microsoft/human-pose-estimation.pytorch).
+- Download the DriPE dataset [here](https://gitlab.liris.cnrs.fr/aura_autobehave/dripe) and place it in the `data/` directory.
 
-##### **HPE on the DriPE test set.**
-AP OKS (\%) | AP | AP<sup>50</sup> | AP<sup>75</sup> | AP<sup>L</sup> | AR | AR<sup>50</sup> | AR<sup>75</sup> | AR<sup>L</sup>
-:---- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 
-SBl | 75 | 99 | 91 | 75 | 81 | 99 | 94 | 81
-MSPN | 81 | 99 | 97 | __81__ | 85 | 99 | 97 | __85__
-RSN | 75 | 99 | 93 | 75 | 79 | 99 | 95 | 79
 
-# Evaluation
-Evaluation is performed using two metrics:
-* __AP OKS__, the original metric from COCO dataset, which is already implemented in the [cocoapi](https://github.com/cocodataset/cocoapi)
-and in the three network repositories
-* __mAPK__, our new keypoint-centered metric. We provide script for evaluate the network predictions in this repository.
-
-Evaluation with mAPK can be used by running the eval_mpk.py script.
-```Script to evaluate prediction in COCO format using the mAPK metric.
-Usage: python eval_mapk.py [json_prediction_path] [json_annotation_path]
-Paths can be absolute, relative to the script or relative to the respective json/gts or json/preds directory.
-    -h, --help\tdisplay this help message and exit
+# Testing the visibility module
+You can download pretrained weights [here](http://dionysos.univ-lyon2.fr/~ccrispimVisPred/models).
+For example, execute:
+```
+python pose_estimation/valid.py \
+    --cfg experiments/coco/resnet50/256x192_vis_freeze.yaml \
+    --flip-test \
+    --model-file models/pytorch/pose_coco/coco_vis2_raise_soft.pth.tar
 ```
 
-We provide in this repo one annotation and one prediction file. To evaluate these predictions, run:
+# Training the visibility module
+You can download pretrained weights for the base model [here](http://dionysos.univ-lyon2.fr/~ccrispimVisPred/models/coco_vis2_0_no_linear.pth.tar).
+Place this file in `models/pytorch/resnet50_vis`.
+Then, execute:
 ```
-python eval_mapk.py keypoints_out_SBL_autob_test-repo.json autob_coco_test.json
+python pose_estimation/train.py \
+    --cfg experiments/coco/resnet50/256x192_vis_freeze.yaml
 ```
 
-Expected results are :
-	F1 score: 0.733
 
-Metric |  Head  |  Should.  |  Elbow  |  Wrist  |  Hip  |  Knee  |  Ankle  |  All  |  Mean  |  Std
-:--- | :---: | :----: | :----: | :---: | :----: | :----: | :----: | :----: | :----: | :-----:
-AP  | 0.30 |  0.86 |  0.78 |  0.92 | 0.91 | 0.76 |  0.13 | 0.68 | 0.67 | 0.29
-AR  | 0.87 |  0.92 |  0.93 |  0.96 | 0.88 | 0.61 |  0.05 | 0.80 | 0.75 | 0.31
+##### **HPE on the COCO 2017 validation set.**
+AP OKS (\%) | AP | AP<sup>50</sup> | AP<sup>75</sup> | AP<sup>M</sup> | AP<sup>L</sup> | AR | AR<sup>50</sup> | AR<sup>75</sup> | AR<sup>M</sup> | AR<sup>L</sup>
+:--- | :---: | :---: | :---: | :---: | :---: | :---: |:---: | :---: | :---: | :---: | 
+SBl | 72 | 92 | 79 | 69 | 76 | 75 | 93 | 82 | 72 | 80
 
-# Citation
-If you use this dataset or code in your research, please send us an email with the following details and we will update our webpage with your results.
-* Performance (%)
-* Experimental Setup
-* Paper details
 
-The DRIPE dataset is only to be used for scientific purposes. It must not be republished other than by the original authors. The scientific use includes processing the data and showing it in publications and presentations. If you use it, please cite:
+# Citation
+If you use this code, please cite:
 ```
-@InProceedings{Guesdon_2021_ICCV,
+@InProceedings{Guesdon_2022_Visapp,
     author    = {Guesdon, Romain and Crispim-Junior, Carlos and Tougne, Laure},
-    title     = {DriPE: A Dataset for Human Pose Estimation in Real-World Driving Settings},
+    title     = {Multitask Metamodel for Keypoint Visibility Prediction in Human Pose Estimation},
-    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
+    booktitle = {Proceedings of the International Joint Conference on Computer Vision, Imaging and Computer Graphics Theory and Applications (VISAPP)},
-    month     = {October},
-    year      = {2021},
+    month     = {February},
+    year      = {2022},
     pages     = {2865-2874}
 }
 ```
diff --git a/assets/logo_liris.png b/assets/logo_liris.png
new file mode 100644
index 0000000000000000000000000000000000000000..37143f9d88e379ff7c6314eaa23b069f7ce997ef
Binary files /dev/null and b/assets/logo_liris.png differ
diff --git a/assets/logo_ra.png b/assets/logo_ra.png
new file mode 100644
index 0000000000000000000000000000000000000000..6ceef305fd77646ce549c8eabbc6b34c0f048c7a
Binary files /dev/null and b/assets/logo_ra.png differ
diff --git a/assets/metamodel.png b/assets/metamodel.png
new file mode 100644
index 0000000000000000000000000000000000000000..51ae6c24c47159243eea5928c435ec646295846e
Binary files /dev/null and b/assets/metamodel.png differ
diff --git a/experiments/coco/resnet50/256x192_d256x3_adam_lr1e-3.yaml b/experiments/coco/resnet50/256x192_d256x3_adam_lr1e-3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..107c1403c14d0cbf7109bb15134fdc010b5c3c30
--- /dev/null
+++ b/experiments/coco/resnet50/256x192_d256x3_adam_lr1e-3.yaml
@@ -0,0 +1,76 @@
+GPUS: '0'
+DATA_DIR: ''
+OUTPUT_DIR: 'output'
+LOG_DIR: 'log'
+WORKERS: 4
+PRINT_FREQ: 100
+
+DATASET:
+  DATASET: 'coco'
+  ROOT: 'data/coco/'
+  TEST_SET: 'val2017'
+  TRAIN_SET: 'train2017'
+  FLIP: true
+  ROT_FACTOR: 40
+  SCALE_FACTOR: 0.3
+MODEL:
+  NAME: 'pose_resnet'
+  PRETRAINED: 'models/pytorch/imagenet/resnet50-19c8e357.pth'
+  IMAGE_SIZE:
+  - 192
+  - 256
+  NUM_JOINTS: 17
+  EXTRA:
+    TARGET_TYPE: 'gaussian'
+    HEATMAP_SIZE:
+    - 48
+    - 64
+    SIGMA: 2
+    FINAL_CONV_KERNEL: 1
+    DECONV_WITH_BIAS: false
+    NUM_DECONV_LAYERS: 3
+    NUM_DECONV_FILTERS:
+    - 256
+    - 256
+    - 256
+    NUM_DECONV_KERNELS:
+    - 4
+    - 4
+    - 4
+    NUM_LAYERS: 50
+LOSS:
+  USE_TARGET_WEIGHT: true
+TRAIN:
+  BATCH_SIZE: 32
+  SHUFFLE: true
+  BEGIN_EPOCH: 0
+  END_EPOCH: 140
+  RESUME: false
+  OPTIMIZER: 'adam'
+  LR: 0.001
+  LR_FACTOR: 0.1
+  LR_STEP:
+  - 90
+  - 120 
+  WD: 0.0001
+  GAMMA1: 0.99
+  GAMMA2: 0.0
+  MOMENTUM: 0.9
+  NESTEROV: false
+TEST:
+  BATCH_SIZE: 32
+  COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
+  BBOX_THRE: 1.0
+  FLIP_TEST: false
+  IMAGE_THRE: 0.0
+  IN_VIS_THRE: 0.2
+  MODEL_FILE: ''
+  NMS_THRE: 1.0
+  OKS_THRE: 0.9
+  USE_GT_BBOX: false
+DEBUG:
+  DEBUG: true
+  SAVE_BATCH_IMAGES_GT: true
+  SAVE_BATCH_IMAGES_PRED: true
+  SAVE_HEATMAPS_GT: true
+  SAVE_HEATMAPS_PRED: true
diff --git a/experiments/coco/resnet50/256x192_vis_freeze.yaml b/experiments/coco/resnet50/256x192_vis_freeze.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..73d38dc93993a61c7d0af980c0a98b17b2fe1eb4
--- /dev/null
+++ b/experiments/coco/resnet50/256x192_vis_freeze.yaml
@@ -0,0 +1,88 @@
+GPUS: '0'
+DATA_DIR: ''
+OUTPUT_DIR: 'output'
+LOG_DIR: 'log'
+WORKERS: 0
+PRINT_FREQ: 100
+
+DATASET:
+  DATASET: 'coco'
+  ROOT: 'data/coco/'
+  TEST_SET: 'val2017'
+  TRAIN_SET: 'train2017'
+  FLIP: true
+  ROT_FACTOR: 40
+  SCALE_FACTOR: 0.3
+MODEL:
+  NAME: 'pose_vis'
+  PRETRAINED: 'models/pytorch/resnet50_vis/coco_vis2_0_no_linear.pth.tar'
+  IMAGE_SIZE:
+  - 192
+  - 256
+  NUM_JOINTS: 17
+  PREDICT_VIS: true
+  EXTRA:
+    TARGET_TYPE: 'gaussian'
+    HEATMAP_SIZE:
+    - 48
+    - 64
+    SIGMA: 2
+    FINAL_CONV_KERNEL: 1
+    DECONV_WITH_BIAS: false
+    NUM_DECONV_LAYERS: 3
+    NUM_DECONV_FILTERS:
+    - 256
+    - 256
+    - 256
+    NUM_DECONV_KERNELS:
+    - 4
+    - 4
+    - 4
+    NUM_LINEAR_LAYERS:
+    - 2048
+    NUM_LAYERS: 50
+LOSS:
+  USE_TARGET_WEIGHT: false
+  
+  USE_CLASS_WEIGHT: true
+  VIS_RATIO: 0.1
+  VIS_FACTOR: 0.1
+  VIS_STEP:
+  - 20
+  - 40
+  - 60
+TRAIN:
+  BATCH_SIZE: 32
+  SHUFFLE: true
+  BEGIN_EPOCH: 0
+  END_EPOCH: 80
+  RESUME: false
+  OPTIMIZER: 'adam'
+  LR: 0.001
+  LR_FACTOR: 0.1
+  LR_STEP:
+  - 50
+  WD: 0.0001
+  GAMMA1: 0.99
+  GAMMA2: 0.0
+  MOMENTUM: 0.9
+  NESTEROV: false
+  SAVE_CHECKPOINT: 20
+  FREEZE: true
+TEST:
+  BATCH_SIZE: 32
+  COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
+  BBOX_THRE: 1.0
+  FLIP_TEST: false
+  IMAGE_THRE: 0.0
+  IN_VIS_THRE: 0.2
+  MODEL_FILE: ''
+  NMS_THRE: 1.0
+  OKS_THRE: 0.9
+  USE_GT_BBOX: true
+DEBUG:
+  DEBUG: true
+  SAVE_BATCH_IMAGES_GT: false
+  SAVE_BATCH_IMAGES_PRED: false
+  SAVE_HEATMAPS_GT: false
+  SAVE_HEATMAPS_PRED: false
diff --git a/experiments/dripe/resnet50/256x192_vis_2_raise_soft_fine.yaml b/experiments/dripe/resnet50/256x192_vis_2_raise_soft_fine.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b6a110d059d958311617025c115edc71328de0e5
--- /dev/null
+++ b/experiments/dripe/resnet50/256x192_vis_2_raise_soft_fine.yaml
@@ -0,0 +1,80 @@
+GPUS: '0'
+DATA_DIR: ''
+OUTPUT_DIR: 'output'
+LOG_DIR: 'log'
+WORKERS: 0
+PRINT_FREQ: 100
+
+DATASET:
+  DATASET: 'dripe'
+  ROOT: 'data/dripe/'
+  TEST_SET: 'dripe_coco_train'
+  TRAIN_SET: 'dripe_coco_train'
+  FLIP: true
+  ROT_FACTOR: 40
+  SCALE_FACTOR: 0.3
+MODEL:
+  NAME: 'pose_resnet_vis_2'
+  PRETRAINED: 'output/coco/pose_resnet_vis_2_50/256x192_vis_2_raise_soft/final_state.pth.tar'
+  IMAGE_SIZE:
+  - 192
+  - 256
+  NUM_JOINTS: 17
+  PREDICT_VIS: true
+  EXTRA:
+    TARGET_TYPE: 'gaussian'
+    HEATMAP_SIZE:
+    - 48
+    - 64
+    SIGMA: 2
+    FINAL_CONV_KERNEL: 1
+    DECONV_WITH_BIAS: false
+    NUM_DECONV_LAYERS: 3
+    NUM_DECONV_FILTERS:
+    - 256
+    - 256
+    - 256
+    NUM_DECONV_KERNELS:
+    - 4
+    - 4
+    - 4
+    NUM_LAYERS: 50
+LOSS:
+  USE_TARGET_WEIGHT: false
+  
+  USE_CLASS_WEIGHT: true
+  VIS_RATIO: 0.2
+TRAIN:
+  BATCH_SIZE: 32
+  SHUFFLE: true
+  BEGIN_EPOCH: 0
+  END_EPOCH: 10
+  RESUME: false
+  OPTIMIZER: 'adam'
+  LR: 0.0001
+  LR_FACTOR: 0.1
+  LR_STEP:
+  - 50
+  WD: 0.0001
+  GAMMA1: 0.99
+  GAMMA2: 0.0
+  MOMENTUM: 0.9
+  NESTEROV: false
+  SAVE_CHECKPOINT: 20
+TEST:
+  BATCH_SIZE: 12
+  COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
+  BBOX_THRE: 1.0
+  FLIP_TEST: false
+  IMAGE_THRE: 0.0
+  IN_VIS_THRE: 0.2
+  MODEL_FILE: ''
+  NMS_THRE: 1.0
+  OKS_THRE: 0.9
+  USE_GT_BBOX: true
+DEBUG:
+  DEBUG: true
+  SAVE_BATCH_IMAGES_GT: false
+  SAVE_BATCH_IMAGES_PRED: false
+  SAVE_HEATMAPS_GT: true
+  SAVE_HEATMAPS_PRED: true
diff --git a/lib/Makefile b/lib/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..57d5804403a10cbe1effa219b30a848c79a07785
--- /dev/null
+++ b/lib/Makefile
@@ -0,0 +1,4 @@
+all:
+	cd nms; python setup.py build_ext --inplace; rm -rf build; cd ../../
+clean:
+	cd nms; rm *.so; cd ../../
diff --git a/lib/core/config.py b/lib/core/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6d2e6aceaeb86846e4a495bdf3426c227db0e5
--- /dev/null
+++ b/lib/core/config.py
@@ -0,0 +1,267 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import yaml
+
+import numpy as np
+from easydict import EasyDict as edict
+
+config = edict()
+
+config.OUTPUT_DIR = ''
+config.LOG_DIR = ''
+config.DATA_DIR = ''
+config.GPUS = '0'
+config.WORKERS = 4
+config.PRINT_FREQ = 20
+
+# Cudnn related params
+config.CUDNN = edict()
+config.CUDNN.BENCHMARK = True
+config.CUDNN.DETERMINISTIC = False
+config.CUDNN.ENABLED = True
+
+# pose_resnet related params
+POSE_RESNET = edict()
+POSE_RESNET.NUM_LAYERS = 50
+POSE_RESNET.DECONV_WITH_BIAS = False
+POSE_RESNET.NUM_DECONV_LAYERS = 3
+POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
+POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
+POSE_RESNET.FINAL_CONV_KERNEL = 1
+POSE_RESNET.TARGET_TYPE = 'gaussian'
+POSE_RESNET.HEATMAP_SIZE = [64, 64]  # width * height, ex: 24 * 32
+POSE_RESNET.SIGMA = 2
+
+POSE_RESNET_VIS = edict()
+POSE_RESNET_VIS.update(POSE_RESNET)
+POSE_RESNET_VIS.NUM_LINEAR_LAYERS = [4096, 2048, 1024]
+
+POS_EFFICIENT_VIS = edict()
+POS_EFFICIENT_VIS.EFFICIENT_NAME = ''
+POS_EFFICIENT_VIS.NUM_LINEAR_LAYERS = [4096, 2048, 1024]
+
+POS_MSPN_VIS = edict()
+POS_MSPN_VIS.STAGE_NUM = 2
+POS_MSPN_VIS.UPSAMPLE_CHANNEL_NUM = 256
+
+MODEL_EXTRAS = {
+    'pose_resnet': POSE_RESNET,
+    'pose_resnet_vis': POSE_RESNET_VIS,
+    'pose_resnet_vis_2': POSE_RESNET_VIS,
+    'pose_resnet_vis_3': POSE_RESNET_VIS,
+    'pose_resnet_merge': POSE_RESNET_VIS,
+    'efficient_pose': POS_EFFICIENT_VIS,
+    'mspn_pose': POS_MSPN_VIS,
+}
+
+# common params for NETWORK
+config.MODEL = edict()
+config.MODEL.NAME = 'pose_resnet'
+config.MODEL.INIT_WEIGHTS = True
+config.MODEL.PRETRAINED = ''
+config.MODEL.NUM_JOINTS = 16
+config.MODEL.IMAGE_SIZE = [256, 256]  # width * height, ex: 192 * 256
+config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]
+
+config.MODEL.PREDICT_VIS = False
+config.MODEL.NB_VIS = 3
+config.MODEL.STYLE = 'pytorch'
+
+config.LOSS = edict()
+config.LOSS.USE_TARGET_WEIGHT = True
+
+config.LOSS.VIS_RATIO = 0.5
+config.LOSS.VIS_FACTOR = 0.25
+config.LOSS.VIS_STEP = []
+config.LOSS.USE_CLASS_WEIGHT = False
+
+# DATASET related params
+config.DATASET = edict()
+config.DATASET.ROOT = ''
+config.DATASET.DATASET = 'mpii'
+config.DATASET.TRAIN_SET = 'train'
+config.DATASET.TEST_SET = 'valid'
+config.DATASET.DATA_FORMAT = 'jpg'
+config.DATASET.HYBRID_JOINTS_TYPE = ''
+config.DATASET.SELECT_DATA = False
+
+# training data augmentation
+config.DATASET.FLIP = True
+config.DATASET.SCALE_FACTOR = 0.25
+config.DATASET.ROT_FACTOR = 30
+
+# train
+config.TRAIN = edict()
+
+config.TRAIN.LR_FACTOR = 0.1
+config.TRAIN.LR_STEP = [90, 110]
+config.TRAIN.LR = 0.001
+
+config.TRAIN.OPTIMIZER = 'adam'
+config.TRAIN.MOMENTUM = 0.9
+config.TRAIN.WD = 0.0001
+config.TRAIN.NESTEROV = False
+config.TRAIN.GAMMA1 = 0.99
+config.TRAIN.GAMMA2 = 0.0
+
+config.TRAIN.BEGIN_EPOCH = 0
+config.TRAIN.END_EPOCH = 140
+
+config.TRAIN.RESUME = False
+config.TRAIN.CHECKPOINT = ''
+
+config.TRAIN.SAVE_CHECKPOINT = 0
+
+config.TRAIN.BATCH_SIZE = 32
+config.TRAIN.SHUFFLE = True
+
+config.TRAIN.FREEZE = False
+
+# testing
+config.TEST = edict()
+
+# size of images for each device
+config.TEST.BATCH_SIZE = 32
+# Test Model Epoch
+config.TEST.FLIP_TEST = False
+config.TEST.POST_PROCESS = True
+config.TEST.SHIFT_HEATMAP = True
+
+config.TEST.USE_GT_BBOX = False
+# nms
+config.TEST.OKS_THRE = 0.5
+config.TEST.IN_VIS_THRE = 0.0
+config.TEST.COCO_BBOX_FILE = ''
+config.TEST.BBOX_THRE = 1.0
+config.TEST.MODEL_FILE = ''
+config.TEST.IMAGE_THRE = 0.0
+config.TEST.NMS_THRE = 1.0
+
+# debug
+config.DEBUG = edict()
+config.DEBUG.DEBUG_MEMORY = False
+config.DEBUG.DEBUG = False
+config.DEBUG.SAVE_BATCH_IMAGES_GT = False
+config.DEBUG.SAVE_BATCH_IMAGES_PRED = False
+config.DEBUG.SAVE_HEATMAPS_GT = False
+config.DEBUG.SAVE_HEATMAPS_PRED = False
+
+
+def _update_dict(k, v):
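+    """Merge one top-level section `k` of an experiment file into the global config,
+    converting MEAN/STD, IMAGE_SIZE and HEATMAP_SIZE entries to numpy arrays on the way."""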
+    if k == 'DATASET':
+        if 'MEAN' in v and v['MEAN']:
+            v['MEAN'] = np.array([eval(x) if isinstance(x, str) else x
+                                  for x in v['MEAN']])
+        if 'STD' in v and v['STD']:
+            v['STD'] = np.array([eval(x) if isinstance(x, str) else x
+                                 for x in v['STD']])
+    if k == 'MODEL':
+        if 'NAME' in v:
+            config['MODEL']['EXTRA'] = MODEL_EXTRAS[v['NAME']]
+        if 'NAME' in v and '_vis' in v['NAME'] and 'PREDICT_VIS' not in v:  # auto-enable visibility prediction for *_vis models
+            v['PREDICT_VIS'] = True
+        if 'EXTRA' in v and 'HEATMAP_SIZE' in v['EXTRA']:
+            if isinstance(v['EXTRA']['HEATMAP_SIZE'], int):
+                v['EXTRA']['HEATMAP_SIZE'] = np.array(
+                    [v['EXTRA']['HEATMAP_SIZE'], v['EXTRA']['HEATMAP_SIZE']])
+            else:
+                v['EXTRA']['HEATMAP_SIZE'] = np.array(
+                    v['EXTRA']['HEATMAP_SIZE'])
+        if 'IMAGE_SIZE' in v:
+            if isinstance(v['IMAGE_SIZE'], int):
+                v['IMAGE_SIZE'] = np.array([v['IMAGE_SIZE'], v['IMAGE_SIZE']])
+            else:
+                v['IMAGE_SIZE'] = np.array(v['IMAGE_SIZE'])
+
+    for vk, vv in v.items():
+        if vk in config[k]:
+            if isinstance(vv, dict):
+                config[k][vk].update(vv)
+            else:
+                config[k][vk] = vv
+        else:
+            raise ValueError("{}.{} does not exist in config.py".format(k, vk))
+
+
+def update_config(config_file):
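+    """Load an experiment YAML file and overwrite the matching entries of the default config above."""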
+    exp_config = None
+    with open(config_file) as f:
+        exp_config = edict(yaml.load(f, Loader=yaml.SafeLoader))
+        for k, v in exp_config.items():
+            if k in config:
+                if isinstance(v, dict):
+                    _update_dict(k, v)
+                else:
+                    if k == 'SCALES':
+                        config[k][0] = (tuple(v))
+                    else:
+                        config[k] = v
+            else:
+                raise ValueError("{} does not exist in config.py".format(k))
+
+
+def gen_config(config_file):
+    cfg = dict(config)
+    for k, v in cfg.items():
+        if isinstance(v, edict):
+            cfg[k] = dict(v)
+
+    with open(config_file, 'w') as f:
+        yaml.dump(dict(cfg), f, default_flow_style=False)
+
+
+def update_dir(model_dir, log_dir, data_dir):
+    if model_dir:
+        config.OUTPUT_DIR = model_dir
+
+    if log_dir:
+        config.LOG_DIR = log_dir
+
+    if data_dir:
+        config.DATA_DIR = data_dir
+
+    config.DATASET.ROOT = os.path.join(
+        config.DATA_DIR, config.DATASET.ROOT)
+
+    config.TEST.COCO_BBOX_FILE = os.path.join(
+        config.DATA_DIR, config.TEST.COCO_BBOX_FILE)
+
+    config.MODEL.PRETRAINED = os.path.join(
+        config.DATA_DIR, config.MODEL.PRETRAINED)
+
+
+def get_model_name(cfg):
+    name = cfg.MODEL.NAME
+    full_name = cfg.MODEL.NAME
+    extra = cfg.MODEL.EXTRA
+    if name in MODEL_EXTRAS:
+        name = '{model}_{num_layers}'.format(
+            model=name,
+            num_layers=extra.NUM_LAYERS)
+        deconv_suffix = ''.join(
+            'd{}'.format(num_filters)
+            for num_filters in extra.NUM_DECONV_FILTERS)
+        full_name = '{height}x{width}_{name}_{deconv_suffix}'.format(
+            height=cfg.MODEL.IMAGE_SIZE[1],
+            width=cfg.MODEL.IMAGE_SIZE[0],
+            name=name,
+            deconv_suffix=deconv_suffix)
+    else:
+        raise ValueError('Unknown model: {}'.format(cfg.MODEL))
+
+    return name, full_name
+
+
+if __name__ == '__main__':
+    import sys
+
+    gen_config(sys.argv[1])
diff --git a/lib/core/evaluate.py b/lib/core/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fd05786a575402a0cd6f8e8329af0c5d2040d98
--- /dev/null
+++ b/lib/core/evaluate.py
@@ -0,0 +1,117 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import logging
+
+import numpy as np
+from sklearn.metrics import f1_score, precision_score, recall_score
+
+from core.inference import get_max_preds
+from utils.tabs import Tabs
+
+logger = logging.getLogger(__name__)
+
+
+def calc_dists(preds, target, normalize):
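+    ''' Per-joint distances between predictions and targets, both divided by `normalize`;
+        joints whose target coordinates are <= 1 are marked with a distance of -1 '''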
+    preds = preds.astype(np.float32)
+    target = target.astype(np.float32)
+    dists = np.zeros((preds.shape[1], preds.shape[0]))
+    for n in range(preds.shape[0]):
+        for c in range(preds.shape[1]):
+            if target[n, c, 0] > 1 and target[n, c, 1] > 1:
+                normed_preds = preds[n, c, :] / normalize[n]
+                normed_targets = target[n, c, :] / normalize[n]
+                dists[c, n] = np.linalg.norm(normed_preds - normed_targets)
+            else:
+                dists[c, n] = -1
+    return dists
+
+
+def dist_acc(dists, thr=0.5):
+    ''' Return percentage below threshold while ignoring values with a -1 '''
+    dist_cal = np.not_equal(dists, -1)
+    num_dist_cal = dist_cal.sum()
+    if num_dist_cal > 0:
+        return np.less(dists[dist_cal], thr).sum() * 1.0 / num_dist_cal
+    else:
+        return -1
+
+
+def accuracy(output, target, hm_type='gaussian', thr=0.5):
+    '''
+    Calculate accuracy according to PCK,
+    but uses ground truth heatmap rather than x,y locations
+    First value to be returned is average accuracy across 'idxs',
+    followed by individual accuracies
+    '''
+    idx = list(range(output.shape[1]))
+    norm = 1.0
+    if hm_type == 'gaussian':
+        pred, _ = get_max_preds(output)
+        target, _ = get_max_preds(target)
+        h = output.shape[2]
+        w = output.shape[3]
+        norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10
+    dists = calc_dists(pred, target, norm)
+
+    acc = np.zeros((len(idx) + 1))
+    avg_acc = 0
+    cnt = 0
+
+    for i in range(len(idx)):
+        acc[i + 1] = dist_acc(dists[idx[i]])
+        if acc[i + 1] >= 0:
+            avg_acc = avg_acc + acc[i + 1]
+            cnt += 1
+
+    avg_acc = avg_acc / cnt if cnt != 0 else 0
+    if cnt != 0:
+        acc[0] = avg_acc
+    return acc, avg_acc, cnt, pred
+
+
+def accuracy_vis(output, target, f1=False, save=None):
+    '''
+    Calculate the classification accuracy of the visibility predictions
+    :param output: Prediction numpy array: either class indices (nb_preds, nb_joints) or raw model output (nb_preds, nb_joints, nb_classes)
+    :param target: Ground-truth numpy array (nb_preds, nb_joints)
+    :param f1: Boolean: also compute and log precision, recall and F1 scores
+    :param save: Path to a json file where targets and predictions are dumped
+    :return: accuracy, and the flattened predictions if f1 is True (None otherwise)
+    '''
+
+    lbls = [[i] for i in range(output.shape[2])] + [None]
+
+    output_preds = output
+    if len(output.shape) > 2:
+        output_preds = output.argmax(axis=2)
+    corrects = output_preds == target
+
+    if f1:
+        avg_mode = 'weighted'
+        flat_targ = target.flatten()
+        flat_out = output_preds.flatten()
+        prec_scs = []
+        reca_scs = []
+        f1_scs = []
+        for lb in lbls:
+            prec_scs.append(precision_score(flat_targ, flat_out, labels=lb, average=avg_mode, zero_division=0))
+            reca_scs.append(recall_score(flat_targ, flat_out, labels=lb, average=avg_mode, zero_division=0))
+            f1_scs.append(f1_score(flat_targ, flat_out, labels=lb, average=avg_mode, zero_division=0))
+
+        Tabs([prec_scs, reca_scs, f1_scs], lbls_r=['Prec', 'Reca', 'F1'], lbls_c=[str(lb) for lb in lbls],
+             logger=logger)
+
+    if save is not None:
+        with open(save, 'w') as f_out:
+            json.dump(dict([(l[0], l[1].tolist()) for l in [('target', target), ('output', output)]]), f_out)
+
+    return corrects.mean(), flat_out if f1 else None
diff --git a/lib/core/function.py b/lib/core/function.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bf372a358c9305be8fd9ef4e9927a4cd9027717
--- /dev/null
+++ b/lib/core/function.py
@@ -0,0 +1,393 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import logging
+import sys
+import time
+import os
+
+import numpy as np
+import torch
+import cv2
+
+from core.config import get_model_name
+from core.evaluate import accuracy, accuracy_vis
+from core.inference import get_final_preds, get_max_preds
+from utils.transforms import flip_back
+from utils.vis import save_debug_images
+from utils.debug import GradPlots, print_tensors
+
+logger = logging.getLogger(__name__)
+
+
+def train(config, train_loader, model, criterion, optimizer, epoch,
+          output_dir, tb_log_dir, writer_dict):
+    batch_time = AverageMeter()
+    data_time = AverageMeter()
+    losses = AverageMeter()
+    acc = AverageMeter()
+    acc_vis = AverageMeter()
+    losses_det = AverageMeters()
+
+    # switch to train mode
+    model.train()
+
+    end = time.time()
+    predict_vis = config.MODEL.PREDICT_VIS
+
+    grad_plot = GradPlots(key='fc')
+    for i, (input, target, target_weight, target_vis, meta) in enumerate(train_loader):
+        # measure data loading time
+        data_time.update(time.time() - end)
+
+        # compute output
+        output = model(input)
+        if predict_vis or type(output) == tuple:
+            output, out_vis = output
+        else:
+            out_vis = None
+
+        if config.MODEL.NB_VIS == 2:
+            target_vis = target_vis > 0
+            target_vis = target_vis.long()
+
+        target = target.cuda(non_blocking=True)
+        target_weight = target_weight.cuda(non_blocking=True)
+        target_vis = target_vis.cuda(non_blocking=True)
+
+        loss, loss_detail = criterion([output, out_vis], [target, target_vis], target_weight)
+
+        if config.DEBUG.DEBUG_MEMORY:
+            print_tensors()
+            sys.exit(0)
+
+        # compute gradient and do update step
+        optimizer.zero_grad()
+        loss.backward()
+
+        # plot_grad_flow(model.named_parameters())
+        # sys.exit(0)
+        # grad_plot.save_grads(model.named_parameters())
+
+        optimizer.step()
+
+        # measure accuracy and record loss
+        losses.update(loss.item(), input.size(0))
+        losses_det.update(loss_detail, input.size(0))
+
+        _, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
+                                         target.detach().cpu().numpy())
+        acc.update(avg_acc, cnt)
+
+        if predict_vis:
+            accu_vis, _ = accuracy_vis(out_vis.detach().cpu().numpy(), target_vis.detach().cpu().numpy())
+            acc_vis.update(accu_vis, input.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if i % config.PRINT_FREQ == 0:
+            msg = 'Epoch: [{0}][{1}/{2}]\t' \
+                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
+                  'Speed {speed:.1f} samples/s\t' \
+                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
+                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
+                  f'Loss_Det {" ".join([f"{l.val:.5f}" for l in losses_det])} ({" ".join([f"{l.avg:.5f}" for l in losses_det])})\t' \
+                  'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
+                epoch, i, len(train_loader), batch_time=batch_time,
+                speed=input.size(0) / batch_time.val,
+                data_time=data_time, loss=losses, acc=acc)
+            if predict_vis:
+                msg += f'\tAccu_vis {acc_vis.val:.3f} ({acc_vis.avg:.3f})'
+            logger.info(msg)
+
+            writer = writer_dict['writer']
+            global_steps = writer_dict['train_global_steps']
+            writer.add_scalar('train_loss', losses.val, global_steps)
+            writer.add_scalar('train_acc', acc.val, global_steps)
+            writer_dict['train_global_steps'] = global_steps + 1
+
+            prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), i)
+            save_debug_images(config, input, meta, target, pred * 4, output,
+                              prefix)
+
+
+def validate(config, val_loader, val_dataset, model, criterion, output_dir,
+             tb_log_dir, writer_dict=None):
+    batch_time = AverageMeter()
+    losses = AverageMeter()
+    acc = AverageMeter()
+    acc_vis = AverageMeter()
+
+    losses_det = AverageMeters()
+
+    # switch to evaluate mode
+    model.eval()
+
+    num_samples = len(val_dataset)
+    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
+                         dtype=np.float32)
+    all_boxes = np.zeros((num_samples, 6))
+    all_maps = np.zeros(
+        (num_samples, config.MODEL.NUM_JOINTS, config.MODEL.EXTRA.HEATMAP_SIZE[1], config.MODEL.EXTRA.HEATMAP_SIZE[0]),
+        dtype=np.float32)
+    all_gts = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3), dtype=np.float32)
+    all_gts_maps = np.zeros(
+        (num_samples, config.MODEL.NUM_JOINTS, config.MODEL.EXTRA.HEATMAP_SIZE[1], config.MODEL.EXTRA.HEATMAP_SIZE[0]),
+        dtype=np.float32)
+    all_preds_vis = np.zeros((num_samples, config.MODEL.NUM_JOINTS, config.MODEL.NB_VIS), dtype=np.float32)
+    all_gts_vis = np.zeros((num_samples, config.MODEL.NUM_JOINTS), dtype=np.int32)
+
+    all_names = []
+    image_path = []
+    filenames = []
+    imgnums = []
+    idx = 0
+
+    predict_vis = config.MODEL.PREDICT_VIS
+
+    with torch.no_grad():
+        end = time.time()
+        for i, (input, target, target_weight, target_vis, meta) in enumerate(val_loader):
+            # compute output
+            output = model(input)
+            if predict_vis or type(output) == tuple:
+                output, out_vis = output
+            else:
+                out_vis = None
+
+            if config.TEST.FLIP_TEST:
+                # this part is ugly, because pytorch has not supported negative index
+                # input_flipped = model(input[:, :, :, ::-1])
+                input_flipped = np.flip(input.cpu().numpy(), 3).copy()
+                input_flipped = torch.from_numpy(input_flipped).cuda()
+                # compute output
+                output_flipped = model(input_flipped)
+                if predict_vis or type(output_flipped) == tuple:
+                    output_flipped, out_vis_flipped = output_flipped
+                    out_vis = (out_vis + out_vis_flipped) * 0.5
+                else:
+                    out_vis_flipped = None
+
+                output_flipped = flip_back(output_flipped.cpu().numpy(),
+                                           val_dataset.flip_pairs)
+                output_flipped = torch.from_numpy(output_flipped.copy()).cuda()
+
+                # feature is not aligned, shift flipped heatmap for higher accuracy
+                if config.TEST.SHIFT_HEATMAP:
+                    output_flipped[:, :, :, 1:] = \
+                        output_flipped.clone()[:, :, :, 0:-1]
+                    # output_flipped[:, :, :, 0] = 0
+
+                output = (output + output_flipped) * 0.5
+
+            if config.MODEL.NB_VIS == 2:
+                target_vis = target_vis > 0
+                target_vis = target_vis.long()
+
+            if False:  # debug block: dump heatmap overlays for joints labelled not visible but predicted with a strong peak
+                t = target_vis < 1
+                for r in range(t.shape[0]):
+                    for c in range(t.shape[1]):
+                        if t[r][c]:
+                            m = output[r][c].max()
+                            if m > 0.6:
+                                ti = int(time.time())
+                                hm = output[r][c].mul(255).clamp(0, 255).byte().cpu().numpy()
+                                img = input[r].mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
+                                hmi = cv2.applyColorMap(cv2.resize(hm, (img.shape[1], img.shape[0])), cv2.COLORMAP_JET)
+                                cv2.imwrite(f'temp/{c}_{r}_{m}.png', hmi * 0.7 + 0.3 * img)
+
+            target = target.cuda(non_blocking=True)
+            target_weight = target_weight.cuda(non_blocking=True)
+            target_vis = target_vis.cuda(non_blocking=True)
+
+            loss, loss_detail = criterion([output, out_vis], [target, target_vis], target_weight)
+
+            num_images = input.size(0)
+            # measure accuracy and record loss
+            losses.update(loss.item(), num_images)
+            losses_det.update(loss_detail, num_images)
+
+            _, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
+                                             target.cpu().numpy())
+
+            acc.update(avg_acc, cnt)
+
+            if predict_vis:
+                accu_vis, _ = accuracy_vis(out_vis.cpu().numpy(), target_vis.cpu().numpy())
+                acc_vis.update(accu_vis, num_images)
+
+            # measure elapsed time
+            batch_time.update(time.time() - end)
+            end = time.time()
+
+            c = meta['center'].numpy()
+            s = meta['scale'].numpy()
+            score = meta['score'].numpy()
+
+            output_np = output.clone().cpu().numpy()
+
+            preds, maxvals = get_final_preds(
+                config, output_np, c, s)
+
+            target_cpu = target.clone().cpu().numpy()
+            target_vis_cpu = target_vis.cpu().numpy()
+            all_maps[idx:idx + num_images, :, :, :] = output_np
+            all_gts_maps[idx:idx + num_images, :, :, :] = target_cpu
+            all_gts[idx:idx + num_images, :, :2] = get_max_preds(target_cpu)[0]
+            all_gts[idx:idx + num_images, :, 2] = target_vis_cpu
+            all_names += [os.path.split(n)[1] for n in meta['image']]
+
+            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
+            all_preds[idx:idx + num_images, :, 2:3] = maxvals
+            # double check this all_boxes parts
+            all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
+            all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
+            all_boxes[idx:idx + num_images, 4] = np.prod(s * 200, 1)
+            all_boxes[idx:idx + num_images, 5] = score
+            image_path.extend(meta['image'])
+
+            if predict_vis:
+                all_preds_vis[idx:idx + num_images, :, :] = out_vis.cpu().numpy()
+                all_gts_vis[idx:idx + num_images, :] = target_vis_cpu
+
+            if config.DATASET.DATASET == 'posetrack':
+                filenames.extend(meta['filename'])
+                imgnums.extend(meta['imgnum'].numpy())
+
+            idx += num_images
+
+            if i % config.PRINT_FREQ == 0:
+                msg = 'Test: [{0}/{1}]\t' \
+                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
+                      'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
+                      f'Loss_Det {" ".join([f"{l.val:.5f}" for l in losses_det])} ({" ".join([f"{l.avg:.5f}" for l in losses_det])})\t' \
+                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
+                    i, len(val_loader), batch_time=batch_time,
+                    loss=losses, acc=acc)
+                if predict_vis:
+                    msg += f'\tAccu_vis {acc_vis.val:.3f} ({acc_vis.avg:.3f})'
+                logger.info(msg)
+
+                prefix = '{}_{}'.format(os.path.join(output_dir, 'val'), i)
+                save_debug_images(config, input, meta, target, pred * 4, output,
+                                  prefix)
+
+        name_values, perf_indicator = val_dataset.evaluate(
+            config, all_preds, output_dir, all_boxes, image_path,
+            filenames, imgnums)
+
+        with open(os.path.join(output_dir, 'results', 'keypoints_%s_preds.json' % val_dataset.image_set), 'w') as f:
+            json.dump({'annotations': all_preds.tolist(), 'names': all_names}, f)
+
+        with open(os.path.join(output_dir, 'results', 'keypoints_%s_gts.json' % val_dataset.image_set), 'w') as f:
+            json.dump({'annotations': all_gts.tolist(), 'names': all_names}, f)
+
+        if predict_vis:
+            accu_vis, flat_out_vis = accuracy_vis(all_preds_vis, all_gts_vis, f1=True,
+                                                  save=os.path.join(output_dir, 'results', 'result_vis_val.json'))
+            logger.info(f"Total vis accuracy: {accu_vis:.3f}")
+
+            all_vis_filter = flat_out_vis.reshape((-1, 17)) > 0
+            all_preds_nnz = all_preds.copy()
+            for i in range(all_preds_nnz.shape[-1]):
+                all_preds_nnz[:, :, i] *= all_vis_filter
+
+            res_file = os.path.join(
+                output_dir, 'results', 'keypoints_%s_results_nnz.json' % val_dataset.image_set)
+
+            val_dataset.evaluate(
+                config, all_preds_nnz, output_dir, all_boxes, image_path,
+                filenames, imgnums, res_file=res_file)
+
+        # np.save('val_out.npy', all_maps)
+        # np.savez_compressed('eval_out', names=all_names, maps=all_maps)
+        # np.savez_compressed('eval_gts', names=all_names, maps=all_gts_maps)
+
+        _, full_arch_name = get_model_name(config)
+        if isinstance(name_values, list):
+            for name_value in name_values:
+                _print_name_value(name_value, full_arch_name)
+        else:
+            _print_name_value(name_values, full_arch_name)
+
+        if writer_dict:
+            writer = writer_dict['writer']
+            global_steps = writer_dict['valid_global_steps']
+            writer.add_scalar('valid_loss', losses.avg, global_steps)
+            writer.add_scalar('valid_acc', acc.avg, global_steps)
+            if isinstance(name_values, list):
+                for name_value in name_values:
+                    writer.add_scalars('valid', dict(name_value), global_steps)
+            else:
+                writer.add_scalars('valid', dict(name_values), global_steps)
+            writer_dict['valid_global_steps'] = global_steps + 1
+
+    return perf_indicator
+
+
+# markdown format output
+def _print_name_value(name_value, full_arch_name):
+    names = name_value.keys()
+    values = name_value.values()
+    num_values = len(name_value)
+    logger.info(
+        '| Arch ' +
+        ' '.join(['| {}'.format(name) for name in names]) +
+        ' |'
+    )
+    logger.info('|---' * (num_values + 1) + '|')
+    logger.info(
+        '| ' + full_arch_name + ' ' +
+        ' '.join(['| {:.3f}'.format(value) for value in values]) +
+        ' |'
+    )
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count if self.count != 0 else 0
+
+
+class AverageMeters(object):
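+    """Maintains one AverageMeter per element of the loss-detail list returned by the criterion."""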
+    def __init__(self):
+        self.meters = []
+
+    def update(self, vals, n=1):
+        if len(self.meters) < len(vals):
+            self.meters = [AverageMeter() for _ in vals]
+
+        for i, val in enumerate(vals):
+            self.meters[i].update(val.item(), n)
+
+    def reset(self):
+        for met in self.meters:
+            met.reset()
+
+    def __iter__(self):
+        for met in self.meters:
+            yield met
diff --git a/lib/core/inference.py b/lib/core/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfbfb713588cf1216bc7f10f250937c881f0c3c5
--- /dev/null
+++ b/lib/core/inference.py
@@ -0,0 +1,74 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+
+from utils.transforms import transform_preds
+
+
+def get_max_preds(batch_heatmaps):
+    '''
+    get predictions from score maps
+    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
+    '''
+    assert isinstance(batch_heatmaps, np.ndarray), \
+        'batch_heatmaps should be numpy.ndarray'
+    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
+
+    batch_size = batch_heatmaps.shape[0]
+    num_joints = batch_heatmaps.shape[1]
+    width = batch_heatmaps.shape[3]
+    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
+    idx = np.argmax(heatmaps_reshaped, 2)
+    maxvals = np.amax(heatmaps_reshaped, 2)
+
+    maxvals = maxvals.reshape((batch_size, num_joints, 1))
+    idx = idx.reshape((batch_size, num_joints, 1))
+
+    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
+
+    preds[:, :, 0] = (preds[:, :, 0]) % width
+    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)
+
+    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
+    pred_mask = pred_mask.astype(np.float32)
+
+    preds *= pred_mask
+    return preds, maxvals
+
+
+def get_final_preds(config, batch_heatmaps, center, scale):
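+    '''
+    Convert a batch of heatmaps into keypoint coordinates: take the per-joint argmax,
+    optionally shift it a quarter pixel towards the neighbouring bin with the higher
+    response (TEST.POST_PROCESS), then map the coordinates back to the original image
+    space with the inverse of the crop transform.
+    '''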
+    coords, maxvals = get_max_preds(batch_heatmaps)
+
+    heatmap_height = batch_heatmaps.shape[2]
+    heatmap_width = batch_heatmaps.shape[3]
+
+    # post-processing
+    if config.TEST.POST_PROCESS:
+        for n in range(coords.shape[0]):
+            for p in range(coords.shape[1]):
+                hm = batch_heatmaps[n][p]
+                px = int(math.floor(coords[n][p][0] + 0.5))
+                py = int(math.floor(coords[n][p][1] + 0.5))
+                if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1:
+                    diff = np.array([hm[py][px+1] - hm[py][px-1],
+                                     hm[py+1][px]-hm[py-1][px]])
+                    coords[n][p] += np.sign(diff) * .25
+
+    preds = coords.copy()
+
+    # Transform back
+    for i in range(coords.shape[0]):
+        preds[i] = transform_preds(coords[i], center[i], scale[i],
+                                   [heatmap_width, heatmap_height])
+
+    return preds, maxvals
diff --git a/lib/core/loss.py b/lib/core/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f50b82e85de75a4e6e09998b624ccd89213ef28
--- /dev/null
+++ b/lib/core/loss.py
@@ -0,0 +1,110 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch.nn as nn
+from abc import ABC, abstractmethod
+
+
+class JointsLoss(nn.Module, ABC):
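+    """Base joints loss: `forward` unpacks the [heatmap, visibility] outputs and targets and
+    dispatches to `_forward`, passing the visibility terms only when `use_vis` is set."""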
+    def __init__(self, use_target_weight, use_vis=False):
+        super(JointsLoss, self).__init__()
+        self.criterion = None
+        self.use_target_weight = use_target_weight
+        self.use_vis = use_vis
+
+    def forward(self, outputs, targets, target_weights):
+        if self.use_vis:
+            return self._forward(*outputs, *targets, target_weight=target_weights)
+        else:
+            return self._forward(outputs[0], targets[0], target_weight=target_weights)
+
+    @abstractmethod
+    def _forward(self, outputs, targets, target_weights):
+        pass
+
+
+class JointsMSELoss(JointsLoss):
+    def __init__(self, use_target_weight):
+        super(JointsMSELoss, self).__init__(use_target_weight, use_vis=False)
+        self.criterion = nn.MSELoss(reduction='mean')
+
+    def _forward(self, output, target, target_weight):
+        batch_size = output.size(0)
+        num_joints = output.size(1)
+        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
+        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
+        loss = 0
+
+        for idx in range(num_joints):
+            heatmap_pred = heatmaps_pred[idx].squeeze()
+            heatmap_gt = heatmaps_gt[idx].squeeze()
+            if self.use_target_weight:
+                loss += 0.5 * self.criterion(
+                    heatmap_pred.mul(target_weight[:, idx]),
+                    heatmap_gt.mul(target_weight[:, idx])
+                )
+            else:
+                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
+
+        loss /= num_joints
+        return loss, [loss]
+
+
+class JointsMSELossVis(JointsLoss):
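+    """Multitask loss: heatmap MSE plus per-joint visibility cross-entropy,
+    blended as (1 - vis_ratio) * loss_heatmap + vis_ratio * loss_vis."""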
+    def __init__(self, use_target_weight, vis_ratio=.5, vis_weight=None):
+        super(JointsMSELossVis, self).__init__(use_target_weight, use_vis=True)
+        self.criterion = nn.MSELoss(reduction='mean')
+        self.vis_criterion = nn.CrossEntropyLoss(weight=vis_weight, reduction='mean')
+
+        # Ratio loss_vis / loss_hm
+        self.vis_ratio = vis_ratio
+
+    def update_vis_ratio(self, factor):
+        self.vis_ratio += factor
+        if self.vis_ratio > 1:
+            self.vis_ratio = 1.
+        if self.vis_ratio < 0:
+            self.vis_ratio = 0.
+
+    def _forward(self, output, vis_preds, target, vis_gts, target_weight):
+        batch_size = output.size(0)
+        num_joints = output.size(1)
+        heatmaps_pred = output.reshape((batch_size, num_joints, -1)).split(1, 1)
+        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)
+
+        vis_preds = vis_preds.split(1, 1)
+        vis_gts = vis_gts.split(1, 1)
+
+        loss = 0
+        loss_vis = 0
+
+        for idx in range(num_joints):
+            heatmap_pred = heatmaps_pred[idx].squeeze(1)
+            heatmap_gt = heatmaps_gt[idx].squeeze(1)
+            vis_pred = vis_preds[idx].squeeze(1)
+            vis_gt = vis_gts[idx].squeeze(1)
+
+            if self.use_target_weight:
+                loss += 0.5 * self.criterion(
+                    heatmap_pred.mul(target_weight[:, idx]),
+                    heatmap_gt.mul(target_weight[:, idx])
+                )
+            else:
+                loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
+
+            # 2.5E-3 provides around a 50/50 ratio between the heatmap and visibility losses
+            l_vis = 0.5 * 2.5E-3 * self.vis_criterion(vis_pred, vis_gt)
+            
+            # print(l_vis)
+            loss_vis += l_vis
+
+        loss /= num_joints
+        loss_vis /= num_joints
+        return (loss * (1 - self.vis_ratio) + loss_vis * self.vis_ratio), [loss, loss_vis]
diff --git a/lib/dataset/JointsDataset.py b/lib/dataset/JointsDataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a39781f195378cb2041ea5b7529ea1620672535
--- /dev/null
+++ b/lib/dataset/JointsDataset.py
@@ -0,0 +1,231 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import logging
+import random
+
+import cv2
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+from utils.transforms import get_affine_transform
+from utils.transforms import affine_transform
+from utils.transforms import fliplr_joints
+
+
+logger = logging.getLogger(__name__)
+
+
+class JointsDataset(Dataset):
+    def __init__(self, cfg, root, image_set, is_train, transform=None):
+        self.num_joints = 0
+        self.pixel_std = 200
+        self.flip_pairs = []
+        self.parent_ids = []
+
+        self.is_train = is_train
+        self.root = root
+        self.image_set = image_set
+
+        self.output_path = cfg.OUTPUT_DIR
+        self.data_format = cfg.DATASET.DATA_FORMAT
+
+        self.scale_factor = cfg.DATASET.SCALE_FACTOR
+        self.rotation_factor = cfg.DATASET.ROT_FACTOR
+        self.flip = cfg.DATASET.FLIP
+
+        self.image_size = cfg.MODEL.IMAGE_SIZE
+        self.target_type = cfg.MODEL.EXTRA.TARGET_TYPE
+        self.heatmap_size = cfg.MODEL.EXTRA.HEATMAP_SIZE
+        self.sigma = cfg.MODEL.EXTRA.SIGMA
+
+        self.transform = transform
+        self.db = []
+
+    def _get_db(self):
+        raise NotImplementedError
+
+    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
+        raise NotImplementedError
+
+    def __len__(self,):
+        return len(self.db)
+
+    def __getitem__(self, idx):
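+        # Returns: the transformed input image tensor, the heatmap target, the per-joint
+        # target weight, the per-joint visibility class target, and a metadata dictionary.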
+        db_rec = copy.deepcopy(self.db[idx])
+
+        image_file = db_rec['image']
+        image_file = image_file.replace('\\', '/')
+
+        filename = db_rec['filename'] if 'filename' in db_rec else ''
+        imgnum = db_rec['imgnum'] if 'imgnum' in db_rec else ''
+
+        if self.data_format == 'zip':
+            from utils import zipreader
+            data_numpy = zipreader.imread(
+                image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
+        else:
+            data_numpy = cv2.imread(
+                image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
+
+        if data_numpy is None:
+            logger.error('=> fail to read {}'.format(image_file))
+            raise ValueError('Fail to read {}'.format(image_file))
+
+        joints = db_rec['joints_3d']
+        joints_vis = db_rec['joints_3d_vis']
+
+        c = db_rec['center']
+        s = db_rec['scale']
+        score = db_rec['score'] if 'score' in db_rec else 1
+        r = 0
+
+        if self.is_train:
+            sf = self.scale_factor
+            rf = self.rotation_factor
+            s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
+            r = np.clip(np.random.randn()*rf, -rf*2, rf*2) \
+                if random.random() <= 0.6 else 0
+
+            if self.flip and random.random() <= 0.5:
+                data_numpy = data_numpy[:, ::-1, :]
+                joints, joints_vis = fliplr_joints(
+                    joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
+                c[0] = data_numpy.shape[1] - c[0] - 1
+
+        trans = get_affine_transform(c, s, r, self.image_size)
+        input = cv2.warpAffine(
+            data_numpy,
+            trans,
+            (int(self.image_size[0]), int(self.image_size[1])),
+            flags=cv2.INTER_LINEAR)
+
+        if self.transform:
+            input = self.transform(input)
+            
+        # target_vis = np.zeros((self.num_joints, 3), dtype=np.float)
+        target_vis = np.zeros((self.num_joints), dtype=np.int64)
+        for i in range(self.num_joints):
+            if joints_vis[i, 0] > 0.0:
+                joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)
+            # target_vis[i, int(joints_vis[i, 0])] = 1.0
+            target_vis[i] = int(joints_vis[i, 0])
+
+        # target_vis = torch.from_numpy(target_vis).float()
+        target_vis = torch.from_numpy(target_vis).long()
+
+        target, target_weight = self.generate_target(joints, joints_vis)
+
+        target = torch.from_numpy(target)
+        target_weight = torch.from_numpy(target_weight)
+        
+        meta = {
+            'image': image_file,
+            'filename': filename,
+            'imgnum': imgnum,
+            'joints': joints,
+            'joints_vis': joints_vis,
+            'center': c,
+            'scale': s,
+            'rotation': r,
+            'score': score
+        }
+
+        return input, target, target_weight, target_vis, meta
+
+    def select_data(self, db):
+        db_selected = []
+        for rec in db:
+            num_vis = 0
+            joints_x = 0.0
+            joints_y = 0.0
+            for joint, joint_vis in zip(
+                    rec['joints_3d'], rec['joints_3d_vis']):
+                if joint_vis[0] <= 0:
+                    continue
+                num_vis += 1
+
+                joints_x += joint[0]
+                joints_y += joint[1]
+            if num_vis == 0:
+                continue
+
+            joints_x, joints_y = joints_x / num_vis, joints_y / num_vis
+
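+            # OKS-style score between the mean keypoint position and the box center;
+            # records whose keypoints drift too far from the box center are discarded.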
+            area = rec['scale'][0] * rec['scale'][1] * (self.pixel_std**2)
+            joints_center = np.array([joints_x, joints_y])
+            bbox_center = np.array(rec['center'])
+            diff_norm2 = np.linalg.norm((joints_center-bbox_center), 2)
+            ks = np.exp(-1.0*(diff_norm2**2) / ((0.2)**2*2.0*area))
+
+            metric = (0.2 / 16) * num_vis + 0.45 - 0.2 / 16
+            if ks > metric:
+                db_selected.append(rec)
+
+        logger.info('=> num db: {}'.format(len(db)))
+        logger.info('=> num selected db: {}'.format(len(db_selected)))
+        return db_selected
+
+    def generate_target(self, joints, joints_vis):
+        '''
+        :param joints:  [num_joints, 3]
+        :param joints_vis: [num_joints, 3]
+        :return: target, target_weight (1: visible, 0: invisible)
+        '''
+        target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
+        target_weight[:, 0] *= (joints_vis[:, 0] > 0)
+
+        assert self.target_type == 'gaussian', \
+            'Only support gaussian map now!'
+
+        if self.target_type == 'gaussian':
+            target = np.zeros((self.num_joints,
+                               self.heatmap_size[1],
+                               self.heatmap_size[0]),
+                              dtype=np.float32)
+
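+            # The gaussian is truncated at 3*sigma on each side of its peak.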
+            tmp_size = self.sigma * 3
+
+            for joint_id in range(self.num_joints):
+                feat_stride = self.image_size / self.heatmap_size
+                mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
+                mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
+                # Check that any part of the gaussian is in-bounds
+                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
+                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
+                if ul[0] >= self.heatmap_size[0] or ul[1] >= self.heatmap_size[1] \
+                        or br[0] < 0 or br[1] < 0:
+                    # If not, zero this joint's weight and skip it
+                    target_weight[joint_id] = 0
+                    continue
+
+                # Generate gaussian
+                size = 2 * tmp_size + 1
+                x = np.arange(0, size, 1, np.float32)
+                y = x[:, np.newaxis]
+                x0 = y0 = size // 2
+                # The gaussian is not normalized, we want the center value to equal 1
+                g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))
+
+                # Usable gaussian range
+                g_x = max(0, -ul[0]), min(br[0], self.heatmap_size[0]) - ul[0]
+                g_y = max(0, -ul[1]), min(br[1], self.heatmap_size[1]) - ul[1]
+                # Image range
+                img_x = max(0, ul[0]), min(br[0], self.heatmap_size[0])
+                img_y = max(0, ul[1]), min(br[1], self.heatmap_size[1])
+
+                v = target_weight[joint_id]
+                if v > 0.5:
+                    target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
+                        g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+
+        return target, target_weight
diff --git a/lib/dataset/__init__.py b/lib/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..144268e0fb778f68f444db61bf2a35fcb3bb06d6
--- /dev/null
+++ b/lib/dataset/__init__.py
@@ -0,0 +1,14 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from .mpii import MPIIDataset as mpii
+from .coco import COCODataset as coco
+from .dripe import DriPEDataset as dripe
+from .demo_loader import DemoLoader as demo_loader
\ No newline at end of file
diff --git a/lib/dataset/coco.py b/lib/dataset/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..37a2d526a8147697bf9cfe4861250ac0812b755a
--- /dev/null
+++ b/lib/dataset/coco.py
@@ -0,0 +1,412 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+import pickle
+from collections import defaultdict
+from collections import OrderedDict
+
+import json_tricks as json
+import numpy as np
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+
+from dataset.JointsDataset import JointsDataset
+from nms.nms import oks_nms
+
+logger = logging.getLogger(__name__)
+
+
+class COCODataset(JointsDataset):
+    '''
+    "keypoints": {
+        0: "nose",
+        1: "left_eye",
+        2: "right_eye",
+        3: "left_ear",
+        4: "right_ear",
+        5: "left_shoulder",
+        6: "right_shoulder",
+        7: "left_elbow",
+        8: "right_elbow",
+        9: "left_wrist",
+        10: "right_wrist",
+        11: "left_hip",
+        12: "right_hip",
+        13: "left_knee",
+        14: "right_knee",
+        15: "left_ankle",
+        16: "right_ankle"
+    },
+	"skeleton": [
+        [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
+        [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]
+    '''
+
+    def __init__(self, cfg, root, image_set, is_train, transform=None):
+        super().__init__(cfg, root, image_set, is_train, transform)
+        self.nms_thre = cfg.TEST.NMS_THRE
+        self.image_thre = cfg.TEST.IMAGE_THRE
+        self.oks_thre = cfg.TEST.OKS_THRE
+        self.in_vis_thre = cfg.TEST.IN_VIS_THRE
+        self.bbox_file = cfg.TEST.COCO_BBOX_FILE
+        self.use_gt_bbox = cfg.TEST.USE_GT_BBOX
+        self.image_width = cfg.MODEL.IMAGE_SIZE[0]
+        self.image_height = cfg.MODEL.IMAGE_SIZE[1]
+        self.aspect_ratio = self.image_width * 1.0 / self.image_height
+        self.pixel_std = 200
+        self.coco = COCO(self._get_ann_file_keypoint())
+
+        # deal with class names
+        cats = [cat['name']
+                for cat in self.coco.loadCats(self.coco.getCatIds())]
+        self.classes = ['__background__'] + cats
+        logger.info('=> classes: {}'.format(self.classes))
+        self.num_classes = len(self.classes)
+        self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
+        self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
+        self._coco_ind_to_class_ind = dict([(self._class_to_coco_ind[cls],
+                                             self._class_to_ind[cls])
+                                            for cls in self.classes[1:]])
+
+        # load image file names
+        self.image_set_index = self._load_image_set_index()
+        self.num_images = len(self.image_set_index)
+        logger.info('=> num_images: {}'.format(self.num_images))
+
+        self.num_joints = 17
+        self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
+                           [9, 10], [11, 12], [13, 14], [15, 16]]
+        self.parent_ids = None
+
+        self.db = self._get_db()
+
+        if is_train and cfg.DATASET.SELECT_DATA:
+            self.db = self.select_data(self.db)
+
+        logger.info('=> load {} samples'.format(len(self.db)))
+
+    def _get_ann_file_keypoint(self):
+        """ self.root / annotations / person_keypoints_train2017.json """
+        prefix = 'person_keypoints' \
+            if 'test' not in self.image_set else 'image_info'
+        return os.path.join(self.root, 'annotations',
+                            prefix + '_' + self.image_set + '.json')
+
+    def _load_image_set_index(self):
+        """ image id: int """
+        image_ids = self.coco.getImgIds()
+        return image_ids
+
+    def _get_db(self):
+        if self.is_train or self.use_gt_bbox:
+            # use ground truth bbox
+            gt_db = self._load_coco_keypoint_annotations()
+        else:
+            # use bbox from detection
+            gt_db = self._load_coco_person_detection_results()
+        return gt_db
+
+    def _load_coco_keypoint_annotations(self):
+        """ ground truth bbox and keypoints """
+        gt_db = []
+        for index in self.image_set_index:
+            gt_db.extend(self._load_coco_keypoint_annotation_kernal(index))
+        return gt_db
+
+    def _load_coco_keypoint_annotation_kernal(self, index):
+        """
+        coco ann: [u'segmentation', u'area', u'iscrowd', u'image_id', u'bbox', u'category_id', u'id']
+        iscrowd:
+            crowd instances are handled by marking their overlaps with all categories to -1
+            and later excluded in training
+        bbox:
+            [x1, y1, w, h]
+        :param index: coco image id
+        :return: db entry
+        """
+        im_ann = self.coco.loadImgs(index)[0]
+        width = im_ann['width']
+        height = im_ann['height']
+
+        annIds = self.coco.getAnnIds(imgIds=index, iscrowd=False)
+        objs = self.coco.loadAnns(annIds)
+
+        # sanitize bboxes
+        valid_objs = []
+        for obj in objs:
+            x, y, w, h = obj['bbox']
+            x1 = np.max((0, x))
+            y1 = np.max((0, y))
+            x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
+            y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
+            if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
+                # obj['clean_bbox'] = [x1, y1, x2, y2]
+                obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
+                valid_objs.append(obj)
+        objs = valid_objs
+
+        rec = []
+        for obj in objs:
+            cls = self._coco_ind_to_class_ind[obj['category_id']]
+            if cls != 1:
+                continue
+
+            # ignore objs without keypoints annotation
+            if max(obj['keypoints']) == 0:
+                continue
+
+            joints_3d = np.zeros((self.num_joints, 3), dtype=np.float64)
+            joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float64)
+            for ipt in range(self.num_joints):
+                joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
+                joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
+                joints_3d[ipt, 2] = 0
+                t_vis = obj['keypoints'][ipt * 3 + 2]
+                # if t_vis > 1:
+                #     t_vis = 1
+                joints_3d_vis[ipt, 0] = t_vis
+                joints_3d_vis[ipt, 1] = t_vis
+                joints_3d_vis[ipt, 2] = 0
+
+            center, scale = self._box2cs(obj['clean_bbox'][:4])
+            rec.append({
+                'image': self.image_path_from_index(index),
+                'center': center,
+                'scale': scale,
+                'joints_3d': joints_3d,
+                'joints_3d_vis': joints_3d_vis,
+                'filename': '',
+                'imgnum': 0,
+            })
+
+        return rec
+
+    def _box2cs(self, box):
+        x, y, w, h = box[:4]
+        return self._xywh2cs(x, y, w, h)
+
+    def _xywh2cs(self, x, y, w, h):
+        center = np.zeros((2), dtype=np.float32)
+        center[0] = x + w * 0.5
+        center[1] = y + h * 0.5
+
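+        # Expand the box to the model aspect ratio, express the scale in units of
+        # pixel_std (200 px), and enlarge it by 25% to keep some context around the person.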
+        if w > self.aspect_ratio * h:
+            h = w * 1.0 / self.aspect_ratio
+        elif w < self.aspect_ratio * h:
+            w = h * self.aspect_ratio
+        scale = np.array(
+            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
+            dtype=np.float32)
+        if center[0] != -1:
+            scale = scale * 1.25
+
+        return center, scale
+
+    def image_path_from_index(self, index):
+        """ example: images / train2017 / 000000119993.jpg """
+        file_name = '%012d.jpg' % index
+        if '2014' in self.image_set:
+            file_name = 'COCO_%s_' % self.image_set + file_name
+
+        prefix = 'test2017' if 'test' in self.image_set else self.image_set
+
+        data_name = prefix + '.zip@' if self.data_format == 'zip' else prefix
+
+        image_path = os.path.join(
+            self.root, 'images', data_name, file_name)
+
+        return image_path
+
+    def _load_coco_person_detection_results(self):
+        all_boxes = None
+        with open(self.bbox_file, 'r') as f:
+            all_boxes = json.load(f)
+
+        if not all_boxes:
+            logger.error('=> Load %s fail!' % self.bbox_file)
+            return None
+
+        logger.info('=> Total boxes: {}'.format(len(all_boxes)))
+
+        kpt_db = []
+        num_boxes = 0
+        for n_img in range(0, len(all_boxes)):
+            det_res = all_boxes[n_img]
+            if det_res['category_id'] != 1:
+                continue
+            img_name = self.image_path_from_index(det_res['image_id'])
+            box = det_res['bbox']
+            score = det_res['score']
+
+            if score < self.image_thre:
+                continue
+
+            num_boxes = num_boxes + 1
+
+            center, scale = self._box2cs(box)
+            joints_3d = np.zeros((self.num_joints, 3), dtype=np.float64)
+            joints_3d_vis = np.ones(
+                (self.num_joints, 3), dtype=np.float64)
+            kpt_db.append({
+                'image': img_name,
+                'center': center,
+                'scale': scale,
+                'score': score,
+                'joints_3d': joints_3d,
+                'joints_3d_vis': joints_3d_vis,
+            })
+
+        logger.info('=> Total boxes after filter low score@{}: {}'.format(
+            self.image_thre, num_boxes))
+        return kpt_db
+
+    # need to double-check this API and the classes field
+    def evaluate(self, cfg, preds, output_dir, all_boxes, img_path,
+                 *args, **kwargs):
+        res_folder = os.path.join(output_dir, 'results')
+        if not os.path.exists(res_folder):
+            os.makedirs(res_folder)
+        res_file = os.path.join(
+            res_folder, 'keypoints_%s_results.json' % self.image_set)
+        if 'res_file' in kwargs:
+            res_file = kwargs['res_file']
+
+        # person x (keypoints)
+        _kpts = []
+        for idx, kpt in enumerate(preds):
+            _kpts.append({
+                'keypoints': kpt,
+                'center': all_boxes[idx][0:2],
+                'scale': all_boxes[idx][2:4],
+                'area': all_boxes[idx][4],
+                'score': all_boxes[idx][5],
+                'image': int(img_path[idx][-16:-4])
+            })
+        # image x person x (keypoints)
+        kpts = defaultdict(list)
+        for kpt in _kpts:
+            kpts[kpt['image']].append(kpt)
+
+        # rescoring and oks nms
+        num_joints = self.num_joints
+        in_vis_thre = self.in_vis_thre
+        oks_thre = self.oks_thre
+        oks_nmsed_kpts = []
+
+        for img in kpts.keys():
+            img_kpts = kpts[img]
+            for n_p in img_kpts:
+                box_score = n_p['score']
+                kpt_score = 0
+                valid_num = 0
+                for n_jt in range(0, num_joints):
+                    t_s = n_p['keypoints'][n_jt][2]
+                    if t_s > in_vis_thre:
+                        kpt_score = kpt_score + t_s
+                        valid_num = valid_num + 1
+                if valid_num != 0:
+                    kpt_score = kpt_score / valid_num
+                # rescoring
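+                # final score = detection box score * mean confidence of keypoints above in_vis_thre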
+                n_p['score'] = kpt_score * box_score
+            keep = oks_nms([img_kpts[i] for i in range(len(img_kpts))],
+                           oks_thre)
+            if len(keep) == 0:
+                oks_nmsed_kpts.append(img_kpts)
+            else:
+                oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])
+
+        self._write_coco_keypoint_results(
+            oks_nmsed_kpts, res_file)
+        if 'test' not in self.image_set:
+            info_str = self._do_python_keypoint_eval(
+                res_file, res_folder)
+            name_value = OrderedDict(info_str)
+            return name_value, name_value['AP']
+        else:
+            return {'Null': 0}, 0
+
+    def _write_coco_keypoint_results(self, keypoints, res_file):
+        data_pack = [{'cat_id': self._class_to_coco_ind[cls],
+                      'cls_ind': cls_ind,
+                      'cls': cls,
+                      'ann_type': 'keypoints',
+                      'keypoints': keypoints
+                      }
+                     for cls_ind, cls in enumerate(self.classes) if not cls == '__background__']
+
+        results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
+        logger.info('=> Writing results json to %s' % res_file)
+        with open(res_file, 'w') as f:
+            json.dump(results, f, sort_keys=True, indent=4)
+        try:
+            json.load(open(res_file))
+        except Exception:
+            content = []
+            with open(res_file, 'r') as f:
+                for line in f:
+                    content.append(line)
+            content[-1] = ']'
+            with open(res_file, 'w') as f:
+                for c in content:
+                    f.write(c)
+
+    def _coco_keypoint_results_one_category_kernel(self, data_pack):
+        cat_id = data_pack['cat_id']
+        keypoints = data_pack['keypoints']
+        cat_results = []
+
+        for img_kpts in keypoints:
+            if len(img_kpts) == 0:
+                continue
+
+            _key_points = np.array([img_kpts[k]['keypoints']
+                                    for k in range(len(img_kpts))])
+            key_points = np.zeros(
+                (_key_points.shape[0], self.num_joints * 3), dtype=np.float64)
+
+            for ipt in range(self.num_joints):
+                key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
+                key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
+                key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2]  # keypoints score.
+
+            result = [{'image_id': img_kpts[k]['image'],
+                       'category_id': cat_id,
+                       'keypoints': list(key_points[k]),
+                       'score': img_kpts[k]['score'],
+                       'center': list(img_kpts[k]['center']),
+                       'scale': list(img_kpts[k]['scale'])
+                       } for k in range(len(img_kpts))]
+            cat_results.extend(result)
+
+        return cat_results
+
+    def _do_python_keypoint_eval(self, res_file, res_folder):
+        coco_dt = self.coco.loadRes(res_file)
+        coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
+        coco_eval.params.useSegm = None
+        coco_eval.evaluate()
+        coco_eval.accumulate()
+        coco_eval.summarize()
+        stats_names = ['AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
+
+        info_str = []
+        for ind, name in enumerate(stats_names):
+            info_str.append((name, coco_eval.stats[ind]))
+
+        eval_file = os.path.join(
+            res_folder, 'keypoints_%s_results.pkl' % self.image_set)
+
+        with open(eval_file, 'wb') as f:
+            pickle.dump(coco_eval, f, pickle.HIGHEST_PROTOCOL)
+        logger.info('=> coco eval results saved to %s' % eval_file)
+
+        return info_str
diff --git a/lib/dataset/demo_loader.py b/lib/dataset/demo_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..f317fdfab8bfc24b07a963b64d2b82a3cc16e2ca
--- /dev/null
+++ b/lib/dataset/demo_loader.py
@@ -0,0 +1,188 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import copy
+import os
+import pickle
+from collections import defaultdict
+from collections import OrderedDict
+import glob
+import json_tricks as json
+import random
+
+import numpy as np
+import cv2
+import torch
+from torch.utils.data import Dataset
+
+from utils.transforms import get_affine_transform
+from utils.transforms import affine_transform
+
+logger = logging.getLogger(__name__)
+
+
+class DemoLoader(Dataset):
+    '''
+    "keypoints": {
+        0: "nose",
+        1: "left_eye",
+        2: "right_eye",
+        3: "left_ear",
+        4: "right_ear",
+        5: "left_shoulder",
+        6: "right_shoulder",
+        7: "left_elbow",
+        8: "right_elbow",
+        9: "left_wrist",
+        10: "right_wrist",
+        11: "left_hip",
+        12: "right_hip",
+        13: "left_knee",
+        14: "right_knee",
+        15: "left_ankle",
+        16: "right_ankle"
+    },
+	"skeleton": [
+        [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
+        [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]
+    '''
+
+    def __init__(self, cfg, root, image_set, is_train, transform=None):
+        self.nms_thre = cfg.TEST.NMS_THRE
+        self.image_thre = cfg.TEST.IMAGE_THRE
+        self.oks_thre = cfg.TEST.OKS_THRE
+        self.in_vis_thre = cfg.TEST.IN_VIS_THRE
+        self.bbox_file = cfg.TEST.COCO_BBOX_FILE
+        self.use_gt_bbox = cfg.TEST.USE_GT_BBOX
+        self.image_width = cfg.MODEL.IMAGE_SIZE[0]
+        self.image_height = cfg.MODEL.IMAGE_SIZE[1]
+        self.aspect_ratio = self.image_width * 1.0 / self.image_height
+        self.pixel_std = 200
+
+        # load image file names
+        self.image_size = cfg.MODEL.IMAGE_SIZE
+        self.image_list = glob.glob('demo/*.png') + glob.glob('demo/*.jpg') + glob.glob('demo/*.jpeg')
+        self.num_images = len(self.image_list)
+        logger.info('=> num_images: {}'.format(self.num_images))
+
+        self.is_train = False
+        self.root = root
+        self.image_set = image_set
+        self.transform = transform
+        self.num_joints = 17
+
+    def __len__(self):
+        return len(self.image_list)
+
+    def __getitem__(self, idx):
+        image_file = self.image_list[idx]
+        image_file = image_file.replace('\\', '/')
+
+        data_numpy = cv2.imread(
+            image_file, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
+
+        if data_numpy is None:
+            logger.error('=> fail to read {}'.format(image_file))
+            raise ValueError('Fail to read {}'.format(image_file))
+
+        c, s = self._get_cs(data_numpy.shape)
+        r = 0
+
+        trans = get_affine_transform(c, s, r, self.image_size)
+        input = cv2.warpAffine(
+            data_numpy,
+            trans,
+            (int(self.image_size[0]), int(self.image_size[1])),
+            flags=cv2.INTER_LINEAR)
+
+        if self.transform:
+            input = self.transform(input)
+
+        meta = {
+            'image': image_file,
+            'center': c,
+            'scale': s,
+            'rotation': r,
+        }
+
+        return input, meta
+
+    def _get_cs(self, img_shape):
+        h, w = img_shape[:2]
+
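+        # Use the whole image as the person box: centered on the image, with the
+        # scale normalized by pixel_std, following the same convention as the COCO loader.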
+        center = np.zeros((2), dtype=np.float32)
+        center[0] = w * 0.5
+        center[1] = h * 0.5
+
+        if w > self.aspect_ratio * h:
+            h = w * 1.0 / self.aspect_ratio
+        elif w < self.aspect_ratio * h:
+            w = h * self.aspect_ratio
+        scale = np.array(
+            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
+            dtype=np.float32)
+        if center[0] != -1:
+            scale = scale * 1.25
+
+        return center, scale
+
+    def _write_coco_keypoint_results(self, keypoints, res_file):
+        data_pack = [
+            {
+                'ann_type': 'keypoints',
+                'keypoints': keypoints
+            }
+        ]
+
+        results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
+        logger.info('=> Writing results json to %s' % res_file)
+        with open(res_file, 'w') as f:
+            json.dump(results, f, sort_keys=True, indent=4)
+        try:
+            json.load(open(res_file))
+        except Exception:
+            content = []
+            with open(res_file, 'r') as f:
+                for line in f:
+                    content.append(line)
+            content[-1] = ']'
+            with open(res_file, 'w') as f:
+                for c in content:
+                    f.write(c)
+
+    def _coco_keypoint_results_one_category_kernel(self, data_pack):
+        keypoints = data_pack['keypoints']
+        cat_results = []
+
+        for img_kpts in keypoints:
+            if len(img_kpts) == 0:
+                continue
+
+            _key_points = np.array([img_kpts[k]['keypoints']
+                                    for k in range(len(img_kpts))])
+            key_points = np.zeros(
+                (_key_points.shape[0], self.num_joints * 3), dtype=np.float64)
+
+            for ipt in range(self.num_joints):
+                key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
+                key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
+                key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2]  # keypoints score.
+
+            result = [{'image_id': img_kpts[k]['image'],
+                       'keypoints': list(key_points[k]),
+                       'score': img_kpts[k]['score'],
+                       'center': list(img_kpts[k]['center']),
+                       'scale': list(img_kpts[k]['scale'])
+                       } for k in range(len(img_kpts))]
+            cat_results.extend(result)
+
+        return cat_results
diff --git a/lib/dataset/dripe.py b/lib/dataset/dripe.py
new file mode 100644
index 0000000000000000000000000000000000000000..7705945a6cca9f7f42919a35823ff33970b496a7
--- /dev/null
+++ b/lib/dataset/dripe.py
@@ -0,0 +1,417 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import os
+import pickle
+from collections import defaultdict
+from collections import OrderedDict
+
+import json_tricks as json
+import numpy as np
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+
+from dataset.JointsDataset import JointsDataset
+from nms.nms import oks_nms
+
+
+logger = logging.getLogger(__name__)
+
+
+class DriPEDataset(JointsDataset):
+    '''
+    "keypoints": {
+        0: "nose",
+        1: "left_eye",
+        2: "right_eye",
+        3: "left_ear",
+        4: "right_ear",
+        5: "left_shoulder",
+        6: "right_shoulder",
+        7: "left_elbow",
+        8: "right_elbow",
+        9: "left_wrist",
+        10: "right_wrist",
+        11: "left_hip",
+        12: "right_hip",
+        13: "left_knee",
+        14: "right_knee",
+        15: "left_ankle",
+        16: "right_ankle"
+    },
+	"skeleton": [
+        [16,14],[14,12],[17,15],[15,13],[12,13],[6,12],[7,13], [6,7],[6,8],
+        [7,9],[8,10],[9,11],[2,3],[1,2],[1,3],[2,4],[3,5],[4,6],[5,7]]
+    '''
+    def __init__(self, cfg, root, image_set, is_train, transform=None):
+        super().__init__(cfg, root, image_set, is_train, transform)
+        self.nms_thre = cfg.TEST.NMS_THRE
+        self.image_thre = cfg.TEST.IMAGE_THRE
+        self.oks_thre = cfg.TEST.OKS_THRE
+        self.in_vis_thre = cfg.TEST.IN_VIS_THRE
+        self.bbox_file = cfg.TEST.COCO_BBOX_FILE
+        self.use_gt_bbox = cfg.TEST.USE_GT_BBOX
+        self.image_width = cfg.MODEL.IMAGE_SIZE[0]
+        self.image_height = cfg.MODEL.IMAGE_SIZE[1]
+        self.aspect_ratio = self.image_width * 1.0 / self.image_height
+        self.pixel_std = 200
+        self.coco = COCO(self._get_ann_file_keypoint())
+        self.ids_to_name = dict([(im['id'], im['file_name']) for im in self.coco.imgs.values()])
+        self.name_to_ids = dict([(im['file_name'], im['id']) for im in self.coco.imgs.values()])
+        
+
+        # deal with class names
+        cats = [cat['name']
+                for cat in self.coco.loadCats(self.coco.getCatIds())]
+        self.classes = ['__background__'] + cats
+        logger.info('=> classes: {}'.format(self.classes))
+        self.num_classes = len(self.classes)
+        self._class_to_ind = dict(zip(self.classes, range(self.num_classes)))
+        self._class_to_coco_ind = dict(zip(cats, self.coco.getCatIds()))
+        self._coco_ind_to_class_ind = dict([(self._class_to_coco_ind[cls],
+                                             self._class_to_ind[cls])
+                                            for cls in self.classes[1:]])
+
+        # load image file names
+        self.image_set_index = self._load_image_set_index()
+        self.num_images = len(self.image_set_index)
+        logger.info('=> num_images: {}'.format(self.num_images))
+
+        self.num_joints = 17
+        self.flip_pairs = [[1, 2], [3, 4], [5, 6], [7, 8],
+                           [9, 10], [11, 12], [13, 14], [15, 16]]
+        self.parent_ids = None
+
+        self.db = self._get_db()
+
+        if is_train and cfg.DATASET.SELECT_DATA:
+            self.db = self.select_data(self.db)
+
+        logger.info('=> load {} samples'.format(len(self.db)))
+
+    def _get_ann_file_keypoint(self):
+        """ self.root / annotations / {image_set}.json """
+        return os.path.join(self.root, 'annotations', self.image_set + '.json')
+
+    def _load_image_set_index(self):
+        """ image id: int """
+        image_ids = self.coco.getImgIds()
+        return image_ids
+
+    def _get_db(self):
+        if self.is_train or self.use_gt_bbox:
+            # use ground truth bbox
+            gt_db = self._load_coco_keypoint_annotations()
+        else:
+            # use bbox from detection            
+            gt_db = self._load_coco_person_detection_results()
+        return gt_db
+
+    def _load_coco_keypoint_annotations(self):
+        """ ground truth bbox and keypoints """
+        gt_db = []
+        for index in self.image_set_index:
+            gt_db.extend(self._load_coco_keypoint_annotation_kernal(index))
+        return gt_db
+
+    def _load_coco_keypoint_annotation_kernal(self, index):
+        """
+        coco ann: [u'segmentation', u'area', u'iscrowd', u'image_id', u'bbox', u'category_id', u'id']
+        iscrowd:
+            crowd instances are handled by marking their overlaps with all categories to -1
+            and later excluded in training
+        bbox:
+            [x1, y1, w, h]
+        :param index: coco image id
+        :return: db entry
+        """
+        im_ann = self.coco.loadImgs(index)[0]
+        width = im_ann['width']
+        height = im_ann['height']
+
+        annIds = self.coco.getAnnIds(imgIds=index, iscrowd=False)
+        objs = self.coco.loadAnns(annIds)
+
+        # sanitize bboxes
+        valid_objs = []
+        for obj in objs:
+            x, y, w, h = obj['bbox']
+            x1 = np.max((0, x))
+            y1 = np.max((0, y))
+            x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
+            y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
+            if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
+                # obj['clean_bbox'] = [x1, y1, x2, y2]
+                obj['clean_bbox'] = [x1, y1, x2-x1, y2-y1]
+                valid_objs.append(obj)
+        objs = valid_objs
+
+        rec = []
+        for obj in objs:
+            cls = self._coco_ind_to_class_ind[obj['category_id']]
+            if cls != 1:
+                continue
+
+            # ignore objs without keypoints annotation
+            if max(obj['keypoints']) == 0:
+                continue
+
+            joints_3d = np.zeros((self.num_joints, 3), dtype=np.float64)
+            joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float64)
+            for ipt in range(self.num_joints):
+                joints_3d[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
+                joints_3d[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
+                joints_3d[ipt, 2] = 0
+                t_vis = obj['keypoints'][ipt * 3 + 2]
+                # if t_vis > 1:
+                #     t_vis = 1
+                joints_3d_vis[ipt, 0] = t_vis
+                joints_3d_vis[ipt, 1] = t_vis
+                joints_3d_vis[ipt, 2] = 0
+
+            center, scale = self._box2cs(obj['clean_bbox'][:4])
+            rec.append({
+                'image': self.image_path_from_index(index),
+                'center': center,
+                'scale': scale,
+                'joints_3d': joints_3d,
+                'joints_3d_vis': joints_3d_vis,
+                'filename': '',
+                'imgnum': 0,
+            })
+
+        return rec
+
+    def _box2cs(self, box):
+        x, y, w, h = box[:4]
+        return self._xywh2cs(x, y, w, h)
+
+    def _xywh2cs(self, x, y, w, h):
+        center = np.zeros((2), dtype=np.float32)
+        center[0] = x + w * 0.5
+        center[1] = y + h * 0.5
+
+        if w > self.aspect_ratio * h:
+            h = w * 1.0 / self.aspect_ratio
+        elif w < self.aspect_ratio * h:
+            w = h * self.aspect_ratio
+        scale = np.array(
+            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
+            dtype=np.float32)
+        if center[0] != -1:
+            scale = scale * 1.25
+
+        return center, scale
+
+    def image_path_from_index(self, index):
+        """ example: images / 000000119993.jpg """
+        file_name = self.ids_to_name[index]
+        image_path = os.path.join(self.root, 'images', file_name)
+
+        return image_path
+
+    def _load_coco_person_detection_results(self):
+        all_boxes = None
+        with open(self.bbox_file, 'r') as f:
+            all_boxes = json.load(f)
+
+        if not all_boxes:
+            logger.error('=> Load %s fail!' % self.bbox_file)
+            return None
+
+        logger.info('=> Total boxes: {}'.format(len(all_boxes)))
+
+        kpt_db = []
+        num_boxes = 0
+        for n_img in range(0, len(all_boxes)):
+            det_res = all_boxes[n_img]
+            if det_res['category_id'] != 1:
+                continue
+            img_name = self.image_path_from_index(det_res['image_id'])
+            box = det_res['bbox']
+            score = det_res['score']
+
+            if score < self.image_thre:
+                continue
+
+            num_boxes = num_boxes + 1
+
+            center, scale = self._box2cs(box)
+            joints_3d = np.zeros((self.num_joints, 3), dtype=np.float64)
+            joints_3d_vis = np.ones(
+                (self.num_joints, 3), dtype=np.float64)
+            kpt_db.append({
+                'image': img_name,
+                'center': center,
+                'scale': scale,
+                'score': score,
+                'joints_3d': joints_3d,
+                'joints_3d_vis': joints_3d_vis,
+            })
+
+        logger.info('=> Total boxes after filter low score@{}: {}'.format(
+            self.image_thre, num_boxes))
+        return kpt_db
+
+    # need to double-check this API and the classes field
+    def evaluate(self, cfg, preds, output_dir, all_boxes, img_path,
+                 *args, **kwargs):
+        res_folder = os.path.join(output_dir, 'results')
+        if not os.path.exists(res_folder):
+            os.makedirs(res_folder)
+        res_file = os.path.join(
+            res_folder, 'keypoints_%s_results.json' % self.image_set)
+        if 'res_file' in kwargs:
+            res_file = kwargs['res_file']
+
+        # person x (keypoints)
+        _kpts = []
+        for idx, kpt in enumerate(preds):
+            _kpts.append({
+                'keypoints': kpt,
+                'center': all_boxes[idx][0:2],
+                'scale': all_boxes[idx][2:4],
+                'area': all_boxes[idx][4],
+                'score': all_boxes[idx][5],
+                'image': int(self.name_to_ids[img_path[idx].split('/')[-1]])
+            })
+        # image x person x (keypoints)
+        kpts = defaultdict(list)
+        
+        for kpt in _kpts:
+            kpts[kpt['image']].append(kpt)
+            
+        logger.info('=> evaluate: {} predictions, {} keypoint records, {} images'.format(
+            len(preds), len(_kpts), len(kpts)))
+
+        # rescoring and oks nms
+        num_joints = self.num_joints
+        in_vis_thre = self.in_vis_thre
+        oks_thre = self.oks_thre
+        oks_nmsed_kpts = []
+        for img in kpts.keys():
+            img_kpts = kpts[img]
+            for n_p in img_kpts:
+                box_score = n_p['score']
+                kpt_score = 0
+                valid_num = 0
+                for n_jt in range(0, num_joints):
+                    t_s = n_p['keypoints'][n_jt][2]
+                    if t_s > in_vis_thre:
+                        kpt_score = kpt_score + t_s
+                        valid_num = valid_num + 1
+                if valid_num != 0:
+                    kpt_score = kpt_score / valid_num
+                
+                # rescoring
+                n_p['score'] = kpt_score * box_score
+            keep = oks_nms([img_kpts[i] for i in range(len(img_kpts))],
+                           oks_thre)
+            if len(keep) == 0:
+                oks_nmsed_kpts.append(img_kpts)
+            else:
+                oks_nmsed_kpts.append([img_kpts[_keep] for _keep in keep])
+        
+        self._write_coco_keypoint_results(
+            oks_nmsed_kpts, res_file)
+
+        # Unlike the COCO loader, evaluation is run for every image set
+        # (no special case for a 'test' split without annotations).
+        info_str = self._do_python_keypoint_eval(
+            res_file, res_folder)
+        name_value = OrderedDict(info_str)
+        return name_value, name_value['AP']
+        
+    def _write_coco_keypoint_results(self, keypoints, res_file):
+        data_pack = [{'cat_id': self._class_to_coco_ind[cls],
+                      'cls_ind': cls_ind,
+                      'cls': cls,
+                      'ann_type': 'keypoints',
+                      'keypoints': keypoints
+                      }
+                     for cls_ind, cls in enumerate(self.classes) if not cls == '__background__']
+
+        results = self._coco_keypoint_results_one_category_kernel(data_pack[0])
+        logger.info('=> Writing results json to %s' % res_file)
+        with open(res_file, 'w') as f:
+            json.dump(results, f, sort_keys=True, indent=4)
+        try:
+            json.load(open(res_file))
+        except Exception:
+            content = []
+            with open(res_file, 'r') as f:
+                for line in f:
+                    content.append(line)
+            content[-1] = ']'
+            with open(res_file, 'w') as f:
+                for c in content:
+                    f.write(c)
+
+    def _coco_keypoint_results_one_category_kernel(self, data_pack):
+        cat_id = data_pack['cat_id']
+        keypoints = data_pack['keypoints']
+        cat_results = []
+
+        for img_kpts in keypoints:
+            if len(img_kpts) == 0:
+                continue
+
+            _key_points = np.array([img_kpts[k]['keypoints']
+                                    for k in range(len(img_kpts))])
+            key_points = np.zeros(
+                (_key_points.shape[0], self.num_joints * 3), dtype=np.float64)
+
+            for ipt in range(self.num_joints):
+                key_points[:, ipt * 3 + 0] = _key_points[:, ipt, 0]
+                key_points[:, ipt * 3 + 1] = _key_points[:, ipt, 1]
+                key_points[:, ipt * 3 + 2] = _key_points[:, ipt, 2]  # keypoints score.
+
+            result = [{'image_id': img_kpts[k]['image'],
+                       'category_id': cat_id,
+                       'keypoints': list(key_points[k]),
+                       'score': img_kpts[k]['score'],
+                       'center': list(img_kpts[k]['center']),
+                       'scale': list(img_kpts[k]['scale'])
+                       } for k in range(len(img_kpts))]
+            cat_results.extend(result)
+
+        return cat_results
+
+    def _do_python_keypoint_eval(self, res_file, res_folder):
+        coco_dt = self.coco.loadRes(res_file)
+        coco_eval = COCOeval(self.coco, coco_dt, 'keypoints')
+        coco_eval.params.useSegm = None
+        coco_eval.evaluate()
+        coco_eval.accumulate()
+        coco_eval.summarize()
+        stats_names = ['AP', 'AP .5', 'AP .75', 'AP (M)', 'AP (L)', 'AR', 'AR .5', 'AR .75', 'AR (M)', 'AR (L)']
+
+        info_str = []
+        for ind, name in enumerate(stats_names):
+            info_str.append((name, coco_eval.stats[ind]))
+
+        eval_file = os.path.join(
+            res_folder, 'keypoints_%s_results.pkl' % self.image_set)
+
+        with open(eval_file, 'wb') as f:
+            pickle.dump(coco_eval, f, pickle.HIGHEST_PROTOCOL)
+        logger.info('=> dripe eval results saved to %s' % eval_file)
+
+        return info_str
diff --git a/lib/dataset/mpii.py b/lib/dataset/mpii.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cadb1ac701d58809740e07fa99c5cc538f682e8
--- /dev/null
+++ b/lib/dataset/mpii.py
@@ -0,0 +1,176 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import OrderedDict
+import logging
+import os
+import json_tricks as json
+
+import numpy as np
+from scipy.io import loadmat, savemat
+
+from dataset.JointsDataset import JointsDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+class MPIIDataset(JointsDataset):
+    def __init__(self, cfg, root, image_set, is_train, transform=None):
+        super().__init__(cfg, root, image_set, is_train, transform)
+
+        self.num_joints = 16
+        self.flip_pairs = [[0, 5], [1, 4], [2, 3], [10, 15], [11, 14], [12, 13]]
+        self.parent_ids = [1, 2, 6, 6, 3, 4, 6, 6, 7, 8, 11, 12, 7, 7, 13, 14]
+
+        self.db = self._get_db()
+
+        if is_train and cfg.DATASET.SELECT_DATA:
+            self.db = self.select_data(self.db)
+
+        logger.info('=> load {} samples'.format(len(self.db)))
+
+    def _get_db(self):
+        # create train/val split
+        file_name = os.path.join(self.root,
+                                 'annot',
+                                 self.image_set+'.json')
+        with open(file_name) as anno_file:
+            anno = json.load(anno_file)
+
+        gt_db = []
+        for a in anno:
+            image_name = a['image']
+
+            c = np.array(a['center'], dtype=np.float64)
+            s = np.array([a['scale'], a['scale']], dtype=np.float64)
+
+            # Adjust center/scale slightly to avoid cropping limbs
+            if c[0] != -1:
+                c[1] = c[1] + 15 * s[1]
+                s = s * 1.25
+
+            # MPII annotations use MATLAB's 1-based indexing,
+            # so convert to 0-based indices first
+            c = c - 1
+
+            joints_3d = np.zeros((self.num_joints, 3), dtype=np.float64)
+            joints_3d_vis = np.zeros((self.num_joints, 3), dtype=np.float64)
+            if self.image_set != 'test':
+                joints = np.array(a['joints'])
+                joints[:, 0:2] = joints[:, 0:2] - 1
+                joints_vis = np.array(a['joints_vis'])
+                assert len(joints) == self.num_joints, \
+                    'joint num diff: {} vs {}'.format(len(joints),
+                                                      self.num_joints)
+
+                joints_3d[:, 0:2] = joints[:, 0:2]
+                joints_3d_vis[:, 0] = joints_vis[:]
+                joints_3d_vis[:, 1] = joints_vis[:]
+
+            image_dir = 'images.zip@' if self.data_format == 'zip' else 'images'
+            gt_db.append({
+                'image': os.path.join(self.root, image_dir, image_name),
+                'center': c,
+                'scale': s,
+                'joints_3d': joints_3d,
+                'joints_3d_vis': joints_3d_vis,
+                'filename': '',
+                'imgnum': 0,
+                })
+
+        return gt_db
+
+    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
+        # convert 0-based index to 1-based index
+        preds = preds[:, :, 0:2] + 1.0
+
+        if output_dir:
+            pred_file = os.path.join(output_dir, 'pred.mat')
+            savemat(pred_file, mdict={'preds': preds})
+
+        if 'test' in cfg.DATASET.TEST_SET:
+            return {'Null': 0.0}, 0.0
+
+        SC_BIAS = 0.6
+        threshold = 0.5
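+        # PCKh@0.5: a predicted joint counts as correct if its error is below
+        # 0.5 * head size, with head size taken from the head box and scaled by SC_BIAS.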
+
+        gt_file = os.path.join(cfg.DATASET.ROOT,
+                               'annot',
+                               'gt_{}.mat'.format(cfg.DATASET.TEST_SET))
+        gt_dict = loadmat(gt_file)
+        dataset_joints = gt_dict['dataset_joints']
+        jnt_missing = gt_dict['jnt_missing']
+        pos_gt_src = gt_dict['pos_gt_src']
+        headboxes_src = gt_dict['headboxes_src']
+
+        pos_pred_src = np.transpose(preds, [1, 2, 0])
+
+        head = np.where(dataset_joints == 'head')[1][0]
+        lsho = np.where(dataset_joints == 'lsho')[1][0]
+        lelb = np.where(dataset_joints == 'lelb')[1][0]
+        lwri = np.where(dataset_joints == 'lwri')[1][0]
+        lhip = np.where(dataset_joints == 'lhip')[1][0]
+        lkne = np.where(dataset_joints == 'lkne')[1][0]
+        lank = np.where(dataset_joints == 'lank')[1][0]
+
+        rsho = np.where(dataset_joints == 'rsho')[1][0]
+        relb = np.where(dataset_joints == 'relb')[1][0]
+        rwri = np.where(dataset_joints == 'rwri')[1][0]
+        rkne = np.where(dataset_joints == 'rkne')[1][0]
+        rank = np.where(dataset_joints == 'rank')[1][0]
+        rhip = np.where(dataset_joints == 'rhip')[1][0]
+
+        jnt_visible = 1 - jnt_missing
+        uv_error = pos_pred_src - pos_gt_src
+        uv_err = np.linalg.norm(uv_error, axis=1)
+        headsizes = headboxes_src[1, :, :] - headboxes_src[0, :, :]
+        headsizes = np.linalg.norm(headsizes, axis=0)
+        headsizes *= SC_BIAS
+        scale = np.multiply(headsizes, np.ones((len(uv_err), 1)))
+        scaled_uv_err = np.divide(uv_err, scale)
+        scaled_uv_err = np.multiply(scaled_uv_err, jnt_visible)
+        jnt_count = np.sum(jnt_visible, axis=1)
+        less_than_threshold = np.multiply((scaled_uv_err <= threshold),
+                                          jnt_visible)
+        PCKh = np.divide(100.*np.sum(less_than_threshold, axis=1), jnt_count)
+
+        # save
+        rng = np.arange(0, 0.5+0.01, 0.01)
+        pckAll = np.zeros((len(rng), 16))
+
+        for r in range(len(rng)):
+            threshold = rng[r]
+            less_than_threshold = np.multiply(scaled_uv_err <= threshold,
+                                              jnt_visible)
+            pckAll[r, :] = np.divide(100.*np.sum(less_than_threshold, axis=1),
+                                     jnt_count)
+
+        PCKh = np.ma.array(PCKh, mask=False)
+        PCKh.mask[6:8] = True
+
+        jnt_count = np.ma.array(jnt_count, mask=False)
+        jnt_count.mask[6:8] = True
+        jnt_ratio = jnt_count / np.sum(jnt_count).astype(np.float64)
+
+        name_value = [
+            ('Head', PCKh[head]),
+            ('Shoulder', 0.5 * (PCKh[lsho] + PCKh[rsho])),
+            ('Elbow', 0.5 * (PCKh[lelb] + PCKh[relb])),
+            ('Wrist', 0.5 * (PCKh[lwri] + PCKh[rwri])),
+            ('Hip', 0.5 * (PCKh[lhip] + PCKh[rhip])),
+            ('Knee', 0.5 * (PCKh[lkne] + PCKh[rkne])),
+            ('Ankle', 0.5 * (PCKh[lank] + PCKh[rank])),
+            ('Mean', np.sum(PCKh * jnt_ratio)),
+            ('Mean@0.1', np.sum(pckAll[11, :] * jnt_ratio))
+        ]
+        name_value = OrderedDict(name_value)
+
+        return name_value, name_value['Mean']
diff --git a/lib/models/__init__.py b/lib/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b695778113b14e5555e65ff215d2545fe0472fa6
--- /dev/null
+++ b/lib/models/__init__.py
@@ -0,0 +1,12 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import models.pose_resnet
+import models.pose_vis
\ No newline at end of file
diff --git a/lib/models/modules.py b/lib/models/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c77da9d792983c7d837ee8de4bda69d8f3451cac
--- /dev/null
+++ b/lib/models/modules.py
@@ -0,0 +1,196 @@
+import os
+import logging
+
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+from .resnet import Bottleneck, BN_MOMENTUM
+
+logger = logging.getLogger(__name__)
+
+
+class DeconvStage(nn.Module):
+    def __init__(self, inplanes, cfg):
+        super(DeconvStage, self).__init__()
+
+        self.inplanes = inplanes
+        extra = cfg.MODEL.EXTRA
+        self.deconv_with_bias = extra.DECONV_WITH_BIAS
+        self.nb_joints = cfg.MODEL.NUM_JOINTS
+
+        self.deconv_layers = self._make_deconv_layer(
+            extra.NUM_DECONV_LAYERS,
+            extra.NUM_DECONV_FILTERS,
+            extra.NUM_DECONV_KERNELS,
+        )
+
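+        # Final convolution maps the upsampled features to one heatmap per joint.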
+        self.final_layer = nn.Conv2d(
+            in_channels=extra.NUM_DECONV_FILTERS[-1],
+            out_channels=self.nb_joints,
+            kernel_size=extra.FINAL_CONV_KERNEL,
+            stride=1,
+            padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0
+        )
+
+    def _get_deconv_cfg(self, deconv_kernel, index):
+        if deconv_kernel == 4:
+            padding = 1
+            output_padding = 0
+        elif deconv_kernel == 3:
+            padding = 1
+            output_padding = 1
+        elif deconv_kernel == 2:
+            padding = 0
+            output_padding = 0
+        else:
+            raise ValueError('unsupported deconv kernel size: {}'.format(deconv_kernel))
+
+        return deconv_kernel, padding, output_padding
+
+    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
+        assert num_layers == len(num_filters), \
+            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
+        assert num_layers == len(num_kernels), \
+            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
+
+        layers = []
+        for i in range(num_layers):
+            kernel, padding, output_padding = \
+                self._get_deconv_cfg(num_kernels[i], i)
+
+            planes = num_filters[i]
+            layers.append(
+                nn.ConvTranspose2d(
+                    in_channels=self.inplanes,
+                    out_channels=planes,
+                    kernel_size=kernel,
+                    stride=2,
+                    padding=padding,
+                    output_padding=output_padding,
+                    bias=self.deconv_with_bias))
+            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
+            layers.append(nn.ReLU(inplace=True))
+            self.inplanes = planes
+
+        return nn.Sequential(*layers)
+
+    def forward(self, feats):
+        x = self.deconv_layers(feats)
+        x = self.final_layer(x)
+
+        return x
+
+    def freeze(self, freeze=True):
+        for layer in [
+            self.deconv_layers,
+            self.final_layer,
+        ]:
+            for p in layer.parameters():
+                p.requires_grad = not freeze
+
+
+class DeconvStageVis(DeconvStage):
+    def __init__(self, inplanes, block, cfg):
+        super(DeconvStageVis, self).__init__(inplanes, cfg)
+
+        self.nb_vis = cfg.MODEL.NB_VIS
+        extra = cfg.MODEL.EXTRA
+
+        # New visibility deductor
+        self.fc = self._make_fc(512 * block.expansion,
+                                extra.HEATMAP_SIZE,
+                                extra.NUM_DECONV_LAYERS,
+                                fc_sizes=cfg.MODEL.EXTRA.NUM_LINEAR_LAYERS
+                                )
+
+    def _make_fc(self, input_channel=512 * 4, hm_size=[64, 48], deconv_ratio=3, fc_sizes=[4096, 2048, 1024]):
+        out_conv = 512
+        max_pool = 2
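+        # The fc head runs on the encoder features, whose spatial size is the heatmap
+        # size divided by 2**deconv_ratio; after the max-pool, the flattened size is
+        # out_conv * hm_w * hm_h / (max_pool * 2**deconv_ratio)**2.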
+        layers = [
+            Bottleneck(input_channel, int(out_conv / Bottleneck.expansion),
+                       downsample=nn.Conv2d(input_channel, out_conv, kernel_size=1, stride=1, padding=0, bias=False)),
+            nn.BatchNorm2d(out_conv, momentum=BN_MOMENTUM), nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=max_pool, stride=max_pool, padding=0), nn.Flatten(), ]
+
+        # input layer
+        layers += [
+            nn.Linear(int(out_conv * hm_size[0] * hm_size[1] / (max_pool * 2 ** deconv_ratio) ** 2), fc_sizes[0]),
+            nn.ReLU()]
+
+        # hidden layers
+        for i in range(len(fc_sizes) - 1):
+            layers.append(nn.Linear(fc_sizes[i], fc_sizes[i + 1]))
+            layers.append(nn.ReLU())
+
+        # output layers
+        layers.append(nn.Linear(fc_sizes[-1], self.nb_vis * self.nb_joints))
+        layers.append(nn.ReLU())
+
+        return nn.Sequential(*layers)
+
+    def forward(self, feats):
+        x = self.deconv_layers(feats)
+        x = self.final_layer(x)
+
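+        # Visibility branch: per-joint visibility logits predicted from the shared encoder features.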
+        vis_feats = self.fc(feats)
+        vis_preds = vis_feats.reshape((vis_feats.shape[0], self.nb_joints, self.nb_vis))
+
+        return x, vis_preds
+
+
+class PoseNet(nn.Module):
+    def __init__(self):
+        super(PoseNet, self).__init__()
+
+        self.final_stage = None
+
+    def forward(self, x):
+        raise NotImplementedError()
+
+    def init_weights(self, pretrained_pth=''):
+        if os.path.isfile(pretrained_pth):
+            logger.info('=> init deconv weights from normal distribution')
+            for name, m in self.final_stage.deconv_layers.named_modules():
+                if isinstance(m, nn.ConvTranspose2d):
+                    logger.info('=> init {}.weight as normal(0, 0.001)'.format(name))
+                    logger.info('=> init {}.bias as 0'.format(name))
+                    nn.init.normal_(m.weight, std=0.001)
+                    if self.final_stage.deconv_with_bias:
+                        nn.init.constant_(m.bias, 0)
+                elif isinstance(m, nn.BatchNorm2d):
+                    logger.info('=> init {}.weight as 1'.format(name))
+                    logger.info('=> init {}.bias as 0'.format(name))
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+            logger.info('=> init final conv weights from normal distribution')
+            for name, m in self.final_stage.final_layer.named_modules():
+                if isinstance(m, nn.Conv2d):
+                    # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                    logger.info('=> init {}.weight as normal(0, 0.001)'.format(name))
+                    logger.info('=> init {}.bias as 0'.format(name))
+                    nn.init.normal_(m.weight, std=0.001)
+                    nn.init.constant_(m.bias, 0)
+
+            # pretrained_state_dict = torch.load(pretrained)
+            logger.info('=> loading pretrained model {}'.format(pretrained_pth))
+            # self.load_state_dict(pretrained_state_dict, strict=False)
+            checkpoint = torch.load(pretrained_pth)
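+            # The checkpoint may be a bare state_dict (an OrderedDict) or a
+            # dict wrapping one under 'state_dict'; keys saved from a
+            # DataParallel wrapper keep their 'module.' prefix, stripped below.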
+            if isinstance(checkpoint, OrderedDict):
+                state_dict = checkpoint
+            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+                state_dict_old = checkpoint['state_dict']
+                state_dict = OrderedDict()
+                # delete 'module.' because it is saved from DataParallel module
+                for key in state_dict_old.keys():
+                    if key.startswith('module.'):
+                        # state_dict[key[7:]] = state_dict[key]
+                        # state_dict.pop(key)
+                        state_dict[key[7:]] = state_dict_old[key]
+                    else:
+                        state_dict[key] = state_dict_old[key]
+            else:
+                raise RuntimeError(
+                    'No state_dict found in checkpoint file {}'.format(pretrained_pth))
+            self.load_state_dict(state_dict, strict=False)
+        else:
+            logger.error(f'=> ImageNet pretrained model does not exist: {pretrained_pth}')
+            logger.error('=> please download it first')
+            raise ValueError('imagenet pretrained model does not exist')
diff --git a/lib/models/pose_resnet.py b/lib/models/pose_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..206aed967265bb9286f837617f724a6350a11f83
--- /dev/null
+++ b/lib/models/pose_resnet.py
@@ -0,0 +1,80 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import logging
+
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+from .modules import DeconvStage, PoseNet
+from .resnet import ResNet, BasicBlock, Bottleneck
+from utils.utils import convert_state_dict
+
+logger = logging.getLogger(__name__)
+
+
+class PoseResNet(PoseNet):
+    def __init__(self, block, layers, cfg, **kwargs):
+        super(PoseResNet, self).__init__()
+
+        self.resnet = ResNet(block, layers, cfg, **kwargs)
+        self.final_stage = DeconvStage(self.resnet.inplanes, cfg)
+
+    def forward(self, x):
+        x = self.resnet(x)
+        x = self.final_stage(x)
+        return x
+
+    def freeze_encoder(self, freeze=True):
+        self.resnet.freeze(freeze=freeze)
+        logger.info("Encoder frozen")
+
+    def freeze_deconv(self, freeze=True):
+        self.final_stage.freeze(freeze=freeze)
+        logger.info("Deconv frozen")
+
+    def load_state_dict(self, state_dict, strict=True):
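+        # Strip the DataParallel 'module.' prefix, then try a strict load;
+        # on a key mismatch, convert_state_dict attempts to remap the keys
+        # onto the 'resnet' / 'final_stage' submodules before retrying.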
+        state_dict = OrderedDict({k.replace('module.', ''): v
+                                  for k, v in state_dict.items()})
+        try:
+            nn.Module.load_state_dict(self, state_dict=state_dict, strict=True)
+        except RuntimeError as err:
+            str_err = str(err)
+            if 'Missing key(s)' not in str_err and 'Unexpected key(s)' not in str_err:
+                raise err
+            if strict:
+                new_dict = convert_state_dict(state_dict, ['resnet', 'final_stage'], err)
+            else:
+                print('\n'.join(str(err).split('\n')[1:]))
+                new_dict = state_dict
+            nn.Module.load_state_dict(self, state_dict=new_dict, strict=strict)
+
+
+resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
+               34: (BasicBlock, [3, 4, 6, 3]),
+               50: (Bottleneck, [3, 4, 6, 3]),
+               101: (Bottleneck, [3, 4, 23, 3]),
+               152: (Bottleneck, [3, 8, 36, 3])}
+
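+# Usage sketch (assuming a config with MODEL.EXTRA.NUM_LAYERS = 50):
+#   block, layers = resnet_spec[50]          # Bottleneck, [3, 4, 6, 3]
+#   model = PoseResNet(block, layers, cfg)   # heatmap head on a ResNet-50
+# get_pose_net() below wraps exactly this and, when training with
+# MODEL.INIT_WEIGHTS enabled, also loads the weights from MODEL.PRETRAINED.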
+
+def get_pose_net(cfg, is_train, **kwargs):
+    num_layers = cfg.MODEL.EXTRA.NUM_LAYERS
+    style = cfg.MODEL.STYLE
+
+    block_class, layers = resnet_spec[num_layers]
+
+    model = PoseResNet(block_class, layers, cfg, **kwargs)
+
+    if is_train and cfg.MODEL.INIT_WEIGHTS:
+        model.init_weights(cfg.MODEL.PRETRAINED)
+
+    return model
diff --git a/lib/models/pose_vis.py b/lib/models/pose_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0aaa5a4d05ee946244adc0cbf767cc0d1aec28b
--- /dev/null
+++ b/lib/models/pose_vis.py
@@ -0,0 +1,39 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from .modules import DeconvStageVis
+from .resnet import ResNet
+from .pose_resnet import PoseResNet, resnet_spec
+
+
+class PoseResNetVis(PoseResNet):
+    def __init__(self, block, layers, cfg, **kwargs):
+        super(PoseResNetVis, self).__init__(block, layers, cfg, **kwargs)
+        self.nb_vis = cfg.MODEL.NB_VIS
+
+        self.resnet = ResNet(block, layers, cfg, **kwargs)
+        self.final_stage = DeconvStageVis(self.resnet.inplanes, block, cfg)
+
+    def forward(self, x):
+        x = self.resnet(x)
+        x, vis_preds = self.final_stage(x)
+
+        return x, vis_preds
+
+
+def get_pose_net(cfg, is_train, **kwargs):
+    num_layers = cfg.MODEL.EXTRA.NUM_LAYERS
+    style = cfg.MODEL.STYLE
+
+    block_class, layers = resnet_spec[num_layers]
+
+    if style == 'caffe':
+        raise NotImplementedError('Caffe not handled')
+
+    model = PoseResNetVis(block_class, layers, cfg, **kwargs)
+
+    if is_train and cfg.MODEL.INIT_WEIGHTS:
+        model.init_weights(cfg.MODEL.PRETRAINED)
+
+    return model
diff --git a/lib/models/resnet.py b/lib/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..51ded256ba45c60707900a04d50e316142d43abe
--- /dev/null
+++ b/lib/models/resnet.py
@@ -0,0 +1,145 @@
+import torch.nn as nn
+BN_MOMENTUM = 0.1
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                               padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+                               bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+                                  momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+
+class ResNet(nn.Module):
+    def __init__(self, block, layers, cfg, **kwargs):
+        super(ResNet, self).__init__()
+
+        self.inplanes = 64
+        extra = cfg.MODEL.EXTRA
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(self.inplanes, planes * block.expansion,
+                          kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        return x
+
+    def freeze(self, freeze=True):
+        for layer in [
+            self.conv1,
+            self.bn1,
+            self.relu,
+            self.maxpool,
+            self.layer1,
+            self.layer2,
+            self.layer3,
+            self.layer4,
+        ]:
+            for p in layer.parameters():
+                p.requires_grad = not freeze
+
diff --git a/lib/nms/__init__.py b/lib/nms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/nms/cpu_nms.pyx b/lib/nms/cpu_nms.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..59388fe78d16620671ce1d929388e9553e30e510
--- /dev/null
+++ b/lib/nms/cpu_nms.pyx
@@ -0,0 +1,67 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Modified from py-faster-rcnn (https://github.com/rbgirshick/py-faster-rcnn)
+# ------------------------------------------------------------------------------
+
+import numpy as np
+cimport numpy as np
+
+cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
+    return a if a >= b else b
+
+cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
+    return a if a <= b else b
+
+def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh):
+    cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
+    cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
+    cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
+    cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
+    cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
+
+    cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    cdef np.ndarray[np.int_t, ndim=1] order = scores.argsort()[::-1].astype('i')
+
+    cdef int ndets = dets.shape[0]
+    cdef np.ndarray[np.int_t, ndim=1] suppressed = \
+            np.zeros((ndets), dtype=np.int_)
+
+    # nominal indices
+    cdef int _i, _j
+    # sorted indices
+    cdef int i, j
+    # temp variables for box i's (the box currently under consideration)
+    cdef np.float32_t ix1, iy1, ix2, iy2, iarea
+    # variables for computing overlap with box j (lower scoring box)
+    cdef np.float32_t xx1, yy1, xx2, yy2
+    cdef np.float32_t w, h
+    cdef np.float32_t inter, ovr
+
+    keep = []
+    for _i in range(ndets):
+        i = order[_i]
+        if suppressed[i] == 1:
+            continue
+        keep.append(i)
+        ix1 = x1[i]
+        iy1 = y1[i]
+        ix2 = x2[i]
+        iy2 = y2[i]
+        iarea = areas[i]
+        for _j in range(_i + 1, ndets):
+            j = order[_j]
+            if suppressed[j] == 1:
+                continue
+            xx1 = max(ix1, x1[j])
+            yy1 = max(iy1, y1[j])
+            xx2 = min(ix2, x2[j])
+            yy2 = min(iy2, y2[j])
+            w = max(0.0, xx2 - xx1 + 1)
+            h = max(0.0, yy2 - yy1 + 1)
+            inter = w * h
+            ovr = inter / (iarea + areas[j] - inter)
+            if ovr >= thresh:
+                suppressed[j] = 1
+
+    return keep
diff --git a/lib/nms/gpu_nms.hpp b/lib/nms/gpu_nms.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..68b6d42cd88b59496b22a9e77919abe529b09014
--- /dev/null
+++ b/lib/nms/gpu_nms.hpp
@@ -0,0 +1,2 @@
+void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
+          int boxes_dim, float nms_overlap_thresh, int device_id);
diff --git a/lib/nms/gpu_nms.pyx b/lib/nms/gpu_nms.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..b6ff660cf70e16680f3772d2b6caf05f1dd26e0c
--- /dev/null
+++ b/lib/nms/gpu_nms.pyx
@@ -0,0 +1,30 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Modified from py-faster-rcnn (https://github.com/rbgirshick/py-faster-rcnn)
+# ------------------------------------------------------------------------------
+
+import numpy as np
+cimport numpy as np
+
+assert sizeof(int) == sizeof(np.int32_t)
+
+cdef extern from "gpu_nms.hpp":
+    void _nms(np.int32_t*, int*, np.float32_t*, int, int, float, int)
+
+def gpu_nms(np.ndarray[np.float32_t, ndim=2] dets, np.float thresh,
+            np.int32_t device_id=0):
+    cdef int boxes_num = dets.shape[0]
+    cdef int boxes_dim = dets.shape[1]
+    cdef int num_out
+    cdef np.ndarray[np.int32_t, ndim=1] \
+        keep = np.zeros(boxes_num, dtype=np.int32)
+    cdef np.ndarray[np.float32_t, ndim=1] \
+        scores = dets[:, 4]
+    cdef np.ndarray[np.int32_t, ndim=1] \
+        order = scores.argsort()[::-1].astype(np.int32)
+    cdef np.ndarray[np.float32_t, ndim=2] \
+        sorted_dets = dets[order, :]
+    _nms(&keep[0], &num_out, &sorted_dets[0, 0], boxes_num, boxes_dim, thresh, device_id)
+    keep = keep[:num_out]
+    return list(order[keep])
diff --git a/lib/nms/nms.py b/lib/nms/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..e02850d9535609cecb854d59a4050c8eb696083e
--- /dev/null
+++ b/lib/nms/nms.py
@@ -0,0 +1,123 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Modified from py-faster-rcnn (https://github.com/rbgirshick/py-faster-rcnn)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from .cpu_nms import cpu_nms
+from .gpu_nms import gpu_nms
+
+
+def py_nms_wrapper(thresh):
+    def _nms(dets):
+        return nms(dets, thresh)
+    return _nms
+
+
+def cpu_nms_wrapper(thresh):
+    def _nms(dets):
+        return cpu_nms(dets, thresh)
+    return _nms
+
+
+def gpu_nms_wrapper(thresh, device_id):
+    def _nms(dets):
+        return gpu_nms(dets, thresh, device_id)
+    return _nms
+
+
+def nms(dets, thresh):
+    """
+    greedily select boxes with high confidence and overlap with current maximum <= thresh
+    rule out overlap >= thresh
+    :param dets: [[x1, y1, x2, y2 score]]
+    :param thresh: retain overlap < thresh
+    :return: indexes to keep
+    """
+    if dets.shape[0] == 0:
+        return []
+
+    x1 = dets[:, 0]
+    y1 = dets[:, 1]
+    x2 = dets[:, 2]
+    y2 = dets[:, 3]
+    scores = dets[:, 4]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= thresh)[0]
+        order = order[inds + 1]
+
+    return keep
+
+def oks_iou(g, d, a_g, a_d, sigmas=None, in_vis_thre=None):
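+    # Object Keypoint Similarity between one ground-truth pose g and every
+    # detection in d: per keypoint, e = dist^2 / (2 * s^2 * k^2) with
+    # k = 2 * sigma and s^2 the mean of the two instance areas; the OKS is
+    # the mean of exp(-e) over the (optionally visibility-filtered) keypoints.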
+    if not isinstance(sigmas, np.ndarray):
+        sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
+    vars = (sigmas * 2) ** 2
+    xg = g[0::3]
+    yg = g[1::3]
+    vg = g[2::3]
+    ious = np.zeros((d.shape[0]))
+    for n_d in range(0, d.shape[0]):
+        xd = d[n_d, 0::3]
+        yd = d[n_d, 1::3]
+        vd = d[n_d, 2::3]
+        dx = xd - xg
+        dy = yd - yg
+        e = (dx ** 2 + dy ** 2) / vars / ((a_g + a_d[n_d]) / 2 + np.spacing(1)) / 2
+        if in_vis_thre is not None:
+            # keep only keypoints visible enough in both the ground truth and the detection
+            ind = np.logical_and(vg > in_vis_thre, vd > in_vis_thre)
+            e = e[ind]
+        ious[n_d] = np.sum(np.exp(-e)) / e.shape[0] if e.shape[0] != 0 else 0.0
+    return ious
+
+def oks_nms(kpts_db, thresh, sigmas=None, in_vis_thre=None):
+    """
+    greedily select boxes with high confidence and overlap with current maximum <= thresh
+    rule out overlap >= thresh, overlap = oks
+    :param kpts_db
+    :param thresh: retain overlap < thresh
+    :return: indexes to keep
+    """
+    if len(kpts_db) == 0:
+        return []
+
+    scores = np.array([kpts_db[i]['score'] for i in range(len(kpts_db))])
+    kpts = np.array([kpts_db[i]['keypoints'].flatten() for i in range(len(kpts_db))])
+    areas = np.array([kpts_db[i]['area'] for i in range(len(kpts_db))])
+
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+
+        oks_ovr = oks_iou(kpts[i], kpts[order[1:]], areas[i], areas[order[1:]], sigmas, in_vis_thre)
+
+        inds = np.where(oks_ovr <= thresh)[0]
+        order = order[inds + 1]
+
+    return keep
+
diff --git a/lib/nms/nms_kernel.cu b/lib/nms/nms_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f6176c6de2274cbf9f3527a3b0220cd3f23f3ae6
--- /dev/null
+++ b/lib/nms/nms_kernel.cu
@@ -0,0 +1,143 @@
+// ------------------------------------------------------------------
+// Copyright (c) Microsoft
+// Licensed under The MIT License
+// Modified from MATLAB Faster R-CNN (https://github.com/shaoqingren/faster_rcnn)
+// ------------------------------------------------------------------
+
+#include "gpu_nms.hpp"
+#include <vector>
+#include <iostream>
+#include <cstring>  // memset
+
+#define CUDA_CHECK(condition) \
+  /* Code block avoids redefinition of cudaError_t error */ \
+  do { \
+    cudaError_t error = condition; \
+    if (error != cudaSuccess) { \
+      std::cout << cudaGetErrorString(error) << std::endl; \
+    } \
+  } while (0)
+
+#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
+int const threadsPerBlock = sizeof(unsigned long long) * 8;
+
+__device__ inline float devIoU(float const * const a, float const * const b) {
+  float left = max(a[0], b[0]), right = min(a[2], b[2]);
+  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
+  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
+  float interS = width * height;
+  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
+  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
+  return interS / (Sa + Sb - interS);
+}
+
+__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
+                           const float *dev_boxes, unsigned long long *dev_mask) {
+  const int row_start = blockIdx.y;
+  const int col_start = blockIdx.x;
+
+  // if (row_start > col_start) return;
+
+  const int row_size =
+        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
+  const int col_size =
+        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
+
+  __shared__ float block_boxes[threadsPerBlock * 5];
+  if (threadIdx.x < col_size) {
+    block_boxes[threadIdx.x * 5 + 0] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
+    block_boxes[threadIdx.x * 5 + 1] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
+    block_boxes[threadIdx.x * 5 + 2] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
+    block_boxes[threadIdx.x * 5 + 3] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
+    block_boxes[threadIdx.x * 5 + 4] =
+        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < row_size) {
+    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
+    const float *cur_box = dev_boxes + cur_box_idx * 5;
+    int i = 0;
+    unsigned long long t = 0;
+    int start = 0;
+    if (row_start == col_start) {
+      start = threadIdx.x + 1;
+    }
+    for (i = start; i < col_size; i++) {
+      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
+        t |= 1ULL << i;
+      }
+    }
+    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
+    dev_mask[cur_box_idx * col_blocks + col_start] = t;
+  }
+}
+
+void _set_device(int device_id) {
+  int current_device;
+  CUDA_CHECK(cudaGetDevice(&current_device));
+  if (current_device == device_id) {
+    return;
+  }
+  // The call to cudaSetDevice must come before any calls to Get, which
+  // may perform initialization using the GPU.
+  CUDA_CHECK(cudaSetDevice(device_id));
+}
+
+void _nms(int* keep_out, int* num_out, const float* boxes_host, int boxes_num,
+          int boxes_dim, float nms_overlap_thresh, int device_id) {
+  _set_device(device_id);
+
+  float* boxes_dev = NULL;
+  unsigned long long* mask_dev = NULL;
+
+  const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
+
+  CUDA_CHECK(cudaMalloc(&boxes_dev,
+                        boxes_num * boxes_dim * sizeof(float)));
+  CUDA_CHECK(cudaMemcpy(boxes_dev,
+                        boxes_host,
+                        boxes_num * boxes_dim * sizeof(float),
+                        cudaMemcpyHostToDevice));
+
+  CUDA_CHECK(cudaMalloc(&mask_dev,
+                        boxes_num * col_blocks * sizeof(unsigned long long)));
+
+  dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
+              DIVUP(boxes_num, threadsPerBlock));
+  dim3 threads(threadsPerBlock);
+  nms_kernel<<<blocks, threads>>>(boxes_num,
+                                  nms_overlap_thresh,
+                                  boxes_dev,
+                                  mask_dev);
+
+  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
+  CUDA_CHECK(cudaMemcpy(&mask_host[0],
+                        mask_dev,
+                        sizeof(unsigned long long) * boxes_num * col_blocks,
+                        cudaMemcpyDeviceToHost));
+
+  std::vector<unsigned long long> remv(col_blocks);
+  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
+
+  int num_to_keep = 0;
+  for (int i = 0; i < boxes_num; i++) {
+    int nblock = i / threadsPerBlock;
+    int inblock = i % threadsPerBlock;
+
+    if (!(remv[nblock] & (1ULL << inblock))) {
+      keep_out[num_to_keep++] = i;
+      unsigned long long *p = &mask_host[0] + i * col_blocks;
+      for (int j = nblock; j < col_blocks; j++) {
+        remv[j] |= p[j];
+      }
+    }
+  }
+  *num_out = num_to_keep;
+
+  CUDA_CHECK(cudaFree(boxes_dev));
+  CUDA_CHECK(cudaFree(mask_dev));
+}
diff --git a/lib/nms/setup.py b/lib/nms/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..faada13e9330598c7861491ce6a639aee2450014
--- /dev/null
+++ b/lib/nms/setup.py
@@ -0,0 +1,140 @@
+# --------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Modified from py-faster-rcnn (https://github.com/rbgirshick/py-faster-rcnn)
+# --------------------------------------------------------
+
+import os
+from os.path import join as pjoin
+from setuptools import setup
+from distutils.extension import Extension
+from Cython.Distutils import build_ext
+import numpy as np
+
+
+def find_in_path(name, path):
+    "Find a file in a search path"
+    # Adapted from
+    # http://code.activestate.com/recipes/52224-find-a-file-given-a-search-path/
+    for dir in path.split(os.pathsep):
+        binpath = pjoin(dir, name)
+        if os.path.exists(binpath):
+            return os.path.abspath(binpath)
+    return None
+
+
+def locate_cuda():
+    """Locate the CUDA environment on the system
+    Returns a dict with keys 'home', 'nvcc', 'include', and 'lib64'
+    and values giving the absolute path to each directory.
+    Starts by looking for the CUDAHOME env variable. If not found, everything
+    is based on finding 'nvcc' in the PATH.
+    """
+
+    # first check if the CUDAHOME env variable is in use
+    if 'CUDAHOME' in os.environ:
+        home = os.environ['CUDAHOME']
+        nvcc = pjoin(home, 'bin', 'nvcc')
+    else:
+        # otherwise, search the PATH for NVCC
+        default_path = pjoin(os.sep, 'usr', 'local', 'cuda', 'bin')
+        nvcc = find_in_path('nvcc', os.environ['PATH'] + os.pathsep + default_path)
+        if nvcc is None:
+            raise EnvironmentError('The nvcc binary could not be '
+                'located in your $PATH. Either add it to your path, or set $CUDAHOME')
+        home = os.path.dirname(os.path.dirname(nvcc))
+
+    cudaconfig = {'home':home, 'nvcc':nvcc,
+                  'include': pjoin(home, 'include'),
+                  'lib64': pjoin(home, 'lib64')}
+    for k, v in cudaconfig.items():
+        if not os.path.exists(v):
+            raise EnvironmentError('The CUDA %s path could not be located in %s' % (k, v))
+
+    return cudaconfig
+CUDA = locate_cuda()
+
+
+# Obtain the numpy include directory.  This logic works across numpy versions.
+try:
+    numpy_include = np.get_include()
+except AttributeError:
+    numpy_include = np.get_numpy_include()
+
+
+def customize_compiler_for_nvcc(self):
+    """inject deep into distutils to customize how the dispatch
+    to gcc/nvcc works.
+    If you subclass UnixCCompiler, it's not trivial to get your subclass
+    injected in, and still have the right customizations (i.e.
+    distutils.sysconfig.customize_compiler) run on it. So instead of going
+    the OO route, I have this. Note, it's kind of like a weird functional
+    subclassing going on."""
+
+    # tell the compiler it can process .cu files
+    self.src_extensions.append('.cu')
+
+    # save references to the default compiler_so and _compile methods
+    default_compiler_so = self.compiler_so
+    super = self._compile
+
+    # now redefine the _compile method. This gets executed for each
+    # object but distutils doesn't have the ability to change compilers
+    # based on source extension: we add it.
+    def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
+        if os.path.splitext(src)[1] == '.cu':
+            # use the cuda for .cu files
+            self.set_executable('compiler_so', CUDA['nvcc'])
+            # use only a subset of the extra_postargs, which are 1-1 translated
+            # from the extra_compile_args in the Extension class
+            postargs = extra_postargs['nvcc']
+        else:
+            postargs = extra_postargs['gcc']
+
+        super(obj, src, ext, cc_args, postargs, pp_opts)
+        # reset the default compiler_so, which we might have changed for cuda
+        self.compiler_so = default_compiler_so
+
+    # inject our redefined _compile method into the class
+    self._compile = _compile
+
+
+# run the customize_compiler
+class custom_build_ext(build_ext):
+    def build_extensions(self):
+        customize_compiler_for_nvcc(self.compiler)
+        build_ext.build_extensions(self)
+
+
+ext_modules = [
+    Extension(
+        "cpu_nms",
+        ["cpu_nms.pyx"],
+        extra_compile_args={'gcc': ["-Wno-cpp", "-Wno-unused-function"]},
+        include_dirs = [numpy_include]
+    ),
+    Extension('gpu_nms',
+        ['nms_kernel.cu', 'gpu_nms.pyx'],
+        library_dirs=[CUDA['lib64']],
+        libraries=['cudart'],
+        language='c++',
+        runtime_library_dirs=[CUDA['lib64']],
+        # this syntax is specific to this build system
+        # we're only going to use certain compiler args with nvcc and not with
+        # gcc; the implementation of this trick is in customize_compiler_for_nvcc() above
+        extra_compile_args={'gcc': ["-Wno-unused-function"],
+                            'nvcc': ['-arch=sm_35',
+                                     '--ptxas-options=-v',
+                                     '-c',
+                                     '--compiler-options',
+                                     "'-fPIC'"]},
+        include_dirs = [numpy_include, CUDA['include']]
+    ),
+]
+
+setup(
+    name='nms',
+    ext_modules=ext_modules,
+    # inject our custom trigger
+    cmdclass={'build_ext': custom_build_ext},
+)
diff --git a/lib/utils/__init__.py b/lib/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/lib/utils/debug.py b/lib/utils/debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae6a69a58606ca2e469ee55877ab7fe42b64d24c
--- /dev/null
+++ b/lib/utils/debug.py
@@ -0,0 +1,84 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import gc, torch
+
+
+class GradPlots:
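+    """Debugging helper: save_grads() records the mean absolute gradient of a
+    few layers (the 10 largest, or those whose name contains `key`) each time
+    it is called, and plot_graph() plots the recorded values per call to
+    inspect gradient flow."""
+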
+    def __init__(self, key=None):
+        self.layers = []
+        self.ave_grads = {}
+        self.epochs = 0
+        self.key = key
+
+    def save_grads(self, named_parameters):
+        ave_grads = []
+        layers = []
+        for n, p in named_parameters:
+            if self.key is not None and self.key in n:
+                print(n)
+            if (p.requires_grad) and ("bias" not in n):
+                layers.append(n)
+                ave_grads.append(p.grad.abs().mean().item())
+
+        np_grads = np.asarray(ave_grads)
+        if self.key is None:
+            # Keep 10 greatest
+            idx = np_grads.argpartition(-10)[-10:]
+        else:
+            idx = [i for i, k in enumerate(layers) if self.key in k]
+        ave_grads = np_grads[idx].tolist()
+        layers = np.asarray(layers)[idx].tolist()
+
+        for l, layer in enumerate(layers):
+            if layer in self.layers:
+                self.ave_grads[layer].append(ave_grads[l])
+            else:
+                self.ave_grads[layer] = [ave_grads[l]]
+                self.layers.append(layer)
+
+        self.epochs += 1
+
+        for k in self.ave_grads:
+            while len(self.ave_grads[k]) < self.epochs:
+                self.ave_grads[k] = [0] + self.ave_grads[k]
+
+    def plot_graph(self):
+        for e in range(self.epochs):
+            ave_grad = [self.ave_grads[l][e] for l in self.layers]
+
+            plt.plot(ave_grad, alpha=0.3, label=f'{e}')
+
+        plt.hlines(0, 0, len(self.layers) + 1, linewidth=1, color="k")
+        plt.xticks(range(0, len(self.layers), 1), self.layers, rotation="vertical")
+        plt.xlim(xmin=0, xmax=len(self.layers))
+        plt.xlabel("Layers")
+        plt.ylabel("average gradient")
+        plt.title("Gradient flow")
+        plt.grid(True)
+        plt.legend()
+        plt.savefig('debug_grad.png', bbox_inches='tight')
+
+
+def print_tensors():
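+    """Walks the garbage collector, collects every live torch tensor or
+    parameter with a human-readable size, and prints them largest first."""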
+    tensors = {}
+    for obj in gc.get_objects():
+        try:
+            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
+                lbl_type = type(obj)
+                if lbl_type == torch.nn.parameter.Parameter:
+                    lbl_type = 'parameter'
+                elif lbl_type == torch.Tensor:
+                    lbl_type = 'tensor'
+
+                mem_size = obj.element_size() * obj.nelement()
+                lbls_mem = ['', 'k', 'M', 'G', 'T', 'P']
+                i = 0
+                while mem_size >= 1024:
+                    mem_size /= 1024
+                    i += 1
+                tensors[
+                    obj.element_size() * obj.nelement()] = f'{lbl_type}: {tuple(obj.size())}, {mem_size:.2f}{lbls_mem[i]}b'
+        except Exception:
+            pass
+    for k in sorted(tensors, reverse=True):
+        print(tensors[k])
diff --git a/lib/utils/tabs.py b/lib/utils/tabs.py
new file mode 100644
index 0000000000000000000000000000000000000000..7019e72f97892cfb41e4c15232f4bada0c5121bf
--- /dev/null
+++ b/lib/utils/tabs.py
@@ -0,0 +1,103 @@
+import numpy as np
+
+
+def text_from_data(data, lbls_r=None, lbls_c=None):
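+    """Formats `data` (one row or a list of rows) as tab-separated text,
+    optionally preceded by the column labels `lbls_c` and prefixed per row
+    with `lbls_r`; floats are rendered with two decimals."""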
+    text = ''
+    if not data:
+        return text
+
+    if not type(data[0]) in (list, dict, np.ndarray):
+        data = [data]
+
+    if lbls_c:
+        text = '\t' if lbls_r else ''
+        text += '\t'.join(lbls_c)
+
+    for r, row in enumerate(data):
+        if lbls_c or r:
+            text += '\n'
+        if lbls_r:
+            text += lbls_r[r] + '\t'
+
+        text += '\t'.join([f'{x:0.2f}' if type(x) in [float, np.float_] else f'{x}' for x in row])
+
+    return text
+
+
+class Tabs:
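+    """Parses tab-separated text (or builds it with text_from_data), prints a
+    column-aligned version of the table and can export it as a LaTeX tabular
+    via to_latex()."""
+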
+    def __init__(self, data, lbls_r=None, lbls_c=None, logger=None):
+        if not data:
+            return
+        self.text = data if type(data) == str else text_from_data(data, lbls_r=lbls_r, lbls_c=lbls_c)
+
+        self.rows = [l.split('\t') for l in self.text.split('\n')]
+        if self.text[-1] == '\n':
+            self.rows.pop()
+
+        self.num_l = len(self.rows)
+        self.num_c = len(self.rows[0]) if self.rows else 0
+
+        self.cols = [[l[i] if len(l) > i else '' for l in self.rows] for i in range(self.num_c)]
+
+        self.len_c = [max([len(x) for x in c]) for c in self.cols]
+
+        self.logger = logger
+        self.disp_text = self.disp()
+
+    def disp(self):
+        disp_text = ''
+        for row in self.rows:
+            for c, w in enumerate(row):
+                disp_text += '\t' if c > 0 else ''
+                disp_text += w + ' ' * (self.len_c[c] - len(w))
+            disp_text += '\n'
+
+        if self.logger:
+            self.logger.info('\n' + disp_text)
+        else:
+            print(disp_text)
+        return disp_text
+
+    def to_latex(self):
+        header = """
+        \\begin{table}[!htb]
+        \\centering
+        \\small
+        \\renewcommand{\\tabcolsep}{2pt}
+        \\begin{tabular}{"""
+
+        header2 = ''.join(['c' + (' ' if c else '|') for c in range(self.num_c)])
+        header3 = """}
+        \\hline
+        """
+
+        header = header + header2 + header3
+
+        feat = """\\hline        
+        \\multicolumn{9}{c}{}
+        \\end{tabular}
+        \\caption{Comparisons of AP scores on the COCO 2017 val set with AP OKS.}
+        \\label{tab:cocotest}
+        \\end{table}"""
+
+        mid = ''
+        for r, row in enumerate(self.rows):
+            for c, w in enumerate(row):
+                mid += ' & ' if c > 0 else ''
+                mid += w
+            mid += '\\\\\n'
+            if not r:
+                mid += "\\hline\n"
+
+        tab = header + mid + feat
+        print(tab.replace('        ', ''))
+
+
+if __name__ == '__main__':
+    lbl_c = ['AP', 'AP$^{50}$', 'AP$^{75}$', 'AP$^L$', 'AR', 'AR$^{50}$', 'AR$^{75}$', 'AR$^L$']
+    lbl_r = [r'SBl~\cite{Xiao_2018_ECCV}', r'MSPN~\cite{MSPN_2019}', 'RSN~\cite{RSN_2020}']
+    data = [[.72, .92, .80, .77, .76, .93, .82, .80],
+            [.77, .94, .85, .82, .80, .95, .87, .85],
+            [.76, .94, .84, .81, .79, .94, .85, .84]]
+    t = Tabs(data, lbls_r=lbl_r, lbls_c=lbl_c)
+    t.to_latex()
diff --git a/lib/utils/transforms.py b/lib/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..173d439cbf6ba7dadf92ee131e6ac4a9b21d9905
--- /dev/null
+++ b/lib/utils/transforms.py
@@ -0,0 +1,123 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import cv2
+
+
+def flip_back(output_flipped, matched_parts):
+    '''
+    output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
+    '''
+    assert output_flipped.ndim == 4,\
+        'output_flipped should be [batch_size, num_joints, height, width]'
+
+    output_flipped = output_flipped[:, :, :, ::-1]
+
+    for pair in matched_parts:
+        tmp = output_flipped[:, pair[0], :, :].copy()
+        output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
+        output_flipped[:, pair[1], :, :] = tmp
+
+    return output_flipped
+
+
+def fliplr_joints(joints, joints_vis, width, matched_parts):
+    """
+    flip coords
+    """
+    # Flip horizontal
+    joints[:, 0] = width - joints[:, 0] - 1
+
+    # Change left-right parts
+    for pair in matched_parts:
+        joints[pair[0], :], joints[pair[1], :] = \
+            joints[pair[1], :], joints[pair[0], :].copy()
+        joints_vis[pair[0], :], joints_vis[pair[1], :] = \
+            joints_vis[pair[1], :], joints_vis[pair[0], :].copy()
+
+    return joints*(joints_vis > 0), joints_vis
+
+
+def transform_preds(coords, center, scale, output_size):
+    target_coords = np.zeros(coords.shape)
+    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
+    for p in range(coords.shape[0]):
+        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
+    return target_coords
+
+
+def get_affine_transform(center,
+                         scale,
+                         rot,
+                         output_size,
+                         shift=np.array([0, 0], dtype=np.float32),
+                         inv=0):
+    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
+        print(scale)
+        scale = np.array([scale, scale])
+
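+    # The person scale is expressed in units of 200 pixels (the usual COCO
+    # person-box convention), hence the fixed factor below.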
+    scale_tmp = scale * 200.0
+    src_w = scale_tmp[0]
+    dst_w = output_size[0]
+    dst_h = output_size[1]
+
+    rot_rad = np.pi * rot / 180
+    src_dir = get_dir([0, src_w * -0.5], rot_rad)
+    dst_dir = np.array([0, dst_w * -0.5], np.float32)
+
+    src = np.zeros((3, 2), dtype=np.float32)
+    dst = np.zeros((3, 2), dtype=np.float32)
+    src[0, :] = center + scale_tmp * shift
+    src[1, :] = center + src_dir + scale_tmp * shift
+    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+
+    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
+    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
+
+    if inv:
+        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+    else:
+        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+    return trans
+
+
+def affine_transform(pt, t):
+    new_pt = np.array([pt[0], pt[1], 1.]).T
+    new_pt = np.dot(t, new_pt)
+    return new_pt[:2]
+
+
+def get_3rd_point(a, b):
+    direct = a - b
+    return b + np.array([-direct[1], direct[0]], dtype=np.float32)
+
+
+def get_dir(src_point, rot_rad):
+    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+
+    src_result = [0, 0]
+    src_result[0] = src_point[0] * cs - src_point[1] * sn
+    src_result[1] = src_point[0] * sn + src_point[1] * cs
+
+    return src_result
+
+
+def crop(img, center, scale, output_size, rot=0):
+    trans = get_affine_transform(center, scale, rot, output_size)
+
+    dst_img = cv2.warpAffine(img,
+                             trans,
+                             (int(output_size[0]), int(output_size[1])),
+                             flags=cv2.INTER_LINEAR)
+
+    return dst_img
diff --git a/lib/utils/utils.py b/lib/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7f7620642875ad9070534cf3f7998034ba1e85b
--- /dev/null
+++ b/lib/utils/utils.py
@@ -0,0 +1,129 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import logging
+import sys
+import time
+from pathlib import Path
+from collections import OrderedDict
+
+import torch
+import torch.optim as optim
+
+from core.config import get_model_name
+
+
+def create_logger(cfg, cfg_name, phase='train'):
+    root_output_dir = Path(cfg.OUTPUT_DIR)
+    # set up logger
+    if not root_output_dir.exists():
+        print('=> creating {}'.format(root_output_dir))
+        root_output_dir.mkdir()
+
+    dataset = cfg.DATASET.DATASET + '_' + cfg.DATASET.HYBRID_JOINTS_TYPE \
+        if cfg.DATASET.HYBRID_JOINTS_TYPE else cfg.DATASET.DATASET
+    dataset = dataset.replace(':', '_')
+    model, _ = get_model_name(cfg)
+    cfg_name = os.path.basename(cfg_name).split('.')[0]
+
+    final_output_dir = root_output_dir / dataset / model / cfg_name
+
+    print('=> creating {}'.format(final_output_dir))
+    final_output_dir.mkdir(parents=True, exist_ok=True)
+
+    time_str = time.strftime('%Y-%m-%d-%H-%M')
+    log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
+    final_log_file = final_output_dir / log_file
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(filename=str(final_log_file),
+                        format=head)
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+    console = logging.StreamHandler(sys.stdout)
+    logging.getLogger('').addHandler(console)
+
+    tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
+                          (cfg_name + '_' + time_str)
+    print('=> creating {}'.format(tensorboard_log_dir))
+    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)
+
+    return logger, str(final_output_dir), str(tensorboard_log_dir)
+
+
+def get_optimizer(cfg, model):
+    optimizer = None
+    if cfg.TRAIN.OPTIMIZER == 'sgd':
+        optimizer = optim.SGD(
+            model.parameters(),
+            lr=cfg.TRAIN.LR,
+            momentum=cfg.TRAIN.MOMENTUM,
+            weight_decay=cfg.TRAIN.WD,
+            nesterov=cfg.TRAIN.NESTEROV
+        )
+    elif cfg.TRAIN.OPTIMIZER == 'adam':
+        optimizer = optim.Adam(
+            model.parameters(),
+            lr=cfg.TRAIN.LR
+        )
+
+    return optimizer
+
+
+def save_checkpoint(states, is_best, output_dir,
+                    filename='checkpoint.pth.tar'):
+    torch.save(states, os.path.join(output_dir, filename))
+    if is_best and 'state_dict' in states:
+        torch.save(states['state_dict'],
+                   os.path.join(output_dir, 'model_best.pth.tar'))
+
+
+def convert_state_dict(state_dict, keys, error):
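+    """Recovers from a partial state_dict mismatch: parses the missing and
+    unexpected keys out of the RuntimeError raised by load_state_dict(), and
+    remaps each unexpected key onto a missing key with the same suffix under
+    one of the submodule prefixes in `keys` (e.g. 'resnet', 'final_stage').
+    Re-raises the original error when the mismatch cannot be resolved."""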
+    msg = str(error)
+    for m in ['Error(s) in loading state_dict for', 'Missing key(s)', 'Unexpected key(s)']:
+        if m not in msg:
+            raise error
+
+    lines = msg.split('\n')
+    missing = []
+    unexpected = []
+    for i, line in enumerate(lines[1:]):
+        storage = (missing, unexpected)[i]
+        for word in line.split():
+            if '"' not in word:
+                continue
+            storage.append(word[1:-2])
+
+    if len(missing) > len(unexpected):
+        raise error
+
+    new_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k not in unexpected:
+            new_dict[k] = v
+            continue
+        elif '.num_batches_tracked' in k:
+            continue
+        for i, m in enumerate(missing):
+            split_m = m.split('.')
+            short_m = '.'.join(split_m[1:])
+            if short_m == k and split_m[0] in keys:
+                new_dict[m] = v
+
+                if '.running_var' in m:
+                    nbt = m.replace('.running_var', '.num_batches_tracked')
+                    new_dict[nbt] = state_dict[k.replace('.running_var', '.num_batches_tracked')]
+
+                missing.pop(i)
+                break
+        else:
+            raise error
+
+    return OrderedDict(new_dict)
diff --git a/lib/utils/vis.py b/lib/utils/vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..adc0947a460328eb25047aa65e24c1febfdbeab7
--- /dev/null
+++ b/lib/utils/vis.py
@@ -0,0 +1,141 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+import torchvision
+import cv2
+
+from core.inference import get_max_preds
+
+
+def save_batch_image_with_joints(batch_image, batch_joints, batch_joints_vis,
+                                 file_name, nrow=8, padding=2):
+    '''
+    batch_image: [batch_size, channel, height, width]
+    batch_joints: [batch_size, num_joints, 3]
+    batch_joints_vis: [batch_size, num_joints, 1]
+    '''
+    grid = torchvision.utils.make_grid(batch_image, nrow, padding, True)
+    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
+    ndarr = ndarr.copy()
+
+    nmaps = batch_image.size(0)
+    xmaps = min(nrow, nmaps)
+    ymaps = int(math.ceil(float(nmaps) / xmaps))
+    height = int(batch_image.size(2) + padding)
+    width = int(batch_image.size(3) + padding)
+    k = 0
+    for y in range(ymaps):
+        for x in range(xmaps):
+            if k >= nmaps:
+                break
+            joints = batch_joints[k]
+            joints_vis = batch_joints_vis[k]
+
+            for joint, joint_vis in zip(joints, joints_vis):
+                joint[0] = x * width + padding + joint[0]
+                joint[1] = y * height + padding + joint[1]
+                if joint_vis[0]:
+                    cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, [255, 0, 0], 2)
+            k = k + 1
+    cv2.imwrite(file_name, ndarr)
+
+
+def save_batch_heatmaps(batch_image, batch_heatmaps, file_name,
+                        normalize=True):
+    '''
+    batch_image: [batch_size, channel, height, width]
+    batch_heatmaps: [batch_size, num_joints, height, width]
+    file_name: saved file name
+    '''
+    if normalize:
+        batch_image = batch_image.clone()
+        min = float(batch_image.min())
+        max = float(batch_image.max())
+
+        batch_image.add_(-min).div_(max - min + 1e-5)
+
+    batch_size = batch_heatmaps.size(0)
+    num_joints = batch_heatmaps.size(1)
+    heatmap_height = batch_heatmaps.size(2)
+    heatmap_width = batch_heatmaps.size(3)
+
+    grid_image = np.zeros((batch_size*heatmap_height,
+                           (num_joints+1)*heatmap_width,
+                           3),
+                          dtype=np.uint8)
+
+    preds, maxvals = get_max_preds(batch_heatmaps.detach().cpu().numpy())
+
+    for i in range(batch_size):
+        image = batch_image[i].mul(255)\
+                              .clamp(0, 255)\
+                              .byte()\
+                              .permute(1, 2, 0)\
+                              .cpu().numpy()
+        heatmaps = batch_heatmaps[i].mul(255)\
+                                    .clamp(0, 255)\
+                                    .byte()\
+                                    .cpu().numpy()
+
+        resized_image = cv2.resize(image,
+                                   (int(heatmap_width), int(heatmap_height)))
+
+        height_begin = heatmap_height * i
+        height_end = heatmap_height * (i + 1)
+        for j in range(num_joints):
+            cv2.circle(resized_image,
+                       (int(preds[i][j][0]), int(preds[i][j][1])),
+                       1, [0, 0, 255], 1)
+            heatmap = heatmaps[j, :, :]
+            colored_heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+            masked_image = colored_heatmap*0.7 + resized_image*0.3
+            cv2.circle(masked_image,
+                       (int(preds[i][j][0]), int(preds[i][j][1])),
+                       1, [0, 0, 255], 1)
+
+            width_begin = heatmap_width * (j+1)
+            width_end = heatmap_width * (j+2)
+            grid_image[height_begin:height_end, width_begin:width_end, :] = \
+                masked_image
+            # grid_image[height_begin:height_end, width_begin:width_end, :] = \
+            #     colored_heatmap*0.7 + resized_image*0.3
+
+        grid_image[height_begin:height_end, 0:heatmap_width, :] = resized_image
+
+    cv2.imwrite(file_name, grid_image)
+
+
+def save_debug_images(config, input, meta, target, joints_pred, output,
+                      prefix):
+    if not config.DEBUG.DEBUG:
+        return
+
+    if config.DEBUG.SAVE_BATCH_IMAGES_GT:
+        save_batch_image_with_joints(
+            input, meta['joints'], meta['joints_vis'],
+            '{}_gt.jpg'.format(prefix)
+        )
+    if config.DEBUG.SAVE_BATCH_IMAGES_PRED:
+        save_batch_image_with_joints(
+            input, joints_pred, meta['joints_vis'],
+            '{}_pred.jpg'.format(prefix)
+        )
+    if config.DEBUG.SAVE_HEATMAPS_GT:
+        save_batch_heatmaps(
+            input, target, '{}_hm_gt.jpg'.format(prefix)
+        )
+    if config.DEBUG.SAVE_HEATMAPS_PRED:
+        save_batch_heatmaps(
+            input, output, '{}_hm_pred.jpg'.format(prefix)
+        )
diff --git a/lib/utils/zipreader.py b/lib/utils/zipreader.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c447a31f20a8c29af0b212b4d0fe1a95632a1cc
--- /dev/null
+++ b/lib/utils/zipreader.py
@@ -0,0 +1,70 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import zipfile
+import xml.etree.ElementTree as ET
+
+import cv2
+import numpy as np
+
+_im_zfile = []
+_xml_path_zip = []
+_xml_zfile = []
+
+
+def imread(filename, flags=cv2.IMREAD_COLOR):
+    global _im_zfile
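+    # `filename` is expected to look like '<archive>.zip@/<path inside the
+    # archive>'; the archive is opened once and cached in _im_zfile.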
+    path = filename
+    pos_at = path.find('@')
+    if pos_at == -1:
+        print("character '@' is not found in the given path '%s'" % (path))
+        assert 0
+    path_zip = path[0: pos_at]
+    path_img = path[pos_at + 2:]
+    if not os.path.isfile(path_zip):
+        print("zip file '%s' is not found"%(path_zip))
+        assert 0
+    for i in range(len(_im_zfile)):
+        if _im_zfile[i]['path'] == path_zip:
+            data = _im_zfile[i]['zipfile'].read(path_img)
+            return cv2.imdecode(np.frombuffer(data, np.uint8), flags)
+
+    _im_zfile.append({
+        'path': path_zip,
+        'zipfile': zipfile.ZipFile(path_zip, 'r')
+    })
+    data = _im_zfile[-1]['zipfile'].read(path_img)
+
+    return cv2.imdecode(np.frombuffer(data, np.uint8), flags)
+
+
+def xmlread(filename):
+    global _xml_path_zip
+    global _xml_zfile
+    path = filename
+    pos_at = path.find('@')
+    if pos_at == -1:
+        print("character '@' is not found in the given path '%s'" % (path))
+        assert 0
+    path_zip = path[0: pos_at]
+    path_xml = path[pos_at + 2:]
+    if not os.path.isfile(path_zip):
+        print("zip file '%s' is not found"%(path_zip))
+        assert 0
+    for i in range(len(_xml_path_zip)):
+        if _xml_path_zip[i] == path_zip:
+            data = _xml_zfile[i].open(path_xml)
+            return ET.fromstring(data.read())
+    _xml_path_zip.append(path_zip)
+    print("read new xml file '%s'"%(path_zip))
+    _xml_zfile.append(zipfile.ZipFile(path_zip, 'r'))
+    data = _xml_zfile[-1].open(path_xml)
+    return ET.fromstring(data.read())
diff --git a/pose_estimation/_init_paths.py b/pose_estimation/_init_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3c81c33f183e04f6afe043a99793b67804fc17b
--- /dev/null
+++ b/pose_estimation/_init_paths.py
@@ -0,0 +1,23 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path as osp
+import sys
+
+
+def add_path(path):
+    if path not in sys.path:
+        sys.path.insert(0, path)
+
+
+this_dir = osp.dirname(__file__)
+
+lib_path = osp.join(this_dir, '..', 'lib')
+add_path(lib_path)
diff --git a/pose_estimation/train.py b/pose_estimation/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..899897d1da6be9d43e39793da09455ddcd5b5ca6
--- /dev/null
+++ b/pose_estimation/train.py
@@ -0,0 +1,272 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import pprint
+import shutil
+from collections import OrderedDict
+
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+from tensorboardX import SummaryWriter
+
+import _init_paths
+from core.config import config
+from core.config import update_config
+from core.config import update_dir
+from core.config import get_model_name
+from core.loss import JointsMSELoss, JointsMSELossVis
+from core.function import train
+from core.function import validate
+from utils.utils import get_optimizer
+from utils.utils import save_checkpoint
+from utils.utils import create_logger
+
+import dataset
+import models
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train keypoints network')
+    # general
+    parser.add_argument('--cfg',
+                        help='experiment configure file name',
+                        required=True,
+                        type=str)
+
+    args, rest = parser.parse_known_args()
+    # update config
+    update_config(args.cfg)
+
+    # training
+    parser.add_argument('--frequent',
+                        help='frequency of logging',
+                        default=config.PRINT_FREQ,
+                        type=int)
+    parser.add_argument('--gpus',
+                        help='gpus',
+                        type=str)
+    parser.add_argument('--workers',
+                        help='num of dataloader workers',
+                        type=int)
+    parser.add_argument('--resume',
+                        action='store_true',
+                        help='Resume')
+
+    parser.add_argument('--model-file',
+                        help='model state file',
+                        type=str)
+
+    parser.add_argument('--debug-memory', action='store_true')
+
+    args = parser.parse_args()
+
+    return args
+
+
+def reset_config(config, args):
+    if args.gpus:
+        config.GPUS = args.gpus
+    if args.workers:
+        config.WORKERS = args.workers
+
+    if args.resume:
+        config.MODEL.INIT_WEIGHTS = False
+        if args.model_file:
+            config.TRAIN.CHECKPOINT = args.model_file
+
+    if args.debug_memory:
+        config.DEBUG.DEBUG_MEMORY = True
+
+
+def main():
+    args = parse_args()
+    reset_config(config, args)
+
+    logger, final_output_dir, tb_log_dir = create_logger(
+        config, args.cfg, 'train')
+
+    logger.info(pprint.pformat(args))
+    logger.info(pprint.pformat(config))
+
+    # cudnn related setting
+    cudnn.benchmark = config.CUDNN.BENCHMARK
+    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
+    torch.backends.cudnn.enabled = config.CUDNN.ENABLED
+
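+    # The network constructor is resolved dynamically from the config,
+    # e.g. MODEL.NAME = 'pose_resnet' maps to models.pose_resnet.get_pose_net.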
+    model = eval('models.' + config.MODEL.NAME + '.get_pose_net')(
+        config, is_train=True
+    )
+
+    if args.resume:
+        logger.info('=> loading model from {}'.format(config.TRAIN.CHECKPOINT))
+        meta_info = torch.load(config.TRAIN.CHECKPOINT)
+
+        # resume previous training
+        state_dict = OrderedDict({k.replace('module.', ''): v
+                                  for k, v in meta_info['state_dict'].items()})
+        config.TRAIN.BEGIN_EPOCH = meta_info['epoch']
+        model.load_state_dict(state_dict)
+
+    else:
+        # Optionally adapt the final layer when initializing from a COCO
+        # checkpoint with a different number of joints, e.g.:
+        #   logger.info('=> initialize from coco, so adjusting the last layer')
+        #   model.final_layer = torch.nn.Conv2d(256, config.MODEL.NUM_JOINTS, (1, 1), (1, 1))
+        #   logger.info('=> adjust done.')
+        pass
+
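+    # With TRAIN.FREEZE enabled, the backbone encoder and deconvolution
+    # layers are frozen so only the remaining head layers are trained.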
+    if config.TRAIN.FREEZE:
+        model.freeze_encoder()
+        model.freeze_deconv()
+
+    # copy model file
+    this_dir = os.path.dirname(__file__)
+    shutil.copy2(
+        os.path.join(this_dir, '../lib/models', config.MODEL.NAME + '.py'),
+        final_output_dir)
+
+    writer_dict = {
+        'writer': SummaryWriter(log_dir=tb_log_dir),
+        'train_global_steps': 0,
+        'valid_global_steps': 0,
+    }
+
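+    # Trace the model once on a dummy batch so its graph is logged to
+    # TensorBoard.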
+    dump_input = torch.rand((config.TRAIN.BATCH_SIZE,
+                             3,
+                             config.MODEL.IMAGE_SIZE[1],
+                             config.MODEL.IMAGE_SIZE[0]))
+    writer_dict['writer'].add_graph(model, (dump_input,), verbose=False)
+
+    gpus = [int(i) for i in config.GPUS.split(',')]
+    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
+
+    # define loss function (criterion) and optimizer
+    if config.MODEL.PREDICT_VIS:
+        class_weights = None
+        if config.LOSS.USE_CLASS_WEIGHT:
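+            # Per-class weights for the visibility head, assumed to follow the
+            # COCO flags (0 = not labeled, 1 = labeled but occluded,
+            # 2 = visible); with NB_VIS == 2 the first two classes are merged.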
+            vis_weight = [1.61, 7.83, 1.]
+            if config.MODEL.NB_VIS == 2:
+                vis_weight = [(1 / vis_weight[1] + 1 / vis_weight[2]) * vis_weight[0], 1.]
+            class_weights = torch.FloatTensor(vis_weight).cuda()
+        criterion = JointsMSELossVis(
+            use_target_weight=config.LOSS.USE_TARGET_WEIGHT,
+            vis_ratio=config.LOSS.VIS_RATIO,
+            vis_weight=class_weights
+        ).cuda()
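+        # When resuming, replay the vis-ratio updates that would already have
+        # been applied before BEGIN_EPOCH.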
+        for st in config.LOSS.VIS_STEP:
+            if config.TRAIN.BEGIN_EPOCH >= st:
+                criterion.update_vis_ratio(config.LOSS.VIS_FACTOR)
+            else:
+                break
+    else:
+        criterion = JointsMSELoss(
+            use_target_weight=config.LOSS.USE_TARGET_WEIGHT
+        ).cuda()
+
+    optimizer = get_optimizer(config, model)
+
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR
+    )
+
+    # Data loading code
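+    # Inputs are normalized with the standard ImageNet statistics expected by
+    # the pretrained backbones.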
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+    train_dataset = eval('dataset.' + config.DATASET.DATASET)(
+        config,
+        config.DATASET.ROOT,
+        config.DATASET.TRAIN_SET,
+        True,
+        transforms.Compose([
+            transforms.ToTensor(),
+            normalize,
+        ])
+    )
+    valid_dataset = eval('dataset.' + config.DATASET.DATASET)(
+        config,
+        config.DATASET.ROOT,
+        config.DATASET.TEST_SET,
+        False,
+        transforms.Compose([
+            transforms.ToTensor(),
+            normalize,
+        ])
+    )
+
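+    # Batch sizes in the config are per GPU; DataParallel splits the combined
+    # batch across the listed devices.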
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset,
+        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
+        shuffle=config.TRAIN.SHUFFLE,
+        num_workers=config.WORKERS,
+        pin_memory=True
+    )
+    valid_loader = torch.utils.data.DataLoader(
+        valid_dataset,
+        batch_size=config.TEST.BATCH_SIZE * len(gpus),
+        shuffle=False,
+        num_workers=config.WORKERS,
+        pin_memory=True
+    )
+
+    best_perf = 0.0
+    best_model = False
+    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
+
+        # step the visibility-loss ratio according to the LOSS.VIS_STEP schedule
+        if config.MODEL.PREDICT_VIS and epoch in config.LOSS.VIS_STEP:
+            criterion.update_vis_ratio(config.LOSS.VIS_FACTOR)
+
+        # train for one epoch
+        train(config, train_loader, model, criterion, optimizer, epoch,
+              final_output_dir, tb_log_dir, writer_dict)
+
+        lr_scheduler.step()
+
+        # evaluate on validation set
+        perf_indicator = validate(config, valid_loader, valid_dataset, model,
+                                  criterion, final_output_dir, tb_log_dir,
+                                  writer_dict)
+
+        if perf_indicator > best_perf:
+            best_perf = perf_indicator
+            best_model = True
+        else:
+            best_model = False
+
+        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
+        save_checkpoint({
+            'epoch': epoch + 1,
+            'model': get_model_name(config),
+            'state_dict': model.state_dict(),
+            'perf': perf_indicator,
+            'optimizer': optimizer.state_dict(),
+        }, best_model, final_output_dir)
+
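+        # TRAIN.SAVE_CHECKPOINT acts as an interval in epochs: keep a numbered
+        # copy of the rolling checkpoint every SAVE_CHECKPOINT epochs.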
+        if config.TRAIN.SAVE_CHECKPOINT and not (epoch + 1) % config.TRAIN.SAVE_CHECKPOINT:
+            shutil.copy2(os.path.join(final_output_dir, 'checkpoint.pth.tar'),
+                         os.path.join(final_output_dir, f'checkpoint_{epoch + 1}.pth.tar'))
+
+    final_model_state_file = os.path.join(final_output_dir,
+                                          'final_state.pth.tar')
+    logger.info('saving final model state to {}'.format(
+        final_model_state_file))
+    torch.save(model.module.state_dict(), final_model_state_file)
+    writer_dict['writer'].close()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/pose_estimation/valid.py b/pose_estimation/valid.py
new file mode 100644
index 0000000000000000000000000000000000000000..35856a31b7e40a6c3c8acbde0e5c015044e6d88f
--- /dev/null
+++ b/pose_estimation/valid.py
@@ -0,0 +1,189 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# ------------------------------------------------------------------------------
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import pprint
+from collections import OrderedDict
+
+import torch
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+
+import _init_paths
+from core.config import config
+from core.config import update_config
+from core.config import update_dir
+from core.loss import JointsMSELoss, JointsMSELossVis
+from core.function import validate
+from utils.utils import create_logger
+
+import dataset
+import models
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Validate keypoints network')
+    # general
+    parser.add_argument('--cfg',
+                        help='experiment configure file name',
+                        required=True,
+                        type=str)
+
+    args, rest = parser.parse_known_args()
+    # update config
+    update_config(args.cfg)
+
+    # validation
+    parser.add_argument('--frequent',
+                        help='frequency of logging',
+                        default=config.PRINT_FREQ,
+                        type=int)
+    parser.add_argument('--gpus',
+                        help='gpus',
+                        type=str)
+    parser.add_argument('--workers',
+                        help='num of dataloader workers',
+                        type=int)
+    parser.add_argument('--model-file',
+                        help='model state file',
+                        type=str)
+    parser.add_argument('--use-detect-bbox',
+                        help='use detect bbox',
+                        action='store_true')
+    parser.add_argument('--flip-test',
+                        help='use flip test',
+                        action='store_true')
+    parser.add_argument('--post-process',
+                        help='use post process',
+                        action='store_true')
+    parser.add_argument('--shift-heatmap',
+                        help='shift heatmap',
+                        action='store_true')
+    parser.add_argument('--coco-bbox-file',
+                        help='coco detection bbox file',
+                        type=str)
+
+    args = parser.parse_args()
+
+    return args
+
+
+def reset_config(config, args):
+    if args.gpus:
+        config.GPUS = args.gpus
+    if args.workers:
+        config.WORKERS = args.workers
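+    # --use-detect-bbox switches evaluation from ground-truth person boxes to
+    # detector boxes (see TEST.COCO_BBOX_FILE).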
+    if args.use_detect_bbox:
+        config.TEST.USE_GT_BBOX = not args.use_detect_bbox
+    if args.flip_test:
+        config.TEST.FLIP_TEST = args.flip_test
+    if args.post_process:
+        config.TEST.POST_PROCESS = args.post_process
+    if args.shift_heatmap:
+        config.TEST.SHIFT_HEATMAP = args.shift_heatmap
+    if args.model_file:
+        config.TEST.MODEL_FILE = args.model_file
+    if args.coco_bbox_file:
+        config.TEST.COCO_BBOX_FILE = args.coco_bbox_file
+
+
+def main():
+    args = parse_args()
+    reset_config(config, args)
+
+    logger, final_output_dir, tb_log_dir = create_logger(
+        config, args.cfg, 'valid')
+
+    logger.info(pprint.pformat(args))
+    logger.info(pprint.pformat(config))
+
+    # cudnn related setting
+    cudnn.benchmark = config.CUDNN.BENCHMARK
+    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
+    torch.backends.cudnn.enabled = config.CUDNN.ENABLED
+
+    model = eval('models.' + config.MODEL.NAME + '.get_pose_net')(
+        config, is_train=False
+    )
+
+    if config.TEST.MODEL_FILE:
+        logger.info('=> loading model from {}'.format(config.TEST.MODEL_FILE))
+
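+        # Checkpoints saved under DataParallel prefix parameter names with
+        # 'module.', which is stripped before loading into the bare model.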
+        state_dict = torch.load(config.TEST.MODEL_FILE)
+        if "state_dict" in state_dict:
+            state_dict = OrderedDict([(k.replace('module.', ''), v) for k, v in state_dict['state_dict'].items()])
+        model.load_state_dict(state_dict)
+    else:
+        model_state_file = os.path.join(final_output_dir,
+                                        'final_state.pth.tar')
+        logger.info('=> loading model from {}'.format(model_state_file))
+        model.load_state_dict(torch.load(model_state_file))
+
+    gpus = [int(i) for i in config.GPUS.split(',')]
+    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
+
+    # define loss function (criterion)
+    if config.MODEL.PREDICT_VIS:
+        class_weights = None
+        if config.LOSS.USE_CLASS_WEIGHT:
+            vis_weight = [1.61, 7.83, 1.]
+            if config.MODEL.NB_VIS == 2:
+                vis_weight = [(1 / vis_weight[1] + 1 / vis_weight[2]) * vis_weight[0], 1.]
+            class_weights = torch.FloatTensor(vis_weight).cuda()
+        criterion = JointsMSELossVis(
+            use_target_weight=config.LOSS.USE_TARGET_WEIGHT,
+            vis_ratio=config.LOSS.VIS_RATIO,
+            vis_weight=class_weights
+        ).cuda()
+        for st in config.LOSS.VIS_STEP:
+            if config.TRAIN.BEGIN_EPOCH >= st:
+                criterion.update_vis_ratio(config.LOSS.VIS_FACTOR)
+            else:
+                break
+    else:
+        criterion = JointsMSELoss(
+            use_target_weight=config.LOSS.USE_TARGET_WEIGHT
+        ).cuda()
+
+    # Data loading code
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                     std=[0.229, 0.224, 0.225])
+
+    valid_dataset = eval('dataset.' + config.DATASET.DATASET)(
+        config,
+        config.DATASET.ROOT,
+        config.DATASET.TEST_SET,
+        False,
+        transforms.Compose([
+            transforms.ToTensor(),
+            normalize,
+        ])
+    )
+    valid_loader = torch.utils.data.DataLoader(
+        valid_dataset,
+        batch_size=config.TEST.BATCH_SIZE * len(gpus),
+        shuffle=False,
+        num_workers=config.WORKERS,
+        pin_memory=True
+    )
+
+    # evaluate on validation set
+    validate(config, valid_loader, valid_dataset, model, criterion,
+             final_output_dir, tb_log_dir)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1e9306dee28fbd14d3194979dcbf45339b9e0179
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+EasyDict
+opencv-python==3.4.13.47
+Cython
+scipy
+numpy
+pandas
+pyyaml
+json_tricks
+scikit-image
+scikit-learn
+tensorboard
+tensorboardX>=1.2
+torchvision