Commit 11419cc0 authored by Karl Stelzner

Update README and config

The code currently supports the following datasets. Simply download and place (or symlink) them in the data directory.
- The 3D datasets introduced by [ObSuRF](https://stelzner.github.io/obsurf/).
- OSRT's [MultiShapeNet (MSN-hard)](https://osrt-paper.github.io/#dataset) dataset. It may be downloaded via gsutil:
```
pip install gsutil
mkdir -p data/osrt/multi_shapenet_frames/
gsutil -m cp -r gs://kubric-public/tfds/kubric-frames/multi_shapenet_conditional/2.8.0/ data/osrt/multi_shapenet_frames/
```
### Dependencies
This code requires at least Python 3.9 and [PyTorch 1.11](https://pytorch.org/get-started/locally/). Additional dependencies may be installed via `pip install -r requirements.txt`. Note that Tensorflow is required to load OSRT's MultiShapeNet data, though the CPU version suffices.
Rendering videos additionally depends on `ffmpeg>=4.3` being available in your `$PATH`.
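
Since Tensorflow is only needed for data loading, a quick way to verify the download is to open the records directly with `tensorflow_datasets`. This is a minimal sketch, assuming the gsutil command above produced a `data/osrt/multi_shapenet_frames/2.8.0/` version folder; the feature names printed are whatever the dataset defines:
```python
# Sketch: open the downloaded MSN-hard records and inspect one example.
# The directory layout below is an assumption about the download, not repo code.
import tensorflow_datasets as tfds

builder = tfds.builder_from_directory('data/osrt/multi_shapenet_frames/2.8.0')
dataset = builder.as_dataset(split='train')

example = next(iter(dataset))
for key, value in example.items():
    # Tensors print their shapes; nested features fall back to their repr.
    print(key, getattr(value, 'shape', value))
```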
To train on multiple GPUs on a single machine, launch multiple processes via [torchrun](https://pytorch.org/docs/stable/elastic/run.html), e.g.
```
torchrun --standalone --nnodes 1 --nproc_per_node $NUM_GPUS train.py runs/clevr3d/osrt/config.yaml
```
Checkpoints are automatically stored in and (if available) loaded from the run directory. Visualizations and evaluations are produced periodically. Check the args of `train.py` for additional options. Importantly, to log training progress, use the `--wandb` flag to enable [Weights & Biases](https://wandb.ai).
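
Checkpointing here follows the usual PyTorch pattern of bundling model and optimizer state; the sketch below illustrates that round-trip generically. The filename `model.pt` and the state-dict keys are hypothetical placeholders, not necessarily what `train.py` writes:
```python
# Generic PyTorch checkpoint round-trip; the filename and keys are illustrative
# placeholders, not the repo's actual checkpoint format.
import os
import torch
from torch import nn, optim

run_dir = 'runs/clevr3d/osrt'                  # run directory, as in the README
ckpt_path = os.path.join(run_dir, 'model.pt')  # hypothetical checkpoint name

model = nn.Linear(4, 4)                        # stand-in for the OSRT model
optimizer = optim.Adam(model.parameters())

# Store model and optimizer state together so training can resume exactly.
os.makedirs(run_dir, exist_ok=True)
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}, ckpt_path)

# On restart, load the state back if a checkpoint exists.
if os.path.exists(ckpt_path):
    state = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
```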
### Rendering videos
Videos may be rendered using `render.py`, e.g.
```
python render.py runs/msn/osrt/config.yaml --sceneid 1 --motion rotate_and_closeup --fade
```
Rendered frames and videos are placed in the run directory. Check the args of `render.py` for various camera movements, and `compile_video.py` for different ways of compiling videos.
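
If you prefer to stitch frames yourself rather than using `compile_video.py`, a plain ffmpeg call also works. The sketch below assumes a hypothetical `frame_%04d.png` naming pattern and output path:
```python
# Sketch: compile rendered frames into an mp4 with ffmpeg (>= 4.3, per the
# dependency note above). Frame pattern and paths are assumptions.
import subprocess

subprocess.run([
    'ffmpeg', '-y',
    '-framerate', '30',                        # playback frame rate
    '-i', 'runs/msn/osrt/frames/frame_%04d.png',
    '-c:v', 'libx264',
    '-pix_fmt', 'yuv420p',                     # broadly compatible pixel format
    'runs/msn/osrt/video.mp4',
], check=True)
```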
## Results
We have found OSRT's object segmentation performance to be strongly dependent on the batch sizes used during training. Due to memory constraints, we were unable to match OSRT's setting on MSN-hard. Our largest and most successful run so far used 2304 target rays per scene, as opposed to the 8192 specified in the paper. It reached a foreground ARI of around 0.73 and a PSNR of 22.8 after 750k iterations. The checkpoint may be downloaded here:
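
For context, foreground ARI conventionally means the adjusted Rand index computed only over pixels that are foreground in the ground truth, so the background segment cannot dominate the score. A minimal sketch of the metric follows; it is an illustration, not this repo's evaluation code:
```python
# Sketch of foreground ARI: ARI restricted to ground-truth foreground pixels.
import numpy as np
from sklearn.metrics import adjusted_rand_score

def foreground_ari(true_seg, pred_seg, bg_label=0):
    """ARI over pixels whose ground-truth label is not background."""
    true_seg = np.asarray(true_seg).ravel()
    pred_seg = np.asarray(pred_seg).ravel()
    fg = true_seg != bg_label
    return adjusted_rand_score(true_seg[fg], pred_seg[fg])

# Toy example: two segmentations that agree up to a label permutation.
gt   = np.array([0, 0, 1, 1, 2, 2])
pred = np.array([0, 1, 2, 2, 1, 1])
print(foreground_ari(gt, pred))  # 1.0: permuted labels still match perfectly
```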
## Citation
```
@article{sajjadi2022osrt,
    author = {Sajjadi, Mehdi S. M. and Duckworth, Daniel and Mahendran, Aravindh
        and van Steenkiste, Sjoerd and Paveti{\'c}, Filip and Lu{\v{c}}i{\'c}, Mario
        and Guibas, Leonidas J. and Greff, Klaus and Kipf, Thomas
    },
    title = {{Object Scene Representation Transformer}},
    journal = {NeurIPS},
    year = {2022}
}
```
Updated config for the CLEVR-3D run (`runs/clevr3d/osrt/config.yaml`):
```
data:
  dataset: clevr3d
  num_points: 8192
  kwargs:
    downsample: 1
model:
  encoder: osrt
  encoder_kwargs:
    pos_start_octave: -5
    num_slots: 6
  decoder: slot_mixer
  decoder_kwargs:
    pos_start_octave: -5
```
Updated config for the MSN-hard run (`runs/msn/osrt/config.yaml`):
```
data:
  dataset: osrt
  num_points: 2304
model:
  encoder: osrt
  encoder_kwargs:
    # ...
training:
  num_workers: 1
  batch_size: 256
  model_selection_metric: psnr
  model_selection_mode: maximize
  print_every: 10
  visualize_every: 2000
  validate_every: 2000
  checkpoint_every: 200
  backup_every: 25000
  max_it: 4000000
  decay_it: 4000000
```
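
The run configs are plain YAML, so they are easy to inspect programmatically; judging by the numbers above, `data.num_points` is the number of target rays sampled per scene and `training.batch_size` counts scenes per batch. A minimal sketch, assuming only standard YAML parsing rather than the repo's own config loader:
```python
# Sketch: read a run config with PyYAML and print the key training settings.
import yaml

with open('runs/msn/osrt/config.yaml') as f:  # path from the README examples
    cfg = yaml.safe_load(f)

print(cfg['data']['num_points'])      # target rays per scene (2304 here)
print(cfg['training']['batch_size'])  # scenes per batch (256 here)
```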