Skip to content
Snippets Groups Projects
dataloading.py 9.96 KiB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import os
import random
import uuid

import numpy as np

import torch
from torch.utils.data.dataloader import DataLoader as torchDataLoader
from torch.utils.data.dataloader import default_collate
import operator

from detectron2.data.build import (
    AspectRatioGroupedDataset,
    worker_init_reset_seed,
    trivial_batch_collator,
    InferenceSampler,
)

from core.utils.my_comm import get_world_size

from .samplers import YoloBatchSampler, InfiniteSampler

# from .datasets import Base_DatasetFromList


class DataLoader(torchDataLoader):
    """Lightnet dataloader that enables on the fly resizing of the images.

    See :class:`torch.utils.data.DataLoader` for more information on the arguments.
    Check more on the following website:
    https://gitlab.com/EAVISE/lightnet/-/blob/master/lightnet/data/_dataloading.py

    If no ``batch_sampler`` is supplied, the batch sampler built by torch is
    replaced with a :class:`YoloBatchSampler` wrapping ``sampler`` (or, when
    ``sampler`` is None, a RandomSampler/SequentialSampler chosen by
    ``shuffle``). A supplied ``batch_sampler`` is used as-is.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # This subclass is also named ``DataLoader``, so ``self.__initialized``
        # is name-mangled to ``_DataLoader__initialized`` -- the same attribute
        # torch's base class uses to refuse reassigning ``batch_sampler`` after
        # construction. Clearing it temporarily lets us install our own.
        self.__initialized = False

        def _arg(pos, name, default):
            # Fetch an argument of torch's DataLoader.__init__ regardless of
            # whether the caller passed it positionally or by keyword.
            if len(args) > pos:
                return args[pos]
            return kwargs.get(name, default)

        # Positions follow torch's signature:
        # (dataset, batch_size, shuffle, sampler, batch_sampler, ...)
        shuffle = _arg(2, "shuffle", False)
        sampler = _arg(3, "sampler", None)
        batch_sampler = _arg(4, "batch_sampler", None)

        # Use custom BatchSampler
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = torch.utils.data.sampler.RandomSampler(self.dataset)
                else:
                    sampler = torch.utils.data.sampler.SequentialSampler(self.dataset)
            # NOTE(review): build_yolox_batch_data_loader constructs
            # YoloBatchSampler without ``input_dimension``; confirm the sampler
            # in ``.samplers`` still accepts this keyword.
            batch_sampler = YoloBatchSampler(
                sampler,
                self.batch_size,
                self.drop_last,
                input_dimension=self.dataset.input_dim,
            )

        self.batch_sampler = batch_sampler

        self.__initialized = True

    def close_mosaic(self):
        """Set ``mosaic = False`` on the batch sampler.

        Assumes the batch sampler exposes a ``mosaic`` attribute (as
        YoloBatchSampler is constructed with one in this module) -- presumably
        consumed downstream to stop mosaic augmentation; confirm in ``.samplers``.
        """
        self.batch_sampler.mosaic = False


# def list_collate(batch):
#     """
#     Function that collates lists or tuples together into one list (of lists/tuples).
#     Use this as the collate function in a Dataloader, if you want to have a list of
#     items as an output, as opposed to tensors (eg. Brambox.boxes).
#     """
#     items = list(zip(*batch))

#     for i in range(len(items)):
#         if isinstance(items[i][0], (list, tuple)):
#             items[i] = list(items[i])
#         else:
#             items[i] = default_collate(items[i])

#     return items


def build_yolox_batch_data_loader(
    dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0, pin_memory=False
):
    """Build a batched dataloader for YOLOX training.

    The main differences from `torch.utils.data.DataLoader` are:
    1. optional aspect ratio grouping (detectron2's AspectRatioGroupedDataset);
    2. otherwise batches come from :class:`YoloBatchSampler`, constructed with
       the dataset's ``enable_mosaic`` attribute (False when absent).

    Args:
        dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
        sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices.
        total_batch_size (int): total batch size across all processes; must be
            positive and divisible by the world size.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio.
        num_workers (int): number of parallel data loading workers.
        pin_memory (bool): forwarded to the underlying torch DataLoader in
            both branches.

    Returns:
        With ``aspect_ratio_grouping``: an ``AspectRatioGroupedDataset`` yielding
        lists whose length is the batch size of the current process.
        Otherwise: a :class:`DataLoader` driven by ``YoloBatchSampler`` with
        ``drop_last=False`` and torch's default collation.
    """
    world_size = get_world_size()
    assert (
        total_batch_size > 0 and total_batch_size % world_size == 0
    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(total_batch_size, world_size)

    batch_size = total_batch_size // world_size
    if aspect_ratio_grouping:
        data_loader = torch.utils.data.DataLoader(
            dataset,
            sampler=sampler,
            num_workers=num_workers,
            batch_sampler=None,
            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
            worker_init_fn=worker_init_reset_seed,
            pin_memory=pin_memory,  # honour pin_memory in this branch too
        )  # yield individual mapped dict
        return AspectRatioGroupedDataset(data_loader, batch_size)

    batch_sampler = YoloBatchSampler(
        mosaic=getattr(dataset, "enable_mosaic", False),
        sampler=sampler,
        batch_size=batch_size,
        drop_last=False,  # NOTE: different to d2
    )
    return DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        # TODO: use collate_fn=trivial_batch_collator when item is changed to dict
        worker_init_fn=worker_init_reset_seed,
        pin_memory=pin_memory,
    )


def build_yolox_train_loader(
    dataset,
    *,
    aug_wrapper,
    total_batch_size,
    sampler=None,
    aspect_ratio_grouping=False,
    num_workers=0,
    pin_memory=False,
    seed=None
):
    """Build a YOLOX training dataloader with some default features.

    This interface is experimental.

    Args:
        dataset (torch.utils.data.Dataset): Base_DatasetFromList
        aug_wrapper (callable or None): e.g. MosaicDetection. When not None,
            ``dataset`` is replaced by ``aug_wrapper.init_dataset(dataset)``
            (mosaic, mixup and other augmentations).
        total_batch_size (int): total batch size across all processes; must be
            positive and divisible by the world size.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``. Defaults to
            ``InfiniteSampler(len(dataset), seed=seed)``.
        aspect_ratio_grouping (bool): whether to group images with similar
            aspect ratio for efficiency. When enabled, it requires each
            element in dataset be a dict with keys "width" and "height".
        num_workers (int): number of parallel data loading workers
        pin_memory (bool): forwarded to the underlying torch DataLoader.
        seed (int or None): seed of the default InfiniteSampler; ``None`` means 0.
            Ignored when ``sampler`` is given.

    Returns:
        The loader built by :func:`build_yolox_batch_data_loader`; the
        per-process batch size is ``total_batch_size // world_size``.
    """
    if aug_wrapper is not None:
        # MosaicDetection (mosaic, mixup, other augs)
        dataset = aug_wrapper.init_dataset(dataset)

    if sampler is None:
        sampler = InfiniteSampler(len(dataset), seed=0 if seed is None else seed)
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_yolox_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def build_yolox_test_loader(
    dataset, *, aug_wrapper=None, total_batch_size=1, sampler=None, num_workers=0, pin_memory=False
):
    """Build a YOLOX inference dataloader using :class:`InferenceSampler`.

    The default sampler coordinates all processes to produce the exact set of
    all samples. This interface is experimental.

    Args:
        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
            or a map-style pytorch dataset.
        aug_wrapper (callable or None): e.g. MosaicDetection. When not None,
            ``dataset`` is replaced by ``aug_wrapper.init_dataset(dataset)``.
        total_batch_size (int): total batch size across all processes; must be
            positive. Each process gets ``total_batch_size // world_size``
            images per batch, but never fewer than 1. Default is 1.
        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
            which splits the dataset across all workers.
        num_workers (int): number of parallel data loading workers
        pin_memory (bool): forwarded to the torch DataLoader.

    Returns:
        DataLoader: a torch DataLoader batching with ``BatchSampler``
        (``drop_last=False``) and torch's default collation.

    Raises:
        ValueError: if ``total_batch_size`` is not positive.

    Examples:
    ::
        data_loader = build_yolox_test_loader(dataset, total_batch_size=8, num_workers=4)
    """
    if total_batch_size <= 0:
        raise ValueError("total_batch_size must be positive, got {}.".format(total_batch_size))

    if aug_wrapper is not None:
        # MosaicDetection (mosaic, mixup, other augs)
        dataset = aug_wrapper.init_dataset(dataset)

    world_size = get_world_size()
    # Floor division alone yields 0 when total_batch_size < world_size (e.g. the
    # default of 1 on multiple processes), which BatchSampler rejects.
    batch_size = max(total_batch_size // world_size, 1)
    if sampler is None:
        sampler = InferenceSampler(len(dataset))
    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size, drop_last=False)
    return torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        # collate_fn=trivial_batch_collator,
        pin_memory=pin_memory,
    )


def yolox_worker_init_reset_seed(worker_id):
    """Reseed ``random``, NumPy and torch with one freshly drawn seed.

    Shaped like a DataLoader ``worker_init_fn``: ``worker_id`` is accepted for
    that interface but not used. The seed comes from :func:`uuid.uuid4` on
    every call, so results are not reproducible across runs.

    Args:
        worker_id (int): index of the data-loading worker (unused).
    """
    seed = uuid.uuid4().int % 2**32  # NumPy's legacy seeding needs [0, 2**32)
    random.seed(seed)
    # torch.manual_seed already seeds the default generator; the previous
    # set_rng_state(get_state()) round-trip on that same generator was a no-op.
    torch.manual_seed(seed)
    np.random.seed(seed)