diff --git a/README.md b/README.md
index 5fc3fff3dda977c89631cb50eb3cff44b806ba4e..cf8c9e3f465f56269aee2cbda44264008a9f4447 100644
--- a/README.md
+++ b/README.md
@@ -91,7 +91,30 @@ V3-SSD (aug2) | 46.1 | 4.9
 
 
 # Training
-Coming soon...
+To train one of the SSD models, use the training script:
+```
+python -m synthetic_ssd.scripts.train_ssd [OPTIONS]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  --model {mobilenet_v2_ssd,mobilenet_v3_ssd,mobilenet_v3_small_ssd}
+                        Object detection model (default: 'mobilenet_v2_ssd').
+  --tf {aug1,aug2}      Augmentation method to apply for training (default: 'aug2').
+  --random_seed RANDOM_SEED
+                        Random seed (default: 3).
+  --valid_id VALID_ID   Folder ID to use for validation (default: 49).
+  --num_workers NUM_WORKERS
+                        Number of data loading workers (default: 2).
+  --no_print            Disable printing; this also prevents the evaluation step from running.
+  --eval_freq EVAL_FREQ
+                        Frequency at which to evaluate the model (default: 5, i.e. every 5 epochs).
+  --eval_on_valid       Evaluation is performed on the validation set (synthetic images).
+  --eval_on_test        Evaluation is performed on the test set (real images).
+  --save_freq SAVE_FREQ
+                        Frequency at which to save the trained model (default: 0, only save at the end).
+  --batch_size BATCH_SIZE
+                        Batch size (default: 16).
+  --epochs EPOCHS       Number of training epochs (default: 75).
+  --lr LR               Learning rate (default: 0.05).
+  --momentum MOMENTUM   Momentum (default: 0.9).
+  --weight_decay WEIGHT_DECAY
+                        Weight decay (default: 0.000012).
+```
+Training uses the proposed augmentation method (aug2) by default; pass `--tf aug1` to train with the alternative augmentation method.
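+
+For example, to train the MobileNetV3-SSD model with the default settings and evaluate it on the real test images every 5 epochs:
+```
+python -m synthetic_ssd.scripts.train_ssd --model mobilenet_v3_ssd --eval_on_test
+```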
 
 # Citation
 If you use this code in your research, please cite the paper:
diff --git a/data/weights/imagenet-mobilenetv3-large.pth b/data/weights/imagenet-mobilenetv3-large.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c941861b17e19545b1920a23ffede01b31854edc
Binary files /dev/null and b/data/weights/imagenet-mobilenetv3-large.pth differ
diff --git a/data/weights/imagenet-mobilenetv3-small.pth b/data/weights/imagenet-mobilenetv3-small.pth
new file mode 100644
index 0000000000000000000000000000000000000000..672c0057c5b8d7e00b2afee5ce3991f44af6bf50
Binary files /dev/null and b/data/weights/imagenet-mobilenetv3-small.pth differ
diff --git a/synthetic_ssd/config.py b/synthetic_ssd/config.py
index cd73b4473b37da99eb1041c9d766673998c15026..29a48fac57274b1cc5a68f5b9663922a3d68dfa5 100644
--- a/synthetic_ssd/config.py
+++ b/synthetic_ssd/config.py
@@ -13,3 +13,4 @@ TLESS_TEST_PATH = TLESS_BASE_PATH / 'test_primesense'
 
 DEPS_DIR = PROJECT_DIR / 'deps'
 
+SAVE_DIR = DATA_DIR / 'output'
diff --git a/synthetic_ssd/datasets/detection_dataset.py b/synthetic_ssd/datasets/detection_dataset.py
index 402ebb817a3c1b5834ec07d4b8b28fa410d5de24..008e39292654a04d8f486b228cb74cf229224170 100644
--- a/synthetic_ssd/datasets/detection_dataset.py
+++ b/synthetic_ssd/datasets/detection_dataset.py
@@ -25,6 +25,7 @@ class TLESSDetection(Dataset):
         self.gt = dict()
         self.gt['boxes'] = dict()
         self.gt['labels'] = dict()
+        annot_gt = None
         for cur_path in path.iterdir():
             if not cur_path.is_dir():
                 continue
@@ -49,6 +50,7 @@ class TLESSDetection(Dataset):
                     labels.append(id_data["obj_id"])
                 self.gt['boxes'][cur_path.name.zfill(6)][image_id.zfill(6)] = boxes
                 self.gt['labels'][cur_path.name.zfill(6)][image_id.zfill(6)] = labels
+        assert annot_gt is not None, f"Path {path} to dataset seems incorrect, got 0 annotations."
         del annot_gt, annot_gt_info
 
         self.length = len(self.rgb_paths)
@@ -58,7 +60,7 @@ class TLESSDetection(Dataset):
         return self.length
 
     def __repr__(self):
-        return f"Dataset TLESS of RGB images for object detection: "
+        return f"Dataset T-LESS of RGB images for object detection: "
 
     def __getitem__(self, idx):
         color_img = cv2.imread(self.rgb_paths[idx])[:, :, :3]  # uint8, BGR
@@ -94,6 +96,22 @@ class TLESSDetection(Dataset):
 
         return color_img, target
 
+    def train_valid_indices(self, folder_id=49):
+        """
+        Divides the T-LESS synthetic dataset into training and validation subsets.
+        :param folder_id: (int) ID of the folder to use as validation data (last folder, 49, by default).
+        :return: (tuple) pair of lists corresponding to the indices of the training and validation sets respectively.
+        """
+        train_ind = list()
+        valid_ind = list()
+        for ind, file in enumerate(self.rgb_paths):
+            path, _ = os.path.split(file)
+            if str(folder_id) in path:
+                valid_ind.append(ind)
+            else:
+                train_ind.append(ind)
+        return train_ind, valid_ind
+
 
 def normalize_bboxes(bboxes, im_height, im_width):
     """
diff --git a/synthetic_ssd/models/__init__.py b/synthetic_ssd/models/__init__.py
index 92c225878515ea2e7926a9afe9b4741a8a4e958e..6c7f439c5b3f73ab07971156005451078ba76a87 100644
--- a/synthetic_ssd/models/__init__.py
+++ b/synthetic_ssd/models/__init__.py
@@ -1,49 +1,85 @@
 from torchvision.models.detection import maskrcnn_resnet50_fpn
+from torchvision.models import mobilenet_v2
+import torch
 
 from .mobilenet_v3 import MobileNetV3
 from .mobilenet_v2_ssd_lite import mobilenet_v2_ssd_lite
 from .mobilenet_v3_ssd_lite import mobilenet_v3_ssd_lite
 from . import mobilenetv2_ssd_config, mobilenetv3_ssd_config
 from .box_utils import generate_ssd_priors
-
-from torchvision.models import mobilenet_v2
+from .ssd import MatchPrior
+from ..config import WEIGHTS_DIR
+from .multibox_loss import MultiboxLoss
 
 
 def create_model(opt):
     num_classes_with_bkgd = opt.num_classes + 1
-    if not opt.train:
-        if opt.model == 'mask_rcnn':
-            # Load classes info
-            model = maskrcnn_resnet50_fpn(pretrained=False,
-                                          pretrained_backbone=False,
-                                          num_classes=num_classes_with_bkgd,
-                                          box_score_thresh=0.1)
-            opt.priors = None
-        elif opt.model == 'mobilenet_v2_ssd':
-            backbone = mobilenet_v2(pretrained=True)
-            priors = generate_ssd_priors(mobilenetv2_ssd_config.specs, opt.input_size)
-            model = mobilenet_v2_ssd_lite(num_classes=num_classes_with_bkgd,
-                                          base_net=backbone,
-                                          is_test=False,
-                                          config=mobilenetv2_ssd_config,
-                                          priors=priors,
-                                          device=opt.device
-                                          )
-            opt.priors = priors
-        elif 'mobilenet_v3' in opt.model:
-            if 'small' in opt.model:
-                backbone = MobileNetV3(mode='small')
-            else:
-                backbone = MobileNetV3(mode='large')
-            priors = generate_ssd_priors(mobilenetv3_ssd_config.specs, opt.input_size)
-            model = mobilenet_v3_ssd_lite(num_classes=num_classes_with_bkgd,
-                                          base_net=backbone,
-                                          is_test=False,
-                                          config=mobilenetv3_ssd_config,
-                                          priors=priors,
-                                          device=opt.device
-                                          )
-            opt.priors = priors
+    if opt.model == 'mobilenet_v2_ssd':
+        backbone = mobilenet_v2(pretrained=True)
+        priors = generate_ssd_priors(mobilenetv2_ssd_config.specs, opt.input_size)
+        model = mobilenet_v2_ssd_lite(num_classes=num_classes_with_bkgd,
+                                      base_net=backbone,
+                                      is_test=False,
+                                      config=mobilenetv2_ssd_config,
+                                      priors=priors,
+                                      device=opt.device
+                                      )
+        opt.priors = priors
+        opt.model_config = mobilenetv2_ssd_config
+        if opt.train:
+            opt.target_tf = MatchPrior(opt.priors,
+                                       mobilenetv2_ssd_config.center_variance,
+                                       mobilenetv2_ssd_config.size_variance,
+                                       mobilenetv2_ssd_config.iou_threshold
+                                       )
+        else:
+            opt.target_tf = None
+    elif 'mobilenet_v3' in opt.model:
+        if 'small' in opt.model:
+            backbone = MobileNetV3(mode='small')
+            weights_path = WEIGHTS_DIR / 'imagenet-mobilenetv3-small.pth'
+        else:
+            backbone = MobileNetV3(mode='large')
+            weights_path = WEIGHTS_DIR / 'imagenet-mobilenetv3-large.pth'
+
+        if opt.train:
+            # Load ImageNet pretrained weights
+            state = torch.load(weights_path)
+            copy_state = dict()
+
+            for k in state.keys():
+                if k.startswith("classifier"):
+                    # remove classifier weights (trained for 1000 classes)
+                    continue
+                copy_state[k] = state[k]
+
+            backbone.load_state_dict(copy_state, strict=False)
+
+        priors = generate_ssd_priors(mobilenetv3_ssd_config.specs, opt.input_size)
+        model = mobilenet_v3_ssd_lite(num_classes=num_classes_with_bkgd,
+                                      base_net=backbone,
+                                      is_test=False,
+                                      config=mobilenetv3_ssd_config,
+                                      priors=priors,
+                                      device=opt.device
+                                      )
+        opt.priors = priors
+        opt.model_config = mobilenetv3_ssd_config
+        if opt.train:
+            opt.target_tf = MatchPrior(opt.priors,
+                                       mobilenetv3_ssd_config.center_variance,
+                                       mobilenetv3_ssd_config.size_variance,
+                                       mobilenetv3_ssd_config.iou_threshold
+                                       )
         else:
-            raise ValueError(f"Model {opt.model} is not available.")
+            opt.target_tf = None
+    elif opt.model == 'mask_rcnn':
+        assert not opt.train, "Mask R-CNN can only be used in inference mode, not training."
+        model = maskrcnn_resnet50_fpn(pretrained=False,
+                                      pretrained_backbone=False,
+                                      num_classes=num_classes_with_bkgd,
+                                      box_score_thresh=0.1)
+        opt.priors = None
+    else:
+        raise ValueError(f"Model {opt.model} is not available.")
     return model
diff --git a/synthetic_ssd/models/multibox_loss.py b/synthetic_ssd/models/multibox_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b39e8002671ef5337fdda17d6aebe21e0059e417
--- /dev/null
+++ b/synthetic_ssd/models/multibox_loss.py
@@ -0,0 +1,56 @@
+"""
+Source: https://github.com/shaoshengsong/MobileNetV3-SSD/blob/0ac9f36cff59c2286cf7555da7719c54d3c88c2c/vision/nn/multibox_loss.py
+"""
+
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+from .box_utils import hard_negative_mining
+
+
+class MultiboxLoss(nn.Module):
+    def __init__(self, priors, neg_pos_ratio, center_variance,
+                 size_variance, device, iou_threshold=0.5):
+        """Implement SSD Multibox Loss.
+
+        Basically, Multibox loss combines classification loss and Smooth L1 regression loss.
+        """
+        super(MultiboxLoss, self).__init__()
+        self.iou_threshold = iou_threshold
+        self.neg_pos_ratio = neg_pos_ratio
+        self.center_variance = center_variance
+        self.size_variance = size_variance
+        self.priors = priors
+        self.priors = self.priors.to(device)  # tensor.to() is not in-place
+
+    def forward(self, predicted_locations, confidence, gt_locations, labels):
+        """Compute classification loss and smooth l1 loss.
+
+        Args:
+            confidence (batch_size, num_priors, num_classes): class predictions.
+            predicted_locations (batch_size, num_priors, 4): predicted locations.
+            labels (batch_size, num_priors): real labels of all the priors.
+            gt_locations (batch_size, num_priors, 4): real boxes corresponding all the priors.
+        """
+        num_classes = confidence.size(2)
+        with torch.no_grad():
+            # derived from cross_entropy=sum(log(p))
+            loss = -F.log_softmax(confidence, dim=2)
+            loss = loss[:, :, 0]
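+            # loss[:, :, 0] is the background-class loss; hard_negative_mining keeps all positives and the hardest negatives, at most neg_pos_ratio per positive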
+            mask = hard_negative_mining(loss, labels, self.neg_pos_ratio)
+
+        confidence = confidence[mask, :]
+        classification_loss = F.cross_entropy(confidence.reshape(-1, num_classes), labels[mask], reduction='sum')
+        pos_mask = labels > 0
+        predicted_locations = predicted_locations[pos_mask, :].reshape(-1, 4)
+        gt_locations = gt_locations[pos_mask, :].reshape(-1, 4)
+        smooth_l1_loss = F.smooth_l1_loss(predicted_locations, gt_locations, reduction='sum')
+        num_pos = gt_locations.size(0)
+        if num_pos != 0:
+            reg_loss = smooth_l1_loss/num_pos
+            classif_loss = classification_loss/num_pos
+        else:
+            reg_loss = 0
+            classif_loss = 0
+        return reg_loss, classif_loss
diff --git a/synthetic_ssd/scripts/run_eval.py b/synthetic_ssd/scripts/run_eval.py
index c372daac1966be551760bd5ea0021403a9965162..c80f96fe125d623a16891594239a0239d46e7750 100644
--- a/synthetic_ssd/scripts/run_eval.py
+++ b/synthetic_ssd/scripts/run_eval.py
@@ -27,6 +27,8 @@ def main():
                       help="Augmentation method to applied for training (default:'aug2').")
     args.add_argument('--random_seed', type=int, default=3,
                       help="Random seed (default: 3).")
+    args.add_argument('--weights', type=str,
+                      help="Path to the weights file to evaluate.")
     opt = args.parse_args()
     opt.train = False
     opt.num_classes = 30
@@ -47,7 +49,7 @@ def run_evaluation(opt):
 
     # Load model
     model = create_model(opt)
-    model = load_weights(model, model_name=opt.model, tf=opt.tf)
+    model = load_weights(model, model_name=opt.model, tf=opt.tf, weights=opt.weights)
 
     model.to(opt.device)
     model.eval()
diff --git a/synthetic_ssd/scripts/train_ssd.py b/synthetic_ssd/scripts/train_ssd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3d88dc3c4fe86e2b44f75b56ba6e31a6d5634f1
--- /dev/null
+++ b/synthetic_ssd/scripts/train_ssd.py
@@ -0,0 +1,260 @@
+import argparse
+import math
+import pprint
+
+import tqdm
+import torch
+from torch.utils.data import DataLoader, Subset
+
+from synthetic_ssd.utils.configuration import initialize
+from synthetic_ssd.utils.evaluate import custom_nms
+from synthetic_ssd.models import create_model, MultiboxLoss
+from synthetic_ssd.datasets.detection_dataset import TLESSDetection, collate_fn
+from synthetic_ssd.datasets.augmentations import train_tf, test_tf_ssd
+from synthetic_ssd.config import SAVE_DIR
+
+from synthetic_ssd.metrics import _init_paths
+from BoundingBox import BoundingBox
+from BoundingBoxes import BoundingBoxes
+from Evaluator import Evaluator
+import utils as metrics_utils
+
+
+def main():
+    args = argparse.ArgumentParser()
+    args.add_argument('--model', type=str, default='mobilenet_v2_ssd',
+                      choices=['mobilenet_v2_ssd', 'mobilenet_v3_ssd', 'mobilenet_v3_small_ssd'],
+                      help="Object detection model (default: 'mobilenet_v2_ssd').")
+    args.add_argument('--tf', type=str, default='aug2', choices=['aug1', 'aug2'],
+                      help="Augmentation method to apply for training (default: 'aug2').")
+    args.add_argument('--random_seed', type=int, default=3,
+                      help="Random seed (default: 3).")
+    args.add_argument('--valid_id', type=int, default=49,
+                      help="Folder ID to use for validation (default: 49).")
+    args.add_argument("--num_workers", type=int, default=2)
+    args.add_argument("--no_print", action="store_true",
+                      help="Flag to disable printing. This will also prevent the evaluation step from running.")
+    args.add_argument("--eval_freq", type=int, default=5,
+                      help="Frequency to evaluate the model (default: 5, for every 5 epoch).")
+    args.add_argument("--eval_on_valid", action="store_true",
+                      help="Evaluation is performed on the validation set (synthetic images).")
+    args.add_argument("--eval_on_test", action="store_true",
+                      help="Evaluation is performed on the test set (real images).")
+    args.add_argument("--save_freq", type=int, default=0,
+                      help="Frequency to save the trained model (default: 0, only save model at the end).")
+    # Training parameters
+    args.add_argument('--batch_size', type=int, default=16,
+                      help="Batch size (default: 16).")
+    args.add_argument('--epochs', type=int, default=75,
+                      help="Number of training epochs (default: 75).")
+    args.add_argument('--lr', type=float, default=0.05,
+                      help="Learning rate (default: 0.05).")
+    args.add_argument('--momentum', type=float, default=0.9,
+                      help="Momentum (default: 0.9).")
+    args.add_argument('--weight_decay', type=float, default=0.000012,
+                      help="Weight decay (default: 0.000012).")
+
+    opt = args.parse_args()
+    opt.train = True
+    opt.num_classes = 30
+    opt.input_size = 224
+    run_training(opt)
+
+
+def run_training(opt):
+    initialize(opt)
+
+    if not SAVE_DIR.exists():
+        SAVE_DIR.mkdir(parents=True)
+
+    # Load model
+    model = create_model(opt)  # Also adds target_tf, priors and model config file into opt
+
+    # Load datasets
+    tf = train_tf(opt.input_size)
+    dataset = TLESSDetection(tf=tf, target_tf=opt.target_tf, mode='train')
+    train_ind, valid_ind = dataset.train_valid_indices(folder_id=opt.valid_id)
+
+    train_dataset = Subset(dataset, train_ind)
+    train_loader = DataLoader(train_dataset,
+                              batch_size=opt.batch_size,
+                              shuffle=True,
+                              num_workers=opt.num_workers,
+                              collate_fn=collate_fn)
+
+    if opt.eval_on_valid:
+        valid_dataset = Subset(dataset, valid_ind)
+        valid_loader = DataLoader(valid_dataset,
+                                  batch_size=1,
+                                  shuffle=False,
+                                  num_workers=1,
+                                  collate_fn=collate_fn)
+    if opt.eval_on_test:
+        test_dataset = TLESSDetection(tf=test_tf_ssd(opt.input_size), target_tf=None, mode="test")
+        test_loader = DataLoader(test_dataset,
+                                 batch_size=1,
+                                 shuffle=False,
+                                 num_workers=1,
+                                 collate_fn=collate_fn)
+
+    if not opt.no_print:
+        print(model)
+        print(dataset)
+        pprint.pprint(opt)
+
+    model.to(opt.device)
+    optimizer = torch.optim.SGD(model.parameters(),
+                                lr=opt.lr,
+                                momentum=opt.momentum,
+                                weight_decay=opt.weight_decay)
+
+    criterion = MultiboxLoss(opt.priors,
+                             iou_threshold=opt.model_config.iou_threshold,
+                             neg_pos_ratio=3,
+                             center_variance=opt.model_config.center_variance,
+                             size_variance=opt.model_config.size_variance,
+                             device=opt.device)
+    try:
+        for epoch in range(opt.epochs):
+            print(f"\nTraining epoch {epoch+1}/{opt.epochs}")
+
+            model.train(True)
+            model.is_test = False
+            running_loss = 0.0
+            running_regression_loss = 0.0
+            running_classification_loss = 0.0
+
+            num = 0
+            optimizer.zero_grad()
+
+            for data in tqdm.tqdm(train_loader):
+                images, targets = data
+                boxes = [t['boxes'] for t in targets]
+                labels = [t['labels'] for t in targets]
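+                # MatchPrior (opt.target_tf) assigns a box and label to every prior, so targets stack into fixed-size tensors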
+                if len(boxes) > 0:
+                    boxes = torch.stack(boxes, dim=0).to(opt.device)
+                    labels = torch.stack(labels, dim=0).to(opt.device)
+
+                confidences, locations = model(images.to(opt.device))
+                regression_loss, classification_loss = criterion(locations, confidences, boxes, labels)
+
+                loss = regression_loss + classification_loss
+                loss.backward()
+                optimizer.step()
+                optimizer.zero_grad()
+
+                running_loss += loss.item()
+                running_regression_loss += regression_loss.item()
+                running_classification_loss += classification_loss.item()
+
+                if not math.isfinite(loss.item()):
+                    print(f"Loss is {loss.item()}, stopping training")
+                    exit(1)
+
+                num += images.size(0)
+            running_loss /= num
+            running_regression_loss /= num
+            running_classification_loss /= num
+            if not opt.no_print:
+                print(f"\nRunning loss: {running_loss} ({running_classification_loss}"
+                      f" classification, {running_regression_loss} regression).\n")
+
+            if opt.save_freq > 0:
+                if ((epoch+1) % opt.save_freq) == 0:
+                    name = f"{opt.model}_epoch{epoch+1}.pth"
+                    torch.save(model.state_dict(), SAVE_DIR / name)
+
+            if ((epoch+1) % opt.eval_freq) == 0 and not opt.no_print:
+                if opt.eval_on_valid:
+                    print("Evaluation on validation set")
+                    run_evaluation(opt, model, valid_loader)
+                if opt.eval_on_test:
+                    print("Evaluation on test set")
+                    run_evaluation(opt, model, test_loader)
+
+    finally:
+        print("Done!")
+        name = f"{opt.model}_epoch{epoch+1}_last.pth"
+        full_path = SAVE_DIR / name
+        torch.save(model.state_dict(), full_path)
+        print(f"Saved trained model: {full_path}")
+
+
+@torch.no_grad()
+def run_evaluation(opt, model, dataloader):
+    model.eval()
+    model.is_test = True
+
+    evaluator = Evaluator()
+    bboxes = BoundingBoxes()
+
+    for data in tqdm.tqdm(dataloader):
+        images, targets = data
+        raw_boxes = [t['raw_boxes'] for t in targets]
+        raw_labels = [t['raw_labels'] for t in targets]
+        image_names = [t['image_name'] for t in targets]
+        image_sizes = [t['image_size'] for t in targets]
+
+        for im in range(len(images)):
+            im_boxes = raw_boxes[im]  # tensor of size (n_boxes, 4)
+            im_labels = raw_labels[im]
+            im_name = image_names[im]
+            im_size = image_sizes[im]
+            for i in range(im_boxes.size(0)):
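+                # convert corner coordinates (x1, y1, x2, y2) to the relative center format (x, y, w, h) expected by the metrics tool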
+                x1, y1, x2, y2 = im_boxes[i]
+                x = (x1.item()+x2.item())/2.
+                y = (y1.item()+y2.item())/2.
+                w = x2.item() - x1.item()
+                h = y2.item() - y1.item()
+                gtBox = BoundingBox(imageName=im_name,
+                                    classId=str(im_labels[i].item()),
+                                    x=x, y=y,
+                                    w=w, h=h,
+                                    typeCoordinates=metrics_utils.CoordinatesType.Relative,
+                                    imgSize=(im_size[1], im_size[0]),
+                                    bbType=metrics_utils.BBType.GroundTruth,
+                                    format=metrics_utils.BBFormat.XYWH)
+                bboxes.addBoundingBox(gtBox)
+
+        confidences, locations = model(images.to(opt.device))
+        out_boxes, out_labels, out_confidences, out_names = custom_nms(locations, confidences,
+                                                                       images_names=image_names,
+                                                                       iou_threshold=0.5,
+                                                                       score_threshold=0.1)
+
+        for i in range(len(out_boxes)):
+            im_size = image_sizes[image_names.index(out_names[i])]  # (h, w)
+            x1, y1, x2, y2 = out_boxes[i].cpu()
+            x = (x1.item()+x2.item())/2.
+            y = (y1.item()+y2.item())/2.
+            w = x2.item() - x1.item()
+            h = y2.item() - y1.item()
+            newBox = BoundingBox(imageName=out_names[i],
+                                 classId=str(out_labels[i]),
+                                 x=x, y=y, w=w, h=h,
+                                 typeCoordinates=metrics_utils.CoordinatesType.Relative,
+                                 imgSize=(im_size[1], im_size[0]),
+                                 bbType=metrics_utils.BBType.Detected,
+                                 classConfidence=out_confidences[i],
+                                 format=metrics_utils.BBFormat.XYWH)
+            bboxes.addBoundingBox(newBox)
+
+    metrics = evaluator.GetPascalVOCMetrics(bboxes)
+    AP = []
+    for cls_metrics in metrics:
+        print("Class {cls}: {AP}% AP, {pos} gt positives, "
+              "{TP} true positives, {FP} false positives".
+              format(cls=cls_metrics['class'],
+                     AP=cls_metrics['AP'] * 100,
+                     pos=cls_metrics['total positives'],
+                     TP=cls_metrics['total TP'],
+                     FP=cls_metrics['total FP']
+                     ))
+        AP.append(cls_metrics['AP'])
+    mean_ap = sum(AP) / len(AP)
+    print(f"Average precision: {mean_ap * 100}%")
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/synthetic_ssd/utils/load_weights.py b/synthetic_ssd/utils/load_weights.py
index 2e8d0c512f1848d0eaceaa51fb4ab4e5ed5f59b8..b70dec51a77d00663324f4c00a0de7cc233d743c 100644
--- a/synthetic_ssd/utils/load_weights.py
+++ b/synthetic_ssd/utils/load_weights.py
@@ -3,41 +3,44 @@ import torch
 from synthetic_ssd.config import WEIGHTS_DIR, MASK_RCNN_PATH
 
 
-def load_weights(model, model_name, tf):
+def load_weights(model, model_name, tf, weights=None):
     if model_name == "mask_rcnn":
-        model = load_mask_rcnn_weights(model)
+        model = load_mask_rcnn_weights(model, weights=weights)
     elif model_name == "mobilenet_v2_ssd":
-        model = load_mobilenet_v2_ssd_weights(model, tf)
+        model = load_mobilenet_v2_ssd_weights(model, tf, weights=weights)
     elif model_name == "mobilenet_v3_ssd":
-        model = load_mobilenet_v3_ssd_weights(model, tf)
+        model = load_mobilenet_v3_ssd_weights(model, tf, weights=weights)
     elif model_name == "mobilenet_v3_small_ssd":
-        model = load_mobilenet_v3_small_ssd_weights(model, tf)
+        model = load_mobilenet_v3_small_ssd_weights(model, tf, weights=weights)
     return model
 
 
-def load_mask_rcnn_weights(model):
+def load_mask_rcnn_weights(model, weights=None):
     weights = MASK_RCNN_PATH
     state_dict = torch.load(weights)
     model.load_state_dict(state_dict['state_dict'])
     return model
 
 
-def load_mobilenet_v2_ssd_weights(model, tf):
-    weights = WEIGHTS_DIR / tf / "tless_icip21_V2ssd.pth"
+def load_mobilenet_v2_ssd_weights(model, tf, weights=None):
+    if weights is None:
+        weights = WEIGHTS_DIR / tf / "tless_icip21_V2ssd.pth"
     state_dict = torch.load(weights)
     model.load_state_dict(state_dict)
     return model
 
 
-def load_mobilenet_v3_ssd_weights(model, tf):
-    weights = WEIGHTS_DIR / tf / "tless_icip21_V3ssd.pth"
+def load_mobilenet_v3_ssd_weights(model, tf, weights=None):
+    if weights is None:
+        weights = WEIGHTS_DIR / tf / "tless_icip21_V3ssd.pth"
     state_dict = torch.load(weights)
     model.load_state_dict(state_dict, strict=False)
     return model
 
 
-def load_mobilenet_v3_small_ssd_weights(model, tf):
-    weights = WEIGHTS_DIR / tf / "tless_icip21_V3smallssd.pth"
+def load_mobilenet_v3_small_ssd_weights(model, tf, weights=None):
+    if weights is None:
+        weights = WEIGHTS_DIR / tf / "tless_icip21_V3smallssd.pth"
     state_dict = torch.load(weights)
     model.load_state_dict(state_dict, strict=False)
     return model