OSRT: Object Scene Representation Transformer
This is an independent PyTorch implementation of OSRT, as presented in the paper "Object Scene Representation Transformer" by Sajjadi et al. All credit for the model goes to the original authors.
Setup
git clone --recursive git@github.com:alexcbb/OSRT-experiments.git
After cloning the repository and creating a new conda environment, the following steps will get you started:
Data
The code currently supports the following datasets. Simply download and place (or symlink) them in the data directory; a symlink sketch follows the list below.
- The 3D datasets introduced by ObSuRF.
- OSRT's MultiShapeNet (MSN-hard) dataset. It may be downloaded via gsutil:
pip install gsutil
mkdir -p data/osrt/multi_shapenet_frames/
gsutil -m cp -r gs://kubric-public/tfds/kubric-frames/multi_shapenet_conditional/2.8.0/ data/osrt/multi_shapenet_frames/
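If you already have the ObSuRF datasets downloaded elsewhere, symlinking them into the data directory works as well. A minimal sketch, assuming a hypothetical source path and dataset directory name (adjust both to your downloads and to the paths expected by the run configs):
mkdir -p data
# hypothetical source path and target name -- adjust to your setup
ln -s /path/to/obsurf/clevr3d data/clevr3d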
Dependencies
This code requires at least Python 3.9 and PyTorch 1.11.
Additional dependencies may be installed via pip install -r requirements.txt. Note that Tensorflow is required to load OSRT's MultiShapeNet data, though the CPU version suffices. Rendering videos additionally depends on ffmpeg>=4.3 being available in your $PATH.
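Putting the requirements together, a possible environment setup looks like the following. This is a sketch rather than a pinned recipe: the environment name is arbitrary, and you should pick the PyTorch build matching your CUDA version.
conda create -n osrt python=3.9     # any Python >= 3.9
conda activate osrt
pip install torch==1.11.0           # or a newer / CUDA-specific build
pip install -r requirements.txt
pip install tensorflow-cpu          # only needed for the MultiShapeNet loader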
To install the Segment Anything dependencies:
cd segment-anything/
pip install -e .
Running Experiments
Each run's config, checkpoints, and visualizations are stored in a dedicated directory. Recommended configs can be found under runs/[dataset]/[model].
Training
To train a model on a single GPU, simply run, e.g.:
python train.py runs/clevr3d/osrt/config.yaml
To train on multiple GPUs on a single machine, launch multiple processes via Torchrun, where $NUM_GPUS is the number of GPUs to use:
torchrun --standalone --nnodes 1 --nproc_per_node $NUM_GPUS train.py runs/clevr3d/osrt/config.yaml
Checkpoints are automatically stored in and (if available) loaded from the run directory.
Visualizations and evaluations are produced periodically. Check the args of train.py for additional options. Importantly, to log training progress, use the --wandb flag to enable Weights & Biases.
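For example, a single-GPU run with Weights & Biases logging enabled:
python train.py runs/clevr3d/osrt/config.yaml --wandb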
Rendering videos
Videos may be rendered using render.py, e.g.:
python render.py runs/msn/osrt/config.yaml --sceneid 1 --motion rotate_and_closeup --fade
Rendered frames and videos are placed in the run directory. Check the args of render.py for various camera movements, and compile_video.py for different ways of compiling videos.
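Assuming the scripts expose their arguments via a standard argparse interface (an assumption, not something stated above), the available options can be listed with:
python render.py --help          # should list --motion choices and other rendering args
python compile_video.py --help   # should list the available compilation modes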
Results
We have found OSRT's object segmentation performance to be strongly dependent on the batch sizes used during training. Due to memory constraints, we were unable to match OSRT's settings on MSN-hard. Our largest and most successful run thus far utilized 2304 target rays per scene as opposed to the 8192 specified in the paper. It reached a foreground ARI of around 0.73 and a PSNR of 22.8 after 750k iterations. For download, we provide both the checkpoint and a sample video.
To match the memory availability of your hardware, consider adjusting data/num_points or training/batch_size in config.yaml. However, setting these too low can make the model prone to getting stuck in local optima, especially early in training.
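As an illustrative sketch of where these keys live in config.yaml (the values below are placeholders, not recommended settings):
data:
  num_points: 2048      # target rays sampled per scene; lower to reduce memory
training:
  batch_size: 32        # scenes per batch; lower to reduce memory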
Citation
@article{sajjadi2022osrt,
author = {Sajjadi, Mehdi S. M.
and Duckworth, Daniel
and Mahendran, Aravindh
and van Steenkiste, Sjoerd
and Paveti{\'c}, Filip
and Lu{\v{c}}i{\'c}, Mario
and Guibas, Leonidas J.
and Greff, Klaus
and Kipf, Thomas
},
title = {{Object Scene Representation Transformer}},
journal = {NeurIPS},
year = {2022}
}