diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..9ce682d9d475592c2e1b42d88eb1a9db5df0368f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+.idea
+*__pycache__*
+synthetic_ssd/metrics/*
+data/tless/*
+data/weights/maskrcnn*
+deps/*
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..211e3c68dd5374837233666a6ea9488f99bf31ff
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "deps/Object-Detection-Metrics"]
+	path = deps
+	url = https://github.com/rafaelpadilla/Object-Detection-Metrics.git
diff --git a/README.md b/README.md
index e1bacd878e730ec7dbdbcc61d10a210ed4eb43dd..a8c8c983eaf73b29b6d2b18f802072c6a9de4747 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,117 @@
-# Synthetic-SSD
+<h1 align="center">
+TRAINING AN EMBEDDED OBJECT DETECTOR FOR INDUSTRIAL SETTINGS WITHOUT REAL IMAGES
+</h1>
+<div align="center">
+<h3>
+<a href="https://liris.cnrs.fr/page-membre/julia-cohen">Julia Cohen</a>,
+<a href="https://liris.cnrs.fr/page-membre/carlos-crispim-junior">Carlos Crispim-Junior</a>,
+<a>Jean-Marc Chiappa</a>,
+<a href="https://liris.cnrs.fr/page-membre/laure-tougne">Laure Tougne</a>
+<br>
+<br>
+IEEE ICIP: International Conference on Image Processing, 2021
+</h3>
+</div>
+
+
+# Table of contents
+- [Overview](#overview)
+- [Installation](#installation)
+- [Dataset](#dataset)
+- [Evaluation](#evaluation)
+- [Training](#training)
+- [Citation](#citation)
+- [Acknowledgements](#acknowledgements)
+
+# Overview
+This repository contains the code for the MobileNet-V3 SSD, MobileNet-V3 Small SSD, and MobileNet-V2 SSD 
+models described in the paper. Weights for these models trained on the T-LESS dataset are available 
+for download, along with the Mask R-CNN model used for comparison.
+
+![Results Images](data/images/results.png)
+
+# Installation
+## Project
+Clone the current repository:
+```
+git clone --recurse-submodules https://gitlab.liris.cnrs.fr/jcohen/synthetic-ssd.git
+```
+The "Object Detection Metrics" project should be downloaded in the `deps` folder. 
+After download, move the `lib` subfolder and `_init_paths.py` file into `synthetic_ssd\metrics`.
+
+## Environment
+Create and activate a conda environment:
+```
+conda env create --file requirements.yml
+conda activate synthetic_ssd
+```
+This project has been tested on Windows with Python 3.8 and PyTorch 1.9, but previous versions of
+Python 3 and PyTorch should also work.
+
+# Dataset
+The T-LESS dataset is available on the [BOP website](https://bop.felk.cvut.cz/datasets/). 
+We use the "PBR-BlenderProc4BOP training images" and "All test images" subsets, saved in the `data` folder. 
+The expected folder structure is:  
+```
+- data
+| - tless
+  | - test_primesense
+  | - train_pbr
+```
+Otherwise, you can change the `TLESS_BASE_PATH` variable in
+[synthetic_ssd/config.py](synthetic_ssd/config.py).
+
+# Evaluation
+For the Mask R-CNN model, download the trained weights file from the 
+[CosyPose project](https://github.com/ylabbe/cosypose#downloading-bop-datasets) into the `data/weights` 
+folder, or elsewhere and change the `MASK_RCNN_PATH` variable in [synthetic_ssd/config.py](synthetic_ssd/config.py) accordingly. The
+model evaluated in the paper is the one with `model_id` `detector-bop-tless-pbr--873074`.
+
+To reproduce the performance reported in the paper, use the evaluation script:
+```
+python -m synthetic_ssd.scripts.run_eval \
+--model [mobilenet_v2_ssd, mobilenet_v3_ssd, mobilenet_v3_small_ssd, mask_rcnn] \
+--tf [aug1, aug2]
+```
+Default values are set to **mobilenet_v2_ssd** and **aug2**.
+
+Model   | mAP (%) | Parameters (M)
+:--- | :----: | :----:
+Mask R-CNN | 32.8 | 44.0
+V3small-SSD (aug1) | 18.6 | 2.6
+V3-SSD (aug1) | 36.3 | 4.9 
+V2-SSD (aug1) | 38.3 | 3.5
+V3small-SSD (aug2) | 23.5 | 2.6
+V3-SSD (aug2) | 46.1 | 4.9
+**V2-SSD (aug2)** | **47.7** | 3.5
+
+
+# Training
+Coming soon...
+
+# Citation
+If you use this code in your research, please cite the paper:
+
+```
+@inproceedings{cohen2021training,
+  title={Training An Embedded Object Detector For Industrial Settings Without Real Images},
+  author={Cohen, Julia and Crispim-Junior, Carlos and Chiappa, Jean-Marc and Tougne, Laure},
+  booktitle={2021 IEEE International Conference on Image Processing (ICIP)},
+  pages={714--718},
+  year={2021},
+  organization={IEEE}
+}
+```
+
+# Acknowledgements
+This repository is based on the following works:
+- MobileNet-V3 with ImageNet pretraining: [https://github.com/d-li14/mobilenetv3.pytorch](https://github.com/d-li14/mobilenetv3.pytorch)
+- SSD: [https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection](https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection)
+- MobileNet-V2 - SSD: [https://github.com/qfgaohao/pytorch-ssd](https://github.com/qfgaohao/pytorch-ssd)
+- MobileNet-V3 - SSD: [https://github.com/tongyuhome/MobileNetV3-SSD](https://github.com/tongyuhome/MobileNetV3-SSD)
+
+This work was supported by grant CIFRE n.2018/0872 from ANRT.  
+<div align="center">
+<img src="data/images/LogoLIRIS.jpg" alt="LIRIS logo" height="100" width="100"/>
+<img src="data/images/LogoDEMS.png" alt="DEMS logo" height="100" width="100"/>
+</div>
 
-Code associated with the paper "Training an Embedded Object Detector for Industrial Settings Without Real Images" (IEEE ICIP2021).
\ No newline at end of file
diff --git a/data/images/LogoDEMS.png b/data/images/LogoDEMS.png
new file mode 100644
index 0000000000000000000000000000000000000000..117dd16151f369fd201fea867524453e43b578be
Binary files /dev/null and b/data/images/LogoDEMS.png differ
diff --git a/data/images/LogoLIRIS.jpg b/data/images/LogoLIRIS.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..74e45b48397f3f0d1e68ffbfdbc5a5987639219c
Binary files /dev/null and b/data/images/LogoLIRIS.jpg differ
diff --git a/data/images/results.png b/data/images/results.png
new file mode 100644
index 0000000000000000000000000000000000000000..0f08f8e6716cb0a78fc209cfefefc6c238e89076
Binary files /dev/null and b/data/images/results.png differ
diff --git a/data/weights/aug1/tless_icip21_V2ssd.pth b/data/weights/aug1/tless_icip21_V2ssd.pth
new file mode 100644
index 0000000000000000000000000000000000000000..daf37a54302569ca75b6362b8290bb9c72846a99
Binary files /dev/null and b/data/weights/aug1/tless_icip21_V2ssd.pth differ
diff --git a/data/weights/aug1/tless_icip21_V3smallssd.pth b/data/weights/aug1/tless_icip21_V3smallssd.pth
new file mode 100644
index 0000000000000000000000000000000000000000..5e42d4d5a383c31ec142709e54ad9e3c37b027a5
Binary files /dev/null and b/data/weights/aug1/tless_icip21_V3smallssd.pth differ
diff --git a/data/weights/aug1/tless_icip21_V3ssd.pth b/data/weights/aug1/tless_icip21_V3ssd.pth
new file mode 100644
index 0000000000000000000000000000000000000000..82f7daa1fc0ca6ab381573599cbd3daee98fe815
Binary files /dev/null and b/data/weights/aug1/tless_icip21_V3ssd.pth differ
diff --git a/data/weights/aug2/tless_icip21_V2ssd.pth b/data/weights/aug2/tless_icip21_V2ssd.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3976479908100a013faf558bfde5bf5ad177abc6
Binary files /dev/null and b/data/weights/aug2/tless_icip21_V2ssd.pth differ
diff --git a/data/weights/aug2/tless_icip21_V3smallssd.pth b/data/weights/aug2/tless_icip21_V3smallssd.pth
new file mode 100644
index 0000000000000000000000000000000000000000..47cfe3edd5e31d4477615eed0e018a699ef54c8e
Binary files /dev/null and b/data/weights/aug2/tless_icip21_V3smallssd.pth differ
diff --git a/data/weights/aug2/tless_icip21_V3ssd.pth b/data/weights/aug2/tless_icip21_V3ssd.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f64a4b4dc0817594f0054210e30f1a9c03365790
Binary files /dev/null and b/data/weights/aug2/tless_icip21_V3ssd.pth differ
diff --git a/deps b/deps
new file mode 160000
index 0000000000000000000000000000000000000000..d612d3f543a565ee4229daee908333777f733532
--- /dev/null
+++ b/deps
@@ -0,0 +1 @@
+Subproject commit d612d3f543a565ee4229daee908333777f733532
diff --git a/requirements.yml b/requirements.yml
new file mode 100644
index 0000000000000000000000000000000000000000..65fa6ec9a23ea36622756808aa5e2e396755768f
--- /dev/null
+++ b/requirements.yml
@@ -0,0 +1,16 @@
+name: synthetic_ssd
+channels:
+  - conda-forge
+  - pytorch
+dependencies:
+  - python=3.8.10
+  - pytorch
+  - torchvision
+  - cudatoolkit
+  - yaml
+  - tqdm
+  - numpy
+  - opencv
+  - matplotlib
+  - imgaug
+  - albumentations
\ No newline at end of file
diff --git a/synthetic_ssd/__init__.py b/synthetic_ssd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/synthetic_ssd/config.py b/synthetic_ssd/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd73b4473b37da99eb1041c9d766673998c15026
--- /dev/null
+++ b/synthetic_ssd/config.py
@@ -0,0 +1,15 @@
+import synthetic_ssd
+from pathlib import Path
+
+PROJECT_ROOT = Path(synthetic_ssd.__file__).parent.parent
+PROJECT_DIR = PROJECT_ROOT
+DATA_DIR = PROJECT_DIR / 'data'
+WEIGHTS_DIR = DATA_DIR / 'weights'
+MASK_RCNN_PATH = WEIGHTS_DIR / "maskrcnn_detector-bop-tless-pbr--873074.pth"
+
+TLESS_BASE_PATH = DATA_DIR / 'tless'
+TLESS_TRAIN_PATH = TLESS_BASE_PATH / 'train_pbr'
+TLESS_TEST_PATH = TLESS_BASE_PATH / 'test_primesense'
+
+DEPS_DIR = PROJECT_DIR / 'deps'
+
diff --git a/synthetic_ssd/datasets/__init__.py b/synthetic_ssd/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/synthetic_ssd/datasets/augmentations.py b/synthetic_ssd/datasets/augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..55bb5edeed7046b076519dab0611bab9880fed7c
--- /dev/null
+++ b/synthetic_ssd/datasets/augmentations.py
@@ -0,0 +1,72 @@
+import albumentations as A
+from albumentations.pytorch import ToTensorV2
+
+
+def train_tf(out_size=224):
+    """
+    Proposed extensive augmentation pipeline.
+    :param out_size: side length of the square output image
+    :return: an albumentations Compose pipeline producing a normalized tensor image
+    """
+    tf = A.Compose([
+        A.OneOf([
+            A.ColorJitter(p=0.5, brightness=(0.5, 1.4), contrast=0.5, saturation=0.9, hue=0.5),
+            A.ColorJitter(p=0.5, brightness=0, contrast=0, saturation=0, hue=0.5),
+            A.ColorJitter(p=0.5, brightness=0, contrast=0, saturation=0.9, hue=0),
+            A.ColorJitter(p=0.5, brightness=0, contrast=0.5, saturation=0, hue=0),
+            A.ColorJitter(p=0.5, brightness=(0.5, 1.4), contrast=0, saturation=0, hue=0),
+        ], p=0.8),
+
+        A.CLAHE(p=0.5),
+
+        A.RGBShift(p=0.5),
+
+        A.OneOf([
+            A.Blur(p=0.5, blur_limit=(3, 7)),
+            A.GaussianBlur(p=0.5, blur_limit=(3, 7), sigma_limit=0),
+            A.MedianBlur(p=0.5, blur_limit=(3, 7)),
+            A.MotionBlur(p=0.5, blur_limit=(5, 15)),
+        ], p=0.5),
+
+        A.GaussNoise(p=0.8),
+        A.MultiplicativeNoise(p=0.2, multiplier=(0.7, 1.3)),
+        A.ISONoise(p=0.2),
+
+        A.VerticalFlip(p=0.5),
+
+        A.RandomCrop(p=0.6, height=400, width=400),
+
+        A.Resize(p=1.0, height=out_size, width=out_size),
+        A.Normalize(p=1.0),
+        ToTensorV2(p=1.0)
+    ],
+        bbox_params=A.BboxParams(format="albumentations",
+                                 label_fields=['labels'],
+                                 min_area=250,
+                                 )
+    )
+    return tf
+
+
+def test_tf_ssd(out_size=224):
+    tf = A.Compose(
+        [
+            A.Resize(p=1.0, height=out_size, width=out_size),
+            A.Normalize(p=1.0),
+            ToTensorV2(p=1.0)
+        ],
+        bbox_params=A.BboxParams(format="albumentations",
+                                 label_fields=['labels'])
+    )
+    return tf
+
+
+def test_tf_rcnn():
+    tf = A.Compose(
+        [
+            ToTensorV2(p=1.0)
+        ],
+        bbox_params=A.BboxParams(format="albumentations",
+                                 label_fields=['labels'])
+    )
+    return tf
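+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative values, not from the paper): apply the
+    # training pipeline to a random image with one normalized box.
+    # Run with `python -m synthetic_ssd.datasets.augmentations`.
+    import numpy as np
+
+    tf = train_tf(out_size=224)
+    image = np.random.randint(0, 256, size=(540, 720, 3), dtype=np.uint8)
+    boxes = [[0.2, 0.2, 0.8, 0.8]]  # "albumentations" format: normalized x1, y1, x2, y2
+    out = tf(image=image, bboxes=boxes, labels=[1])
+    print(out["image"].shape)  # torch.Size([3, 224, 224])
+    print(out["bboxes"])       # transformed boxes (may be empty if the crop drops them)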
diff --git a/synthetic_ssd/datasets/detection_dataset.py b/synthetic_ssd/datasets/detection_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..402ebb817a3c1b5834ec07d4b8b28fa410d5de24
--- /dev/null
+++ b/synthetic_ssd/datasets/detection_dataset.py
@@ -0,0 +1,120 @@
+import os
+from pathlib import Path
+
+import json
+import cv2
+import torch
+from torch.utils.data import Dataset
+import albumentations as A
+
+from synthetic_ssd.config import TLESS_TRAIN_PATH, TLESS_TEST_PATH
+
+
+class TLESSDetection(Dataset):
+    def __init__(self, tf, target_tf, mode):
+        self.tf = tf
+        self.target_tf = target_tf
+
+        if mode == "train":
+            path = Path(TLESS_TRAIN_PATH)
+        else:
+            path = Path(TLESS_TEST_PATH)
+        assert path.exists(), f"Path {path} does not exist."
+
+        self.rgb_paths = list()
+        self.gt = dict()
+        self.gt['boxes'] = dict()
+        self.gt['labels'] = dict()
+        for cur_path in path.iterdir():
+            if not cur_path.is_dir():
+                continue
+            images = [str(f) for f in (cur_path / 'rgb').iterdir()
+                      if f.name.endswith(".jpg") or f.name.endswith(".png")]
+            self.rgb_paths.extend(images)
+
+            annot_gt = json.loads((cur_path / 'scene_gt.json').read_text())
+            annot_gt_info = json.loads((cur_path / 'scene_gt_info.json').read_text())
+            self.gt['boxes'][cur_path.name.zfill(6)] = dict()
+            self.gt['labels'][cur_path.name.zfill(6)] = dict()
+            for image_id in annot_gt.keys():
+                boxes = list()
+                labels = list()
+                for id_data, bbox_data in zip(annot_gt[image_id], annot_gt_info[image_id]):
+                    x, y, w, h = bbox_data["bbox_visib"]  # left, top, width, height
+                    if x == -1 and y == -1 and w == -1 and h == -1:
+                        continue
+                    if bbox_data["px_count_visib"] < 100 or w < 10 or h < 10:
+                        continue
+                    boxes.append([x, y, x + w, y + h])
+                    labels.append(id_data["obj_id"])
+                self.gt['boxes'][cur_path.name.zfill(6)][image_id.zfill(6)] = boxes
+                self.gt['labels'][cur_path.name.zfill(6)][image_id.zfill(6)] = labels
+        del annot_gt, annot_gt_info
+
+        self.length = len(self.rgb_paths)
+        self.classes = list(range(1, 31))
+
+    def __len__(self):
+        return self.length
+
+    def __repr__(self):
+        return "Dataset TLESS of RGB images for object detection"
+
+    def __getitem__(self, idx):
+        color_img = cv2.imread(self.rgb_paths[idx])[:, :, :3]  # uint8, BGR
+        color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)
+
+        full_path = self.rgb_paths[idx]
+        list_path = full_path.split(os.sep)
+        folder_id = list_path[-3]
+        image_id = list_path[-1].split(".")[0]
+        target = dict()
+        target["boxes"] = self.gt["boxes"][folder_id][image_id]
+        target["labels"] = self.gt["labels"][folder_id][image_id]
+        target['image_name'] = full_path
+        target['image_size'] = color_img.shape[:2]
+
+        # Pre-process and augment data
+        target["boxes"] = normalize_bboxes(target["boxes"], target["image_size"][0], target["image_size"][1])
+        len_new_boxes = 0
+        # Make sure the image still has boxes after augmentations
+        while not len_new_boxes:
+            transformed = self.tf(image=color_img, bboxes=target["boxes"], labels=target["labels"])
+            len_new_boxes = len(transformed["bboxes"])
+        color_img = transformed["image"]
+        target["boxes"] = torch.as_tensor(transformed["bboxes"], dtype=torch.float32)
+        target["labels"] = torch.as_tensor(transformed["labels"], dtype=torch.int64)
+
+        target['raw_boxes'], target["raw_labels"] = target["boxes"], target["labels"]
+        if self.target_tf:
+            target["boxes"], target["labels"] = self.target_tf(target["boxes"], target["labels"])
+
+        if color_img.dtype == torch.uint8:
+            color_img = color_img.type(torch.FloatTensor)
+
+        return color_img, target
+
+
+def normalize_bboxes(bboxes, im_height, im_width):
+    """
+    Normalize bounding boxes.
+    :param bboxes: a sequence of bounding boxes with format x1y1x2y2
+    :param im_height: image height
+    :param im_width: image width
+    :return: the normalized bounding boxes, with format "albumentations"
+    """
+    return A.normalize_bboxes(bboxes, rows=im_height, cols=im_width)
+
+
+def collate_fn(batch):
+    """
+    Collate target data stored in a dictionary.
+    :param batch:
+    :return:
+    """
+    images = [b[0] for b in batch]
+    images = torch.stack(images, dim=0)
+
+    targets = [b[1] for b in batch]
+
+    return images, targets
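+
+
+if __name__ == "__main__":
+    # Minimal sketch (fake data, no T-LESS images needed): collate a batch of
+    # two samples the way a DataLoader would, e.g.
+    # DataLoader(TLESSDetection(...), batch_size=2, collate_fn=collate_fn).
+    fake_batch = [
+        (torch.zeros(3, 224, 224), {"labels": torch.tensor([1])}),
+        (torch.zeros(3, 224, 224), {"labels": torch.tensor([2, 5])}),
+    ]
+    images, targets = collate_fn(fake_batch)
+    print(images.shape, len(targets))  # torch.Size([2, 3, 224, 224]) 2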
diff --git a/synthetic_ssd/models/__init__.py b/synthetic_ssd/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92c225878515ea2e7926a9afe9b4741a8a4e958e
--- /dev/null
+++ b/synthetic_ssd/models/__init__.py
@@ -0,0 +1,49 @@
+from torchvision.models.detection import maskrcnn_resnet50_fpn
+
+from .mobilenet_v3 import MobileNetV3
+from .mobilenet_v2_ssd_lite import mobilenet_v2_ssd_lite
+from .mobilenet_v3_ssd_lite import mobilenet_v3_ssd_lite
+from . import mobilenetv2_ssd_config, mobilenetv3_ssd_config
+from .box_utils import generate_ssd_priors
+
+from torchvision.models import mobilenet_v2
+
+
+def create_model(opt):
+    num_classes_with_bkgd = opt.num_classes + 1
+    if not opt.train:
+        if opt.model == 'mask_rcnn':
+            # Load classes info
+            model = maskrcnn_resnet50_fpn(pretrained=False,
+                                          pretrained_backbone=False,
+                                          num_classes=num_classes_with_bkgd,
+                                          box_score_thresh=0.1)
+            opt.priors = None
+        elif opt.model == 'mobilenet_v2_ssd':
+            backbone = mobilenet_v2(pretrained=True)
+            priors = generate_ssd_priors(mobilenetv2_ssd_config.specs, opt.input_size)
+            model = mobilenet_v2_ssd_lite(num_classes=num_classes_with_bkgd,
+                                          base_net=backbone,
+                                          is_test=False,
+                                          config=mobilenetv2_ssd_config,
+                                          priors=priors,
+                                          device=opt.device
+                                          )
+            opt.priors = priors
+        elif 'mobilenet_v3' in opt.model:
+            if 'small' in opt.model:
+                backbone = MobileNetV3(mode='small')
+            else:
+                backbone = MobileNetV3(mode='large')
+            priors = generate_ssd_priors(mobilenetv3_ssd_config.specs, opt.input_size)
+            model = mobilenet_v3_ssd_lite(num_classes=num_classes_with_bkgd,
+                                          base_net=backbone,
+                                          is_test=False,
+                                          config=mobilenetv3_ssd_config,
+                                          priors=priors,
+                                          device=opt.device
+                                          )
+            opt.priors = priors
+        else:
+            raise ValueError(f"Model {opt.model} is not available.")
+    else:
+        # Training entry point is not provided yet (see README: "Coming soon").
+        raise NotImplementedError("Training mode is not implemented yet.")
+    return model
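+
+
+# Usage sketch (illustrative; the `opt` attributes are those read above, and
+# `mobilenet_v2(pretrained=True)` downloads ImageNet weights on first use):
+#
+#     from types import SimpleNamespace
+#     import torch
+#     from synthetic_ssd.models import create_model
+#
+#     opt = SimpleNamespace(model='mobilenet_v2_ssd', num_classes=30, train=False,
+#                           input_size=224, device=torch.device('cpu'))
+#     model = create_model(opt).eval()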
diff --git a/synthetic_ssd/models/box_utils.py b/synthetic_ssd/models/box_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..87cf3f0440f723089d09a742e00793b01fa90bea
--- /dev/null
+++ b/synthetic_ssd/models/box_utils.py
@@ -0,0 +1,301 @@
+"""
+From https://github.com/qfgaohao/pytorch-ssd/blob/7174f33aa2a1540f90d827d48dea681ec1a2856c/vision/utils/box_utils.py
+"""
+
+import collections
+import torch
+import itertools
+from typing import List
+import math
+
+SSDBoxSizes = collections.namedtuple('SSDBoxSizes', ['min', 'max'])
+
+SSDSpec = collections.namedtuple('SSDSpec', ['feature_map_size', 'shrinkage', 'box_sizes', 'aspect_ratios'])
+
+
+def generate_ssd_priors(specs: List[SSDSpec], image_size, clamp=True) -> torch.Tensor:
+    """Generate SSD Prior Boxes.
+
+    It returns the center, height and width of the priors. The values are relative to the image size
+    Args:
+        specs: SSDSpecs describing the shapes and sizes of the prior boxes, e.g.
+            specs = [
+                SSDSpec(38, 8, SSDBoxSizes(30, 60), [2]),
+                SSDSpec(19, 16, SSDBoxSizes(60, 111), [2, 3]),
+                SSDSpec(10, 32, SSDBoxSizes(111, 162), [2, 3]),
+                SSDSpec(5, 64, SSDBoxSizes(162, 213), [2, 3]),
+                SSDSpec(3, 100, SSDBoxSizes(213, 264), [2]),
+                SSDSpec(1, 300, SSDBoxSizes(264, 315), [2])
+            ]
+        image_size: image size.
+        clamp: if True, clamp the values to fall within [0.0, 1.0]
+    Returns:
+        priors (num_priors, 4): The prior boxes represented as [[center_x, center_y, w, h]]. All the values
+            are relative to the image size.
+    """
+    priors = []
+    for spec in specs:
+        scale = image_size / spec.shrinkage
+        for j, i in itertools.product(range(spec.feature_map_size), repeat=2):
+            x_center = (i + 0.5) / scale
+            y_center = (j + 0.5) / scale
+
+            # small sized square box
+            size = spec.box_sizes.min
+            h = w = size / image_size
+            priors.append([
+                x_center,
+                y_center,
+                w,
+                h
+            ])
+
+            # big sized square box
+            size = math.sqrt(spec.box_sizes.max * spec.box_sizes.min)
+            h = w = size / image_size
+            priors.append([
+                x_center,
+                y_center,
+                w,
+                h
+            ])
+
+            # change h/w ratio of the small sized box
+            size = spec.box_sizes.min
+            h = w = size / image_size
+            for ratio in spec.aspect_ratios:
+                ratio = math.sqrt(ratio)
+                priors.append([
+                    x_center,
+                    y_center,
+                    w * ratio,
+                    h / ratio
+                ])
+                priors.append([
+                    x_center,
+                    y_center,
+                    w / ratio,
+                    h * ratio
+                ])
+
+    priors = torch.tensor(priors)
+    if clamp:
+        torch.clamp(priors, 0.0, 1.0, out=priors)
+    return priors
+
+
+def convert_locations_to_boxes(locations, priors, center_variance,
+                               size_variance):
+    """Convert regressional location results of SSD into boxes in the form of (center_x, center_y, h, w).
+
+    The conversion:
+        $$predicted\_center * center_variance = \frac {real\_center - prior\_center} {prior\_hw}$$
+        $$exp(predicted\_hw * size_variance) = \frac {real\_hw} {prior\_hw}$$
+    We do it in the inverse direction here.
+    Args:
+        locations (batch_size, num_priors, 4): the regression output of SSD. It will contain the outputs as well.
+        priors (num_priors, 4) or (batch_size/1, num_priors, 4): prior boxes.
+        center_variance: a float used to change the scale of center.
+        size_variance: a float used to change of scale of size.
+    Returns:
+        boxes:  priors: [[center_x, center_y, h, w]]. All the values
+            are relative to the image size.
+    """
+    # priors can have one dimension less.
+    if priors.dim() + 1 == locations.dim():
+        priors = priors.unsqueeze(0)
+    return torch.cat([
+        locations[..., :2] * center_variance * priors[..., 2:] + priors[..., :2],
+        torch.exp(locations[..., 2:] * size_variance) * priors[..., 2:]
+    ], dim=locations.dim() - 1)
+
+
+def convert_boxes_to_locations(center_form_boxes, center_form_priors, center_variance, size_variance):
+    # priors can have one dimension less
+    if center_form_priors.dim() + 1 == center_form_boxes.dim():
+        center_form_priors = center_form_priors.unsqueeze(0)
+    return torch.cat([
+        (center_form_boxes[..., :2] - center_form_priors[..., :2]) / center_form_priors[..., 2:] / center_variance,
+        torch.log(center_form_boxes[..., 2:] / center_form_priors[..., 2:]) / size_variance
+    ], dim=center_form_boxes.dim() - 1)
+
+
+def area_of(left_top, right_bottom) -> torch.Tensor:
+    """Compute the areas of rectangles given two corners.
+
+    Args:
+        left_top (N, 2): left top corner.
+        right_bottom (N, 2): right bottom corner.
+
+    Returns:
+        area (N): return the area.
+    """
+    hw = torch.clamp(right_bottom - left_top, min=0.0)
+    return hw[..., 0] * hw[..., 1]
+
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+    """Return intersection-over-union (Jaccard index) of boxes.
+
+    Args:
+        boxes0 (N, 4): ground truth boxes.
+        boxes1 (N or 1, 4): predicted boxes.
+        eps: a small number to avoid 0 as denominator.
+    Returns:
+        iou (N): IoU values.
+    """
+    overlap_left_top = torch.max(boxes0[..., :2], boxes1[..., :2])
+    overlap_right_bottom = torch.min(boxes0[..., 2:], boxes1[..., 2:])
+
+    overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+    area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+    area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+    return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
+def assign_priors(gt_boxes, gt_labels, corner_form_priors,
+                  iou_threshold):
+    """Assign ground truth boxes and targets to priors.
+
+    Args:
+        gt_boxes (num_targets, 4): ground truth boxes.
+        gt_labels (num_targets): labels of targets.
+        corner_form_priors (num_priors, 4): corner-form priors.
+        iou_threshold: minimum IoU for a prior to be assigned to a target.
+    Returns:
+        boxes (num_priors, 4): real values for priors.
+        labels (num_priors): labels for priors.
+    """
+    # size: num_priors x num_targets
+    ious = iou_of(gt_boxes.unsqueeze(0), corner_form_priors.unsqueeze(1))
+    # size: num_priors
+    best_target_per_prior, best_target_per_prior_index = ious.max(1)
+    # size: num_targets
+    best_prior_per_target, best_prior_per_target_index = ious.max(0)
+
+    for target_index, prior_index in enumerate(best_prior_per_target_index):
+        best_target_per_prior_index[prior_index] = target_index
+    # 2.0 is used to make sure every target has a prior assigned
+    best_target_per_prior.index_fill_(0, best_prior_per_target_index, 2)
+    # size: num_priors
+    labels = gt_labels[best_target_per_prior_index]
+    labels[best_target_per_prior < iou_threshold] = 0  # the background id
+    boxes = gt_boxes[best_target_per_prior_index]
+    return boxes, labels
+
+
+def hard_negative_mining(loss, labels, neg_pos_ratio):
+    """
+    It used to suppress the presence of a large number of negative prediction.
+    It works on image level not batch level.
+    For any example/image, it keeps all the positive predictions and
+     cut the number of negative predictions to make sure the ratio
+     between the negative examples and positive examples is no more
+     the given ratio for an image.
+
+    Args:
+        loss (N, num_priors): the loss for each example.
+        labels (N, num_priors): the labels.
+        neg_pos_ratio:  the ratio between the negative examples and positive examples.
+    Return:
+        indexes of priors to use.
+    """
+    pos_mask = labels > 0
+    num_pos = pos_mask.long().sum(dim=1, keepdim=True)
+    num_neg = num_pos * neg_pos_ratio
+
+    loss[pos_mask] = -math.inf
+    _, indexes = loss.sort(dim=1, descending=True)
+    _, orders = indexes.sort(dim=1)
+    neg_mask = orders < num_neg
+    return pos_mask | neg_mask
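+# Example: with neg_pos_ratio=3 and 10 positive priors in an image, at most the
+# 30 highest-loss negative priors are kept for the classification loss.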
+
+
+def center_form_to_corner_form(locations):
+    return torch.cat([locations[..., :2] - locations[..., 2:]/2,
+                      locations[..., :2] + locations[..., 2:]/2], locations.dim() - 1)
+
+
+def corner_form_to_center_form(boxes):
+    return torch.cat([
+        (boxes[..., :2] + boxes[..., 2:]) / 2,
+        boxes[..., 2:] - boxes[..., :2]
+    ], boxes.dim() - 1)
+
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+    """
+
+    Args:
+        box_scores (N, 5): boxes in corner-form and probabilities.
+        iou_threshold: intersection over union threshold.
+        top_k: keep top_k results. If k <= 0, keep all the results.
+        candidate_size: only consider the candidates with the highest scores.
+    Returns:
+         box_scores (K, 5): the kept boxes and their scores.
+    """
+    scores = box_scores[:, -1]
+    boxes = box_scores[:, :-1]
+    picked = []
+    _, indexes = scores.sort(descending=True)
+    indexes = indexes[:candidate_size]
+    while len(indexes) > 0:
+        current = indexes[0]
+        picked.append(current.item())
+        if 0 < top_k == len(picked) or len(indexes) == 1:
+            break
+        current_box = boxes[current, :]
+        indexes = indexes[1:]
+        rest_boxes = boxes[indexes, :]
+        iou = iou_of(
+            rest_boxes,
+            current_box.unsqueeze(0),
+        )
+        indexes = indexes[iou <= iou_threshold]
+
+    return box_scores[picked, :]
+
+
+def nms(box_scores, nms_method=None, score_threshold=None, iou_threshold=None,
+        sigma=0.5, top_k=-1, candidate_size=200):
+    if nms_method == "soft":
+        return soft_nms(box_scores, score_threshold, sigma, top_k)
+    else:
+        return hard_nms(box_scores, iou_threshold, top_k, candidate_size=candidate_size)
+
+
+def soft_nms(box_scores, score_threshold, sigma=0.5, top_k=-1):
+    """Soft NMS implementation.
+
+    References:
+        https://arxiv.org/abs/1704.04503
+        https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/cython_nms.pyx
+
+    Args:
+        box_scores (N, 5): boxes in corner-form and probabilities.
+        score_threshold: boxes with scores less than value are not considered.
+        sigma: the parameter in score re-computation.
+            scores[i] = scores[i] * exp(-(iou_i)^2 / sigma)
+        top_k: keep top_k results. If k <= 0, keep all the results.
+    Returns:
+         picked_box_scores (K, 5): results of NMS.
+    """
+    picked_box_scores = []
+    while box_scores.size(0) > 0:
+        max_score_index = torch.argmax(box_scores[:, 4])
+        cur_box_prob = box_scores[max_score_index, :].clone()
+        picked_box_scores.append(cur_box_prob)
+        if len(picked_box_scores) == top_k > 0 or box_scores.size(0) == 1:
+            break
+        cur_box = cur_box_prob[:-1]
+        box_scores[max_score_index, :] = box_scores[-1, :]
+        box_scores = box_scores[:-1, :]
+        ious = iou_of(cur_box.unsqueeze(0), box_scores[:, :-1])
+        box_scores[:, -1] = box_scores[:, -1] * torch.exp(-(ious * ious) / sigma)
+        box_scores = box_scores[box_scores[:, -1] > score_threshold, :]
+    if len(picked_box_scores) > 0:
+        return torch.stack(picked_box_scores)
+    else:
+        return torch.tensor([])
+
+
+
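+if __name__ == "__main__":
+    # Minimal sketch with a toy single-scale spec (illustrative values, not the
+    # project's configuration): 2x2 locations, each with 2 square priors plus
+    # 2 boxes per aspect ratio = 4 priors, hence 16 priors in total.
+    specs = [SSDSpec(2, 150, SSDBoxSizes(105, 150), [2])]
+    priors = generate_ssd_priors(specs, image_size=300)
+    print(priors.shape)  # torch.Size([16, 4]), center form [cx, cy, w, h]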
diff --git a/synthetic_ssd/models/mobilenet_v2_ssd_lite.py b/synthetic_ssd/models/mobilenet_v2_ssd_lite.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e3e314b16326bc3c17564ea5d3bdea97d839a1
--- /dev/null
+++ b/synthetic_ssd/models/mobilenet_v2_ssd_lite.py
@@ -0,0 +1,51 @@
+"""
+Adapted from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/model.py
+and https://github.com/qfgaohao/pytorch-ssd/blob/master/vision/ssd/mobilenet_v2_ssd_lite.py
+
+"""
+from torch.nn import ModuleList, Conv2d
+
+from .ssd import GraphPath, InvertedResidual, SeparableConv2d, SSD
+
+
+def mobilenet_v2_ssd_lite(num_classes, base_net, width_mult=1.0,
+                            is_test=False, config=None, priors=None, device=None):
+
+    source_layer_indexes = [
+        GraphPath(14, 'conv', 1),
+        19
+    ]
+
+    extras = ModuleList([
+        InvertedResidual(1280, 512, stride=2, expand_ratio=0.2),
+        InvertedResidual(512, 256, stride=2, expand_ratio=0.25),
+        InvertedResidual(256, 256, stride=2, expand_ratio=0.5),
+        InvertedResidual(256, 64, stride=2, expand_ratio=0.25)
+    ])
+
+    regression_headers = ModuleList([
+        SeparableConv2d(in_channels=round(576 * width_mult), out_channels=6 * 4,
+                        kernel_size=3, padding=1, onnx_compatible=False),
+        SeparableConv2d(in_channels=1280, out_channels=6 * 4, kernel_size=3,
+                        padding=1, onnx_compatible=False),
+        SeparableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3,
+                        padding=1, onnx_compatible=False),
+        SeparableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3,
+                        padding=1, onnx_compatible=False),
+        SeparableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3,
+                        padding=1, onnx_compatible=False),
+        Conv2d(in_channels=64, out_channels=6 * 4, kernel_size=1),
+    ])
+
+    classification_headers = ModuleList([
+        SeparableConv2d(in_channels=round(576 * width_mult), out_channels=6 * num_classes, kernel_size=3, padding=1),
+        SeparableConv2d(in_channels=1280, out_channels=6 * num_classes, kernel_size=3, padding=1),
+        SeparableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
+        SeparableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
+        SeparableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
+        Conv2d(in_channels=64, out_channels=6 * num_classes, kernel_size=1),
+    ])
+
+    return SSD(num_classes, base_net, source_layer_indexes, extras,
+               classification_headers, regression_headers, is_test=is_test,
+               config=config, priors=priors, device=device)
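+
+
+# Note: synthetic_ssd.models.create_model builds this detector with a
+# torchvision mobilenet_v2 backbone and priors generated from
+# mobilenetv2_ssd_config (see models/__init__.py).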
diff --git a/synthetic_ssd/models/mobilenet_v3.py b/synthetic_ssd/models/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71449b8d8ffe384923f188a835e3abe8316b34e
--- /dev/null
+++ b/synthetic_ssd/models/mobilenet_v3.py
@@ -0,0 +1,231 @@
+"""
+Adapted from: https://github.com/d-li14/mobilenetv3.pytorch
+
+Creates a MobileNetV3 Model as defined in:
+Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019).
+Searching for MobileNetV3
+arXiv preprint arXiv:1905.02244. (v.1)
+"""
+
+import torch.nn as nn
+import math
+
+
+def _make_divisible(v, divisor, min_value=None):
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8.
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v: the original channel count
+    :param divisor: the required divisor (8 in this project)
+    :param min_value: lower bound for the result (defaults to `divisor`)
+    :return: the adjusted channel count, never more than 10% below `v`
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
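+# Examples: _make_divisible(37, 8) == 40; _make_divisible(16, 8) == 16.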
+
+
+class h_sigmoid(nn.Module):
+    def __init__(self, inplace=True):
+        super(h_sigmoid, self).__init__()
+        self.relu = nn.ReLU6(inplace=inplace)
+
+    def forward(self, x):
+        return self.relu(x + 3) / 6
+
+
+class h_swish(nn.Module):
+    def __init__(self, inplace=True):
+        super(h_swish, self).__init__()
+        self.sigmoid = h_sigmoid(inplace=inplace)
+
+    def forward(self, x):
+        return x * self.sigmoid(x)
+
+
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=4):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel),
+            h_sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y
+
+
+def conv_3x3_bn(inp, oup, stride):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+        nn.BatchNorm2d(oup),
+        h_swish()
+    )
+
+
+def conv_1x1_bn(inp, oup):
+    return nn.Sequential(
+        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+        nn.BatchNorm2d(oup),
+        h_swish()
+    )
+
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
+        super(InvertedResidual, self).__init__()
+        assert stride in [1, 2]
+
+        self.identity = stride == 1 and inp == oup
+
+        if inp == hidden_dim:
+            self.conv = nn.Sequential(
+                # dw
+                nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                h_swish() if use_hs else nn.ReLU(inplace=True),
+                # Squeeze-and-Excite
+                SELayer(hidden_dim) if use_se else nn.Sequential(),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+            )
+        else:
+            self.conv = nn.Sequential(
+                # pw
+                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                h_swish() if use_hs else nn.ReLU(inplace=True),
+                # dw
+                nn.Conv2d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False),
+                nn.BatchNorm2d(hidden_dim),
+                # Squeeze-and-Excite
+                SELayer(hidden_dim) if use_se else nn.Sequential(),
+                h_swish() if use_hs else nn.ReLU(inplace=True),
+                # pw-linear
+                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+                nn.BatchNorm2d(oup),
+            )
+
+    def forward(self, x):
+        if self.identity:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+
+class MobileNetV3(nn.Module):
+    def __init__(self, mode, num_channels=3, num_classes=1000, width_mult=1.):
+        super(MobileNetV3, self).__init__()
+        # setting of inverted residual blocks
+        assert mode in ['large', 'small']
+        self.mode = mode
+        self.width_mult = width_mult
+        if mode == 'large':
+            cfgs = [
+                # k, t, c, SE, NL, s
+                [3, 16, 16, 0, 0, 1],
+                [3, 64, 24, 0, 0, 2],
+                [3, 72, 24, 0, 0, 1],
+                [5, 72, 40, 1, 0, 2],
+                [5, 120, 40, 1, 0, 1],
+                [5, 120, 40, 1, 0, 1],
+                [3, 240, 80, 0, 1, 2],
+                [3, 200, 80, 0, 1, 1],
+                [3, 184, 80, 0, 1, 1],
+                [3, 184, 80, 0, 1, 1],
+                [3, 480, 112, 1, 1, 1],
+                [3, 672, 112, 1, 1, 1],
+                [5, 672, 160, 1, 1, 2],
+                [5, 672, 160, 1, 1, 1],  # from ArXiv paper v1. In subsequent versions: t=960 instead of 672
+                [5, 960, 160, 1, 1, 1]
+            ]
+        else:
+            cfgs = [
+                # k, t, c, SE, NL, s
+                [3, 16, 16, 1, 0, 2],
+                [3, 72, 24, 0, 0, 2],
+                [3, 88, 24, 0, 0, 1],
+                [5, 96, 40, 1, 1, 2],
+                [5, 240, 40, 1, 1, 1],
+                [5, 240, 40, 1, 1, 1],
+                [5, 120, 48, 1, 1, 1],
+                [5, 144, 48, 1, 1, 1],
+                [5, 288, 96, 1, 1, 2],
+                [5, 576, 96, 1, 1, 1],
+                [5, 576, 96, 1, 1, 1],
+            ]
+
+        self.cfgs = cfgs
+
+        # building first layer
+        input_channel = _make_divisible(16 * width_mult, 8)
+        layers = [conv_3x3_bn(num_channels, input_channel, 2)]
+        # building inverted residual blocks
+        block = InvertedResidual
+        for k, exp_size, c, use_se, use_hs, s in self.cfgs:
+            output_channel = _make_divisible(c * width_mult, 8)
+            layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
+            input_channel = output_channel
+        self.features = nn.Sequential(*layers)
+        # building last several layers
+        self.conv = nn.Sequential(
+            conv_1x1_bn(input_channel, _make_divisible(exp_size * width_mult, 8)),
+            SELayer(_make_divisible(exp_size * width_mult, 8)) if mode == 'small' else nn.Sequential()
+        )
+        self.avgpool = nn.Sequential(
+            nn.AdaptiveAvgPool2d((1, 1)),
+            h_swish()
+        )
+        output_channel = _make_divisible(1280 * width_mult, 8) if width_mult > 1.0 else 1280
+        self.classifier = nn.Sequential(
+            nn.Linear(_make_divisible(exp_size * width_mult, 8), output_channel),
+            nn.BatchNorm1d(output_channel) if mode == 'small' else nn.Sequential(),
+            h_swish(),
+            nn.Linear(output_channel, num_classes),
+            nn.BatchNorm1d(num_classes) if mode == 'small' else nn.Sequential(),
+            h_swish() if mode == 'small' else nn.Sequential()
+        )
+
+        self._initialize_weights()
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.conv(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        x = self.classifier(x)
+        return x
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                n = m.weight.size(1)
+                m.weight.data.normal_(0, 0.01)
+                m.bias.data.zero_()
+
+    def get_mode(self):
+        return self.mode
+
+    def get_width_mult(self):
+        return self.width_mult
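+
+
+if __name__ == "__main__":
+    # Minimal sketch: ImageNet-style classification forward pass. Batch size
+    # must be > 1 here because the 'small' head uses BatchNorm1d in train mode.
+    import torch
+
+    net = MobileNetV3(mode='small')
+    out = net(torch.randn(2, 3, 224, 224))
+    print(out.shape)  # torch.Size([2, 1000])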
diff --git a/synthetic_ssd/models/mobilenet_v3_ssd_lite.py b/synthetic_ssd/models/mobilenet_v3_ssd_lite.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc6c64d6dd17c86625911e8f2d5c43206bab546a
--- /dev/null
+++ b/synthetic_ssd/models/mobilenet_v3_ssd_lite.py
@@ -0,0 +1,103 @@
+'''
+Adapted from https://github.com/tongyuhome/MobileNetV3-SSD/blob/master/mobilenet_v3_ssd_lite.py
+'''
+import torch
+from torch.nn import ModuleList, Conv2d
+
+from .mobilenet_v3 import MobileNetV3
+from .ssd import GraphPath, InvertedResidual, SeparableConv2d, SSD
+
+
+def mobilenet_v3_ssd_lite(num_classes, base_net, width_mult=1.0,
+                            is_test=False, config=None, dropout=0.8, priors=None,
+                            device=None):
+    """
+    Build an SSD-Lite detector on top of a MobileNetV3 backbone.
+    num_classes includes the background class.
+    :return: the SSD model
+    """
+    assert isinstance(base_net, MobileNetV3)
+    if base_net.get_mode() == 'small':
+        source_layer_indexes = [
+            GraphPath(9, 'conv', 3),
+            13
+        ]
+        # Reorganize layers
+        modules = [x for x in base_net.features]
+        modules.append(base_net.conv[:])  # Conv2d + BN + HS + SE
+        modules.append(base_net.avgpool[0])  # avgpool without HS
+        sequence = torch.nn.Sequential(*modules)
+        base_net.features = sequence
+        base_net.conv = torch.nn.Sequential()
+
+        extras = ModuleList([
+            InvertedResidual(576, 512, stride=2, expand_ratio=0.2),
+            InvertedResidual(512, 256, stride=2, expand_ratio=0.25),
+            InvertedResidual(256, 256, stride=2, expand_ratio=0.5),
+            InvertedResidual(256, 64, stride=2, expand_ratio=0.25)
+        ])
+
+        regression_headers = ModuleList([
+            SeparableConv2d(in_channels=round(288 * width_mult), out_channels=6 * 4,
+                            kernel_size=3, padding=1, onnx_compatible=False),
+            SeparableConv2d(in_channels=576, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
+            SeparableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
+            SeparableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
+            SeparableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
+            Conv2d(in_channels=64, out_channels=6 * 4, kernel_size=1),
+        ])
+
+        classification_headers = ModuleList([
+            SeparableConv2d(in_channels=round(288 * width_mult), out_channels=6 * num_classes, kernel_size=3,
+                            padding=1),
+            SeparableConv2d(in_channels=576, out_channels=6 * num_classes, kernel_size=3, padding=1),
+            SeparableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
+            SeparableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
+            SeparableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
+            Conv2d(in_channels=64, out_channels=6 * num_classes, kernel_size=1),
+        ])
+
+    elif base_net.get_mode() == 'large':
+        source_layer_indexes = [
+            GraphPath(13, 'conv', 3),
+            17
+        ]
+        # Reorganize layers
+        modules = [x for x in base_net.features]
+        modules.append(base_net.conv[0])  # Conv2d + BN + HS
+        sequence = torch.nn.Sequential(*modules)
+        base_net.features = sequence
+        base_net.conv = torch.nn.Sequential()
+
+        extras = ModuleList([
+            InvertedResidual(960, 512, stride=2, expand_ratio=0.2),
+            InvertedResidual(512, 256, stride=2, expand_ratio=0.25),
+            InvertedResidual(256, 256, stride=2, expand_ratio=0.5),
+            InvertedResidual(256, 64, stride=2, expand_ratio=0.25)
+        ])
+
+        regression_headers = ModuleList([
+            SeparableConv2d(in_channels=round(672 * width_mult), out_channels=6 * 4,
+                            kernel_size=3, padding=1, onnx_compatible=False),
+            SeparableConv2d(in_channels=960, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
+            SeparableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
+            SeparableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
+            SeparableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
+            Conv2d(in_channels=64, out_channels=6 * 4, kernel_size=1),
+        ])
+
+        classification_headers = ModuleList([
+            SeparableConv2d(in_channels=round(672 * width_mult), out_channels=6 * num_classes, kernel_size=3,
+                            padding=1),
+            SeparableConv2d(in_channels=960, out_channels=6 * num_classes, kernel_size=3, padding=1),
+            SeparableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
+            SeparableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
+            SeparableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
+            Conv2d(in_channels=64, out_channels=6 * num_classes, kernel_size=1),
+        ])
+    else:
+        print("Mode {mode} is not available, use 'small' or 'large'.".format(mode=base_net.get_mode()))
+        exit(1)
+
+    return SSD(num_classes, base_net, source_layer_indexes, extras,
+               classification_headers, regression_headers, is_test=is_test,
+               config=config, priors=priors, device=device)
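+
+
+# Note: synthetic_ssd.models.create_model builds this detector with a
+# MobileNetV3 ('small' or 'large') backbone and priors generated from
+# mobilenetv3_ssd_config (see models/__init__.py).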
diff --git a/synthetic_ssd/models/mobilenetv2_ssd_config.py b/synthetic_ssd/models/mobilenetv2_ssd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..50a66998e0da735a5357eb734012a6ab8a371719
--- /dev/null
+++ b/synthetic_ssd/models/mobilenetv2_ssd_config.py
@@ -0,0 +1,23 @@
+"""
+Modified from https://github.com/qfgaohao/pytorch-ssd/blob/master/vision/ssd/config/mobilenetv1_ssd_config.py
+"""
+
+from .box_utils import SSDSpec, SSDBoxSizes
+
+
+iou_threshold = 0.5
+center_variance = 0.1
+size_variance = 0.2
+neg_pos_ratio = 3
+
+
+specs = [
+    SSDSpec(14, 16, SSDBoxSizes(60, 105), [2, 3]),
+    SSDSpec(7, 32, SSDBoxSizes(105, 150), [2, 3]),
+    SSDSpec(4, 64, SSDBoxSizes(150, 195), [2, 3]),
+    SSDSpec(2, 100, SSDBoxSizes(195, 240), [2, 3]),
+    SSDSpec(1, 150, SSDBoxSizes(240, 285), [2, 3]),
+    SSDSpec(1, 300, SSDBoxSizes(285, 330), [2, 3])
+]
+
+
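+# Sanity check (illustrative): each location yields 2 square priors plus 2 per
+# aspect ratio, i.e. 2 + 2 * 2 = 6 boxes here, matching the "6 * 4" and
+# "6 * num_classes" output channels of the SSD heads.
+# Run with `python -m synthetic_ssd.models.mobilenetv2_ssd_config`.
+if __name__ == "__main__":
+    num_priors = sum(s.feature_map_size ** 2 * (2 + 2 * len(s.aspect_ratios))
+                     for s in specs)
+    print(num_priors)  # (196 + 49 + 16 + 4 + 1 + 1) * 6 = 1602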
diff --git a/synthetic_ssd/models/mobilenetv3_ssd_config.py b/synthetic_ssd/models/mobilenetv3_ssd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..24028fb0c0e1ade33d28dd8fad13e991c7534b9f
--- /dev/null
+++ b/synthetic_ssd/models/mobilenetv3_ssd_config.py
@@ -0,0 +1,16 @@
+from .box_utils import SSDSpec, SSDBoxSizes
+
+iou_threshold = 0.5
+center_variance = 0.1
+size_variance = 0.2
+neg_pos_ratio = 3
+
+specs = [
+    SSDSpec(14, 16, SSDBoxSizes(60, 105), [2, 3]),
+    SSDSpec(7, 32, SSDBoxSizes(105, 150), [2, 3]),
+    SSDSpec(4, 64, SSDBoxSizes(150, 195), [2, 3]),
+    SSDSpec(2, 100, SSDBoxSizes(195, 240), [2, 3]),
+    SSDSpec(1, 150, SSDBoxSizes(240, 285), [2, 3]),
+    SSDSpec(1, 300, SSDBoxSizes(285, 330), [2, 3])
+]
+
diff --git a/synthetic_ssd/models/ssd.py b/synthetic_ssd/models/ssd.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e6bebdd9cf225238659034dc21b7d9c5c00841
--- /dev/null
+++ b/synthetic_ssd/models/ssd.py
@@ -0,0 +1,241 @@
+"""
+Adapted for MobileNet from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection/blob/master/model.py
+and https://github.com/qfgaohao/pytorch-ssd/blob/master/vision/ssd/mobilenet_v2_ssd_lite.py
+and https://github.com/tongyuhome/MobileNetV3-SSD/blob/master/mobilenet_v3_ssd_lite.py
+"""
+from collections import namedtuple
+
+import torch
+from torch import nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from . import box_utils
+
+GraphPath = namedtuple("GraphPath", ['s0', 'name', 's1'])
+
+
+def SeparableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, onnx_compatible=False):
+    """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d.
+    """
+    ReLU = nn.ReLU if onnx_compatible else nn.ReLU6
+    return nn.Sequential(
+        nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
+                  groups=in_channels, stride=stride, padding=padding),
+        nn.BatchNorm2d(in_channels),
+        ReLU(),
+        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
+    )
+
+
+# From torchvision.models.mobilenet.py
+class ConvBNReLU(nn.Sequential):
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        padding = (kernel_size - 1) // 2
+        super(ConvBNReLU, self).__init__(
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.ReLU6(inplace=True)
+        )
+
+
+# From torchvision.models.mobilenet.py
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        layers = []
+        if expand_ratio != 1:
+            # pw
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        layers.extend([
+            # dw
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            # pw-linear
+            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(oup),
+        ])
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+
+class SSD(nn.Module):
+    def __init__(self, num_classes, base_net, source_layer_indexes, extras,
+                 classification_headers, regression_headers, device=None,
+                 is_test=False, config=None, priors=None):
+        super(SSD, self).__init__()
+        self.num_classes = num_classes
+
+        self.base_net = base_net.features
+        self.source_layer_indexes = source_layer_indexes
+        self.extras = extras
+        self.classification_headers = classification_headers
+        self.regression_headers = regression_headers
+
+        self.is_test = is_test
+        self.config_center_variance = config.center_variance
+        self.config_size_variance = config.size_variance
+
+        self.source_layer_add_ons = nn.ModuleList(
+            [t[1] for t in source_layer_indexes
+             if isinstance(t, tuple) and not isinstance(t, GraphPath)])
+
+        if device is not None:
+            self.device = device
+        else:
+            self.device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
+
+        if priors is not None:
+            self.priors = priors.to(self.device)
+
+    def forward(self, x):
+        """
+        :param x: torch.Tensor
+        :return: Tuple[torch.Tensor, torch.Tensor]
+                In train mode, return the confidences and locations
+                In test mode, return the confidences and the boxes (top-left
+                corner, bottom-right corner)
+        """
+        confidences = []
+        locations = []
+        start_layer_index = 0
+        header_index = 0
+
+        for end_layer_index in self.source_layer_indexes:
+            if isinstance(end_layer_index, GraphPath):
+                path = end_layer_index
+                end_layer_index = end_layer_index.s0
+                added_layer = None
+            elif isinstance(end_layer_index, tuple):
+                added_layer = end_layer_index[1]
+                end_layer_index = end_layer_index[0]
+                path = None
+            else:
+                added_layer = None
+                path = None
+
+            for layer in self.base_net[start_layer_index: end_layer_index]:
+                x = layer(x)
+            if added_layer:
+                y = added_layer(x)
+            else:
+                y = x
+            if path:
+                sub = getattr(self.base_net[end_layer_index], path.name)
+                if path.s1 < 0:
+                    for layer in sub:
+                        y = layer(y)
+                else:
+                    for layer in sub[:path.s1]:
+                        x = layer(x)
+                    y = x
+                    for layer in sub[path.s1:]:
+                        x = layer(x)
+                    end_layer_index += 1
+            start_layer_index = end_layer_index
+            confidence, location = self.compute_header(header_index, y)
+            header_index += 1
+            confidences.append(confidence)
+            locations.append(location)
+
+        for layer in self.extras:
+            x = layer(x)
+            confidence, location = self.compute_header(header_index, x)
+            header_index += 1
+            confidences.append(confidence)
+            locations.append(location)
+
+        confidences = torch.cat(confidences, 1)
+        locations = torch.cat(locations, 1)
+
+        if self.is_test:
+            confidences = F.softmax(confidences, dim=2)
+            boxes = box_utils.convert_locations_to_boxes(locations,
+                                                         self.priors, self.config_center_variance,
+                                                         self.config_size_variance)
+            boxes = box_utils.center_form_to_corner_form(boxes)
+            return confidences, boxes
+        else:
+            return confidences, locations
+
+    def compute_header(self, i, x):
+        confidence = self.classification_headers[i](x)
+        confidence = confidence.permute(0, 2, 3, 1).contiguous()
+        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
+
+        location = self.regression_headers[i](x)
+        location = location.permute(0, 2, 3, 1).contiguous()
+        location = location.view(location.size(0), -1, 4)
+
+        return confidence, location
+
+    def init_from_base_net(self, model):
+        self.base_net.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage), strict=True)
+        self.source_layer_add_ons.apply(_xavier_init_)
+        self.extras.apply(_xavier_init_)
+        self.classification_headers.apply(_xavier_init_)
+        self.regression_headers.apply(_xavier_init_)
+
+    def init_xavier_except_base_net(self):
+        self.source_layer_add_ons.apply(_xavier_init_)
+        self.extras.apply(_xavier_init_)
+        self.classification_headers.apply(_xavier_init_)
+        self.regression_headers.apply(_xavier_init_)
+
+    def init_from_pretrained_ssd(self, model):
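+        # Reuse every pretrained weight except the prediction heads, whose
+        # shapes depend on num_classes; those are re-initialized below.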
+        state_dict = torch.load(model, map_location=lambda storage, loc: storage)
+        state_dict = {k: v for k, v in state_dict.items() if
+                      not (k.startswith("classification_headers") or k.startswith("regression_headers"))}
+        model_dict = self.state_dict()
+        model_dict.update(state_dict)
+        self.load_state_dict(model_dict)
+        self.classification_headers.apply(_xavier_init_)
+        self.regression_headers.apply(_xavier_init_)
+
+    def init_xavier(self):
+        self.base_net.apply(_xavier_init_)
+        self.source_layer_add_ons.apply(_xavier_init_)
+        self.extras.apply(_xavier_init_)
+        self.classification_headers.apply(_xavier_init_)
+        self.regression_headers.apply(_xavier_init_)
+
+    def load(self, model):
+        self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
+
+    def save(self, model_path):
+        torch.save(self.state_dict(), model_path)
+
+
+def _xavier_init_(m: nn.Module):
+    if isinstance(m, nn.Conv2d):
+        nn.init.xavier_uniform_(m.weight)
+
+
+class MatchPrior(object):
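+    """Match ground-truth boxes to the prior (anchor) boxes and encode them
+    as SSD regression targets, scaled by the center/size variances.
+
+    Typically used as the target transform of the training dataset.
+    """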
+    def __init__(self, center_form_priors, center_variance, size_variance, iou_threshold):
+        self.center_form_priors = center_form_priors
+        self.corner_form_priors = box_utils.center_form_to_corner_form(center_form_priors)
+        self.center_variance = center_variance
+        self.size_variance = size_variance
+        self.iou_threshold = iou_threshold
+
+    def __call__(self, gt_boxes, gt_labels):
+        if isinstance(gt_boxes, np.ndarray):
+            gt_boxes = torch.from_numpy(gt_boxes)
+        if isinstance(gt_labels, np.ndarray):
+            gt_labels = torch.from_numpy(gt_labels)
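+        # Assign each prior box the ground-truth box with the highest IoU;
+        # priors whose best IoU falls below the threshold get the background label.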
+        boxes, labels = box_utils.assign_priors(gt_boxes, gt_labels,
+                                                self.corner_form_priors, self.iou_threshold)
+        boxes = box_utils.corner_form_to_center_form(boxes)
+        locations = box_utils.convert_boxes_to_locations(boxes, self.center_form_priors,
+                                                         self.center_variance, self.size_variance)
+        return locations, labels
diff --git a/synthetic_ssd/scripts/__init__.py b/synthetic_ssd/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/synthetic_ssd/scripts/run_eval.py b/synthetic_ssd/scripts/run_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c372daac1966be551760bd5ea0021403a9965162
--- /dev/null
+++ b/synthetic_ssd/scripts/run_eval.py
@@ -0,0 +1,201 @@
+import argparse
+
+import tqdm
+import torch
+from torch.utils.data import DataLoader
+
+from synthetic_ssd.utils.configuration import initialize
+from synthetic_ssd.utils.evaluate import custom_nms
+from synthetic_ssd.utils.load_weights import load_weights
+from synthetic_ssd.models import create_model
+from synthetic_ssd.datasets.detection_dataset import TLESSDetection, collate_fn
+from synthetic_ssd.datasets.augmentations import test_tf_ssd, test_tf_rcnn
+
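+# Importing _init_paths adds the Object Detection Metrics library to sys.path,
+# making the top-level imports below resolvable.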
+from synthetic_ssd.metrics import _init_paths
+from BoundingBox import BoundingBox
+from BoundingBoxes import BoundingBoxes
+from Evaluator import Evaluator
+import utils as metrics_utils
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', type=str, default='mobilenet_v2_ssd',
+                        choices=['mobilenet_v2_ssd', 'mobilenet_v3_ssd', 'mobilenet_v3_small_ssd', 'mask_rcnn'],
+                        help="Object detection model (default: 'mobilenet_v2_ssd').")
+    parser.add_argument('--tf', type=str, default='aug2', choices=['aug1', 'aug2'],
+                        help="Augmentation method applied during training (default: 'aug2').")
+    parser.add_argument('--random_seed', type=int, default=3,
+                        help="Random seed (default: 3).")
+    opt = parser.parse_args()
+    opt.train = False
+    opt.num_classes = 30
+    opt.input_size = 224
+    run_evaluation(opt)
+
+
+@torch.no_grad()
+def run_evaluation(opt):
+    initialize(opt)
+
+    # Load dataset
+    if "mobilenet" in opt.model:
+        tf = test_tf_ssd(opt.input_size)
+    else:
+        tf = test_tf_rcnn()
+    dataset = TLESSDetection(tf=tf, target_tf=None, mode='test')
+
+    # Load model
+    model = create_model(opt)
+    model = load_weights(model, model_name=opt.model, tf=opt.tf)
+
+    model.to(opt.device)
+    model.eval()
+    if hasattr(model, 'is_test'):
+        model.is_test = True
+
+    # Eval loop
+    evaluator = Evaluator()
+    if opt.model == "mask_rcnn":
+        bboxes = test_mask_rcnn(opt, dataset, model)
+    else:
+        dataloader = DataLoader(dataset,
+                                batch_size=1,
+                                shuffle=False,
+                                num_workers=2,
+                                collate_fn=collate_fn)
+
+        bboxes = BoundingBoxes()
+
+        for data in tqdm.tqdm(dataloader):
+            images, targets = data
+            raw_boxes = [t['raw_boxes'] for t in targets]
+            raw_labels = [t['raw_labels'] for t in targets]
+            image_names = [t['image_name'] for t in targets]
+            image_sizes = [t['image_size'] for t in targets]
+
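+            # Register ground-truth boxes, converted from corner form
+            # (x1, y1, x2, y2) to relative center form (x_center, y_center, w, h).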
+            for im in range(len(images)):
+                im_boxes = raw_boxes[im]  # tensor of size (n_boxes, 4)
+                im_labels = raw_labels[im]
+                im_name = image_names[im]
+                im_size = image_sizes[im]
+                for i in range(im_boxes.size(0)):
+                    x1, y1, x2, y2 = im_boxes[i]
+                    x = (x1.item()+x2.item())/2.
+                    y = (y1.item()+y2.item())/2.
+                    w = x2.item() - x1.item()
+                    h = y2.item() - y1.item()
+                    gtBox = BoundingBox(imageName=im_name,
+                                        classId=str(im_labels[i].item()),
+                                        x=x, y=y,
+                                        w=w, h=h,
+                                        typeCoordinates=metrics_utils.CoordinatesType.Relative,
+                                        imgSize=(im_size[1], im_size[0]),
+                                        bbType=metrics_utils.BBType.GroundTruth,
+                                        format=metrics_utils.BBFormat.XYWH)
+                    bboxes.addBoundingBox(gtBox)
+
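+            # Forward pass, then class-wise NMS on the decoded boxes.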
+            confidences, locations = model(images.to(opt.device))
+            out_boxes, out_labels, out_confidences, out_names = custom_nms(locations, confidences,
+                                                                           images_names=image_names,
+                                                                           iou_threshold=0.5,
+                                                                           score_threshold=0.1)
+
+            for i in range(len(out_boxes)):
+                im_size = image_sizes[image_names.index(out_names[i])]  # (h, w)
+                x1, y1, x2, y2 = out_boxes[i].cpu()
+                x = (x1.item()+x2.item())/2.
+                y = (y1.item()+y2.item())/2.
+                w = x2.item() - x1.item()
+                h = y2.item() - y1.item()
+                newBox = BoundingBox(imageName=out_names[i],
+                                     classId=str(out_labels[i]),
+                                     x=x, y=y, w=w, h=h,
+                                     typeCoordinates=metrics_utils.CoordinatesType.Relative,
+                                     imgSize=(im_size[1], im_size[0]),
+                                     bbType=metrics_utils.BBType.Detected,
+                                     classConfidence=out_confidences[i],
+                                     format=metrics_utils.BBFormat.XYWH)
+                bboxes.addBoundingBox(newBox)
+
+    metrics = evaluator.GetPascalVOCMetrics(bboxes)
+    AP = []
+    for cls_metrics in metrics:
+        print("Class {cls}: {AP}% AP, {pos} gt positives, "
+              "{TP} true positives, {FP} false positives".
+              format(cls=cls_metrics['class'],
+                     AP=cls_metrics['AP'] * 100,
+                     pos=cls_metrics['total positives'],
+                     TP=cls_metrics['total TP'],
+                     FP=cls_metrics['total FP']
+                     ))
+        AP.append(cls_metrics['AP'])
+    mean_ap = sum(AP) / len(AP)
+    print(f"Average precision: {mean_ap * 100}%")
+
+
+def test_mask_rcnn(opt, dataset, model):
+    bboxes = BoundingBoxes()
+    for img, target in tqdm.tqdm(dataset):
+        img = img.div(255.)
+        _, height, width = img.shape
+
+        img = [img.to(opt.device)]
+        im_boxes = target["raw_boxes"]
+        im_labels = target["raw_labels"]
+        im_name = target["image_name"]
+
+        for i in range(im_boxes.size(0)):
+            x1, y1, x2, y2 = im_boxes[i]
+            x = (x1.item()+x2.item())/2.
+            y = (y1.item()+y2.item())/2.
+            w = x2.item() - x1.item()
+            h = y2.item() - y1.item()
+            gtBox = BoundingBox(imageName=im_name,
+                                classId=str(im_labels[i].item()),
+                                x=x, y=y, w=w, h=h,
+                                typeCoordinates=metrics_utils.CoordinatesType.Relative,
+                                imgSize=(width, height),
+                                bbType=metrics_utils.BBType.GroundTruth,
+                                format=metrics_utils.BBFormat.XYWH)
+            bboxes.addBoundingBox(gtBox)
+
+        with torch.no_grad():
+            output = model(img)[0]
+
+        out_boxes = output["boxes"].cpu()
+        out_labels = output["labels"].cpu()
+        out_confidences = output["scores"].cpu()
+
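+        # Mask R-CNN returns absolute pixel coordinates; normalize to [0, 1]
+        # to match the Relative coordinate type used for the ground truth.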
+        out_boxes[:, 0] = out_boxes[:, 0].div(width)
+        out_boxes[:, 2] = out_boxes[:, 2].div(width)
+        out_boxes[:, 1] = out_boxes[:, 1].div(height)
+        out_boxes[:, 3] = out_boxes[:, 3].div(height)
+
+        score_threshold = 0.1
+        for i in range(out_boxes.shape[0]):
+            if out_confidences[i] < score_threshold:
+                continue
+
+            x1, y1, x2, y2 = out_boxes[i].cpu()
+            x = (x1.item()+x2.item())/2.
+            y = (y1.item()+y2.item())/2.
+            w = x2.item() - x1.item()
+            h = y2.item() - y1.item()
+            newBox = BoundingBox(imageName=im_name,
+                                 classId=str(out_labels[i].item()),
+                                 x=x, y=y, w=w, h=h,
+                                 typeCoordinates=metrics_utils.CoordinatesType.Relative,
+                                 imgSize=(width, height),
+                                 bbType=metrics_utils.BBType.Detected,
+                                 classConfidence=out_confidences[i],
+                                 format=metrics_utils.BBFormat.XYWH)
+            bboxes.addBoundingBox(newBox)
+
+    return bboxes
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/synthetic_ssd/scripts/test_dataset.py b/synthetic_ssd/scripts/test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3de7fc5eb99837f916d6e8144d0288ad07d5ad09
--- /dev/null
+++ b/synthetic_ssd/scripts/test_dataset.py
@@ -0,0 +1,18 @@
+from tqdm import tqdm
+
+from torch.utils.data import DataLoader
+
+from synthetic_ssd.datasets.detection_dataset import TLESSDetection, collate_fn
+from synthetic_ssd.datasets.augmentations import train_tf, test_tf_ssd, test_tf_rcnn
+
+
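+# Smoke test: iterate the T-LESS dataset once for every (mode, transform)
+# combination to check that loading, augmentation, and collation run cleanly.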
+if __name__ == "__main__":
+    for mode in ('train', 'test'):
+        for tf in (train_tf(), test_tf_ssd(), test_tf_rcnn()):
+            dataset = TLESSDetection(tf=tf, target_tf=None, mode=mode)
+            loader = DataLoader(dataset, shuffle=(mode == 'train'),
+                                batch_size=2, num_workers=2,
+                                drop_last=False, collate_fn=collate_fn)
+
+            for data in tqdm(loader):
+                pass
diff --git a/synthetic_ssd/utils/__init__.py b/synthetic_ssd/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eb2c0cf92f01f7a58d7597f558a744a8216fe5c
--- /dev/null
+++ b/synthetic_ssd/utils/__init__.py
@@ -0,0 +1,11 @@
+import argparse
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    opt = parse_arguments()
diff --git a/synthetic_ssd/utils/configuration.py b/synthetic_ssd/utils/configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..572dc84402617d7355ce1c11c36a043827ea82b9
--- /dev/null
+++ b/synthetic_ssd/utils/configuration.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def initialize(opt):
+    torch.manual_seed(opt.random_seed)
+    torch.backends.cudnn.deterministic = True
+    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+    opt.device = device
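+    # Number of input image channels (RGB).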
+    opt.k_rgb = 3
diff --git a/synthetic_ssd/utils/evaluate.py b/synthetic_ssd/utils/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..34974d7a6ee70e65df52090375c8555222c5b4c2
--- /dev/null
+++ b/synthetic_ssd/utils/evaluate.py
@@ -0,0 +1,42 @@
+from torchvision.ops import nms
+
+
+def custom_nms(boxes, confidences, images_names, score_threshold=0.1, iou_threshold=0.5):
+    """
+    :param boxes: tensor of size(batch_size, N_boxes, 4).
+    :param confidences: tensor of size (batch_size, N_boxes, Nclasses+1), with background.
+    :param images_names: list of length batch_size.
+    :param score_threshold: remove boxes with confidence score < score_threshold.
+    :param iou_threshold: remove boxes with IoU with other box of higher confidence > iou_threshold.
+    :return: boxes, labels, confidences, images names as lists.
+    """
+    out_boxes = []
+    out_labels = []
+    out_confidences = []
+    out_names = []
+
+    # Remove the background class
+    confidences = confidences[:, :, 1:]
+
+    # For each image in the batch:
+    for i in range(boxes.size(0)):
+        im_name = images_names[i]
+        # For each class
+        for cls in range(confidences.size(-1)):
+            # /!\ Class label starts at 1 in the ground truth, 0 here /!\
+            conf_mask = confidences[i, :, cls] > score_threshold
+            boxes_im = boxes[i, conf_mask, :]
+            scores_im = confidences[i, conf_mask, cls]
+
+            if scores_im.size(0) > 0:
+                keep = nms(boxes_im, scores_im, iou_threshold=iou_threshold)
+            else:
+                keep = []
+
+            for ind in keep:
+                out_boxes.append(boxes_im[ind.item()])
+                out_labels.append(cls+1)
+                out_confidences.append(scores_im[ind.item()])
+                out_names.append(im_name)
+
+    return out_boxes, out_labels, out_confidences, out_names
\ No newline at end of file
diff --git a/synthetic_ssd/utils/load_weights.py b/synthetic_ssd/utils/load_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e8d0c512f1848d0eaceaa51fb4ab4e5ed5f59b8
--- /dev/null
+++ b/synthetic_ssd/utils/load_weights.py
@@ -0,0 +1,43 @@
+import torch
+
+from synthetic_ssd.config import WEIGHTS_DIR, MASK_RCNN_PATH
+
+
+def load_weights(model, model_name, tf):
+    if model_name == "mask_rcnn":
+        model = load_mask_rcnn_weights(model)
+    elif model_name == "mobilenet_v2_ssd":
+        model = load_mobilenet_v2_ssd_weights(model, tf)
+    elif model_name == "mobilenet_v3_ssd":
+        model = load_mobilenet_v3_ssd_weights(model, tf)
+    elif model_name == "mobilenet_v3_small_ssd":
+        model = load_mobilenet_v3_small_ssd_weights(model, tf)
+    return model
+
+
+def load_mask_rcnn_weights(model):
+    weights = MASK_RCNN_PATH
+    state_dict = torch.load(weights)
+    model.load_state_dict(state_dict['state_dict'])
+    return model
+
+
+def load_mobilenet_v2_ssd_weights(model, tf):
+    weights = WEIGHTS_DIR / tf / "tless_icip21_V2ssd.pth"
+    state_dict = torch.load(weights)
+    model.load_state_dict(state_dict)
+    return model
+
+
+def load_mobilenet_v3_ssd_weights(model, tf):
+    weights = WEIGHTS_DIR / tf / "tless_icip21_V3ssd.pth"
+    state_dict = torch.load(weights)
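+    # strict=False tolerates key mismatches between the V3 checkpoint and the
+    # model (missing/unexpected entries are skipped instead of raising).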
+    model.load_state_dict(state_dict, strict=False)
+    return model
+
+
+def load_mobilenet_v3_small_ssd_weights(model, tf):
+    weights = WEIGHTS_DIR / tf / "tless_icip21_V3smallssd.pth"
+    state_dict = torch.load(weights)
+    model.load_state_dict(state_dict, strict=False)
+    return model