diff --git a/README.md b/README.md index 5fc3fff3dda977c89631cb50eb3cff44b806ba4e..cf8c9e3f465f56269aee2cbda44264008a9f4447 100644 --- a/README.md +++ b/README.md @@ -91,7 +91,30 @@ V3-SSD (aug2) | 46.1 | 4.9 # Training -Coming soon... +To train one of the SSD models, use the trainign script: +``` +python -m synthetic_ssd.scripts.train_ssd \ + -h, --help show this help message and exit + --model {mobilenet_v2_ssd,mobilenet_v3_ssd,mobilenet_v3_small_ssd} + Object detection model (default: 'mobilenet_v2_ssd'). + --random_seed RANDOM_SEED + Random seed (default: 3). + --valid_id VALID_ID Folder ID to use for validation (default: 49). + --num_workers NUM_WORKERS + --no_print + --eval_freq EVAL_FREQ + Frequency to evaluate the model (default: 5, for every 5 epoch). + --eval_on_valid Evaluation is performed on the validation set (synthetic images). + --eval_on_test Evaluation is performed on the test set (real images). + --batch_size BATCH_SIZE + Batch size (default: 16). + --epochs EPOCHS Number of training epochs (default: 75). + --lr LR Learning rate (default: 0.05). + --momentum MOMENTUM Momentum (default: 0.9). + --weight_decay WEIGHT_DECAY + Weight decay (default: 0.000012). +``` +Training uses the proposed augmentation method (aug2). # Citation If you use this code in your research, please cite the paper: diff --git a/data/weights/imagenet-mobilenetv3-large.pth b/data/weights/imagenet-mobilenetv3-large.pth new file mode 100644 index 0000000000000000000000000000000000000000..c941861b17e19545b1920a23ffede01b31854edc Binary files /dev/null and b/data/weights/imagenet-mobilenetv3-large.pth differ diff --git a/data/weights/imagenet-mobilenetv3-small.pth b/data/weights/imagenet-mobilenetv3-small.pth new file mode 100644 index 0000000000000000000000000000000000000000..672c0057c5b8d7e00b2afee5ce3991f44af6bf50 Binary files /dev/null and b/data/weights/imagenet-mobilenetv3-small.pth differ diff --git a/synthetic_ssd/config.py b/synthetic_ssd/config.py index cd73b4473b37da99eb1041c9d766673998c15026..29a48fac57274b1cc5a68f5b9663922a3d68dfa5 100644 --- a/synthetic_ssd/config.py +++ b/synthetic_ssd/config.py @@ -13,3 +13,4 @@ TLESS_TEST_PATH = TLESS_BASE_PATH / 'test_primesense' DEPS_DIR = PROJECT_DIR / 'deps' +SAVE_DIR = DATA_DIR / 'output' diff --git a/synthetic_ssd/datasets/detection_dataset.py b/synthetic_ssd/datasets/detection_dataset.py index 402ebb817a3c1b5834ec07d4b8b28fa410d5de24..008e39292654a04d8f486b228cb74cf229224170 100644 --- a/synthetic_ssd/datasets/detection_dataset.py +++ b/synthetic_ssd/datasets/detection_dataset.py @@ -25,6 +25,7 @@ class TLESSDetection(Dataset): self.gt = dict() self.gt['boxes'] = dict() self.gt['labels'] = dict() + annot_gt = None for cur_path in path.iterdir(): if not cur_path.is_dir(): continue @@ -49,6 +50,7 @@ class TLESSDetection(Dataset): labels.append(id_data["obj_id"]) self.gt['boxes'][cur_path.name.zfill(6)][image_id.zfill(6)] = boxes self.gt['labels'][cur_path.name.zfill(6)][image_id.zfill(6)] = labels + assert annot_gt is not None, f"Path {path} to dataset seems incorrect, got 0 annotations." del annot_gt, annot_gt_info self.length = len(self.rgb_paths) @@ -58,7 +60,7 @@ class TLESSDetection(Dataset): return self.length def __repr__(self): - return f"Dataset TLESS of RGB images for object detection: " + return f"Dataset T-LESS of RGB images for object detection: " def __getitem__(self, idx): color_img = cv2.imread(self.rgb_paths[idx])[:, :, :3] # uint8, BGR @@ -94,6 +96,22 @@ class TLESSDetection(Dataset): return color_img, target + def train_valid_indices(self, folder_id=49): + """ + Divides the TLESS synthetic dataset into a training and validation subsets. + :param id: (int) number identifying the folder to use as validation data (last folder 49 by default). + :return: (tuple) pair of lists corresponding to the indices of the training and validation sets respectively. + """ + train_ind = list() + valid_ind = list() + for ind, file in enumerate(self.rgb_paths): + path, _ = os.path.split(file) + if str(folder_id) in path: + valid_ind.append(ind) + else: + train_ind.append(ind) + return train_ind, valid_ind + def normalize_bboxes(bboxes, im_height, im_width): """ diff --git a/synthetic_ssd/models/__init__.py b/synthetic_ssd/models/__init__.py index 92c225878515ea2e7926a9afe9b4741a8a4e958e..6c7f439c5b3f73ab07971156005451078ba76a87 100644 --- a/synthetic_ssd/models/__init__.py +++ b/synthetic_ssd/models/__init__.py @@ -1,49 +1,85 @@ from torchvision.models.detection import maskrcnn_resnet50_fpn +from torchvision.models import mobilenet_v2 +import torch from .mobilenet_v3 import MobileNetV3 from .mobilenet_v2_ssd_lite import mobilenet_v2_ssd_lite from .mobilenet_v3_ssd_lite import mobilenet_v3_ssd_lite from . import mobilenetv2_ssd_config, mobilenetv3_ssd_config from .box_utils import generate_ssd_priors - -from torchvision.models import mobilenet_v2 +from .ssd import MatchPrior +from ..config import WEIGHTS_DIR +from .multibox_loss import MultiboxLoss def create_model(opt): num_classes_with_bkgd = opt.num_classes + 1 - if not opt.train: - if opt.model == 'mask_rcnn': - # Load classes info - model = maskrcnn_resnet50_fpn(pretrained=False, - pretrained_backbone=False, - num_classes=num_classes_with_bkgd, - box_score_thresh=0.1) - opt.priors = None - elif opt.model == 'mobilenet_v2_ssd': - backbone = mobilenet_v2(pretrained=True) - priors = generate_ssd_priors(mobilenetv2_ssd_config.specs, opt.input_size) - model = mobilenet_v2_ssd_lite(num_classes=num_classes_with_bkgd, - base_net=backbone, - is_test=False, - config=mobilenetv2_ssd_config, - priors=priors, - device=opt.device - ) - opt.priors = priors - elif 'mobilenet_v3' in opt.model: - if 'small' in opt.model: - backbone = MobileNetV3(mode='small') - else: - backbone = MobileNetV3(mode='large') - priors = generate_ssd_priors(mobilenetv3_ssd_config.specs, opt.input_size) - model = mobilenet_v3_ssd_lite(num_classes=num_classes_with_bkgd, - base_net=backbone, - is_test=False, - config=mobilenetv3_ssd_config, - priors=priors, - device=opt.device - ) - opt.priors = priors + if opt.model == 'mobilenet_v2_ssd': + backbone = mobilenet_v2(pretrained=True) + priors = generate_ssd_priors(mobilenetv2_ssd_config.specs, opt.input_size) + model = mobilenet_v2_ssd_lite(num_classes=num_classes_with_bkgd, + base_net=backbone, + is_test=False, + config=mobilenetv2_ssd_config, + priors=priors, + device=opt.device + ) + opt.priors = priors + opt.model_config = mobilenetv2_ssd_config + if opt.train: + opt.target_tf = MatchPrior(opt.priors, + mobilenetv3_ssd_config.center_variance, + mobilenetv3_ssd_config.size_variance, + mobilenetv3_ssd_config.iou_threshold + ) + else: + opt.target_tf = None + elif 'mobilenet_v3' in opt.model: + if 'small' in opt.model: + backbone = MobileNetV3(mode='small') + weights_path = WEIGHTS_DIR / 'imagenet-mobilenetv3-small.pth' + else: + backbone = MobileNetV3(mode='large') + weights_path = WEIGHTS_DIR / 'imagenet-mobilenetv3-large.pth' + + if opt.train: + # Load ImageNet pretrained weights + state = torch.load(weights_path) + copy_state = dict() + + for k in state.keys(): + if k.startswith("classifier"): + # remove classifier weights (trained for 1000 classes) + continue + copy_state[k] = state[k] + + backbone.load_state_dict(copy_state, strict=False) + + priors = generate_ssd_priors(mobilenetv3_ssd_config.specs, opt.input_size) + model = mobilenet_v3_ssd_lite(num_classes=num_classes_with_bkgd, + base_net=backbone, + is_test=False, + config=mobilenetv3_ssd_config, + priors=priors, + device=opt.device + ) + opt.priors = priors + opt.model_config = mobilenetv3_ssd_config + if opt.train: + opt.target_tf = MatchPrior(opt.priors, + mobilenetv3_ssd_config.center_variance, + mobilenetv3_ssd_config.size_variance, + mobilenetv3_ssd_config.iou_threshold + ) else: - raise ValueError(f"Model {opt.model} is not available.") + opt.target_tf = None + elif opt.model == 'mask_rcnn': + assert not opt.train, "Mask R-CNN can only be used in inference mode, not training." + model = maskrcnn_resnet50_fpn(pretrained=False, + pretrained_backbone=False, + num_classes=num_classes_with_bkgd, + box_score_thresh=0.1) + opt.priors = None + else: + raise ValueError(f"Model {opt.model} is not available.") return model diff --git a/synthetic_ssd/models/multibox_loss.py b/synthetic_ssd/models/multibox_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b39e8002671ef5337fdda17d6aebe21e0059e417 --- /dev/null +++ b/synthetic_ssd/models/multibox_loss.py @@ -0,0 +1,56 @@ +""" +Source: https://github.com/shaoshengsong/MobileNetV3-SSD/blob/0ac9f36cff59c2286cf7555da7719c54d3c88c2c/vision/nn/multibox_loss.py +""" + +import torch.nn as nn +import torch +import torch.nn.functional as F + +from .box_utils import hard_negative_mining + + +class MultiboxLoss(nn.Module): + def __init__(self, priors, neg_pos_ratio, center_variance, + size_variance, device, iou_threshold=0.5): + """Implement SSD Multibox Loss. + + Basically, Multibox loss combines classification loss and Smooth L1 regression loss. + """ + super(MultiboxLoss, self).__init__() + self.iou_threshold = iou_threshold + self.neg_pos_ratio = neg_pos_ratio + self.center_variance = center_variance + self.size_variance = size_variance + self.priors = priors + self.priors.to(device) + + def forward(self, predicted_locations, confidence, gt_locations, labels): + """Compute classification loss and smooth l1 loss. + + Args: + confidence (batch_size, num_priors, num_classes): class predictions. + predicted_locations (batch_size, num_priors, 4): predicted locations. + labels (batch_size, num_priors): real labels of all the priors. + gt_locations (batch_size, num_priors, 4): real boxes corresponding all the priors. + """ + num_classes = confidence.size(2) + with torch.no_grad(): + # derived from cross_entropy=sum(log(p)) + loss = -F.log_softmax(confidence, dim=2) + loss = loss[:, :, 0] + mask = hard_negative_mining(loss, labels, self.neg_pos_ratio) + + confidence = confidence[mask, :] + classification_loss = F.cross_entropy(confidence.reshape(-1, num_classes), labels[mask], reduction='sum') + pos_mask = labels > 0 + predicted_locations = predicted_locations[pos_mask, :].reshape(-1, 4) + gt_locations = gt_locations[pos_mask, :].reshape(-1, 4) + smooth_l1_loss = F.smooth_l1_loss(predicted_locations, gt_locations, reduction='sum') + num_pos = gt_locations.size(0) + if num_pos != 0: + reg_loss = smooth_l1_loss/num_pos + classif_loss = classification_loss/num_pos + else: + reg_loss = 0 + classif_loss = 0 + return reg_loss, classif_loss diff --git a/synthetic_ssd/scripts/run_eval.py b/synthetic_ssd/scripts/run_eval.py index c372daac1966be551760bd5ea0021403a9965162..c80f96fe125d623a16891594239a0239d46e7750 100644 --- a/synthetic_ssd/scripts/run_eval.py +++ b/synthetic_ssd/scripts/run_eval.py @@ -27,6 +27,8 @@ def main(): help="Augmentation method to applied for training (default:'aug2').") args.add_argument('--random_seed', type=int, default=3, help="Random seed (default: 3).") + args.add_argument('--weights', type=str, + help="Path to the weights file to evaluate.") opt = args.parse_args() opt.train = False opt.num_classes = 30 @@ -47,7 +49,7 @@ def run_evaluation(opt): # Load model model = create_model(opt) - model = load_weights(model, model_name=opt.model, tf=opt.tf) + model = load_weights(model, model_name=opt.model, tf=opt.tf, weights=opt.weights) model.to(opt.device) model.eval() diff --git a/synthetic_ssd/scripts/train_ssd.py b/synthetic_ssd/scripts/train_ssd.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d88dc3c4fe86e2b44f75b56ba6e31a6d5634f1 --- /dev/null +++ b/synthetic_ssd/scripts/train_ssd.py @@ -0,0 +1,260 @@ +import argparse +import math +import pprint + +import tqdm +import torch +from torch.utils.data import DataLoader, Subset + +from synthetic_ssd.utils.configuration import initialize +from synthetic_ssd.utils.evaluate import custom_nms +from synthetic_ssd.models import create_model, MultiboxLoss +from synthetic_ssd.datasets.detection_dataset import TLESSDetection, collate_fn +from synthetic_ssd.datasets.augmentations import train_tf, test_tf_ssd +from synthetic_ssd.config import SAVE_DIR + +from synthetic_ssd.metrics import _init_paths +from BoundingBox import BoundingBox +from BoundingBoxes import BoundingBoxes +from Evaluator import Evaluator +import utils as metrics_utils + + +def main(): + args = argparse.ArgumentParser() + args.add_argument('--model', type=str, default='mobilenet_v2_ssd', + choices=['mobilenet_v2_ssd', 'mobilenet_v3_ssd', 'mobilenet_v3_small_ssd'], + help="Object detection model (default: 'mobilenet_v2_ssd').") + args.add_argument('--tf', type=str, default='aug2', choices=['aug1', 'aug2'], + help="Augmentation method to applied for training (default:'aug2').") + args.add_argument('--random_seed', type=int, default=3, + help="Random seed (default: 3).") + args.add_argument('--valid_id', type=int, default=49, + help="Folder ID to use for validation (default: 49).") + args.add_argument("--num_workers", type=int, default=2) + args.add_argument("--no_print", action="store_true", + help="Flag to disable printing. This will also prevent the evaluation step from running.") + args.add_argument("--eval_freq", type=int, default=5, + help="Frequency to evaluate the model (default: 5, for every 5 epoch).") + args.add_argument("--eval_on_valid", action="store_true", + help="Evaluation is performed on the validation set (synthetic images).") + args.add_argument("--eval_on_test", action="store_true", + help="Evaluation is performed on the test set (real images).") + args.add_argument("--save_freq", type=int, default=0, + help="Frequency to save the trained model (default: 0, only save model at the end).") + # Training parameters + args.add_argument('--batch_size', type=int, default=16, + help="Batch size (default: 16).") + args.add_argument('--epochs', type=int, default=75, + help="Number of training epochs (default: 75).") + args.add_argument('--lr', type=float, default=0.05, + help="Learning rate (default: 0.05).") + args.add_argument('--momentum', type=float, default=0.9, + help="Momentum (default: 0.9).") + args.add_argument('--weight_decay', type=float, default=0.000012, + help="Weight decay (default: 0.000012).") + + opt = args.parse_args() + opt.train = True + opt.num_classes = 30 + opt.input_size = 224 + run_training(opt) + + +def run_training(opt): + initialize(opt) + + if not SAVE_DIR.exists(): + SAVE_DIR.mkdir(parents=True) + + # Load model + model = create_model(opt) # Also adds target_tf, priors and model config file into opt + + # Load datasets + tf = train_tf(opt.input_size) + dataset = TLESSDetection(tf=tf, target_tf=opt.target_tf, mode='train') + train_ind, valid_ind = dataset.train_valid_indices(folder_id=opt.valid_id) + + train_dataset = Subset(dataset, train_ind) + train_loader = DataLoader(train_dataset, + batch_size=opt.batch_size, + shuffle=True, + num_workers=opt.num_workers, + collate_fn=collate_fn) + + if opt.eval_on_valid: + valid_dataset = Subset(dataset, valid_ind) + valid_loader = DataLoader(valid_dataset, + batch_size=1, + shuffle=False, + num_workers=1, + collate_fn=collate_fn) + if opt.eval_on_test: + test_dataset = TLESSDetection(tf=test_tf_ssd(opt.input_size), target_tf=None, mode="test") + test_loader = DataLoader(test_dataset, + batch_size=1, + shuffle=False, + num_workers=1, + collate_fn=collate_fn) + + if not opt.no_print: + print(model) + print(dataset) + pprint.pprint(opt) + + model.to(opt.device) + optimizer = torch.optim.SGD(model.parameters(), + lr=opt.lr, + momentum=opt.momentum, + weight_decay=opt.weight_decay) + + criterion = MultiboxLoss(opt.priors, + iou_threshold=opt.model_config.iou_threshold, + neg_pos_ratio=3, + center_variance=opt.model_config.center_variance, + size_variance=opt.model_config.size_variance, + device=opt.device) + try: + for epoch in range(opt.epochs): + print(f"\nTraining epoch {epoch+1}/{opt.epochs}") + + model.train(True) + model.is_test = False + running_loss = 0.0 + running_regression_loss = 0.0 + running_classification_loss = 0.0 + + num = 0 + optimizer.zero_grad() + + for data in tqdm.tqdm(train_loader): + images, targets = data + boxes = [t['boxes'] for t in targets] + labels = [t['labels'] for t in targets] + if len(boxes) > 0: + boxes = torch.stack(boxes, dim=0).to(opt.device) + labels = torch.stack(labels, dim=0).to(opt.device) + + confidences, locations = model(images.to(opt.device)) + regression_loss, classification_loss = criterion(locations, confidences, boxes, labels) + + loss = regression_loss + classification_loss + loss.backward() + optimizer.step() + optimizer.zero_grad() + + running_loss += loss.item() + running_regression_loss += regression_loss.item() + running_classification_loss += classification_loss.item() + + if not math.isfinite(loss.item()): + print(f"Loss is {loss.item()}, stopping training") + exit(1) + + num += images.size(0) + running_loss /= num + running_regression_loss /= num + running_classification_loss /= num + if not opt.no_print: + print(f"\nRunning loss: {running_loss} ({running_classification_loss}" + f" classification, {running_regression_loss} regression).\n") + + if opt.save_freq > 0: + if ((epoch+1) % opt.save_freq) == 0: + name = f"{opt.model}_epoch{epoch+1}.pth" + torch.save(model.state_dict(), SAVE_DIR / name) + + if ((epoch+1) % opt.eval_freq) == 0 and not opt.no_print: + if opt.eval_on_valid: + print("Evaluation on validation set") + run_evaluation(opt, model, valid_loader) + if opt.eval_on_test: + print("Evaluation on test set") + run_evaluation(opt, model, test_loader) + + finally: + print("Done!") + name = f"{opt.model}_epoch{epoch+1}_last.pth" + full_path = SAVE_DIR / name + torch.save(model.state_dict(), full_path) + print(f"Saved trained model: {full_path}") + + +@torch.no_grad() +def run_evaluation(opt, model, dataloader): + model.eval() + model.is_test = True + + evaluator = Evaluator() + bboxes = BoundingBoxes() + + for data in tqdm.tqdm(dataloader): + images, targets = data + raw_boxes = [t['raw_boxes'] for t in targets] + raw_labels = [t['raw_labels'] for t in targets] + image_names = [t['image_name'] for t in targets] + image_sizes = [t['image_size'] for t in targets] + + for im in range(len(images)): + im_boxes = raw_boxes[im] # tensor of size (n_boxes, 4) + im_labels = raw_labels[im] + im_name = image_names[im] + im_size = image_sizes[im] + for i in range(im_boxes.size(0)): + x1, y1, x2, y2 = im_boxes[i] + x = (x1.item()+x2.item())/2. + y = (y1.item()+y2.item())/2. + w = x2.item() - x1.item() + h = y2.item() - y1.item() + gtBox = BoundingBox(imageName=im_name, + classId=str(im_labels[i].item()), + x=x, y=y, + w=w, h=h, + typeCoordinates=metrics_utils.CoordinatesType.Relative, + imgSize=(im_size[1], im_size[0]), + bbType=metrics_utils.BBType.GroundTruth, + format=metrics_utils.BBFormat.XYWH) + bboxes.addBoundingBox(gtBox) + + confidences, locations = model(images.to(opt.device)) + out_boxes, out_labels, out_confidences, out_names = custom_nms(locations, confidences, + images_names=image_names, + iou_threshold=0.5, + score_threshold=0.1) + + for i in range(len(out_boxes)): + im_size = image_sizes[image_names.index(out_names[i])] # (h, w) + x1, y1, x2, y2 = out_boxes[i].cpu() + x = (x1.item()+x2.item())/2. + y = (y1.item()+y2.item())/2. + w = x2.item() - x1.item() + h = y2.item() - y1.item() + newBox = BoundingBox(imageName=out_names[i], + classId=str(out_labels[i]), + x=x, y=y, w=w, h=h, + typeCoordinates=metrics_utils.CoordinatesType.Relative, + imgSize=(im_size[1], im_size[0]), + bbType=metrics_utils.BBType.Detected, + classConfidence=out_confidences[i], + format=metrics_utils.BBFormat.XYWH) + bboxes.addBoundingBox(newBox) + + metrics = evaluator.GetPascalVOCMetrics(bboxes) + AP = [] + for cls_metrics in metrics: + print("Class {cls}: {AP}% AP, {pos} gt positives, " + "{TP} true positives, {FP} false positives". + format(cls=cls_metrics['class'], + AP=cls_metrics['AP'] * 100, + pos=cls_metrics['total positives'], + TP=cls_metrics['total TP'], + FP=cls_metrics['total FP'] + )) + AP.append(cls_metrics['AP']) + mean_ap = sum(AP) / len(AP) + print(f"Average precision: {mean_ap * 100}%") + + +if __name__ == "__main__": + main() + diff --git a/synthetic_ssd/utils/load_weights.py b/synthetic_ssd/utils/load_weights.py index 2e8d0c512f1848d0eaceaa51fb4ab4e5ed5f59b8..b70dec51a77d00663324f4c00a0de7cc233d743c 100644 --- a/synthetic_ssd/utils/load_weights.py +++ b/synthetic_ssd/utils/load_weights.py @@ -3,41 +3,44 @@ import torch from synthetic_ssd.config import WEIGHTS_DIR, MASK_RCNN_PATH -def load_weights(model, model_name, tf): +def load_weights(model, model_name, tf, weights=None): if model_name == "mask_rcnn": - model = load_mask_rcnn_weights(model) + model = load_mask_rcnn_weights(model, weights=weights) elif model_name == "mobilenet_v2_ssd": - model = load_mobilenet_v2_ssd_weights(model, tf) + model = load_mobilenet_v2_ssd_weights(model, tf, weights=weights) elif model_name == "mobilenet_v3_ssd": - model = load_mobilenet_v3_ssd_weights(model, tf) + model = load_mobilenet_v3_ssd_weights(model, tf, weights=weights) elif model_name == "mobilenet_v3_small_ssd": - model = load_mobilenet_v3_small_ssd_weights(model, tf) + model = load_mobilenet_v3_small_ssd_weights(model, tf, weights=weights) return model -def load_mask_rcnn_weights(model): +def load_mask_rcnn_weights(model, weights=None): weights = MASK_RCNN_PATH state_dict = torch.load(weights) model.load_state_dict(state_dict['state_dict']) return model -def load_mobilenet_v2_ssd_weights(model, tf): - weights = WEIGHTS_DIR / tf / "tless_icip21_V2ssd.pth" +def load_mobilenet_v2_ssd_weights(model, tf, weights=None): + if weights is not None: + weights = WEIGHTS_DIR / tf / "tless_icip21_V2ssd.pth" state_dict = torch.load(weights) model.load_state_dict(state_dict) return model -def load_mobilenet_v3_ssd_weights(model, tf): - weights = WEIGHTS_DIR / tf / "tless_icip21_V3ssd.pth" +def load_mobilenet_v3_ssd_weights(model, tf, weights=None): + if weights is not None: + weights = WEIGHTS_DIR / tf / "tless_icip21_V3ssd.pth" state_dict = torch.load(weights) model.load_state_dict(state_dict, strict=False) return model -def load_mobilenet_v3_small_ssd_weights(model, tf): - weights = WEIGHTS_DIR / tf / "tless_icip21_V3smallssd.pth" +def load_mobilenet_v3_small_ssd_weights(model, tf, weights=None): + if weights is not None: + weights = WEIGHTS_DIR / tf / "tless_icip21_V3smallssd.pth" state_dict = torch.load(weights) model.load_state_dict(state_dict, strict=False) return model