diff --git a/.gitmodules b/.gitmodules
index 37d13bdac9a6f3ffedcf2fb1299600a66bf0adcf..23cf10097cbd7800f269aec2452b9745705a9815 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -28,3 +28,6 @@
 [submodule "dependencies/zlib"]
 	path = dependencies/zlib
 	url = https://github.com/Tom94/zlib
+[submodule "dependencies/OpenXR-SDK"]
+	path = dependencies/OpenXR-SDK
+	url = https://github.com/KhronosGroup/OpenXR-SDK.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8ea96eb9860e7854f7007440a17a0254f72689d7..92d926af9326d69fe84de5750f97da2c6cfd52f4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -119,6 +119,38 @@ if (NGP_BUILD_WITH_GUI)
 		endif()
 	endif()
 
+	# OpenXR
+	if (WIN32)
+		list(APPEND NGP_DEFINITIONS -DXR_USE_PLATFORM_WIN32 -DGLFW_EXPOSE_NATIVE_WGL)
+	elseif (UNIX AND NOT APPLE)
+		list(APPEND NGP_DEFINITIONS -DGLFW_EXPOSE_NATIVE_GLX)
+		if (NGP_USE_WAYLAND)
+			set(PRESENTATION_BACKEND wayland CACHE STRING " " FORCE)
+			set(BUILD_WITH_XLIB_HEADERS OFF CACHE BOOL " " FORCE)
+			set(BUILD_WITH_XCB_HEADERS OFF CACHE BOOL " " FORCE)
+			set(BUILD_WITH_WAYLAND_HEADERS ON CACHE BOOL " " FORCE)
+			list(APPEND NGP_DEFINITIONS -DGLFW_EXPOSE_NATIVE_WAYLAND -DXR_USE_PLATFORM_WAYLAND)
+		else()
+			set(PRESENTATION_BACKEND xlib CACHE STRING " " FORCE)
+			set(BUILD_WITH_XLIB_HEADERS ON CACHE BOOL " " FORCE)
+			set(BUILD_WITH_XCB_HEADERS OFF CACHE BOOL " " FORCE)
+			set(BUILD_WITH_WAYLAND_HEADERS OFF CACHE BOOL " " FORCE)
+			list(APPEND NGP_DEFINITIONS -DGLFW_EXPOSE_NATIVE_X11 -DXR_USE_PLATFORM_XLIB)
+		endif()
+	else()
+		message(FATAL_ERROR "No OpenXR platform set for this OS")
+	endif()
+
+	add_subdirectory(dependencies/OpenXR-SDK)
+
+	list(APPEND NGP_INCLUDE_DIRECTORIES "dependencies/OpenXR-SDK/include" "dependencies/OpenXR-SDK/src/common")
+	list(APPEND NGP_LIBRARIES openxr_loader)
+	list(APPEND GUI_SOURCES src/openxr_hmd.cu)
+
+	# OpenGL
+	find_package(OpenGL REQUIRED)
+
+	# GLFW
 	set(GLFW_BUILD_EXAMPLES OFF CACHE BOOL " " FORCE)
 	set(GLFW_BUILD_TESTS OFF CACHE BOOL " " FORCE)
 	set(GLFW_BUILD_DOCS OFF CACHE BOOL " " FORCE)
diff --git a/LICENSE.txt b/LICENSE.txt
index 2cfc50b56f42a59b9bbeb82e56b527ad8deace84..34191874786e0339d672cd74d10175bec14b9f9d 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -1,4 +1,4 @@
-Copyright (c) 2022, NVIDIA Corporation & affiliates. All rights reserved.
+Copyright (c) 2022-2023, NVIDIA Corporation & affiliates. All rights reserved.
 
 
 NVIDIA Source Code License for instant neural graphics primitives
diff --git a/README.md b/README.md
index db75273d99b9606bcb8d3cf9695565b85e48985c..d2808e6f5530b21134e09d5e53f75327a6544001 100644
--- a/README.md
+++ b/README.md
@@ -5,120 +5,55 @@
 Ever wanted to train a NeRF model of a fox in under 5 seconds? Or fly around a scene captured from photos of a factory robot? Of course you have!
 
 Here you will find an implementation of four __neural graphics primitives__, being neural radiance fields (NeRF), signed distance functions (SDFs), neural images, and neural volumes.
-In each case, we train and render a MLP with multiresolution hash input encoding using the [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn) framework.
+In each case, we train and render an MLP with multiresolution hash input encoding using the [__tiny-cuda-nn__](https://github.com/NVlabs/tiny-cuda-nn) framework.
 
 > __Instant Neural Graphics Primitives with a Multiresolution Hash Encoding__  
 > [Thomas Müller](https://tom94.net), [Alex Evans](https://research.nvidia.com/person/alex-evans), [Christoph Schied](https://research.nvidia.com/person/christoph-schied), [Alexander Keller](https://research.nvidia.com/person/alex-keller)  
 > _ACM Transactions on Graphics (__SIGGRAPH__), July 2022_  
 > __[Project page](https://nvlabs.github.io/instant-ngp) / [Paper](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.pdf) / [Video](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.mp4) / [Presentation](https://tom94.net/data/publications/mueller22instant/mueller22instant-gtc.mp4) / [Real-Time Live](https://tom94.net/data/publications/mueller22instant/mueller22instant-rtl.mp4) / [BibTeX](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.bib)__
 
-To get started with NVIDIA Instant NeRF, check out the [blog post](https://developer.nvidia.com/blog/getting-started-with-nvidia-instant-nerfs/) and [SIGGRAPH tutorial](https://www.nvidia.com/en-us/on-demand/session/siggraph2022-sigg22-s-16/).
-
 For business inquiries, please submit the [NVIDIA research licensing form](https://www.nvidia.com/en-us/research/inquiries/).
 
 
-## Windows binary release
+## Installation
 
-If you have Windows and if you do not need developer Python bindings, you can download one of the following binary releases and then jump directly to the [usage instructions](https://github.com/NVlabs/instant-ngp#interactive-training-and-rendering) or to [creating your own NeRF from a recording](docs/nerf_dataset_tips.md).
+If you have Windows, download one of the following releases corresponding to your graphics card and extract it. Then, start `instant-ngp.exe`.
 
 - [**RTX 3000 & 4000 series, RTX A4000–A6000**, and other Ampere & Ada cards](https://github.com/NVlabs/instant-ngp/releases/download/continuous/Instant-NGP-for-RTX-3000-and-4000.zip)
 - [**RTX 2000 series, Titan RTX, Quadro RTX 4000–8000**, and other Turing cards](https://github.com/NVlabs/instant-ngp/releases/download/continuous/Instant-NGP-for-RTX-2000.zip)
 - [**GTX 1000 series, Titan Xp, Quadro P1000–P6000**, and other Pascal cards](https://github.com/NVlabs/instant-ngp/releases/download/continuous/Instant-NGP-for-GTX-1000.zip)
 
-If you use Linux, or want the developer Python bindings, or if your GPU is not listed above (e.g. Hopper, Volta, or Maxwell generations), use the following step-by-step instructions to compile __instant-ngp__ yourself.
-
-
-## Requirements
-
-- An __NVIDIA GPU__; tensor cores increase performance when available. All shown results come from an RTX 3090.
-- A __C++14__ capable compiler. The following choices are recommended and have been tested:
-  - __Windows:__ Visual Studio 2019 or 2022
-  - __Linux:__ GCC/G++ 8 or higher
-- A recent version of __[CUDA](https://developer.nvidia.com/cuda-toolkit)__. The following choices are recommended and have been tested:
-  - __Windows:__ CUDA 11.5 or higher
-  - __Linux:__ CUDA 10.2 or higher
-- __[CMake](https://cmake.org/) v3.21 or higher__.
-- __(optional) [Python](https://www.python.org/) 3.7 or higher__ for interactive bindings. Also, run `pip install -r requirements.txt`.
-- __(optional) [OptiX](https://developer.nvidia.com/optix) 7.6 or higher__ for faster mesh SDF training.
-- __(optional) [Vulkan SDK](https://vulkan.lunarg.com/)__ for DLSS support.
-
-
-If you are using Debian based Linux distribution, install the following packages
-```sh
-sudo apt-get install build-essential git python3-dev python3-pip libopenexr-dev libxi-dev \
-                     libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev
-```
-
-Alternatively, if you are using Arch or Arch derivatives, install the following packages
-```sh
-sudo pacman -S cuda base-devel cmake openexr libxi glfw openmp libxinerama libxcursor
-```
-
-We also recommend installing [CUDA](https://developer.nvidia.com/cuda-toolkit) and [OptiX](https://developer.nvidia.com/optix) in `/usr/local/` and adding the CUDA installation to your PATH.
-
-For example, if you have CUDA 11.4, add the following to your `~/.bashrc`
-```sh
-export PATH="/usr/local/cuda-11.4/bin:$PATH"
-export LD_LIBRARY_PATH="/usr/local/cuda-11.4/lib64:$LD_LIBRARY_PATH"
-```
-
-
-## Compilation (Windows & Linux)
-
-Begin by cloning this repository and all its submodules using the following command:
-```sh
-$ git clone --recursive https://github.com/nvlabs/instant-ngp
-$ cd instant-ngp
-```
-
-Then, use CMake to build the project: (on Windows, this must be in a [developer command prompt](https://docs.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=msvc-160#developer_command_prompt))
-```sh
-instant-ngp$ cmake . -B build
-instant-ngp$ cmake --build build --config RelWithDebInfo -j
-```
-
-If compilation fails inexplicably or takes longer than an hour, you might be running out of memory. Try running the above command without `-j` in that case.
-If this does not help, please consult [this list of possible fixes](https://github.com/NVlabs/instant-ngp#troubleshooting-compile-errors) before opening an issue.
-
-If the build succeeds, you can now run the code via the `./instant-ngp` executable or the `scripts/run.py` script described below.
-
-If automatic GPU architecture detection fails, (as can happen if you have multiple GPUs installed), set the `TCNN_CUDA_ARCHITECTURES` environment variable for the GPU you would like to use. The following table lists the values for common GPUs. If your GPU is not listed, consult [this exhaustive list](https://developer.nvidia.com/cuda-gpus).
+Keep reading for a guided tour of the application or, if you are interested in creating your own NeRF, watch [the video tutorial](https://www.youtube.com/watch?v=3TWxO1PftMc) or read the [written instructions for creating your own NeRF](docs/nerf_dataset_tips.md).
 
-| H100 | 40X0 | 30X0 | A100 | 20X0 | TITAN V / V100 | 10X0 / TITAN Xp | 9X0 | K80 |
-|:----:|:----:|:----:|:----:|:----:|:--------------:|:---------------:|:---:|:---:|
-|   90 |   89 |   86 |   80 |   75 |             70 |              61 |  52 |  37 |
+If you use Linux, or want the [developer Python bindings](https://github.com/NVlabs/instant-ngp#python-bindings), or if your GPU is not listed above (e.g. Hopper, Volta, or Maxwell generations), you need to [build __instant-ngp__ yourself](https://github.com/NVlabs/instant-ngp#building-instant-ngp-windows--linux).
 
 
+## Usage
 
-## Interactive training and rendering
+### Graphical user interface
 
 <img src="docs/assets_readme/testbed.png" width="100%"/>
 
-This codebase comes with an interactive GUI that includes many features beyond our academic publication:
-- Additional training features, such as extrinsics and intrinsics optimization.
-- Marching cubes for `NeRF->Mesh` and `SDF->Mesh` conversion.
-- A spline-based camera path editor to create videos.
-- Debug visualizations of the activations of every neuron input and output.
-- And many more task-specific settings.
-- See also our [one minute demonstration video of the tool](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.mp4).
+__instant-ngp__ comes with an interactive GUI that includes many features beyond our academic publication, including
+- [comprehensive controls](https://github.com/NVlabs/instant-ngp#gui-controls) for interactively exploring neural graphics primitives,
+- [VR mode](https://github.com/NVlabs/instant-ngp#vr-controls) for viewing neural graphics primitives through a virtual-reality headset,
+- saving and loading "snapshots" so you can share your graphics primitives on the internet,
+- a camera path editor to create videos,
+- `NeRF->Mesh` and `SDF->Mesh` conversion,
+- camera pose and lens optimization,
+- and many more.
+
+See also our [one minute demonstration video of the tool](https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.mp4).
 
-Let's start using __instant-ngp__; more information about the GUI and other scripts follow these test scenes.
 
 ### NeRF fox
 
-One test scene is provided in this repository, using a small number of frames from a casually captured phone video.
-Simply start `instant-ngp` and drag the `data/nerf/fox` folder into the GUI. Or, alternatively, use the command line:
+Simply start `instant-ngp` and drag the `data/nerf/fox` folder into the window. Or, alternatively, use the command line:
 
 ```sh
 instant-ngp$ ./instant-ngp data/nerf/fox
 ```
 
-On Windows you need to reverse the slashes here (and below), i.e.:
-
-```sh
-instant-ngp> .\instant-ngp data\nerf\fox
-```
-
 <img src="docs/assets_readme/fox.png"/>
 
 Alternatively, download any NeRF-compatible scene (e.g. from the [NeRF authors' drive](https://drive.google.com/drive/folders/1JDdLGDruGNXWnM1eqY1FNL9PlStjaKWi), the [SILVR dataset](https://github.com/IDLabMedia/large-lightfields-dataset), or the [DroneDeploy dataset](https://github.com/nickponline/dd-nerf-dataset)).
@@ -132,7 +67,7 @@ instant-ngp$ ./instant-ngp data/nerf_synthetic/lego/transforms_train.json
 
 ### SDF armadillo
 
-Drag `data/sdf/armadillo.obj` into the GUI or use the command:
+Drag `data/sdf/armadillo.obj` into the window or use the command:
 
 ```sh
 instant-ngp$ ./instant-ngp data/sdf/armadillo.obj
@@ -142,7 +77,7 @@ instant-ngp$ ./instant-ngp data/sdf/armadillo.obj
 
 ### Image of Einstein
 
-Drag `data/image/albert.exr` into the GUI or use the command:
+Drag `data/image/albert.exr` into the window or use the command:
 
 ```sh
 instant-ngp$ ./instant-ngp data/image/albert.exr
@@ -161,7 +96,7 @@ instant-ngp$ ./instant-ngp data/image/tokyo.bin
 
 Download the [nanovdb volume for the Disney cloud](https://drive.google.com/drive/folders/1SuycSAOSG64k2KLV7oWgyNWyCvZAkafK?usp=sharing), which is derived [from here](https://disneyanimation.com/data-sets/?drawer=/resources/clouds/) ([CC BY-SA 3.0](https://media.disneyanimation.com/uploads/production/data_set_asset/6/asset/License_Cloud.pdf)).
 
-Then drag `wdas_cloud_quarter.nvdb` into the GUI or use the command:
+Then drag `wdas_cloud_quarter.nvdb` into the window or use the command:
 
 ```sh
 instant-ngp$ ./instant-ngp wdas_cloud_quarter.nvdb
@@ -169,7 +104,7 @@ instant-ngp$ ./instant-ngp wdas_cloud_quarter.nvdb
 <img src="docs/assets_readme/cloud.png"/>
 
 
-### GUI controls
+### Keyboard shortcuts and recommended controls
 
 Here are the main keyboard controls for the __instant-ngp__ application.
 
@@ -177,9 +112,12 @@ Here are the main keyboard controls for the __instant-ngp__ application.
 | :-------------: | ------------- |
 | WASD            | Forward / pan left / backward / pan right. |
 | Spacebar / C    | Move up / down. |
-| = or + / - or _ | Increase / decrease camera velocity. |
+| = or + / - or _ | Increase / decrease camera velocity (first person mode) or zoom in / out (third person mode). |
 | E / Shift+E     | Increase / decrease exposure. |
+| Tab             | Toggle menu visibility. |
 | T               | Toggle training. After around two minutes training tends to settle down, so can be toggled off. |
+| { }             | Go to the first/last training image camera view. |
+| [ ]             | Go to the previous/next training image camera view. |
 | R               | Reload network from file. |
 | Shift+R         | Reset camera. |
 | O               | Toggle visualization or accumulated error map. |
@@ -191,27 +129,140 @@ Here are the main keyboard controls for the __instant-ngp__ application.
 There are many controls in the __instant-ngp__ GUI.
 First, note that this GUI can be moved and resized, as can the "Camera path" GUI (which first must be expanded to be used).
 
-Some popular user controls in __instant-ngp__ are:
+Recommended user controls in __instant-ngp__ are:
 
-* __Snapshot:__ use Save to save the NeRF solution generated, Load to reload. Necessary if you want to make an animation.
-* __Rendering -> DLSS:__ toggling this on and setting "DLSS sharpening" below it to 1.0 can often improve rendering quality.
+* __Snapshot:__ use Save to save the trained NeRF and Load to reload it later. Necessary if you want to make an animation.
+* __Rendering -> DLSS:__ toggling this on and setting "DLSS sharpening" to 1.0 can often improve rendering quality.
 * __Rendering -> Crop size:__ trim back the surrounding environment to focus on the model. "Crop aabb" lets you move the center of the volume of interest and fine tune. See more about this feature in [our NeRF training & dataset tips](https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md).
 
-The "Camera path" GUI lets you set frames along a path. "Add from cam" is the main button you'll want to push. Then, you can render a video `.mp4` of your camera path or export the keyframes to a `.json` file. There is a bit more information about the GUI [in this post](https://developer.nvidia.com/blog/getting-started-with-nvidia-instant-nerfs/) and [in this video guide to creating your own video](https://www.youtube.com/watch?v=3TWxO1PftMc).
+The "Camera path" GUI lets you create a camera path for rendering a video.
+The button "Add from cam" inserts keyframes from the current perspective.
+Then, you can render a video `.mp4` of your camera path or export the keyframes to a `.json` file.
+There is a bit more information about the GUI [in this post](https://developer.nvidia.com/blog/getting-started-with-nvidia-instant-nerfs/) and [in this video guide to creating your own video](https://www.youtube.com/watch?v=3TWxO1PftMc).
+
+
+### VR controls
+
+To view the neural graphics primitive in VR, first start your VR runtime. This will most likely be either
+- __OculusVR__ if you have an Oculus Rift or Meta Quest (with link cable) headset, or
+- __SteamVR__ if you have another headset.
+- Generally, any other OpenXR-compatible runtime will work as well.
+
+Then, press the __View in VR/AR headset__ button in the __instant-ngp__ GUI and put on your headset.
+In VR, you have the following controls.
+
+| Control                | Meaning       |
+| :--------------------: | ------------- |
+| Left stick / trackpad  | Move |
+| Right stick / trackpad | Turn camera |
+| Press stick / trackpad | Erase NeRF around the hand |
+| Grab (one-handed)      | Drag neural graphics primitive |
+| Grab (two-handed)      | Rotate and zoom (like pinch-to-zoom on a smartphone) |
+
+
+## Building instant-ngp (Windows & Linux)
+
+### Requirements
+
+- An __NVIDIA GPU__; tensor cores increase performance when available. All shown results come from an RTX 3090.
+- A __C++14__ capable compiler. The following choices are recommended and have been tested:
+  - __Windows:__ Visual Studio 2019 or 2022
+  - __Linux:__ GCC/G++ 8 or higher
+- A recent version of __[CUDA](https://developer.nvidia.com/cuda-toolkit)__. The following choices are recommended and have been tested:
+  - __Windows:__ CUDA 11.5 or higher
+  - __Linux:__ CUDA 10.2 or higher
+- __[CMake](https://cmake.org/) v3.21 or higher__.
+- __(optional) [Python](https://www.python.org/) 3.7 or higher__ for interactive bindings. Also, run `pip install -r requirements.txt`.
+- __(optional) [OptiX](https://developer.nvidia.com/optix) 7.6 or higher__ for faster mesh SDF training.
+- __(optional) [Vulkan SDK](https://vulkan.lunarg.com/)__ for DLSS support.
+
+
+If you are using a Debian-based Linux distribution, install the following packages:
+```sh
+sudo apt-get install build-essential git python3-dev python3-pip libopenexr-dev libxi-dev \
+                     libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev
+```
+
+Alternatively, if you are using Arch or an Arch derivative, install the following packages:
+```sh
+sudo pacman -S cuda base-devel cmake openexr libxi glfw openmp libxinerama libxcursor
+```
+
+We also recommend installing [CUDA](https://developer.nvidia.com/cuda-toolkit) and [OptiX](https://developer.nvidia.com/optix) in `/usr/local/` and adding the CUDA installation to your PATH.
+
+For example, if you have CUDA 11.4, add the following to your `~/.bashrc`
+```sh
+export PATH="/usr/local/cuda-11.4/bin:$PATH"
+export LD_LIBRARY_PATH="/usr/local/cuda-11.4/lib64:$LD_LIBRARY_PATH"
+```
+
+
+### Compilation
+
+Begin by cloning this repository and all its submodules using the following command:
+```sh
+$ git clone --recursive https://github.com/nvlabs/instant-ngp
+$ cd instant-ngp
+```
+
+Then, use CMake to build the project (on Windows, this must be done in a [developer command prompt](https://docs.microsoft.com/en-us/cpp/build/building-on-the-command-line?view=msvc-160#developer_command_prompt)):
+```sh
+instant-ngp$ cmake . -B build
+instant-ngp$ cmake --build build --config RelWithDebInfo -j
+```
+
+If compilation fails inexplicably or takes longer than an hour, you might be running out of memory. Try running the above command without `-j` in that case.
+If this does not help, please consult [this list of possible fixes](https://github.com/NVlabs/instant-ngp#troubleshooting-compile-errors) before opening an issue.
+
+If the build succeeds, you can now run the code via the `./instant-ngp` executable or the `scripts/run.py` script described below.
+
+If automatic GPU architecture detection fails (as can happen if you have multiple GPUs installed), set the `TCNN_CUDA_ARCHITECTURES` environment variable for the GPU you would like to use. The following table lists the values for common GPUs. If your GPU is not listed, consult [this exhaustive list](https://developer.nvidia.com/cuda-gpus).
+
+| H100 | 40X0 | 30X0 | A100 | 20X0 | TITAN V / V100 | 10X0 / TITAN Xp | 9X0 | K80 |
+|:----:|:----:|:----:|:----:|:----:|:--------------:|:---------------:|:---:|:---:|
+|   90 |   89 |   86 |   80 |   75 |             70 |              61 |  52 |  37 |
 
 
 ## Python bindings
 
-To conduct controlled experiments in an automated fashion, all features from the interactive GUI (and more!) have Python bindings that can be easily instrumented.
+After you have built __instant-ngp__, you can use its Python bindings to conduct controlled experiments in an automated fashion.
+All features from the interactive GUI (and more!) have Python bindings that can be easily instrumented.
 For an example of how the `./instant-ngp` application can be implemented and extended from within Python, see `./scripts/run.py`, which supports a superset of the command line arguments that `./instant-ngp` does.
 
-If you'd rather build new models from the hash encoding and fast neural networks, consider the [__tiny-cuda-nn__'s PyTorch extension](https://github.com/nvlabs/tiny-cuda-nn#pytorch-extension).
+If you would rather build new models from the hash encoding and fast neural networks, consider [__tiny-cuda-nn__'s PyTorch extension](https://github.com/nvlabs/tiny-cuda-nn#pytorch-extension).
 
 Happy hacking!
 
 
+## Additional resources
+
+- [Getting started with NVIDIA Instant NeRF blog post](https://developer.nvidia.com/blog/getting-started-with-nvidia-instant-nerfs/)
+- [SIGGRAPH tutorial for advanced NeRF dataset creation](https://www.nvidia.com/en-us/on-demand/session/siggraph2022-sigg22-s-16/)
+
+
 ## Frequently asked questions (FAQ)
 
+__Q:__ The NeRF reconstruction of my custom dataset looks bad; what can I do?
+
+__A:__ There could be multiple issues:
+- COLMAP might have been unable to reconstruct camera poses.
+- There might have been movement or blur during capture. Don't treat capture as an artistic task; treat it as photogrammetry. You want _\*as little blur as possible\*_ in your dataset (motion, defocus, or otherwise) and all objects must be _\*static\*_ during the entire capture. Bonus points if you are using a wide-angle lens (iPhone wide angle works well), because it covers more space than narrow lenses.
+- The dataset parameters (in particular `aabb_scale`) might have been tuned suboptimally. We recommend starting with `aabb_scale=128` and then increasing or decreasing it by factors of two until you get optimal quality.
+- Carefully read [our NeRF training & dataset tips](https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md).
+
+##
+__Q:__ How can I save the trained model and load it again later?
+
+__A:__ Two options:
+1. Use the GUI's "Snapshot" section.
+2. Use the Python bindings `load_snapshot` / `save_snapshot` (see `scripts/run.py` for example usage).
+
+##
+__Q:__ Can this codebase use multiple GPUs at the same time?
+
+__A:__ Only for VR rendering, in which case one GPU is used per eye. Otherwise, no. To select a specific GPU to run on, use the [CUDA_VISIBLE_DEVICES](https://stackoverflow.com/questions/39649102/how-do-i-select-which-gpu-to-run-a-job-on) environment variable. To optimize the _compilation_ for that specific GPU use the [TCNN_CUDA_ARCHITECTURES](https://github.com/NVlabs/instant-ngp#compilation-windows--linux) environment variable.
+
+##
 __Q:__ How can I run __instant-ngp__ in headless mode?
 
 __A:__ Use `./instant-ngp --no-gui` or `python scripts/run.py`. You can also compile without GUI via `cmake -DNGP_BUILD_WITH_GUI=off ...`
@@ -239,38 +290,16 @@ __Q:__ How can I edit and train the underlying hash encoding or neural network o
 
 __A:__ Use [__tiny-cuda-nn__'s PyTorch extension](https://github.com/nvlabs/tiny-cuda-nn#pytorch-extension).
 
-##
-__Q:__ How can I save the trained model and load it again later?
-
-__A:__ Two options:
-1. Use the GUI's "Snapshot" section.
-2. Use the Python bindings `load_snapshot` / `save_snapshot` (see `scripts/run.py` for example usage).
-
-##
-__Q:__ Can this codebase use multiple GPUs at the same time?
-
-__A:__ No. To select a specific GPU to run on, use the [CUDA_VISIBLE_DEVICES](https://stackoverflow.com/questions/39649102/how-do-i-select-which-gpu-to-run-a-job-on) environment variable. To optimize the _compilation_ for that specific GPU use the [TCNN_CUDA_ARCHITECTURES](https://github.com/NVlabs/instant-ngp#compilation-windows--linux) environment variable.
-
 ##
 __Q:__ What is the coordinate system convention?
 
 __A:__ See [this helpful diagram](https://github.com/NVlabs/instant-ngp/discussions/153?converting=1#discussioncomment-2187652) by user @jc211.
 
-##
-__Q:__ The NeRF reconstruction of my custom dataset looks bad; what can I do?
-
-__A:__ There could be multiple issues:
-- COLMAP might have been unable to reconstruct camera poses.
-- There might have been movement or blur during capture. Don't treat capture as an artistic task; treat it as photogrammetry. You want _\*as little blur as possible\*_ in your dataset (motion, defocus, or otherwise) and all objects must be _\*static\*_ during the entire capture. Bonus points if you are using a wide-angle lens (iPhone wide angle works well), because it covers more space than narrow lenses.
-- The dataset parameters (in particular `aabb_scale`) might have been tuned suboptimally. We recommend starting with `aabb_scale=16` and then increasing or decreasing it by factors of two until you get optimal quality.
-- Carefully read [our NeRF training & dataset tips](https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md).
-
 ##
 __Q:__ Why are background colors randomized during NeRF training?
 
 __A:__ Transparency in the training data indicates a desire for transparency in the learned model. Using a solid background color, the model can minimize its loss by simply predicting that background color, rather than transparency (zero density). By randomizing the background colors, the model is _forced_ to learn zero density to let the randomized colors "shine through".
 
-
 ##
 __Q:__ How to mask away NeRF training pixels (e.g. for dynamic object removal)?
 
diff --git a/dependencies/OpenXR-SDK b/dependencies/OpenXR-SDK
new file mode 160000
index 0000000000000000000000000000000000000000..e2da9ce83a4388c9622da328bf48548471261290
--- /dev/null
+++ b/dependencies/OpenXR-SDK
@@ -0,0 +1 @@
+Subproject commit e2da9ce83a4388c9622da328bf48548471261290
diff --git a/docs/nerf_dataset_tips.md b/docs/nerf_dataset_tips.md
index d5f982429dae4e9edf26c02720ce1042a14e9900..de311e80e227caca34fc73203ef42bcf433001d0 100644
--- a/docs/nerf_dataset_tips.md
+++ b/docs/nerf_dataset_tips.md
@@ -43,7 +43,7 @@ If you have an existing dataset in `transforms.json` format, it should be center
 You can set any of the following parameters, where the listed values are the default.
 ```json
 {
-	"aabb_scale": 16,
+	"aabb_scale": 64,
 	"scale": 0.33,
 	"offset": [0.5, 0.5, 0.5],
 	...	
@@ -88,7 +88,7 @@ If you use Windows, you do not need to install anything. COLMAP and FFmpeg will
 If you are training from a video file, run the [scripts/colmap2nerf.py](/scripts/colmap2nerf.py) script from the folder containing the video, with the following recommended parameters:
 
 ```sh
-data-folder$ python [path-to-instant-ngp]/scripts/colmap2nerf.py --video_in <filename of video> --video_fps 2 --run_colmap --aabb_scale 16
+data-folder$ python [path-to-instant-ngp]/scripts/colmap2nerf.py --video_in <filename of video> --video_fps 2 --run_colmap --aabb_scale 64
 ```
 
 The above assumes a single video file as input, which then has frames extracted at the specified framerate (2). It is recommended to choose a frame rate that leads to around 50-150 images. So for a one minute video, `--video_fps 2` is ideal.
@@ -96,7 +96,7 @@ The above assumes a single video file as input, which then has frames extracted
 For training from images, place them in a subfolder called `images` and then use suitable options such as the ones below:
 
 ```sh
-data-folder$ python [path-to-instant-ngp]/scripts/colmap2nerf.py --colmap_matcher exhaustive --run_colmap --aabb_scale 16
+data-folder$ python [path-to-instant-ngp]/scripts/colmap2nerf.py --colmap_matcher exhaustive --run_colmap --aabb_scale 64
 ```
 
 The script will run (and install, if you use Windows) FFmpeg and COLMAP as needed, followed by a conversion step to the required `transforms.json` format, which will be written in the current directory. 
@@ -104,7 +104,7 @@ The script will run (and install, if you use Windows) FFmpeg and COLMAP as neede
 By default, the script invokes colmap with the "sequential matcher", which is suitable for images taken from a smoothly changing camera path, as in a video. The exhaustive matcher is more appropriate if the images are in no particular order, as shown in the image example above.
 For more options, you can run the script with `--help`. For more advanced uses of COLMAP or for challenging scenes, please see the [COLMAP documentation](https://colmap.github.io/cli.html); you may need to modify the [scripts/colmap2nerf.py](/scripts/colmap2nerf.py) script itself.
 
-The `aabb_scale` parameter is the most important `instant-ngp` specific parameter. It specifies the extent of the scene, defaulting to 1; that is, the scene is scaled such that the camera positions are at an average distance of 1 unit from the origin. For small synthetic scenes such as the original NeRF dataset, the default `aabb_scale` of 1 is ideal and leads to fastest training. The NeRF model makes the assumption that the training images can entirely be explained by a scene contained within this bounding box. However, for natural scenes where there is a background that extends beyond this bounding box, the NeRF model will struggle and may hallucinate "floaters" at the boundaries of the box. By setting `aabb_scale` to a larger power of 2 (up to a maximum of 16), the NeRF model will extend rays to a much larger bounding box. Note that this can impact training speed slightly. If in doubt, for natural scenes, start with an `aabb_scale` of 128, and subsequently reduce it if possible. The value can be directly edited in the `transforms.json` output file, without re-running the [scripts/colmap2nerf.py](/scripts/colmap2nerf.py) script.
+The `aabb_scale` parameter is the most important `instant-ngp` specific parameter. It specifies the extent of the scene, defaulting to 1; that is, the scene is scaled such that the camera positions are at an average distance of 1 unit from the origin. For small synthetic scenes such as the original NeRF dataset, the default `aabb_scale` of 1 is ideal and leads to fastest training. The NeRF model makes the assumption that the training images can entirely be explained by a scene contained within this bounding box. However, for natural scenes where there is a background that extends beyond this bounding box, the NeRF model will struggle and may hallucinate "floaters" at the boundaries of the box. By setting `aabb_scale` to a larger power of 2 (up to a maximum of 128), the NeRF model will extend rays to a much larger bounding box. Note that this can impact training speed slightly. If in doubt, for natural scenes, start with an `aabb_scale` of 128, and subsequently reduce it if possible. The value can be directly edited in the `transforms.json` output file, without re-running the [scripts/colmap2nerf.py](/scripts/colmap2nerf.py) script.
 
 You can optionally pass in object categories (e.g. `--mask_categories person car`) which runs [Detectron2](https://github.com/facebookresearch/detectron2) to generate masks automatically.
 __instant-ngp__ will not use the masked pixels for training.
diff --git a/include/neural-graphics-primitives/camera_path.h b/include/neural-graphics-primitives/camera_path.h
index 8d3fe24792b182606950fba15df410bbb692816c..9d121757233a323295f0f722f2025df523b0dffe 100644
--- a/include/neural-graphics-primitives/camera_path.h
+++ b/include/neural-graphics-primitives/camera_path.h
@@ -132,8 +132,8 @@ struct CameraPath {
 #ifdef NGP_GUI
 	ImGuizmo::MODE m_gizmo_mode = ImGuizmo::LOCAL;
 	ImGuizmo::OPERATION m_gizmo_op = ImGuizmo::TRANSLATE;
-	bool imgui_viz(ImDrawList* list, Eigen::Matrix<float, 4, 4>& view2proj, Eigen::Matrix<float, 4, 4>& world2proj, Eigen::Matrix<float, 4, 4>& world2view, Eigen::Vector2f focal, float aspect);
 	int imgui(char path_filename_buf[1024], float frame_milliseconds, Eigen::Matrix<float, 3, 4>& camera, float slice_plane_z, float scale, float fov, float aperture_size, float bounding_radius, const Eigen::Matrix<float, 3, 4>& first_xform, int glow_mode, float glow_y_cutoff);
+	bool imgui_viz(ImDrawList* list, Eigen::Matrix<float, 4, 4>& view2proj, Eigen::Matrix<float, 4, 4>& world2proj, Eigen::Matrix<float, 4, 4>& world2view, Eigen::Vector2f focal, float aspect, float znear, float zfar);
 #endif
 };
 
diff --git a/include/neural-graphics-primitives/common.h b/include/neural-graphics-primitives/common.h
index 8d3f6068b8462f29fedaa4837f3596d5187d1993..3a632b4d2f51fa1189c3fe9cf1c53b3d3052359f 100644
--- a/include/neural-graphics-primitives/common.h
+++ b/include/neural-graphics-primitives/common.h
@@ -232,8 +232,13 @@ enum class ELensMode : int {
 	FTheta,
 	LatLong,
 	OpenCVFisheye,
+	Equirectangular,
 };
-static constexpr const char* LensModeStr = "Perspective\0OpenCV\0F-Theta\0LatLong\0OpenCV Fisheye\0\0";
+static constexpr const char* LensModeStr = "Perspective\0OpenCV\0F-Theta\0LatLong\0OpenCV Fisheye\0Equirectangular\0\0";
+
+inline bool supports_dlss(ELensMode mode) {
+	return mode == ELensMode::Perspective || mode == ELensMode::OpenCV || mode == ELensMode::OpenCVFisheye;
+}
 
 struct Lens {
 	ELensMode mode = ELensMode::Perspective;
@@ -343,6 +348,47 @@ private:
 	std::chrono::time_point<std::chrono::steady_clock> m_creation_time;
 };
 
+template <typename T>
+struct Buffer2DView {
+	T* data = nullptr;
+	Eigen::Vector2i resolution = Eigen::Vector2i::Zero();
+
+	// Lookup via integer pixel position (no bounds checking)
+	NGP_HOST_DEVICE T at(const Eigen::Vector2i& xy) const {
+		return data[xy.x() + xy.y() * resolution.x()];
+	}
+
+	// Lookup via UV coordinates in [0,1]^2
+	NGP_HOST_DEVICE T at(const Eigen::Vector2f& uv) const {
+		Eigen::Vector2i xy = resolution.cast<float>().cwiseProduct(uv).cast<int>().cwiseMax(0).cwiseMin(resolution - Eigen::Vector2i::Ones());
+		return at(xy);
+	}
+
+	// Lookup via UV coordinates in [0,1]^2 and LERP the nearest texels
+	NGP_HOST_DEVICE T at_lerp(const Eigen::Vector2f& uv) const {
+		const Eigen::Vector2f xy_float = resolution.cast<float>().cwiseProduct(uv);
+		const Eigen::Vector2i xy = xy_float.cast<int>();
+
+		const Eigen::Vector2f weight = xy_float - xy.cast<float>();
+
+		auto read_val = [&](Eigen::Vector2i pos) {
+			pos = pos.cwiseMax(0).cwiseMin(resolution - Eigen::Vector2i::Ones());
+			return at(pos);
+		};
+
+		return (
+			(1 - weight.x()) * (1 - weight.y()) * read_val({xy.x(), xy.y()}) +
+			(weight.x()) * (1 - weight.y()) * read_val({xy.x()+1, xy.y()}) +
+			(1 - weight.x()) * (weight.y()) * read_val({xy.x(), xy.y()+1}) +
+			(weight.x()) * (weight.y()) * read_val({xy.x()+1, xy.y()+1})
+		);
+	}
+
+	NGP_HOST_DEVICE operator bool() const {
+		return data;
+	}
+};
+
 uint8_t* load_stbi(const fs::path& path, int* width, int* height, int* comp, int req_comp);
 float* load_stbi_float(const fs::path& path, int* width, int* height, int* comp, int req_comp);
 uint16_t* load_stbi_16(const fs::path& path, int* width, int* height, int* comp, int req_comp);
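
For reference, `at_lerp` above scales UV coordinates by `resolution` rather than `resolution - 1` as the old `read_image` did. Below is a minimal standalone sketch of the same bilinear lookup; the helper is hypothetical, using plain floats and `std::vector` in place of the templated Eigen types:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for Buffer2DView<float>::at_lerp: bilinear filtering
// over a row-major image, with the same clamp-to-edge behavior as read_val.
float at_lerp(const std::vector<float>& data, int w, int h, float u, float v) {
	float xf = u * w, yf = v * h; // scale UV in [0,1]^2 to pixel coordinates
	int x = (int)xf, y = (int)yf;
	float wx = xf - x, wy = yf - y;

	auto read = [&](int px, int py) {
		px = std::min(std::max(px, 0), w - 1);
		py = std::min(std::max(py, 0), h - 1);
		return data[px + py * w];
	};

	return (1 - wx) * (1 - wy) * read(x, y) + wx * (1 - wy) * read(x + 1, y) +
		(1 - wx) * wy * read(x, y + 1) + wx * wy * read(x + 1, y + 1);
}

int main() {
	std::vector<float> img = {0.0f, 1.0f, 2.0f, 3.0f}; // 2x2 image, row major
	printf("%f\n", at_lerp(img, 2, 2, 0.25f, 0.25f));  // blends all four texels: 1.5
	return 0;
}
```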
diff --git a/include/neural-graphics-primitives/common_device.cuh b/include/neural-graphics-primitives/common_device.cuh
index 389fc3bdab7aff5925f14c5cfe88d8aebbd94e39..361c62bafd81b58ce3eabcbe79f02b472f7767d1 100644
--- a/include/neural-graphics-primitives/common_device.cuh
+++ b/include/neural-graphics-primitives/common_device.cuh
@@ -28,6 +28,52 @@ NGP_NAMESPACE_BEGIN
 using precision_t = tcnn::network_precision_t;
 
 
+// The maximum depth that can be produced when rendering a frame.
+// Chosen somewhat low (rather than std::numeric_limits<float>::infinity())
+// to permit numerically stable reprojection and DLSS operation,
+// even when rendering the infinitely distant horizon.
+inline constexpr __device__ float MAX_DEPTH() { return 16384.0f; }
+
+template <typename T>
+class Buffer2D {
+public:
+	Buffer2D() = default;
+	Buffer2D(const Eigen::Vector2i& resolution) {
+		resize(resolution);
+	}
+
+	T* data() const {
+		return m_data.data();
+	}
+
+	size_t bytes() const {
+		return m_data.bytes();
+	}
+
+	void resize(const Eigen::Vector2i& resolution) {
+		m_data.resize(resolution.prod());
+		m_resolution = resolution;
+	}
+
+	const Eigen::Vector2i& resolution() const {
+		return m_resolution;
+	}
+
+	Buffer2DView<T> view() const {
+		// Row major for now.
+		return {data(), m_resolution};
+	}
+
+	Buffer2DView<const T> const_view() const {
+		// Row major for now.
+		return {data(), m_resolution};
+	}
+
+private:
+	tcnn::GPUMemory<T> m_data;
+	Eigen::Vector2i m_resolution;
+};
+
 inline NGP_HOST_DEVICE float srgb_to_linear(float srgb) {
 	if (srgb <= 0.04045f) {
 		return srgb / 12.92f;
@@ -76,42 +122,9 @@ inline NGP_HOST_DEVICE Eigen::Array3f linear_to_srgb_derivative(const Eigen::Arr
 	return {linear_to_srgb_derivative(x.x()), linear_to_srgb_derivative(x.y()), (linear_to_srgb_derivative(x.z()))};
 }
 
-template <uint32_t N_DIMS, typename T>
-NGP_HOST_DEVICE Eigen::Matrix<float, N_DIMS, 1> read_image(const T* __restrict__ data, const Eigen::Vector2i& resolution, const Eigen::Vector2f& pos) {
-	const Eigen::Vector2f pos_float = Eigen::Vector2f{pos.x() * (float)(resolution.x()-1), pos.y() * (float)(resolution.y()-1)};
-	const Eigen::Vector2i texel = pos_float.cast<int>();
-
-	const Eigen::Vector2f weight = pos_float - texel.cast<float>();
-
-	auto read_val = [&](Eigen::Vector2i pos) {
-		pos.x() = std::max(std::min(pos.x(), resolution.x()-1), 0);
-		pos.y() = std::max(std::min(pos.y(), resolution.y()-1), 0);
-
-		Eigen::Matrix<float, N_DIMS, 1> result;
-		if (std::is_same<T, float>::value) {
-			result = *(Eigen::Matrix<T, N_DIMS, 1>*)&data[(pos.x() + pos.y() * resolution.x()) * N_DIMS];
-		} else {
-			auto val = *(tcnn::vector_t<T, N_DIMS>*)&data[(pos.x() + pos.y() * resolution.x()) * N_DIMS];
-
-			NGP_PRAGMA_UNROLL
-			for (uint32_t i = 0; i < N_DIMS; ++i) {
-				result[i] = (float)val[i];
-			}
-		}
-		return result;
-	};
-
-	return (
-		(1 - weight.x()) * (1 - weight.y()) * read_val({texel.x(), texel.y()}) +
-		(weight.x()) * (1 - weight.y()) * read_val({texel.x()+1, texel.y()}) +
-		(1 - weight.x()) * (weight.y()) * read_val({texel.x(), texel.y()+1}) +
-		(weight.x()) * (weight.y()) * read_val({texel.x()+1, texel.y()+1})
-	);
-}
-
 template <uint32_t N_DIMS, typename T>
 __device__ void deposit_image_gradient(const Eigen::Matrix<float, N_DIMS, 1>& value, T* __restrict__ gradient, T* __restrict__ gradient_weight, const Eigen::Vector2i& resolution, const Eigen::Vector2f& pos) {
-	const Eigen::Vector2f pos_float = Eigen::Vector2f{pos.x() * (resolution.x()-1), pos.y() * (resolution.y()-1)};
+	const Eigen::Vector2f pos_float = resolution.cast<float>().cwiseProduct(pos);
 	const Eigen::Vector2i texel = pos_float.cast<int>();
 
 	const Eigen::Vector2f weight = pos_float - texel.cast<float>();
@@ -142,6 +155,138 @@ __device__ void deposit_image_gradient(const Eigen::Matrix<float, N_DIMS, 1>& va
 	deposit_val(value, (weight.x()) * (weight.y()), {texel.x()+1, texel.y()+1});
 }
 
+struct FoveationPiecewiseQuadratic {
+	NGP_HOST_DEVICE FoveationPiecewiseQuadratic() = default;
+
+	FoveationPiecewiseQuadratic(float center_pixel_steepness, float center_inverse_piecewise_y, float center_radius) {
+		float center_inverse_radius = center_radius * center_pixel_steepness;
+		float left_inverse_piecewise_switch = center_inverse_piecewise_y - center_inverse_radius;
+		float right_inverse_piecewise_switch = center_inverse_piecewise_y + center_inverse_radius;
+
+		if (left_inverse_piecewise_switch < 0) {
+			left_inverse_piecewise_switch = 0.0f;
+		}
+
+		if (right_inverse_piecewise_switch > 1) {
+			right_inverse_piecewise_switch = 1.0f;
+		}
+
+		float am = center_pixel_steepness;
+		float d = (right_inverse_piecewise_switch - left_inverse_piecewise_switch) / center_pixel_steepness / 2;
+
+		// Binary search for l, r, bm, since the analytical solution is very complex.
+		float bm;
+		float m_min = 0.0f;
+		float m_max = 1.0f;
+		for (uint32_t i = 0; i < 20; i++) {
+			float m = (m_min + m_max) / 2.0f;
+			float l = m - d;
+			float r = m + d;
+
+			bm = -((am - 1) * l * l) / (r * r - 2 * r + l * l + 1);
+
+			float l_actual = (left_inverse_piecewise_switch - bm) / am;
+			float r_actual = (right_inverse_piecewise_switch - bm) / am;
+			float m_actual = (l_actual + r_actual) / 2;
+
+			if (m_actual > m) {
+				m_min = m;
+			} else {
+				m_max = m;
+			}
+		}
+
+		float l = (left_inverse_piecewise_switch - bm) / am;
+		float r = (right_inverse_piecewise_switch - bm) / am;
+
+		// Full linear case. Default construction covers this.
+		if ((l == 0.0f && r == 1.0f) || (am == 1.0f)) {
+			return;
+		}
+
+		// write out solution
+		switch_left = l;
+		switch_right = r;
+		this->am = am;
+		al = (am - 1) / (r * r - 2 * r + l * l + 1);
+		bl = (am * (r * r - 2 * r + 1) + am * l * l + (2 - 2 * am) * l) / (r * r - 2 * r + l * l + 1);
+		cl = 0;
+		this->bm = bm = -((am - 1) * l * l) / (r * r - 2 * r + l * l + 1);
+		ar = -(am - 1) / (r * r - 2 * r + l * l + 1);
+		br = (am * (r * r + 1) - 2 * r + am * l * l) / (r * r - 2 * r + l * l + 1);
+		cr = -(am * r * r - r * r + (am - 1) * l * l) / (r * r - 2 * r + l * l + 1);
+
+		inv_switch_left = am * switch_left + bm;
+		inv_switch_right = am * switch_right + bm;
+	}
+
+	// left parabola: al * x^2 + bl * x + cl
+	float al = 0.0f, bl = 0.0f, cl = 0.0f;
+	// middle linear piece: am * x + bm.  am should give 1:1 pixel mapping between warped size and full size.
+	float am = 1.0f, bm = 0.0f;
+	// right parabola: ar * x^2 + br * x + cr
+	float ar = 0.0f, br = 0.0f, cr = 0.0f;
+
+	// points where left and right switch over from quadratic to linear
+	float switch_left = 0.0f, switch_right = 1.0f;
+	// same, in inverted space
+	float inv_switch_left = 0.0f, inv_switch_right = 1.0f;
+
+	NGP_HOST_DEVICE float warp(float x) const {
+		x = tcnn::clamp(x, 0.0f, 1.0f);
+		if (x < switch_left) {
+			return al * x * x + bl * x + cl;
+		} else if (x > switch_right) {
+			return ar * x * x + br * x + cr;
+		} else {
+			return am * x + bm;
+		}
+	}
+
+	NGP_HOST_DEVICE float unwarp(float y) const {
+		y = tcnn::clamp(y, 0.0f, 1.0f);
+		if (y < inv_switch_left) {
+			return (std::sqrt(-4 * al * cl + 4 * al * y + bl * bl) - bl) / (2 * al);
+		} else if (y > inv_switch_right) {
+			return (std::sqrt(-4 * ar * cr + 4 * ar * y + br * br) - br) / (2 * ar);
+		} else {
+			return (y - bm) / am;
+		}
+	}
+
+	NGP_HOST_DEVICE float density(float x) const {
+		x = tcnn::clamp(x, 0.0f, 1.0f);
+		if (x < switch_left) {
+			return 2 * al * x + bl;
+		} else if (x > switch_right) {
+			return 2 * ar * x + br;
+		} else {
+			return am;
+		}
+	}
+};
+
+struct Foveation {
+	NGP_HOST_DEVICE Foveation() = default;
+
+	Foveation(const Eigen::Vector2f& center_pixel_steepness, const Eigen::Vector2f& center_inverse_piecewise_y, const Eigen::Vector2f& center_radius)
+	: warp_x{center_pixel_steepness.x(), center_inverse_piecewise_y.x(), center_radius.x()}, warp_y{center_pixel_steepness.y(), center_inverse_piecewise_y.y(), center_radius.y()} {}
+
+	FoveationPiecewiseQuadratic warp_x, warp_y;
+
+	NGP_HOST_DEVICE Eigen::Vector2f warp(const Eigen::Vector2f& x) const {
+		return {warp_x.warp(x.x()), warp_y.warp(x.y())};
+	}
+
+	NGP_HOST_DEVICE Eigen::Vector2f unwarp(const Eigen::Vector2f& y) const {
+		return {warp_x.unwarp(y.x()), warp_y.unwarp(y.y())};
+	}
+
+	NGP_HOST_DEVICE float density(const Eigen::Vector2f& x) const {
+		return warp_x.density(x.x()) * warp_y.density(x.y());
+	}
+};
+
 template <typename T>
 NGP_HOST_DEVICE inline void opencv_lens_distortion_delta(const T* extra_params, const T u, const T v, T* du, T* dv) {
 	const T k1 = extra_params[0];
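
A quick property check on the piecewise-quadratic warp defined above: `warp` and `unwarp` should invert each other, and `density` (the derivative of `warp`) should integrate to `warp(1) - warp(0) = 1`, since the warp only redistributes samples. Below is a hypothetical host-side sketch; it assumes the header is compiled by nvcc so its dependencies resolve, and the constructor arguments are illustrative values, not ones taken from the codebase:

```cpp
#include <neural-graphics-primitives/common_device.cuh>

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
	// Fovea of radius 0.2 centered at 0.5, with a 1:1 region at half the
	// display's pixel rate (illustrative parameters only).
	ngp::FoveationPiecewiseQuadratic f{0.5f, 0.5f, 0.2f};

	float integral = 0.0f;
	for (int i = 0; i < 1000; ++i) {
		float x = (i + 0.5f) / 1000.0f;

		// warp and unwarp are mutual inverses on [0, 1].
		assert(std::abs(f.unwarp(f.warp(x)) - x) < 1e-3f);

		// Midpoint-rule integration of the warp's derivative.
		integral += f.density(x) / 1000.0f;
	}

	printf("integral of density: %f (expect ~1)\n", integral);
	return 0;
}
```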
@@ -292,37 +437,53 @@ inline NGP_HOST_DEVICE Eigen::Vector3f latlong_to_dir(const Eigen::Vector2f& uv)
 	return {sp * ct, st, cp * ct};
 }
 
-inline NGP_HOST_DEVICE Ray pixel_to_ray(
+inline NGP_HOST_DEVICE Eigen::Vector3f equirectangular_to_dir(const Eigen::Vector2f& uv) {
+	float ct = (uv.y() - 0.5f) * 2.0f;
+	float st = std::sqrt(std::max(1.0f - ct * ct, 0.0f));
+	float phi = (uv.x() - 0.5f) * PI() * 2.0f;
+	float sp, cp;
+	sincosf(phi, &sp, &cp);
+	return {sp * st, ct, cp * st};
+}
+
+inline NGP_HOST_DEVICE Ray uv_to_ray(
 	uint32_t spp,
-	const Eigen::Vector2i& pixel,
+	const Eigen::Vector2f& uv,
 	const Eigen::Vector2i& resolution,
 	const Eigen::Vector2f& focal_length,
 	const Eigen::Matrix<float, 3, 4>& camera_matrix,
 	const Eigen::Vector2f& screen_center,
-	const Eigen::Vector3f& parallax_shift,
-	bool snap_to_pixel_centers = false,
+	const Eigen::Vector3f& parallax_shift = Eigen::Vector3f::Zero(),
 	float near_distance = 0.0f,
 	float focus_z = 1.0f,
 	float aperture_size = 0.0f,
+	const Foveation& foveation = {},
+	Buffer2DView<const uint8_t> hidden_area_mask = {},
 	const Lens& lens = {},
-	const float* __restrict__ distortion_grid = nullptr,
-	const Eigen::Vector2i distortion_grid_resolution = Eigen::Vector2i::Zero()
+	Buffer2DView<const Eigen::Vector2f> distortion = {}
 ) {
-	Eigen::Vector2f offset = ld_random_pixel_offset(snap_to_pixel_centers ? 0 : spp);
-	Eigen::Vector2f uv = (pixel.cast<float>() + offset).cwiseQuotient(resolution.cast<float>());
+	Eigen::Vector2f warped_uv = foveation.warp(uv);
+
+	// Check the hidden area mask _after_ applying foveation, because foveation will be undone
+	// before blitting to the framebuffer to which the hidden area mask corresponds.
+	if (hidden_area_mask && !hidden_area_mask.at(warped_uv)) {
+		return Ray::invalid();
+	}
 
 	Eigen::Vector3f dir;
 	if (lens.mode == ELensMode::FTheta) {
-		dir = f_theta_undistortion(uv - screen_center, lens.params, {1000.f, 0.f, 0.f});
-		if (dir.x() == 1000.f) {
-			return {{1000.f, 0.f, 0.f}, {0.f, 0.f, 1.f}}; // return a point outside the aabb so the pixel is not rendered
+		dir = f_theta_undistortion(warped_uv - screen_center, lens.params, {0.f, 0.f, 0.f});
+		if (dir == Eigen::Vector3f::Zero()) {
+			return Ray::invalid();
 		}
 	} else if (lens.mode == ELensMode::LatLong) {
-		dir = latlong_to_dir(uv);
+		dir = latlong_to_dir(warped_uv);
+	} else if (lens.mode == ELensMode::Equirectangular) {
+		dir = equirectangular_to_dir(warped_uv);
 	} else {
 		dir = {
-			(uv.x() - screen_center.x()) * (float)resolution.x() / focal_length.x(),
-			(uv.y() - screen_center.y()) * (float)resolution.y() / focal_length.y(),
+			(warped_uv.x() - screen_center.x()) * (float)resolution.x() / focal_length.x(),
+			(warped_uv.y() - screen_center.y()) * (float)resolution.y() / focal_length.y(),
 			1.0f
 		};
 
@@ -332,8 +493,9 @@ inline NGP_HOST_DEVICE Ray pixel_to_ray(
 			iterative_opencv_fisheye_lens_undistortion(lens.params, &dir.x(), &dir.y());
 		}
 	}
-	if (distortion_grid) {
-		dir.head<2>() += read_image<2>(distortion_grid, distortion_grid_resolution, uv);
+
+	if (distortion) {
+		dir.head<2>() += distortion.at_lerp(warped_uv);
 	}
 
 	Eigen::Vector3f head_pos = {parallax_shift.x(), parallax_shift.y(), 0.f};
@@ -341,26 +503,60 @@ inline NGP_HOST_DEVICE Ray pixel_to_ray(
 	dir = camera_matrix.block<3, 3>(0, 0) * dir;
 
 	Eigen::Vector3f origin = camera_matrix.block<3, 3>(0, 0) * head_pos + camera_matrix.col(3);
-	
-	if (aperture_size > 0.0f) {
+	if (aperture_size != 0.0f) {
 		Eigen::Vector3f lookat = origin + dir * focus_z;
-		Eigen::Vector2f blur = aperture_size * square2disk_shirley(ld_random_val_2d(spp, (uint32_t)pixel.x() * 19349663 + (uint32_t)pixel.y() * 96925573) * 2.0f - Eigen::Vector2f::Ones());
+		Eigen::Vector2f blur = aperture_size * square2disk_shirley(ld_random_val_2d(spp, uv.cwiseProduct(resolution.cast<float>()).cast<int>().dot(Eigen::Vector2i{19349663, 96925573})) * 2.0f - Eigen::Vector2f::Ones());
 		origin += camera_matrix.block<3, 2>(0, 0) * blur;
 		dir = (lookat - origin) / focus_z;
 	}
-	
-	origin += dir * near_distance;
 
+	origin += dir * near_distance;
 	return {origin, dir};
 }
 
-inline NGP_HOST_DEVICE Eigen::Vector2f pos_to_pixel(
+inline NGP_HOST_DEVICE Ray pixel_to_ray(
+	uint32_t spp,
+	const Eigen::Vector2i& pixel,
+	const Eigen::Vector2i& resolution,
+	const Eigen::Vector2f& focal_length,
+	const Eigen::Matrix<float, 3, 4>& camera_matrix,
+	const Eigen::Vector2f& screen_center,
+	const Eigen::Vector3f& parallax_shift = Eigen::Vector3f::Zero(),
+	bool snap_to_pixel_centers = false,
+	float near_distance = 0.0f,
+	float focus_z = 1.0f,
+	float aperture_size = 0.0f,
+	const Foveation& foveation = {},
+	Buffer2DView<const uint8_t> hidden_area_mask = {},
+	const Lens& lens = {},
+	Buffer2DView<const Eigen::Vector2f> distortion = {}
+) {
+	return uv_to_ray(
+		spp,
+		(pixel.cast<float>() + ld_random_pixel_offset(snap_to_pixel_centers ? 0 : spp)).cwiseQuotient(resolution.cast<float>()),
+		resolution,
+		focal_length,
+		camera_matrix,
+		screen_center,
+		parallax_shift,
+		near_distance,
+		focus_z,
+		aperture_size,
+		foveation,
+		hidden_area_mask,
+		lens,
+		distortion
+	);
+}
+
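
The `pixel_to_ray` wrapper above defaults all of the newer parameters, so pre-existing call sites stay short. A hypothetical fragment (assuming this header is included) that generates the ray through the center pixel of an identity camera:

```cpp
// Foveation, hidden-area mask, lens, and distortion all take their
// no-op defaults in this call.
Eigen::Vector2i resolution{1280, 720};
Eigen::Vector2f focal_length{1000.0f, 1000.0f};
Eigen::Matrix<float, 3, 4> camera = Eigen::Matrix<float, 3, 4>::Identity();

ngp::Ray ray = ngp::pixel_to_ray(
	0,            // spp: index of the first sample
	{640, 360},   // center pixel
	resolution,
	focal_length,
	camera,
	{0.5f, 0.5f}  // screen center in uv coordinates
);
// ray.o is the camera origin; ray.d points along +z up to sub-pixel jitter.
```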
+inline NGP_HOST_DEVICE Eigen::Vector2f pos_to_uv(
 	const Eigen::Vector3f& pos,
 	const Eigen::Vector2i& resolution,
 	const Eigen::Vector2f& focal_length,
 	const Eigen::Matrix<float, 3, 4>& camera_matrix,
 	const Eigen::Vector2f& screen_center,
 	const Eigen::Vector3f& parallax_shift,
+	const Foveation& foveation = {},
 	const Lens& lens = {}
 ) {
 	// Express ray in terms of camera frame
@@ -386,10 +582,23 @@ inline NGP_HOST_DEVICE Eigen::Vector2f pos_to_pixel(
 	dir.y() += dv;
 
 	Eigen::Vector2f uv = Eigen::Vector2f{dir.x(), dir.y()}.cwiseProduct(focal_length).cwiseQuotient(resolution.cast<float>()) + screen_center;
-	return uv.cwiseProduct(resolution.cast<float>());
+	return foveation.unwarp(uv);
 }
 
-inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector_3d(
+inline NGP_HOST_DEVICE Eigen::Vector2f pos_to_pixel(
+	const Eigen::Vector3f& pos,
+	const Eigen::Vector2i& resolution,
+	const Eigen::Vector2f& focal_length,
+	const Eigen::Matrix<float, 3, 4>& camera_matrix,
+	const Eigen::Vector2f& screen_center,
+	const Eigen::Vector3f& parallax_shift,
+	const Foveation& foveation = {},
+	const Lens& lens = {}
+) {
+	return pos_to_uv(pos, resolution, focal_length, camera_matrix, screen_center, parallax_shift, foveation, lens).cwiseProduct(resolution.cast<float>());
+}
+
+inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector(
 	const uint32_t sample_index,
 	const Eigen::Vector2i& pixel,
 	const Eigen::Vector2i& resolution,
@@ -400,6 +609,8 @@ inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector_3d(
 	const Eigen::Vector3f& parallax_shift,
 	const bool snap_to_pixel_centers,
 	const float depth,
+	const Foveation& foveation = {},
+	const Foveation& prev_foveation = {},
 	const Lens& lens = {}
 ) {
 	Ray ray = pixel_to_ray(
@@ -414,98 +625,39 @@ inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector_3d(
 		0.0f,
 		1.0f,
 		0.0f,
-		lens,
-		nullptr,
-		Eigen::Vector2i::Zero()
+		foveation,
+		{}, // No hidden area mask
+		lens
 	);
 
 	Eigen::Vector2f prev_pixel = pos_to_pixel(
-		ray.o + ray.d * depth,
+		ray(depth),
 		resolution,
 		focal_length,
 		prev_camera,
 		screen_center,
 		parallax_shift,
+		prev_foveation,
 		lens
 	);
 
 	return prev_pixel - (pixel.cast<float>() + ld_random_pixel_offset(sample_index));
 }
 
-inline NGP_HOST_DEVICE Eigen::Vector2f pixel_to_image_uv(
-	const uint32_t sample_index,
-	const Eigen::Vector2i& pixel,
-	const Eigen::Vector2i& resolution,
-	const Eigen::Vector2i& image_resolution,
-	const Eigen::Vector2f& screen_center,
-	const float view_dist,
-	const Eigen::Vector2f& image_pos,
-	const bool snap_to_pixel_centers
-) {
-	Eigen::Vector2f jit = ld_random_pixel_offset(snap_to_pixel_centers ? 0 : sample_index);
-	Eigen::Vector2f offset = screen_center.cwiseProduct(resolution.cast<float>()) + jit;
-
-	float y_scale = view_dist;
-	float x_scale = y_scale * resolution.x() / resolution.y();
-
-	return {
-		((x_scale * (pixel.x() + offset.x())) / resolution.x() - view_dist * image_pos.x()) / image_resolution.x() * image_resolution.y(),
-		(y_scale * (pixel.y() + offset.y())) / resolution.y() - view_dist * image_pos.y()
-	};
-}
-
-inline NGP_HOST_DEVICE Eigen::Vector2f image_uv_to_pixel(
-	const Eigen::Vector2f& uv,
-	const Eigen::Vector2i& resolution,
-	const Eigen::Vector2i& image_resolution,
-	const Eigen::Vector2f& screen_center,
-	const float view_dist,
-	const Eigen::Vector2f& image_pos
-) {
-	Eigen::Vector2f offset = screen_center.cwiseProduct(resolution.cast<float>());
-
-	float y_scale = view_dist;
-	float x_scale = y_scale * resolution.x() / resolution.y();
+// Maps view-space depth (physical units) in the range [znear, zfar] hyperbolically to
+// the interval [1, 0]. This is the reverse-z-component of "normalized device coordinates",
+// which are commonly used in rasterization, where linear interpolation in screen space
+// has to be equivalent to linear interpolation in real space (which, in turn, is
+// guaranteed by the hyperbolic mapping of depth). This format is commonly found in
+// z-buffers, and hence expected by downstream image processing functions, such as DLSS
+// and VR reprojection.
+inline NGP_HOST_DEVICE float to_ndc_depth(float z, float n, float f) {
+	// Clamp to the view frustum: view depth outside of it would map outside of [0, 1]
+	z = tcnn::clamp(z, n, f);
 
-	return {
-		((uv.x() / image_resolution.y() * image_resolution.x()) + view_dist * image_pos.x()) * resolution.x() / x_scale - offset.x(),
-		(uv.y() + view_dist * image_pos.y()) * resolution.y() / y_scale - offset.y()
-	};
-}
-
-inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector_2d(
-	const uint32_t sample_index,
-	const Eigen::Vector2i& pixel,
-	const Eigen::Vector2i& resolution,
-	const Eigen::Vector2i& image_resolution,
-	const Eigen::Vector2f& screen_center,
-	const float view_dist,
-	const float prev_view_dist,
-	const Eigen::Vector2f& image_pos,
-	const Eigen::Vector2f& prev_image_pos,
-	const bool snap_to_pixel_centers
-) {
-	Eigen::Vector2f uv = pixel_to_image_uv(
-		sample_index,
-		pixel,
-		resolution,
-		image_resolution,
-		screen_center,
-		view_dist,
-		image_pos,
-		snap_to_pixel_centers
-	);
-
-	Eigen::Vector2f prev_pixel = image_uv_to_pixel(
-		uv,
-		resolution,
-		image_resolution,
-		screen_center,
-		prev_view_dist,
-		prev_image_pos
-	);
-
-	return prev_pixel - (pixel.cast<float>() + ld_random_pixel_offset(sample_index));
+	float scale = n / (n - f);
+	float bias = -f * scale;
+	return tcnn::clamp((z * scale + bias) / z, 0.0f, 1.0f);
 }
 
 inline NGP_HOST_DEVICE float fov_to_focal_length(int resolution, float degrees) {
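
Because the body is only a few lines, the hyperbolic mapping can be checked in isolation. A standalone sketch that restates `to_ndc_depth` with `std::clamp` in place of `tcnn::clamp` and verifies the reverse-z endpoints:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>

// to_ndc_depth restated standalone so it runs without the project headers.
float to_ndc_depth(float z, float n, float f) {
	z = std::clamp(z, n, f);
	float scale = n / (n - f);
	float bias = -f * scale;
	return std::clamp((z * scale + bias) / z, 0.0f, 1.0f);
}

int main() {
	float n = 0.1f, f = 16384.0f; // e.g. MAX_DEPTH() as the far plane

	assert(std::abs(to_ndc_depth(n, n, f) - 1.0f) < 1e-5f); // near plane -> 1 (reverse z)
	assert(std::abs(to_ndc_depth(f, n, f) - 0.0f) < 1e-5f); // far plane -> 0

	// The mapping is hyperbolic, so precision concentrates near the camera:
	// the NDC midpoint ~0.5 already sits at roughly z = 2n.
	printf("%f\n", to_ndc_depth(2.0f * n, n, f));
	return 0;
}
```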
@@ -587,7 +739,8 @@ inline NGP_HOST_DEVICE void apply_quilting(uint32_t* x, uint32_t* y, const Eigen
 
 	if (quilting_dims == Eigen::Vector2i{2, 1}) {
 		// Likely VR: parallax_shift.x() is the IPD in this case. The following code centers the camera matrix between both eyes.
-		parallax_shift.x() = idx ? (-0.5f * parallax_shift.x()) : (0.5f * parallax_shift.x());
+		// idx == 0 -> left eye -> -1/2 x
+		parallax_shift.x() = (idx == 0) ? (-0.5f * parallax_shift.x()) : (0.5f * parallax_shift.x());
 	} else {
 		// Likely HoloPlay lenticular display: in this case, `parallax_shift.z()` is the inverse height of the head above the display.
 		// The following code computes the x-offset of views as a function of this.
diff --git a/include/neural-graphics-primitives/dlss.h b/include/neural-graphics-primitives/dlss.h
index dbe86fccca1cd0d911535b340e8c2a05bd406a5b..49d16f95925bd0b032c44440431d7ee72a433ef3 100644
--- a/include/neural-graphics-primitives/dlss.h
+++ b/include/neural-graphics-primitives/dlss.h
@@ -54,16 +54,16 @@ public:
 	virtual EDlssQuality quality() const = 0;
 };
 
-#ifdef NGP_VULKAN
-std::shared_ptr<IDlss> dlss_init(const Eigen::Vector2i& out_resolution);
+class IDlssProvider {
+public:
+	virtual ~IDlssProvider() {}
 
-void vulkan_and_ngx_init();
-size_t dlss_allocated_bytes();
-void vulkan_and_ngx_destroy();
-#else
-inline size_t dlss_allocated_bytes() {
-	return 0;
-}
+	virtual size_t allocated_bytes() const = 0;
+	virtual std::unique_ptr<IDlss> init_dlss(const Eigen::Vector2i& out_resolution) = 0;
+};
+
+#ifdef NGP_VULKAN
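+// Usage sketch: obtain a provider once, then create DLSS instances from it as
+// needed; the provider presumably has to outlive any IDlss object it created.
+//   auto provider = init_vulkan_and_ngx();
+//   auto dlss = provider->init_dlss(out_resolution);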
+std::shared_ptr<IDlssProvider> init_vulkan_and_ngx();
 #endif
 
 NGP_NAMESPACE_END
diff --git a/include/neural-graphics-primitives/envmap.cuh b/include/neural-graphics-primitives/envmap.cuh
index 6960800719147f13170d8347a6493c83064291cc..7ba6698fec85609c9ef981a5c7a305ace0399c77 100644
--- a/include/neural-graphics-primitives/envmap.cuh
+++ b/include/neural-graphics-primitives/envmap.cuh
@@ -26,31 +26,22 @@
 
 NGP_NAMESPACE_BEGIN
 
-template <typename T>
-__device__ Eigen::Array4f read_envmap(const T* __restrict__ envmap_data, const Eigen::Vector2i envmap_resolution, const Eigen::Vector3f& dir) {
+inline __device__ Eigen::Array4f read_envmap(const Buffer2DView<const Eigen::Array4f>& envmap, const Eigen::Vector3f& dir) {
 	auto dir_cyl = dir_to_spherical_unorm({dir.z(), -dir.x(), dir.y()});
 
-	auto envmap_float = Eigen::Vector2f{dir_cyl.y() * (envmap_resolution.x()-1), dir_cyl.x() * (envmap_resolution.y()-1)};
+	auto envmap_float = Eigen::Vector2f{dir_cyl.y() * (envmap.resolution.x()-1), dir_cyl.x() * (envmap.resolution.y()-1)};
 	Eigen::Vector2i envmap_texel = envmap_float.cast<int>();
 
 	auto weight = envmap_float - envmap_texel.cast<float>();
 
 	auto read_val = [&](Eigen::Vector2i pos) {
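+		// Wrap horizontally (the azimuth is periodic), clamp vertically at the poles.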
 		if (pos.x() < 0) {
-			pos.x() += envmap_resolution.x();
-		} else if (pos.x() >= envmap_resolution.x()) {
-			pos.x() -= envmap_resolution.x();
-		}
-		pos.y() = std::max(std::min(pos.y(), envmap_resolution.y()-1), 0);
-
-		Eigen::Array4f result;
-		if (std::is_same<T, float>::value) {
-			result = *(Eigen::Array4f*)&envmap_data[(pos.x() + pos.y() * envmap_resolution.x()) * 4];
-		} else {
-			auto val = *(tcnn::vector_t<T, 4>*)&envmap_data[(pos.x() + pos.y() * envmap_resolution.x()) * 4];
-			result = {(float)val[0], (float)val[1], (float)val[2], (float)val[3]};
+			pos.x() += envmap.resolution.x();
+		} else if (pos.x() >= envmap.resolution.x()) {
+			pos.x() -= envmap.resolution.x();
 		}
-		return result;
+		pos.y() = std::max(std::min(pos.y(), envmap.resolution.y()-1), 0);
+		return envmap.at(pos);
 	};
 
 	auto result = (
diff --git a/include/neural-graphics-primitives/json_binding.h b/include/neural-graphics-primitives/json_binding.h
index 9f04a71d37a4eb30c8ad42f7f1b1f3dca9bad60f..0d2c2283607f0b5f076d9743ed030c237c9777e1 100644
--- a/include/neural-graphics-primitives/json_binding.h
+++ b/include/neural-graphics-primitives/json_binding.h
@@ -166,6 +166,7 @@ inline void to_json(nlohmann::json& j, const NerfDataset& dataset) {
 	j["from_mitsuba"] = dataset.from_mitsuba;
 	j["is_hdr"] = dataset.is_hdr;
 	j["wants_importance_sampling"] = dataset.wants_importance_sampling;
+	j["n_extra_learnable_dims"] = dataset.n_extra_learnable_dims;
 }
 
 inline void from_json(const nlohmann::json& j, NerfDataset& dataset) {
@@ -209,11 +210,14 @@ inline void from_json(const nlohmann::json& j, NerfDataset& dataset) {
 	dataset.aabb_scale = j.at("aabb_scale");
 	dataset.from_mitsuba = j.at("from_mitsuba");
 	dataset.is_hdr = j.value("is_hdr", false);
+
 	if (j.contains("wants_importance_sampling")) {
 		dataset.wants_importance_sampling = j.at("wants_importance_sampling");
 	} else {
 		dataset.wants_importance_sampling = true;
 	}
+
+	dataset.n_extra_learnable_dims = j.value("n_extra_learnable_dims", 0);
 }
 
 NGP_NAMESPACE_END
diff --git a/include/neural-graphics-primitives/openxr_hmd.h b/include/neural-graphics-primitives/openxr_hmd.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee442ef9aa5f07abeae2c2f78a06f3d376c59916
--- /dev/null
+++ b/include/neural-graphics-primitives/openxr_hmd.h
@@ -0,0 +1,261 @@
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
+ * and proprietary rights in and to this software, related documentation
+ * and any modifications thereto.  Any use, reproduction, disclosure or
+ * distribution of this software and related documentation without an express
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
+ */
+
+/** @file   openxr_hmd.h
+ *  @author Thomas Müller & Ingo Esser & Robert Menzel, NVIDIA
+ *  @brief  Wrapper around the OpenXR API, providing access to
+ *          per-eye framebuffers, lens parameters, visible area,
+ *          view, hand, and eye poses, as well as controller inputs.
+ */
+
+#pragma once
+
+#ifdef _WIN32
+#  include <GL/gl3w.h>
+#else
+#  include <GL/glew.h>
+#endif
+
+#define XR_USE_GRAPHICS_API_OPENGL
+
+#include <neural-graphics-primitives/common_device.cuh>
+
+#include <openxr/openxr.h>
+#include <xr_linear.h>
+#include <xr_dependencies.h>
+#include <openxr/openxr_platform.h>
+
+#include <Eigen/Dense>
+
+#include <tiny-cuda-nn/gpu_memory.h>
+
+#include <array>
+#include <memory>
+#include <vector>
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers" // TODO: XR structs are uninitialized apart from their type
+#endif
+
+NGP_NAMESPACE_BEGIN
+
+class OpenXRHMD {
+public:
+	enum class ControlFlow {
+		CONTINUE,
+		RESTART,
+		QUIT,
+	};
+
+	struct FrameInfo {
+		struct View {
+			GLuint framebuffer;
+			XrCompositionLayerProjectionView view{XR_TYPE_COMPOSITION_LAYER_PROJECTION_VIEW};
+			XrCompositionLayerDepthInfoKHR depth_info{XR_TYPE_COMPOSITION_LAYER_DEPTH_INFO_KHR};
+			std::shared_ptr<Buffer2D<uint8_t>> hidden_area_mask = nullptr;
+			Eigen::Matrix<float, 3, 4> pose;
+		};
+		struct Hand {
+			Eigen::Matrix<float, 3, 4> pose;
+			bool pose_active = false;
+			Eigen::Vector2f thumbstick = Eigen::Vector2f::Zero();
+			float grab_strength = 0.0f;
+			bool grabbing = false;
+			bool pressing = false;
+			Eigen::Vector3f grab_pos;
+			Eigen::Vector3f prev_grab_pos;
+			Eigen::Vector3f drag() const {
+				return grab_pos - prev_grab_pos;
+			}
+		};
+		std::vector<View> views;
+		Hand hands[2];
+	};
+	using FrameInfoPtr = std::shared_ptr<FrameInfo>;
+
+	// RAII OpenXRHMD with OpenGL
+#if defined(XR_USE_PLATFORM_WIN32)
+	OpenXRHMD(HDC hdc, HGLRC hglrc);
+#elif defined(XR_USE_PLATFORM_XLIB)
+	OpenXRHMD(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext);
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	OpenXRHMD(wl_display* display);
+#endif
+
+	virtual ~OpenXRHMD();
+
+	// disallow copy / move
+	OpenXRHMD(const OpenXRHMD&) = delete;
+	OpenXRHMD& operator=(const OpenXRHMD&) = delete;
+	OpenXRHMD(OpenXRHMD&&) = delete;
+	OpenXRHMD& operator=(OpenXRHMD&&) = delete;
+
+	void clear();
+
+	// poll events, handle state changes, return control flow information
+	ControlFlow poll_events();
+
+	// begin OpenXR frame, return views to render
+	FrameInfoPtr begin_frame();
+	// must be called for each begin_frame
+	void end_frame(FrameInfoPtr frame_info, float znear, float zfar);
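+
+	// Typical per-frame usage (sketch):
+	//   if (hmd.must_run_frame_loop()) {
+	//       auto frame = hmd.begin_frame();
+	//       // ... render into each frame->views[i].framebuffer ...
+	//       hmd.end_frame(frame, znear, zfar);
+	//   }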
+
+	// if true, call begin_frame and end_frame; does not imply visibility
+	bool must_run_frame_loop() const {
+		return
+			m_session_state == XR_SESSION_STATE_READY ||
+			m_session_state == XR_SESSION_STATE_SYNCHRONIZED ||
+			m_session_state == XR_SESSION_STATE_VISIBLE ||
+			m_session_state == XR_SESSION_STATE_FOCUSED;
+	}
+
+	// if true, VR is being rendered to the HMD
+	bool is_visible() const {
+		// XR_SESSION_STATE_VISIBLE -> app content is shown in HMD
+		// XR_SESSION_STATE_FOCUSED -> VISIBLE + input is sent to app
+		return m_session_state == XR_SESSION_STATE_VISIBLE || m_session_state == XR_SESSION_STATE_FOCUSED;
+	}
+
+private:
+	// steps of the init process, called from the constructor
+	void init_create_xr_instance();
+	void init_get_xr_system();
+	void init_configure_xr_views();
+	void init_check_for_xr_blend_mode();
+	void init_xr_actions();
+
+#if defined(XR_USE_PLATFORM_WIN32)
+	void init_open_gl(HDC hdc, HGLRC hglrc);
+#elif defined(XR_USE_PLATFORM_XLIB)
+	void init_open_gl(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext);
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	void init_open_gl(wl_display* display);
+#endif
+
+	void init_xr_session();
+	void init_xr_spaces();
+	void init_xr_swapchain_open_gl();
+	void init_open_gl_shaders();
+
+	// session state change
+	void session_state_change(XrSessionState state, ControlFlow& flow);
+
+	std::shared_ptr<Buffer2D<uint8_t>> rasterize_hidden_area_mask(uint32_t view_index, const XrCompositionLayerProjectionView& view);
+
+	// system/instance
+	XrInstance m_instance{XR_NULL_HANDLE};
+	XrSystemId m_system_id = {};
+	XrInstanceProperties m_instance_properties = {XR_TYPE_INSTANCE_PROPERTIES};
+	XrSystemProperties m_system_properties = {XR_TYPE_SYSTEM_PROPERTIES};
+	std::vector<XrApiLayerProperties> m_api_layer_properties;
+	std::vector<XrExtensionProperties> m_instance_extension_properties;
+
+	// view and blending
+	XrViewConfigurationType m_view_configuration_type = {};
+	XrViewConfigurationProperties m_view_configuration_properties = {XR_TYPE_VIEW_CONFIGURATION_PROPERTIES};
+	std::vector<XrViewConfigurationView> m_view_configuration_views;
+	std::vector<XrEnvironmentBlendMode> m_environment_blend_modes;
+	XrEnvironmentBlendMode m_environment_blend_mode = {XR_ENVIRONMENT_BLEND_MODE_OPAQUE};
+
+	// actions
+	std::array<XrPath, 2> m_hand_paths;
+	std::array<XrSpace, 2> m_hand_spaces;
+	XrAction m_pose_action{XR_NULL_HANDLE};
+	XrAction m_press_action{XR_NULL_HANDLE};
+	XrAction m_grab_action{XR_NULL_HANDLE};
+
+	// Two separate actions for Xbox controller support
+	std::array<XrAction, 2> m_thumbstick_actions;
+
+	XrActionSet m_action_set{XR_NULL_HANDLE};
+
+#if defined(XR_USE_PLATFORM_WIN32)
+	XrGraphicsBindingOpenGLWin32KHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_WIN32_KHR};
+#elif defined(XR_USE_PLATFORM_XLIB)
+	XrGraphicsBindingOpenGLXlibKHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_XLIB_KHR};
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	XrGraphicsBindingOpenGLWaylandKHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_WAYLAND_KHR};
+#endif
+
+	XrSession m_session{XR_NULL_HANDLE};
+	XrSessionState m_session_state{XR_SESSION_STATE_UNKNOWN};
+
+	// reference space
+	std::vector<XrReferenceSpaceType> m_reference_spaces;
+	XrSpace m_space{XR_NULL_HANDLE};
+	XrExtent2Df m_bounds;
+
+	// swap chains
+	struct Swapchain {
+		Swapchain(XrSwapchainCreateInfo& rgba_create_info, XrSwapchainCreateInfo& depth_create_info, XrSession& session, XrInstance& xr_instance);
+		Swapchain(const Swapchain&) = delete;
+		Swapchain& operator=(const Swapchain&) = delete;
+		Swapchain(Swapchain&& other) {
+			*this = std::move(other);
+		}
+		Swapchain& operator=(Swapchain&& other) {
+			std::swap(handle, other.handle);
+			std::swap(depth_handle, other.depth_handle);
+			std::swap(width, other.width);
+			std::swap(height, other.height);
+			images_gl = std::move(other.images_gl);
+			depth_images_gl = std::move(other.depth_images_gl);
+			framebuffers_gl = std::move(other.framebuffers_gl);
+			return *this;
+		}
+		virtual ~Swapchain();
+
+		void clear();
+
+		XrSwapchain handle{XR_NULL_HANDLE};
+		XrSwapchain depth_handle{XR_NULL_HANDLE};
+
+		int32_t width = 0;
+		int32_t height = 0;
+		std::vector<XrSwapchainImageOpenGLKHR> images_gl;
+		std::vector<XrSwapchainImageOpenGLKHR> depth_images_gl;
+		std::vector<GLuint> framebuffers_gl;
+	};
+
+	int64_t m_swapchain_rgba_format = 0;
+	std::vector<Swapchain> m_swapchains;
+
+	bool m_supports_composition_layer_depth = false;
+	int64_t m_swapchain_depth_format = 0;
+
+	bool m_supports_hidden_area_mask = false;
+	std::vector<std::shared_ptr<Buffer2D<uint8_t>>> m_hidden_area_masks;
+
+	bool m_supports_eye_tracking = false;
+
+	// frame data
+	XrFrameState m_frame_state{XR_TYPE_FRAME_STATE};
+	FrameInfoPtr m_previous_frame_info;
+
+	GLuint m_hidden_area_mask_program = 0;
+
+	// print more debug info during OpenXR's init:
+	const bool m_print_api_layers = false;
+	const bool m_print_extensions = false;
+	const bool m_print_system_properties = false;
+	const bool m_print_instance_properties = false;
+	const bool m_print_view_configuration_types = false;
+	const bool m_print_view_configuration_properties = false;
+	const bool m_print_view_configuration_view = false;
+	const bool m_print_environment_blend_modes = false;
+	const bool m_print_available_swapchain_formats = false;
+	const bool m_print_reference_spaces = false;
+};
+
+NGP_NAMESPACE_END
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
diff --git a/include/neural-graphics-primitives/random_val.cuh b/include/neural-graphics-primitives/random_val.cuh
index 667c938dd4c37ddecda10b8939ec1cd27c51a96e..e436b18ba23500a909a1ef27fca01b62bb0999dc 100644
--- a/include/neural-graphics-primitives/random_val.cuh
+++ b/include/neural-graphics-primitives/random_val.cuh
@@ -61,11 +61,16 @@ inline __host__ __device__ Eigen::Vector2f dir_to_cylindrical(const Eigen::Vecto
 	return {(cos_theta + 1.0f) / 2.0f, (phi / (2.0f * PI())) + 0.5f};
 }
 
-inline __host__ __device__ Eigen::Vector2f dir_to_spherical_unorm(const Eigen::Vector3f& d) {
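+// Returns the spherical angles {theta in [0, pi], phi in [-pi, pi]} of the direction d.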
+inline __host__ __device__ Eigen::Vector2f dir_to_spherical(const Eigen::Vector3f& d) {
 	const float cos_theta = fminf(fmaxf(d.z(), -1.0f), 1.0f);
 	const float theta = acosf(cos_theta);
 	float phi = std::atan2(d.y(), d.x());
-	return {theta / PI(), (phi / (2.0f * PI()) + 0.5f)};
+	return {theta, phi};
+}
+
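+// The same angles normalized to [0, 1]^2: {theta / PI, phi / (2 PI) + 0.5}.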
+inline __host__ __device__ Eigen::Vector2f dir_to_spherical_unorm(const Eigen::Vector3f& d) {
+	Eigen::Vector2f spherical = dir_to_spherical(d);
+	return {spherical.x() / PI(), (spherical.y() / (2.0f * PI()) + 0.5f)};
 }
 
 template <typename RNG>
diff --git a/include/neural-graphics-primitives/render_buffer.h b/include/neural-graphics-primitives/render_buffer.h
index 2e29f72a6fd896c815645fa81273bbf20e219ff2..0e51364f36f228209b125f6f507c41dc0e2f43fb 100644
--- a/include/neural-graphics-primitives/render_buffer.h
+++ b/include/neural-graphics-primitives/render_buffer.h
@@ -34,7 +34,7 @@ public:
 	virtual cudaSurfaceObject_t surface() = 0;
 	virtual cudaArray_t array() = 0;
 	virtual Eigen::Vector2i resolution() const = 0;
-	virtual void resize(const Eigen::Vector2i&) = 0;
+	virtual void resize(const Eigen::Vector2i&, int n_channels = 4) = 0;
 };
 
 class CudaSurface2D : public SurfaceProvider {
@@ -50,7 +50,7 @@ public:
 
 	void free();
 
-	void resize(const Eigen::Vector2i& size) override;
+	void resize(const Eigen::Vector2i& size, int n_channels) override;
 
 	cudaSurfaceObject_t surface() override {
 		return m_surface;
@@ -65,7 +65,8 @@ public:
 	}
 
 private:
-	Eigen::Vector2i m_size = Eigen::Vector2i::Constant(0);
+	Eigen::Vector2i m_size = Eigen::Vector2i::Zero();
+	int m_n_channels = 0;
 	cudaArray_t m_array;
 	cudaSurfaceObject_t m_surface;
 };
@@ -111,10 +112,10 @@ public:
 
 	void load(const uint8_t* data, Eigen::Vector2i new_size, int n_channels);
 
-	void resize(const Eigen::Vector2i& new_size, int n_channels, bool is_8bit = false);
+	void resize(const Eigen::Vector2i& new_size, int n_channels, bool is_8bit);
 
-	void resize(const Eigen::Vector2i& new_size) override {
-		resize(new_size, 4);
+	void resize(const Eigen::Vector2i& new_size, int n_channels) override {
+		resize(new_size, n_channels, false);
 	}
 
 	Eigen::Vector2i resolution() const override {
@@ -124,7 +125,7 @@ public:
 private:
 	class CUDAMapping {
 	public:
-		CUDAMapping(GLuint texture_id, const Eigen::Vector2i& size);
+		CUDAMapping(GLuint texture_id, const Eigen::Vector2i& size, int n_channels);
 		~CUDAMapping();
 
 		cudaSurfaceObject_t surface() const { return m_cuda_surface ? m_cuda_surface->surface() : m_surface; }
@@ -141,6 +142,7 @@ private:
 		cudaSurfaceObject_t m_surface = {};
 
 		Eigen::Vector2i m_size;
+		int m_n_channels;
 		std::vector<float> m_data_cpu;
 
 		std::unique_ptr<CudaSurface2D> m_cuda_surface;
@@ -157,9 +159,20 @@ private:
 };
 #endif //NGP_GUI
 
+struct CudaRenderBufferView {
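+	// Non-owning view of a CudaRenderBuffer's device-side buffers; cheap to pass
+	// around by value (see CudaRenderBuffer::view()).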
+	Eigen::Array4f* frame_buffer = nullptr;
+	float* depth_buffer = nullptr;
+	Eigen::Vector2i resolution = Eigen::Vector2i::Zero();
+	uint32_t spp = 0;
+
+	std::shared_ptr<Buffer2D<uint8_t>> hidden_area_mask = nullptr;
+
+	void clear(cudaStream_t stream) const;
+};
+
 class CudaRenderBuffer {
 public:
-	CudaRenderBuffer(const std::shared_ptr<SurfaceProvider>& surf) : m_surface_provider{surf} {}
+	CudaRenderBuffer(const std::shared_ptr<SurfaceProvider>& rgba, const std::shared_ptr<SurfaceProvider>& depth = nullptr) : m_rgba_target{rgba}, m_depth_target{depth} {}
 
 	CudaRenderBuffer(const CudaRenderBuffer& other) = delete;
 	CudaRenderBuffer& operator=(const CudaRenderBuffer& other) = delete;
@@ -167,7 +180,7 @@ public:
 	CudaRenderBuffer& operator=(CudaRenderBuffer&& other) = default;
 
 	cudaSurfaceObject_t surface() {
-		return m_surface_provider->surface();
+		return m_rgba_target->surface();
 	}
 
 	Eigen::Vector2i in_resolution() const {
@@ -175,7 +188,7 @@ public:
 	}
 
 	Eigen::Vector2i out_resolution() const {
-		return m_surface_provider->resolution();
+		return m_rgba_target->resolution();
 	}
 
 	void resize(const Eigen::Vector2i& res);
@@ -204,11 +217,21 @@ public:
 		return m_accumulate_buffer.data();
 	}
 
+	CudaRenderBufferView view() const {
+		return {
+			frame_buffer(),
+			depth_buffer(),
+			in_resolution(),
+			spp(),
+			hidden_area_mask(),
+		};
+	}
+
 	void clear_frame(cudaStream_t stream);
 
 	void accumulate(float exposure, cudaStream_t stream);
 
-	void tonemap(float exposure, const Eigen::Array4f& background_color, EColorSpace output_color_space, cudaStream_t stream);
+	void tonemap(float exposure, const Eigen::Array4f& background_color, EColorSpace output_color_space, float znear, float zfar, cudaStream_t stream);
 
 	void overlay_image(
 		float alpha,
@@ -238,7 +261,7 @@ public:
 	void overlay_false_color(Eigen::Vector2i training_resolution, bool to_srgb, int fov_axis, cudaStream_t stream, const float *error_map, Eigen::Vector2i error_map_resolution, const float *average, float brightness, bool viridis);
 
 	SurfaceProvider& surface_provider() {
-		return *m_surface_provider;
+		return *m_rgba_target;
 	}
 
 	void set_color_space(EColorSpace color_space) {
@@ -255,22 +278,30 @@ public:
 		}
 	}
 
-	void enable_dlss(const Eigen::Vector2i& max_out_res);
+	void enable_dlss(IDlssProvider& dlss_provider, const Eigen::Vector2i& max_out_res);
 	void disable_dlss();
 	void set_dlss_sharpening(float value) {
 		m_dlss_sharpening = value;
 	}
 
-	const std::shared_ptr<IDlss>& dlss() const {
+	const std::unique_ptr<IDlss>& dlss() const {
 		return m_dlss;
 	}
 
+	void set_hidden_area_mask(const std::shared_ptr<Buffer2D<uint8_t>>& hidden_area_mask) {
+		m_hidden_area_mask = hidden_area_mask;
+	}
+
+	const std::shared_ptr<Buffer2D<uint8_t>>& hidden_area_mask() const {
+		return m_hidden_area_mask;
+	}
+
 private:
 	uint32_t m_spp = 0;
 	EColorSpace m_color_space = EColorSpace::Linear;
 	ETonemapCurve m_tonemap_curve = ETonemapCurve::Identity;
 
-	std::shared_ptr<IDlss> m_dlss;
+	std::unique_ptr<IDlss> m_dlss;
 	float m_dlss_sharpening = 0.0f;
 
 	Eigen::Vector2i m_in_resolution = Eigen::Vector2i::Zero();
@@ -279,7 +310,10 @@ private:
 	tcnn::GPUMemory<float> m_depth_buffer;
 	tcnn::GPUMemory<Eigen::Array4f> m_accumulate_buffer;
 
-	std::shared_ptr<SurfaceProvider> m_surface_provider;
+	std::shared_ptr<Buffer2D<uint8_t>> m_hidden_area_mask = nullptr;
+
+	std::shared_ptr<SurfaceProvider> m_rgba_target;
+	std::shared_ptr<SurfaceProvider> m_depth_target;
 };
 
 NGP_NAMESPACE_END
diff --git a/include/neural-graphics-primitives/testbed.h b/include/neural-graphics-primitives/testbed.h
index e6db007b6ce0c39437b4ac51ff1820566c368ed3..9459285110f29eb678dac7cd0c9063db5c2bef51 100644
--- a/include/neural-graphics-primitives/testbed.h
+++ b/include/neural-graphics-primitives/testbed.h
@@ -26,6 +26,10 @@
 #include <neural-graphics-primitives/thread_pool.h>
 #include <neural-graphics-primitives/trainable_buffer.cuh>
 
+#ifdef NGP_GUI
+#  include <neural-graphics-primitives/openxr_hmd.h>
+#endif
+
 #include <tiny-cuda-nn/multi_stream.h>
 #include <tiny-cuda-nn/random.h>
 
@@ -95,10 +99,11 @@ public:
 			float near_distance,
 			float plane_z,
 			float aperture_size,
-			const float* envmap_data,
-			const Eigen::Vector2i& envmap_resolution,
+			const Foveation& foveation,
+			const Buffer2DView<const Eigen::Array4f>& envmap,
 			Eigen::Array4f* frame_buffer,
 			float* depth_buffer,
+			const Buffer2DView<const uint8_t>& hidden_area_mask,
 			const TriangleOctree* octree,
 			uint32_t n_octree_levels,
 			cudaStream_t stream
@@ -151,21 +156,20 @@ public:
 			const Eigen::Vector4f& rolling_shutter,
 			const Eigen::Vector2f& screen_center,
 			const Eigen::Vector3f& parallax_shift,
-			const Eigen::Vector2i& quilting_dims,
 			bool snap_to_pixel_centers,
 			const BoundingBox& render_aabb,
 			const Eigen::Matrix3f& render_aabb_to_local,
 			float near_distance,
 			float plane_z,
 			float aperture_size,
+			const Foveation& foveation,
 			const Lens& lens,
-			const float* envmap_data,
-			const Eigen::Vector2i& envmap_resolution,
-			const float* distortion_data,
-			const Eigen::Vector2i& distortion_resolution,
+			const Buffer2DView<const Eigen::Array4f>& envmap,
+			const Buffer2DView<const Eigen::Vector2f>& distortion,
 			Eigen::Array4f* frame_buffer,
 			float* depth_buffer,
-			uint8_t* grid,
+			const Buffer2DView<const uint8_t>& hidden_area_mask,
+			const uint8_t* grid,
 			int show_accel,
 			float cone_angle_constant,
 			ERenderMode render_mode,
@@ -177,8 +181,6 @@ public:
 			const BoundingBox& render_aabb,
 			const Eigen::Matrix3f& render_aabb_to_local,
 			const BoundingBox& train_aabb,
-			const uint32_t n_training_images,
-			const TrainingXForm* training_xforms,
 			const Eigen::Vector2f& focal_length,
 			float cone_angle_constant,
 			const uint8_t* grid,
@@ -250,7 +252,11 @@ public:
 		int count;
 	};
 
-	static constexpr float LOSS_SCALE = 128.f;
+	// Due to mixed-precision training, small loss values can lead to
+	// underflow (round to zero) in the gradient computations. Hence,
+	// scale the loss (and thereby gradients) up by this factor and
+	// divide it out in the optimizer later on.
+	static constexpr float LOSS_SCALE = 128.0f;
 
 	struct NetworkDims {
 		uint32_t n_input;
@@ -265,30 +271,91 @@ public:
 
 	NetworkDims network_dims() const;
 
-	void render_volume(CudaRenderBuffer& render_buffer,
-		const Eigen::Vector2f& focal_length,
-		const Eigen::Matrix<float, 3, 4>& camera_matrix,
-		const Eigen::Vector2f& screen_center,
-		cudaStream_t stream
-	);
 	void train_volume(size_t target_batch_size, bool get_loss_scalar, cudaStream_t stream);
 	void training_prep_volume(uint32_t batch_size, cudaStream_t stream) {}
 	void load_volume(const fs::path& data_path);
 
+	class CudaDevice;
+
+	const float* get_inference_extra_dims(cudaStream_t stream) const;
+	void render_nerf(
+		cudaStream_t stream,
+		const CudaRenderBufferView& render_buffer,
+		NerfNetwork<precision_t>& nerf_network,
+		const uint8_t* density_grid_bitfield,
+		const Eigen::Vector2f& focal_length,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix0,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix1,
+		const Eigen::Vector4f& rolling_shutter,
+		const Eigen::Vector2f& screen_center,
+		const Foveation& foveation,
+		int visualized_dimension
+	);
 	void render_sdf(
+		cudaStream_t stream,
 		const distance_fun_t& distance_function,
 		const normals_fun_t& normals_function,
-		CudaRenderBuffer& render_buffer,
-		const Eigen::Vector2i& max_res,
+		const CudaRenderBufferView& render_buffer,
 		const Eigen::Vector2f& focal_length,
 		const Eigen::Matrix<float, 3, 4>& camera_matrix,
 		const Eigen::Vector2f& screen_center,
-		cudaStream_t stream
+		const Foveation& foveation,
+		int visualized_dimension
+	);
+	void render_image(
+		cudaStream_t stream,
+		const CudaRenderBufferView& render_buffer,
+		const Eigen::Vector2f& focal_length,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix,
+		const Eigen::Vector2f& screen_center,
+		const Foveation& foveation,
+		int visualized_dimension
+	);
+	void render_volume(
+		cudaStream_t stream,
+		const CudaRenderBufferView& render_buffer,
+		const Eigen::Vector2f& focal_length,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix,
+		const Eigen::Vector2f& screen_center,
+		const Foveation& foveation
+	);
+
+	void render_frame(
+		cudaStream_t stream,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix0,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix1,
+		const Eigen::Matrix<float, 3, 4>& prev_camera_matrix,
+		const Eigen::Vector2f& screen_center,
+		const Eigen::Vector2f& relative_focal_length,
+		const Eigen::Vector4f& nerf_rolling_shutter,
+		const Foveation& foveation,
+		const Foveation& prev_foveation,
+		int visualized_dimension,
+		CudaRenderBuffer& render_buffer,
+		bool to_srgb = true,
+		CudaDevice* device = nullptr
+	);
+	void render_frame_main(
+		CudaDevice& device,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix0,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix1,
+		const Eigen::Vector2f& screen_center,
+		const Eigen::Vector2f& relative_focal_length,
+		const Eigen::Vector4f& nerf_rolling_shutter,
+		const Foveation& foveation,
+		int visualized_dimension
+	);
+	void render_frame_epilogue(
+		cudaStream_t stream,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix0,
+		const Eigen::Matrix<float, 3, 4>& prev_camera_matrix,
+		const Eigen::Vector2f& screen_center,
+		const Eigen::Vector2f& relative_focal_length,
+		const Foveation& foveation,
+		const Foveation& prev_foveation,
+		CudaRenderBuffer& render_buffer,
+		bool to_srgb = true
 	);
-	const float* get_inference_extra_dims(cudaStream_t stream) const;
-	void render_nerf(CudaRenderBuffer& render_buffer, const Eigen::Vector2i& max_res, const Eigen::Vector2f& focal_length, const Eigen::Matrix<float, 3, 4>& camera_matrix0, const Eigen::Matrix<float, 3, 4>& camera_matrix1, const Eigen::Vector4f& rolling_shutter, const Eigen::Vector2f& screen_center, cudaStream_t stream);
-	void render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream);
-	void render_frame(const Eigen::Matrix<float, 3, 4>& camera_matrix0, const Eigen::Matrix<float, 3, 4>& camera_matrix1, const Eigen::Vector4f& nerf_rolling_shutter, CudaRenderBuffer& render_buffer, bool to_srgb = true) ;
 	void visualize_nerf_cameras(ImDrawList* list, const Eigen::Matrix<float, 4, 4>& world2proj);
 	fs::path find_network_config(const fs::path& network_config_path);
 	nlohmann::json load_network_config(const fs::path& network_config_path);
@@ -310,9 +377,10 @@ public:
 	void set_min_level(float minlevel);
 	void set_visualized_dim(int dim);
 	void set_visualized_layer(int layer);
-	void translate_camera(const Eigen::Vector3f& rel);
-	void mouse_drag(const Eigen::Vector2f& rel, int button);
-	void mouse_wheel(Eigen::Vector2f m, float delta);
+	void translate_camera(const Eigen::Vector3f& rel, const Eigen::Matrix3f& rot, bool allow_up_down = true);
+	Eigen::Matrix3f rotation_from_angles(const Eigen::Vector2f& angles) const;
+	void mouse_drag();
+	void mouse_wheel();
 	void load_file(const fs::path& path);
 	void set_nerf_camera_matrix(const Eigen::Matrix<float, 3, 4>& cam);
 	Eigen::Vector3f look_at() const;
@@ -334,6 +402,7 @@ public:
 	void generate_training_samples_sdf(Eigen::Vector3f* positions, float* distances, uint32_t n_to_generate, cudaStream_t stream, bool uniform_only);
 	void update_density_grid_nerf(float decay, uint32_t n_uniform_density_grid_samples, uint32_t n_nonuniform_density_grid_samples, cudaStream_t stream);
 	void update_density_grid_mean_and_bitfield(cudaStream_t stream);
+	void mark_density_grid_in_sphere_empty(const Eigen::Vector3f& pos, float radius, cudaStream_t stream);
 
 	struct NerfCounters {
 		tcnn::GPUMemory<uint32_t> numsteps_counter; // number of steps each ray took
@@ -364,8 +433,8 @@ public:
 	void training_prep_sdf(uint32_t batch_size, cudaStream_t stream);
 	void training_prep_image(uint32_t batch_size, cudaStream_t stream) {}
 	void train(uint32_t batch_size);
-	Eigen::Vector2f calc_focal_length(const Eigen::Vector2i& resolution, int fov_axis, float zoom) const ;
-	Eigen::Vector2f render_screen_center() const ;
+	Eigen::Vector2f calc_focal_length(const Eigen::Vector2i& resolution, const Eigen::Vector2f& relative_focal_length, int fov_axis, float zoom) const;
+	Eigen::Vector2f render_screen_center(const Eigen::Vector2f& screen_center) const;
 	void optimise_mesh_step(uint32_t N_STEPS);
 	void compute_mesh_vertex_colors();
 	tcnn::GPUMemory<float> get_density_on_grid(Eigen::Vector3i res3d, const BoundingBox& aabb, const Eigen::Matrix3f& render_aabb_to_local); // network version (nerf or sdf)
@@ -373,9 +442,8 @@ public:
 	tcnn::GPUMemory<Eigen::Array4f> get_rgba_on_grid(Eigen::Vector3i res3d, Eigen::Vector3f ray_dir, bool voxel_centers, float depth, bool density_as_alpha = false);
 	int marching_cubes(Eigen::Vector3i res3d, const BoundingBox& render_aabb, const Eigen::Matrix3f& render_aabb_to_local, float thresh);
 
-	// Determines the 3d focus point by rendering a little 16x16 depth image around
-	// the mouse cursor and picking the median depth.
-	void determine_autofocus_target_from_pixel(const Eigen::Vector2i& focus_pixel);
+	float get_depth_from_renderbuffer(const CudaRenderBuffer& render_buffer, const Eigen::Vector2f& uv);
+	Eigen::Vector3f get_3d_pos_from_pixel(const CudaRenderBuffer& render_buffer, const Eigen::Vector2i& focus_pixel);
 	void autofocus();
 	size_t n_params();
 	size_t first_encoder_param();
@@ -396,7 +464,10 @@ public:
 	void destroy_window();
 	void apply_camera_smoothing(float elapsed_ms);
 	int find_best_training_view(int default_view);
-	bool begin_frame_and_handle_user_input();
+	bool begin_frame();
+	void handle_user_input();
+	Eigen::Vector3f vr_to_world(const Eigen::Vector3f& pos) const;
+	void begin_vr_frame_and_handle_vr_input();
 	void gather_histograms();
 	void draw_gui();
 	bool frame();
@@ -479,18 +550,18 @@ public:
 	bool m_dynamic_res = true;
 	float m_dynamic_res_target_fps = 20.0f;
 	int m_fixed_res_factor = 8;
-	float m_last_render_res_factor = 1.0f;
 	float m_scale = 1.0;
-	float m_prev_scale = 1.0;
 	float m_aperture_size = 0.0f;
 	Eigen::Vector2f m_relative_focal_length = Eigen::Vector2f::Ones();
 	uint32_t m_fov_axis = 1;
 	float m_zoom = 1.f; // 2d zoom factor (for insets?)
 	Eigen::Vector2f m_screen_center = Eigen::Vector2f::Constant(0.5f); // center of 2d zoom
 
+	float m_ndc_znear = 1.0f / 32.0f;
+	float m_ndc_zfar = 128.0f;
+
 	Eigen::Matrix<float, 3, 4> m_camera = Eigen::Matrix<float, 3, 4>::Zero();
 	Eigen::Matrix<float, 3, 4> m_smoothed_camera = Eigen::Matrix<float, 3, 4>::Zero();
-	Eigen::Matrix<float, 3, 4> m_prev_camera = Eigen::Matrix<float, 3, 4>::Zero();
 	size_t m_render_skip_due_to_lack_of_camera_movement_counter = 0;
 
 	bool m_fps_camera = false;
@@ -505,8 +576,6 @@ public:
 	float m_bounding_radius = 1;
 	float m_exposure = 0.f;
 
-	Eigen::Vector2i m_quilting_dims = Eigen::Vector2i::Ones();
-
 	ERenderMode m_render_mode = ERenderMode::Shade;
 	EMeshRenderMode m_mesh_render_mode = EMeshRenderMode::VertexNormals;
 
@@ -520,19 +589,31 @@ public:
 		void draw(GLuint texture);
 	} m_second_window;
 
+	float m_drag_depth = 1.0f;
+
+	// The VAO will be empty, but we need a valid one for attribute-less rendering
+	GLuint m_blit_vao = 0;
+	GLuint m_blit_program = 0;
+
+	void init_opengl_shaders();
+	void blit_texture(const Foveation& foveation, GLint rgba_texture, GLint rgba_filter_mode, GLint depth_texture, GLint framebuffer, const Eigen::Vector2i& offset, const Eigen::Vector2i& resolution);
+
 	void create_second_window();
 
+	std::unique_ptr<OpenXRHMD> m_hmd;
+	OpenXRHMD::FrameInfoPtr m_vr_frame_info;
+	void init_vr();
+	void set_n_views(size_t n_views);
+
 	std::function<bool()> m_keyboard_event_callback;
 
 	std::shared_ptr<GLTexture> m_pip_render_texture;
-	std::vector<std::shared_ptr<GLTexture>> m_render_textures;
+	std::vector<std::shared_ptr<GLTexture>> m_rgba_render_textures;
+	std::vector<std::shared_ptr<GLTexture>> m_depth_render_textures;
 #endif
 
-	ThreadPool m_thread_pool;
-	std::vector<std::future<void>> m_render_futures;
 
-	std::vector<CudaRenderBuffer> m_render_surfaces;
-	std::unique_ptr<CudaRenderBuffer> m_pip_render_surface;
+	std::unique_ptr<CudaRenderBuffer> m_pip_render_buffer;
 
 	SharedQueue<std::unique_ptr<ICallable>> m_task_queue;
 
@@ -731,8 +812,6 @@ public:
 	};
 
 	struct Image {
-		Eigen::Vector2f pos = Eigen::Vector2f::Constant(0.0f);
-		Eigen::Vector2f prev_pos = Eigen::Vector2f::Constant(0.0f);
 		tcnn::GPUMemory<char> data;
 
 		EDataType type = EDataType::Float;
@@ -785,7 +864,7 @@ public:
 	EColorSpace m_color_space = EColorSpace::Linear;
 	ETonemapCurve m_tonemap_curve = ETonemapCurve::Identity;
 	bool m_dlss = false;
-	bool m_dlss_supported = false;
+	std::shared_ptr<IDlssProvider> m_dlss_provider;
 	float m_dlss_sharpening = 0.0f;
 
 	// 3D stuff
@@ -814,13 +893,35 @@ public:
 	Eigen::Array4f m_background_color = {0.0f, 0.0f, 0.0f, 1.0f};
 
 	bool m_vsync = false;
+	bool m_render_transparency_as_checkerboard = false;
 
 	// Visualization of neuron activations
 	int m_visualized_dimension = -1;
 	int m_visualized_layer = 0;
+
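+	// State for a single rendered view; presumably one per tile in the grid visualization, or one per eye in VR.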
+	struct View {
+		std::shared_ptr<CudaRenderBuffer> render_buffer;
+		Eigen::Vector2i full_resolution = {1, 1};
+		int visualized_dimension = 0;
+
+		Eigen::Matrix<float, 3, 4> camera0 = Eigen::Matrix<float, 3, 4>::Zero();
+		Eigen::Matrix<float, 3, 4> camera1 = Eigen::Matrix<float, 3, 4>::Zero();
+		Eigen::Matrix<float, 3, 4> prev_camera = Eigen::Matrix<float, 3, 4>::Zero();
+
+		Foveation foveation;
+		Foveation prev_foveation;
+
+		Eigen::Vector2f relative_focal_length;
+		Eigen::Vector2f screen_center;
+
+		CudaDevice* device = nullptr;
+	};
+
+	std::vector<View> m_views;
 	Eigen::Vector2i m_n_views = {1, 1};
-	Eigen::Vector2i m_view_size = {1, 1};
-	bool m_single_view = true; // Whether a single neuron is visualized, or all in a tiled grid
+
+	bool m_single_view = true;
+
 	float m_picture_in_picture_res = 0.f; // if non zero, requests a small second picture :)
 
 	struct ImGuiVars {
@@ -835,9 +936,10 @@ public:
 	} m_imgui;
 
 	bool m_visualize_unit_cube = false;
-	bool m_snap_to_pixel_centers = false;
 	bool m_edit_render_aabb = false;
 
+	bool m_snap_to_pixel_centers = false;
+
 	Eigen::Vector3f m_parallax_shift = {0.0f, 0.0f, 0.0f}; // to shift the viewer's origin by some amount in camera space
 
 	// CUDA stuff
@@ -863,6 +965,172 @@ public:
 	bool m_train_encoding = true;
 	bool m_train_network = true;
 
+	class CudaDevice {
+	public:
+		struct Data {
+			tcnn::GPUMemory<uint8_t> density_grid_bitfield;
+			uint8_t* density_grid_bitfield_ptr;
+
+			tcnn::GPUMemory<precision_t> params;
+			std::shared_ptr<Buffer2D<uint8_t>> hidden_area_mask;
+		};
+
+		CudaDevice(int id, bool is_primary) : m_id{id}, m_is_primary{is_primary} {
+			auto guard = device_guard();
+			m_stream = std::make_unique<tcnn::StreamAndEvent>();
+			m_data = std::make_unique<Data>();
+			m_render_worker = std::make_unique<ThreadPool>(is_primary ? 0u : 1u);
+		}
+
+		CudaDevice(const CudaDevice&) = delete;
+		CudaDevice& operator=(const CudaDevice&) = delete;
+
+		CudaDevice(CudaDevice&&) = default;
+		CudaDevice& operator=(CudaDevice&&) = default;
+
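+		// RAII: makes this device current and restores the previously active CUDA device on destruction.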
+		tcnn::ScopeGuard device_guard() {
+			int prev_device = tcnn::cuda_device();
+			if (prev_device == m_id) {
+				return {};
+			}
+
+			tcnn::set_cuda_device(m_id);
+			return tcnn::ScopeGuard{[prev_device]() {
+				tcnn::set_cuda_device(prev_device);
+			}};
+		}
+
+		int id() const {
+			return m_id;
+		}
+
+		bool is_primary() const {
+			return m_is_primary;
+		}
+
+		std::string name() const {
+			return tcnn::cuda_device_name(m_id);
+		}
+
+		int compute_capability() const {
+			return tcnn::cuda_compute_capability(m_id);
+		}
+
+		cudaStream_t stream() const {
+			return m_stream->get();
+		}
+
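+		// Records an event on `stream` and makes this device's own stream wait for it (GPU-side synchronization).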
+		void wait_for(cudaStream_t stream) const {
+			CUDA_CHECK_THROW(cudaEventRecord(m_primary_device_event.event, stream));
+			m_stream->wait_for(m_primary_device_event.event);
+		}
+
+		void signal(cudaStream_t stream) const {
+			m_stream->signal(stream);
+		}
+
+		const CudaRenderBufferView& render_buffer_view() const {
+			return m_render_buffer_view;
+		}
+
+		void set_render_buffer_view(const CudaRenderBufferView& view) {
+			m_render_buffer_view = view;
+		}
+
+		Data& data() const {
+			return *m_data;
+		}
+
+		bool dirty() const {
+			return m_dirty;
+		}
+
+		void set_dirty(bool value) {
+			m_dirty = value;
+		}
+
+		void set_network(const std::shared_ptr<tcnn::Network<float, precision_t>>& network) {
+			m_network = network;
+		}
+
+		void set_nerf_network(const std::shared_ptr<NerfNetwork<precision_t>>& nerf_network);
+
+		const std::shared_ptr<tcnn::Network<float, precision_t>>& network() const {
+			return m_network;
+		}
+
+		const std::shared_ptr<NerfNetwork<precision_t>>& nerf_network() const {
+			return m_nerf_network;
+		}
+
+		void clear() {
+			m_data = std::make_unique<Data>();
+			m_render_buffer_view = {};
+			m_network = {};
+			m_nerf_network = {};
+			set_dirty(true);
+		}
+
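+		// On the primary device, defer the task to the calling thread (std::launch::deferred);
+		// otherwise, run it on this device's dedicated render worker thread.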
+		template <class F>
+		auto enqueue_task(F&& f) -> std::future<std::result_of_t<F()>> {
+			if (is_primary()) {
+				return std::async(std::launch::deferred, std::forward<F>(f));
+			} else {
+				return m_render_worker->enqueue_task(std::forward<F>(f));
+			}
+		}
+
+	private:
+		int m_id;
+		bool m_is_primary;
+		std::unique_ptr<tcnn::StreamAndEvent> m_stream;
+		struct Event {
+			Event() {
+				CUDA_CHECK_THROW(cudaEventCreate(&event));
+			}
+
+			~Event() {
+				cudaEventDestroy(event);
+			}
+
+			Event(const Event&) = delete;
+			Event& operator=(const Event&) = delete;
+			Event(Event&& other) { *this = std::move(other); }
+			Event& operator=(Event&& other) {
+				std::swap(event, other.event);
+				return *this;
+			}
+
+			cudaEvent_t event = {};
+		};
+		Event m_primary_device_event;
+		std::unique_ptr<Data> m_data;
+		CudaRenderBufferView m_render_buffer_view = {};
+
+		std::shared_ptr<tcnn::Network<float, precision_t>> m_network;
+		std::shared_ptr<NerfNetwork<precision_t>> m_nerf_network;
+
+		bool m_dirty = true;
+
+		std::unique_ptr<ThreadPool> m_render_worker;
+	};
+
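+	// Multi-GPU helpers: mirror state onto a device, scope rendering work to it, and mark cached per-device data as stale.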
+	void sync_device(CudaRenderBuffer& render_buffer, CudaDevice& device);
+	tcnn::ScopeGuard use_device(cudaStream_t stream, CudaRenderBuffer& render_buffer, CudaDevice& device);
+	void set_all_devices_dirty();
+
+	std::vector<CudaDevice> m_devices;
+	CudaDevice& primary_device() {
+		return m_devices.front();
+	}
+
+	ThreadPool m_thread_pool;
+	std::vector<std::future<void>> m_render_futures;
+
+	bool m_use_aux_devices = false;
+	bool m_foveated_rendering = false;
+	float m_foveated_rendering_max_scaling = 2.0f;
+
 	fs::path m_data_path;
 	fs::path m_network_config_path = "base.json";
 
@@ -876,8 +1144,8 @@ public:
 	uint32_t network_width(uint32_t layer) const;
 	uint32_t network_num_forward_activations() const;
 
-	std::shared_ptr<tcnn::Loss<precision_t>> m_loss;
 	// Network & training stuff
+	std::shared_ptr<tcnn::Loss<precision_t>> m_loss;
 	std::shared_ptr<tcnn::Optimizer<precision_t>> m_optimizer;
 	std::shared_ptr<tcnn::Encoding<precision_t>> m_encoding;
 	std::shared_ptr<tcnn::Network<float, precision_t>> m_network;
@@ -890,6 +1158,22 @@ public:
 
 		Eigen::Vector2i resolution;
 		ELossType loss_type;
+
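+		// inference_view() wraps the envmap's inference-time parameters; view() wraps its training-time parameters.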
+		Buffer2DView<const Eigen::Array4f> inference_view() const {
+			if (!envmap) {
+				return {};
+			}
+
+			return {(const Eigen::Array4f*)envmap->inference_params(), resolution};
+		}
+
+		Buffer2DView<const Eigen::Array4f> view() const {
+			if (!envmap) {
+				return {};
+			}
+
+			return {(const Eigen::Array4f*)envmap->params(), resolution};
+		}
 	} m_envmap;
 
 	struct TrainableDistortionMap {
@@ -897,6 +1181,22 @@ public:
 		std::shared_ptr<TrainableBuffer<2, 2, float>> map;
 		std::shared_ptr<tcnn::Trainer<float, float, float>> trainer;
 		Eigen::Vector2i resolution;
+
+		Buffer2DView<const Eigen::Vector2f> inference_view() const {
+			if (!map) {
+				return {};
+			}
+
+			return {(const Eigen::Vector2f*)map->inference_params(), resolution};
+		}
+
+		Buffer2DView<const Eigen::Vector2f> view() const {
+			if (!map) {
+				return {};
+			}
+
+			return {(const Eigen::Vector2f*)map->params(), resolution};
+		}
 	} m_distortion;
 	std::shared_ptr<NerfNetwork<precision_t>> m_nerf_network;
 };
diff --git a/scripts/colmap2nerf.py b/scripts/colmap2nerf.py
index 88ba5e67d6e261aae7d0dd26feafa63566e4c74b..30f5104a0f597dea4e645cf141ef8584d6e61136 100755
--- a/scripts/colmap2nerf.py
+++ b/scripts/colmap2nerf.py
@@ -37,7 +37,7 @@ def parse_args():
 	parser.add_argument("--colmap_camera_params", default="", help="Intrinsic parameters, depending on the chosen model. Format: fx,fy,cx,cy,dist")
 	parser.add_argument("--images", default="images", help="Input path to the images.")
 	parser.add_argument("--text", default="colmap_text", help="Input path to the colmap text files (set automatically if --run_colmap is used).")
-	parser.add_argument("--aabb_scale", default=16, choices=["1", "2", "4", "8", "16", "32", "64", "128"], help="Large scene scale factor. 1=scene fits in unit cube; power of 2 up to 128")
+	parser.add_argument("--aabb_scale", default=64, choices=["1", "2", "4", "8", "16", "32", "64", "128"], help="Large scene scale factor. 1=scene fits in unit cube; power of 2 up to 128")
 	parser.add_argument("--skip_early", default=0, help="Skip this many images from the start.")
 	parser.add_argument("--keep_colmap_coords", action="store_true", help="Keep transforms.json in COLMAP's original frame of reference (this will avoid reorienting and repositioning the scene for preview and rendering).")
 	parser.add_argument("--out", default="transforms.json", help="Output path.")
@@ -414,8 +414,7 @@ if __name__ == "__main__":
 		from detectron2 import model_zoo
 		from detectron2.engine import DefaultPredictor
 
-		dir_path = Path(os.path.dirname(os.path.realpath(__file__)))
-		category2id = json.load(open(dir_path / "category2id.json", "r"))
+		category2id = json.load(open(SCRIPTS_FOLDER / "category2id.json", "r"))
 		mask_ids = [category2id[c] for c in args.mask_categories]
 
 		cfg = get_cfg()
diff --git a/scripts/run.py b/scripts/run.py
index 16152cce42ddf2f2b00edeac8fe0e4088512f33b..d7b2054383b6eda31656629ce5dba0f2ed23514b 100644
--- a/scripts/run.py
+++ b/scripts/run.py
@@ -64,6 +64,7 @@ def parse_args():
 	parser.add_argument("--train", action="store_true", help="If the GUI is enabled, controls whether training starts immediately.")
 	parser.add_argument("--n_steps", type=int, default=-1, help="Number of steps to train for before quitting.")
 	parser.add_argument("--second_window", action="store_true", help="Open a second window containing a copy of the main output.")
+	parser.add_argument("--vr", action="store_true", help="Render to a VR headset.")
 
 	parser.add_argument("--sharpen", default=0, help="Set amount of sharpening applied to NeRF training images. Range 0.0 to 1.0.")
 
@@ -78,6 +79,8 @@ def get_scene(scene):
 
 if __name__ == "__main__":
 	args = parse_args()
+	if args.vr: # VR implies having the GUI running at the moment
+		args.gui = True
 
 	if args.mode:
 		print("Warning: the '--mode' argument is no longer in use. It has no effect. The mode is automatically chosen based on the scene.")
@@ -106,7 +109,9 @@ if __name__ == "__main__":
 		while sw * sh > 1920 * 1080 * 4:
 			sw = int(sw / 2)
 			sh = int(sh / 2)
-		testbed.init_window(sw, sh, second_window = args.second_window or False)
+		testbed.init_window(sw, sh, second_window=args.second_window)
+		if args.vr:
+			testbed.init_vr()
 
 
 	if args.load_snapshot:
@@ -159,10 +164,8 @@ if __name__ == "__main__":
 		# setting here.
 		testbed.nerf.cone_angle_constant = 0
 
-		# Optionally match nerf paper behaviour and train on a
-		# fixed white bg. We prefer training on random BG colors.
-		# testbed.background_color = [1.0, 1.0, 1.0, 1.0]
-		# testbed.nerf.training.random_bg_color = False
+		# Match nerf paper behaviour and train on a fixed bg.
+		testbed.nerf.training.random_bg_color = False
 
 	old_training_step = 0
 	n_steps = args.n_steps
@@ -223,53 +226,24 @@ if __name__ == "__main__":
 
 		testbed.nerf.render_min_transmittance = 1e-4
 
-		testbed.fov_axis = 0
-		testbed.fov = test_transforms["camera_angle_x"] * 180 / np.pi
 		testbed.shall_train = False
+		testbed.load_training_data(args.test_transforms)
 
-		with tqdm(list(enumerate(test_transforms["frames"])), unit="images", desc=f"Rendering test frame") as t:
-			for i, frame in t:
-				p = frame["file_path"]
-				if "." not in p:
-					p = p + ".png"
-				ref_fname = os.path.join(data_dir, p)
-				if not os.path.isfile(ref_fname):
-					ref_fname = os.path.join(data_dir, p + ".png")
-					if not os.path.isfile(ref_fname):
-						ref_fname = os.path.join(data_dir, p + ".jpg")
-						if not os.path.isfile(ref_fname):
-							ref_fname = os.path.join(data_dir, p + ".jpeg")
-							if not os.path.isfile(ref_fname):
-								ref_fname = os.path.join(data_dir, p + ".exr")
-
-				ref_image = read_image(ref_fname)
-
-				# NeRF blends with background colors in sRGB space, rather than first
-				# transforming to linear space, blending there, and then converting back.
-				# (See e.g. the PNG spec for more information on how the `alpha` channel
-				# is always a linear quantity.)
-				# The following lines of code reproduce NeRF's behavior (if enabled in
-				# testbed) in order to make the numbers comparable.
-				if testbed.color_space == ngp.ColorSpace.SRGB and ref_image.shape[2] == 4:
-					# Since sRGB conversion is non-linear, alpha must be factored out of it
-					ref_image[...,:3] = np.divide(ref_image[...,:3], ref_image[...,3:4], out=np.zeros_like(ref_image[...,:3]), where=ref_image[...,3:4] != 0)
-					ref_image[...,:3] = linear_to_srgb(ref_image[...,:3])
-					ref_image[...,:3] *= ref_image[...,3:4]
-					ref_image += (1.0 - ref_image[...,3:4]) * testbed.background_color
-					ref_image[...,:3] = srgb_to_linear(ref_image[...,:3])
+		with tqdm(range(testbed.nerf.training.dataset.n_images), unit="images", desc="Rendering test frames") as t:
+			for i in t:
+				resolution = testbed.nerf.training.dataset.metadata[i].resolution
+				testbed.render_ground_truth = True
+				testbed.set_camera_to_training_view(i)
+				ref_image = testbed.render(resolution[0], resolution[1], 1, True)
+				testbed.render_ground_truth = False
+				image = testbed.render(resolution[0], resolution[1], spp, True)
 
 				if i == 0:
-					write_image("ref.png", ref_image)
-
-				testbed.set_nerf_camera_matrix(np.matrix(frame["transform_matrix"])[:-1,:])
-				image = testbed.render(ref_image.shape[1], ref_image.shape[0], spp, True)
+					write_image("ref.png", ref_image)
+					write_image("out.png", image)
 
-				if i == 0:
-					write_image("out.png", image)
-
-				diffimg = np.absolute(image - ref_image)
-				diffimg[...,3:4] = 1.0
-				if i == 0:
+					diffimg = np.absolute(image - ref_image)
+					diffimg[...,3:4] = 1.0
 					write_image("diff.png", diffimg)
 
 				A = np.clip(linear_to_srgb(image[...,:3]), 0.0, 1.0)
diff --git a/src/camera_path.cu b/src/camera_path.cu
index 853f495a92af1056b3cab7183fd0bb82d36ff2ff..6b8953ce339c6dbff353012d9ca975afba970250 100644
--- a/src/camera_path.cu
+++ b/src/camera_path.cu
@@ -318,13 +318,11 @@ void visualize_nerf_camera(ImDrawList* list, const Matrix<float, 4, 4>& world2pr
 	add_debug_line(list, world2proj, d, a, col, thickness);
 }
 
-bool CameraPath::imgui_viz(ImDrawList* list, Matrix<float, 4, 4> &view2proj, Matrix<float, 4, 4> &world2proj, Matrix<float, 4, 4> &world2view, Vector2f focal, float aspect) {
+bool CameraPath::imgui_viz(ImDrawList* list, Matrix<float, 4, 4> &view2proj, Matrix<float, 4, 4> &world2proj, Matrix<float, 4, 4> &world2view, Vector2f focal, float aspect, float znear, float zfar) {
 	bool changed = false;
 	float flx = focal.x();
 	float fly = focal.y();
 	Matrix<float, 4, 4> view2proj_guizmo;
-	float zfar = 100.f;
-	float znear = 0.1f;
 	view2proj_guizmo <<
 		fly * 2.0f / aspect, 0, 0, 0,
 		0, -fly * 2.0f, 0, 0,
diff --git a/src/dlss.cu b/src/dlss.cu
index d81df9390f895cd0f4f47039e6db9a849081d805..516d92fac67e94c27a0b2c5129cd7c38c511dded 100644
--- a/src/dlss.cu
+++ b/src/dlss.cu
@@ -79,36 +79,18 @@ std::string ngx_error_string(NVSDK_NGX_Result result) {
 			throw std::runtime_error(std::string(FILE_LINE " " #x " failed with error ") + ngx_error_string(result)); \
 	} while(0)
 
-static VkInstance vk_instance = VK_NULL_HANDLE;
-static VkDebugUtilsMessengerEXT vk_debug_messenger = VK_NULL_HANDLE;
-static VkPhysicalDevice vk_physical_device = VK_NULL_HANDLE;
-static VkDevice vk_device = VK_NULL_HANDLE;
-static VkQueue vk_queue = VK_NULL_HANDLE;
-static VkCommandPool vk_command_pool = VK_NULL_HANDLE;
-static VkCommandBuffer vk_command_buffer = VK_NULL_HANDLE;
-
-static bool ngx_initialized = false;
-static NVSDK_NGX_Parameter* ngx_parameters = nullptr;
-
-uint32_t vk_find_memory_type(uint32_t type_filter, VkMemoryPropertyFlags properties) {
-	VkPhysicalDeviceMemoryProperties mem_properties;
-	vkGetPhysicalDeviceMemoryProperties(vk_physical_device, &mem_properties);
-
-	for (uint32_t i = 0; i < mem_properties.memoryTypeCount; i++) {
-		if (type_filter & (1 << i) && (mem_properties.memoryTypes[i].propertyFlags & properties) == properties) {
-			return i;
-		}
-	}
-
-	throw std::runtime_error{"Failed to find suitable memory type."};
-}
-
 static VKAPI_ATTR VkBool32 VKAPI_CALL vk_debug_callback(
 	VkDebugUtilsMessageSeverityFlagBitsEXT message_severity,
 	VkDebugUtilsMessageTypeFlagsEXT message_type,
 	const VkDebugUtilsMessengerCallbackDataEXT* callback_data,
 	void* user_data
 ) {
+	// Ignore JSON files that couldn't be found... third-party tools sometimes install bogus layers
+	// that manifest as warnings like this.
+	if (std::string{callback_data->pMessage}.find("Failed to open JSON file") != std::string::npos) {
+		return VK_FALSE;
+	}
+
 	if (message_severity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT) {
 		tlog::warning() << "Vulkan error: " << callback_data->pMessage;
 	} else if (message_severity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT) {
@@ -120,366 +102,450 @@ static VKAPI_ATTR VkBool32 VKAPI_CALL vk_debug_callback(
 	return VK_FALSE;
 }
 
-void vulkan_and_ngx_init() {
-	static bool already_initialized = false;
+class VulkanAndNgx : public IDlssProvider, public std::enable_shared_from_this<VulkanAndNgx> {
+public:
+	VulkanAndNgx() {
+		ScopeGuard cleanup_guard{[&]() { clear(); }};
 
-	if (already_initialized) {
-		return;
-	}
+		if (!glfwVulkanSupported()) {
+			throw std::runtime_error{"!glfwVulkanSupported()"};
+		}
 
-	already_initialized = true;
+		// -------------------------------
+		// Vulkan Instance
+		// -------------------------------
+		VkApplicationInfo app_info{};
+		app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+		app_info.pApplicationName = "NGP";
+		app_info.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
+		app_info.pEngineName = "No engine";
+		app_info.engineVersion = VK_MAKE_VERSION(1, 0, 0);
+		app_info.apiVersion = VK_API_VERSION_1_0;
+
+		VkInstanceCreateInfo instance_create_info = {};
+		instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+		instance_create_info.pApplicationInfo = &app_info;
+
+		uint32_t available_layer_count;
+		vkEnumerateInstanceLayerProperties(&available_layer_count, nullptr);
+
+		std::vector<VkLayerProperties> available_layers(available_layer_count);
+		vkEnumerateInstanceLayerProperties(&available_layer_count, available_layers.data());
+
+		std::vector<const char*> layers;
+		auto try_add_layer = [&](const char* layer) {
+			for (const auto& props : available_layers) {
+				if (strcmp(layer, props.layerName) == 0) {
+					layers.emplace_back(layer);
+					return true;
+				}
+			}
 
-	if (!glfwVulkanSupported()) {
-		throw std::runtime_error{"!glfwVulkanSupported()"};
-	}
+			return false;
+		};
 
-	// -------------------------------
-	// Vulkan Instance
-	// -------------------------------
-	VkApplicationInfo app_info{};
-	app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
-	app_info.pApplicationName = "NGP";
-	app_info.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
-	app_info.pEngineName = "No engine";
-	app_info.engineVersion = VK_MAKE_VERSION(1, 0, 0);
-	app_info.apiVersion = VK_API_VERSION_1_0;
-
-	VkInstanceCreateInfo instance_create_info = {};
-	instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
-	instance_create_info.pApplicationInfo = &app_info;
-
-	uint32_t available_layer_count;
-	vkEnumerateInstanceLayerProperties(&available_layer_count, nullptr);
-
-	std::vector<VkLayerProperties> available_layers(available_layer_count);
-	vkEnumerateInstanceLayerProperties(&available_layer_count, available_layers.data());
-
-	std::vector<const char*> layers;
-	auto try_add_layer = [&](const char* layer) {
-		for (const auto& props : available_layers) {
-			if (strcmp(layer, props.layerName) == 0) {
-				layers.emplace_back(layer);
-				return true;
-			}
+		bool validation_layer_enabled = try_add_layer("VK_LAYER_KHRONOS_validation");
+		if (!validation_layer_enabled) {
+			tlog::warning() << "Vulkan validation layer is not available. Vulkan errors will be difficult to diagnose.";
 		}
 
-		return false;
-	};
+		instance_create_info.enabledLayerCount = static_cast<uint32_t>(layers.size());
+		instance_create_info.ppEnabledLayerNames = layers.empty() ? nullptr : layers.data();
 
-	bool validation_layer_enabled = try_add_layer("VK_LAYER_KHRONOS_validation");
-	if (!validation_layer_enabled) {
-		tlog::warning() << "Vulkan validation layer is not available. Vulkan errors will be difficult to diagnose.";
-	}
+		std::vector<const char*> instance_extensions;
+		std::vector<const char*> device_extensions;
 
-	instance_create_info.enabledLayerCount = static_cast<uint32_t>(layers.size());
-	instance_create_info.ppEnabledLayerNames = layers.empty() ? nullptr : layers.data();
+		uint32_t n_ngx_instance_extensions = 0;
+		const char** ngx_instance_extensions;
 
-	std::vector<const char*> instance_extensions;
-	std::vector<const char*> device_extensions;
+		uint32_t n_ngx_device_extensions = 0;
+		const char** ngx_device_extensions;
 
-	uint32_t n_ngx_instance_extensions = 0;
-	const char** ngx_instance_extensions;
+		NVSDK_NGX_VULKAN_RequiredExtensions(&n_ngx_instance_extensions, &ngx_instance_extensions, &n_ngx_device_extensions, &ngx_device_extensions);
 
-	uint32_t n_ngx_device_extensions = 0;
-	const char** ngx_device_extensions;
+		for (uint32_t i = 0; i < n_ngx_instance_extensions; ++i) {
+			instance_extensions.emplace_back(ngx_instance_extensions[i]);
+		}
 
-	NVSDK_NGX_VULKAN_RequiredExtensions(&n_ngx_instance_extensions, &ngx_instance_extensions, &n_ngx_device_extensions, &ngx_device_extensions);
+		instance_extensions.emplace_back(VK_KHR_DEVICE_GROUP_CREATION_EXTENSION_NAME);
+		instance_extensions.emplace_back(VK_KHR_EXTERNAL_FENCE_CAPABILITIES_EXTENSION_NAME);
+		instance_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME);
+		instance_extensions.emplace_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
 
-	for (uint32_t i = 0; i < n_ngx_instance_extensions; ++i) {
-		instance_extensions.emplace_back(ngx_instance_extensions[i]);
-	}
+		if (validation_layer_enabled) {
+			instance_extensions.emplace_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
+		}
 
-	instance_extensions.emplace_back(VK_KHR_DEVICE_GROUP_CREATION_EXTENSION_NAME);
-	instance_extensions.emplace_back(VK_KHR_EXTERNAL_FENCE_CAPABILITIES_EXTENSION_NAME);
-	instance_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME);
-	instance_extensions.emplace_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
+		for (uint32_t i = 0; i < n_ngx_device_extensions; ++i) {
+			device_extensions.emplace_back(ngx_device_extensions[i]);
+		}
 
-	if (validation_layer_enabled) {
-		instance_extensions.emplace_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
-	}
+		device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME);
+	#ifdef _WIN32
+		device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME);
+	#else
+		device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME);
+	#endif
+		device_extensions.emplace_back(VK_KHR_DEVICE_GROUP_EXTENSION_NAME);
+
+		instance_create_info.enabledExtensionCount = (uint32_t)instance_extensions.size();
+		instance_create_info.ppEnabledExtensionNames = instance_extensions.data();
+
+		VkDebugUtilsMessengerCreateInfoEXT debug_messenger_create_info = {};
+		debug_messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
+		debug_messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
+		debug_messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
+		debug_messenger_create_info.pfnUserCallback = vk_debug_callback;
+		debug_messenger_create_info.pUserData = nullptr;
+
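+		// Chaining the messenger info into the instance's pNext also captures validation output from vkCreateInstance itself.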
+		if (validation_layer_enabled) {
+			instance_create_info.pNext = &debug_messenger_create_info;
+		}
 
-	for (uint32_t i = 0; i < n_ngx_device_extensions; ++i) {
-		device_extensions.emplace_back(ngx_device_extensions[i]);
-	}
+		VK_CHECK_THROW(vkCreateInstance(&instance_create_info, nullptr, &m_vk_instance));
 
-	device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME);
-#ifdef _WIN32
-	device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME);
-#else
-	device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME);
-#endif
-	device_extensions.emplace_back(VK_KHR_DEVICE_GROUP_EXTENSION_NAME);
+		if (validation_layer_enabled) {
+			auto CreateDebugUtilsMessengerEXT = [](VkInstance instance, const VkDebugUtilsMessengerCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugUtilsMessengerEXT* pDebugMessenger) {
+				auto func = (PFN_vkCreateDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkCreateDebugUtilsMessengerEXT");
+				if (func != nullptr) {
+					return func(instance, pCreateInfo, pAllocator, pDebugMessenger);
+				} else {
+					return VK_ERROR_EXTENSION_NOT_PRESENT;
+				}
+			};
 
-	instance_create_info.enabledExtensionCount = (uint32_t)instance_extensions.size();
-	instance_create_info.ppEnabledExtensionNames = instance_extensions.data();
+			if (CreateDebugUtilsMessengerEXT(m_vk_instance, &debug_messenger_create_info, nullptr, &m_vk_debug_messenger) != VK_SUCCESS) {
+				tlog::warning() << "Vulkan: could not initialize debug messenger.";
+			}
+		}
 
-	VkDebugUtilsMessengerCreateInfoEXT debug_messenger_create_info = {};
-	debug_messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
-	debug_messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
-	debug_messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
-	debug_messenger_create_info.pfnUserCallback = vk_debug_callback;
-	debug_messenger_create_info.pUserData = nullptr;
+		// -------------------------------
+		// Vulkan Physical Device
+		// -------------------------------
+		uint32_t n_devices = 0;
+		vkEnumeratePhysicalDevices(m_vk_instance, &n_devices, nullptr);
 
-	if (validation_layer_enabled) {
-		instance_create_info.pNext = &debug_messenger_create_info;
-	}
+		if (n_devices == 0) {
+			throw std::runtime_error{"Failed to find GPUs with Vulkan support."};
+		}
 
-	VK_CHECK_THROW(vkCreateInstance(&instance_create_info, nullptr, &vk_instance));
+		std::vector<VkPhysicalDevice> devices(n_devices);
+		vkEnumeratePhysicalDevices(m_vk_instance, &n_devices, devices.data());
 
-	if (validation_layer_enabled) {
-		auto CreateDebugUtilsMessengerEXT = [](VkInstance instance, const VkDebugUtilsMessengerCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugUtilsMessengerEXT* pDebugMessenger) {
-			auto func = (PFN_vkCreateDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkCreateDebugUtilsMessengerEXT");
-			if (func != nullptr) {
-				return func(instance, pCreateInfo, pAllocator, pDebugMessenger);
-			} else {
-				return VK_ERROR_EXTENSION_NOT_PRESENT;
-			}
+		struct QueueFamilyIndices {
+			int graphics_family = -1;
+			int compute_family = -1;
+			int transfer_family = -1;
+			int all_family = -1;
 		};
 
-		if (CreateDebugUtilsMessengerEXT(vk_instance, &debug_messenger_create_info, nullptr, &vk_debug_messenger) != VK_SUCCESS) {
-			tlog::warning() << "Vulkan: could not initialize debug messenger.";
-		}
-	}
-
-	// -------------------------------
-	// Vulkan Physical Device
-	// -------------------------------
-	uint32_t n_devices = 0;
-	vkEnumeratePhysicalDevices(vk_instance, &n_devices, nullptr);
+		auto find_queue_families = [](VkPhysicalDevice device) {
+			QueueFamilyIndices indices;
 
-	if (n_devices == 0) {
-		throw std::runtime_error{"Failed to find GPUs with Vulkan support."};
-	}
+			uint32_t queue_family_count = 0;
+			vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr);
 
-	std::vector<VkPhysicalDevice> devices(n_devices);
-	vkEnumeratePhysicalDevices(vk_instance, &n_devices, devices.data());
+			std::vector<VkQueueFamilyProperties> queue_families(queue_family_count);
+			vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data());
 
-	struct QueueFamilyIndices {
-		int graphics_family = -1;
-		int compute_family = -1;
-		int transfer_family = -1;
-		int all_family = -1;
-	};
+			int i = 0;
+			for (const auto& queue_family : queue_families) {
+				if (queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) {
+					indices.graphics_family = i;
+				}
 
-	auto find_queue_families = [](VkPhysicalDevice device) {
-		QueueFamilyIndices indices;
+				if (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) {
+					indices.compute_family = i;
+				}
 
-		uint32_t queue_family_count = 0;
-		vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr);
+				if (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT) {
+					indices.transfer_family = i;
+				}
 
-		std::vector<VkQueueFamilyProperties> queue_families(queue_family_count);
-		vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data());
+				if ((queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) && (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) && (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT)) {
+					indices.all_family = i;
+				}
 
-		int i = 0;
-		for (const auto& queue_family : queue_families) {
-			if (queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) {
-				indices.graphics_family = i;
+				i++;
 			}
 
-			if (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) {
-				indices.compute_family = i;
-			}
+			return indices;
+		};
 
-			if (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT) {
-				indices.transfer_family = i;
-			}
+		cudaDeviceProp cuda_device_prop;
+		CUDA_CHECK_THROW(cudaGetDeviceProperties(&cuda_device_prop, tcnn::cuda_device()));
 
-			if ((queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) && (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) && (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT)) {
-				indices.all_family = i;
-			}
+		auto is_same_as_cuda_device = [&](VkPhysicalDevice device) {
+			VkPhysicalDeviceIDProperties physical_device_id_properties = {};
+			physical_device_id_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES;
+			physical_device_id_properties.pNext = nullptr;
 
-			i++;
-		}
+			VkPhysicalDeviceProperties2 physical_device_properties = {};
+			physical_device_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
+			physical_device_properties.pNext = &physical_device_id_properties;
 
-		return indices;
-	};
+			vkGetPhysicalDeviceProperties2(device, &physical_device_properties);
 
-	cudaDeviceProp cuda_device_prop;
-	CUDA_CHECK_THROW(cudaGetDeviceProperties(&cuda_device_prop, tcnn::cuda_device()));
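+			// A device matches when its Vulkan UUID equals the CUDA device UUID and one queue family supports graphics, compute, and transfer.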
+			return !memcmp(&cuda_device_prop.uuid, physical_device_id_properties.deviceUUID, VK_UUID_SIZE) && find_queue_families(device).all_family >= 0;
+		};
 
-	auto is_same_as_cuda_device = [&](VkPhysicalDevice device) {
-		VkPhysicalDeviceIDProperties physical_device_id_properties = {};
-		physical_device_id_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES;
-		physical_device_id_properties.pNext = NULL;
+		uint32_t device_id = 0;
+		for (uint32_t i = 0; i < n_devices; ++i) {
+			if (is_same_as_cuda_device(devices[i])) {
+				m_vk_physical_device = devices[i];
+				device_id = i;
+				break;
+			}
+		}
 
-		VkPhysicalDeviceProperties2 physical_device_properties = {};
-		physical_device_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
-		physical_device_properties.pNext = &physical_device_id_properties;
+		if (m_vk_physical_device == VK_NULL_HANDLE) {
+			throw std::runtime_error{"Failed to find Vulkan device corresponding to CUDA device."};
+		}
 
-		vkGetPhysicalDeviceProperties2(device, &physical_device_properties);
+		// -------------------------------
+		// Vulkan Logical Device
+		// -------------------------------
+		VkPhysicalDeviceProperties physical_device_properties;
+		vkGetPhysicalDeviceProperties(m_vk_physical_device, &physical_device_properties);
+
+		QueueFamilyIndices indices = find_queue_families(m_vk_physical_device);
+
+		VkDeviceQueueCreateInfo queue_create_info{};
+		queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+		queue_create_info.queueFamilyIndex = indices.all_family;
+		queue_create_info.queueCount = 1;
+
+		float queue_priority = 1.0f;
+		queue_create_info.pQueuePriorities = &queue_priority;
+
+		VkPhysicalDeviceFeatures device_features = {};
+		device_features.shaderStorageImageWriteWithoutFormat = true;
+
+		VkDeviceCreateInfo device_create_info = {};
+		device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
+		device_create_info.pQueueCreateInfos = &queue_create_info;
+		device_create_info.queueCreateInfoCount = 1;
+		device_create_info.pEnabledFeatures = &device_features;
+		device_create_info.enabledExtensionCount = (uint32_t)device_extensions.size();
+		device_create_info.ppEnabledExtensionNames = device_extensions.data();
+		device_create_info.enabledLayerCount = static_cast<uint32_t>(layers.size());
+		device_create_info.ppEnabledLayerNames = layers.data();
+
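+		// This integration requires the buffer device address feature; fail early if the SDK headers do not expose the extension.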
+	#ifdef VK_EXT_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME
+		VkPhysicalDeviceBufferDeviceAddressFeaturesEXT buffer_device_address_feature = {};
+		buffer_device_address_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES_EXT;
+		buffer_device_address_feature.bufferDeviceAddress = VK_TRUE;
+		device_create_info.pNext = &buffer_device_address_feature;
+	#else
+		throw std::runtime_error{"Buffer device address extension not available."};
+	#endif
+
+		VK_CHECK_THROW(vkCreateDevice(m_vk_physical_device, &device_create_info, nullptr, &m_vk_device));
+
+		// -----------------------------------------------
+		// Vulkan queue / command pool / command buffer
+		// -----------------------------------------------
+		vkGetDeviceQueue(m_vk_device, indices.all_family, 0, &m_vk_queue);
+
+		VkCommandPoolCreateInfo command_pool_info = {};
+		command_pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
+		command_pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
+		command_pool_info.queueFamilyIndex = indices.all_family;
+
+		VK_CHECK_THROW(vkCreateCommandPool(m_vk_device, &command_pool_info, nullptr, &m_vk_command_pool));
+
+		VkCommandBufferAllocateInfo command_buffer_alloc_info = {};
+		command_buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+		command_buffer_alloc_info.commandPool = m_vk_command_pool;
+		command_buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+		command_buffer_alloc_info.commandBufferCount = 1;
+
+		VK_CHECK_THROW(vkAllocateCommandBuffers(m_vk_device, &command_buffer_alloc_info, &m_vk_command_buffer));
+
+		// -------------------------------
+		// NGX init
+		// -------------------------------
+		std::wstring path;
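+		// NGX initialization takes a wide-character application data path; convert the UTF-8 working directory on non-Windows platforms.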
+#ifdef _WIN32
+		path = fs::path::getcwd().wstr();
+#else
+		std::string tmp = fs::path::getcwd().str();
+		std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> converter;
+		path = converter.from_bytes(tmp);
+#endif
 
-		return !memcmp(&cuda_device_prop.uuid, physical_device_id_properties.deviceUUID, VK_UUID_SIZE) && find_queue_families(device).all_family >= 0;
-	};
+		NGX_CHECK_THROW(NVSDK_NGX_VULKAN_Init_with_ProjectID("ea75345e-5a42-4037-a5c9-59bf94dee157", NVSDK_NGX_ENGINE_TYPE_CUSTOM, "1.0.0", path.c_str(), m_vk_instance, m_vk_physical_device, m_vk_device));
+		m_ngx_initialized = true;
+
+		// -------------------------------
+		// Ensure DLSS capability
+		// -------------------------------
+		NGX_CHECK_THROW(NVSDK_NGX_VULKAN_GetCapabilityParameters(&m_ngx_parameters));
+
+		int needs_updated_driver = 0;
+		unsigned int min_driver_version_major = 0;
+		unsigned int min_driver_version_minor = 0;
+		NVSDK_NGX_Result result_updated_driver = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_NeedsUpdatedDriver, &needs_updated_driver);
+		NVSDK_NGX_Result result_min_driver_version_major = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMajor, &min_driver_version_major);
+		NVSDK_NGX_Result result_min_driver_version_minor = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMinor, &min_driver_version_minor);
+		if (result_updated_driver == NVSDK_NGX_Result_Success && result_min_driver_version_major == NVSDK_NGX_Result_Success && result_min_driver_version_minor == NVSDK_NGX_Result_Success) {
+			if (needs_updated_driver) {
+				throw std::runtime_error{fmt::format("Driver too old. Minimum version required is {}.{}", min_driver_version_major, min_driver_version_minor)};
+			}
+		}
 
-	uint32_t device_id = 0;
-	for (uint32_t i = 0; i < n_devices; ++i) {
-		if (is_same_as_cuda_device(devices[i])) {
-			vk_physical_device = devices[i];
-			device_id = i;
-			break;
+		int dlss_available = 0;
+		NVSDK_NGX_Result ngx_result = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_Available, &dlss_available);
+		if (ngx_result != NVSDK_NGX_Result_Success || !dlss_available) {
+			ngx_result = NVSDK_NGX_Result_Fail;
+			NVSDK_NGX_Parameter_GetI(m_ngx_parameters, NVSDK_NGX_Parameter_SuperSampling_FeatureInitResult, (int*)&ngx_result);
+			throw std::runtime_error{fmt::format("DLSS not available: {}", ngx_error_string(ngx_result))};
 		}
+
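+		// Initialization succeeded: disarm the cleanup guard so that clear() runs only from the destructor.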
+		cleanup_guard.disarm();
+
+		tlog::success() << "Initialized Vulkan and NGX on GPU #" << device_id << ": " << physical_device_properties.deviceName;
 	}
 
-	if (vk_physical_device == VK_NULL_HANDLE) {
-		throw std::runtime_error{"Failed to find Vulkan device corresponding to CUDA device."};
+	virtual ~VulkanAndNgx() {
+		clear();
 	}
 
-	// -------------------------------
-	// Vulkan Logical Device
-	// -------------------------------
-	VkPhysicalDeviceProperties physical_device_properties;
-	vkGetPhysicalDeviceProperties(vk_physical_device, &physical_device_properties);
-
-	QueueFamilyIndices indices = find_queue_families(vk_physical_device);
-
-	VkDeviceQueueCreateInfo queue_create_info{};
-	queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
-	queue_create_info.queueFamilyIndex = indices.all_family;
-	queue_create_info.queueCount = 1;
-
-	float queue_priority = 1.0f;
-	queue_create_info.pQueuePriorities = &queue_priority;
-
-	VkPhysicalDeviceFeatures device_features = {};
-	device_features.shaderStorageImageWriteWithoutFormat = true;
-
-	VkDeviceCreateInfo device_create_info = {};
-	device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
-	device_create_info.pQueueCreateInfos = &queue_create_info;
-	device_create_info.queueCreateInfoCount = 1;
-	device_create_info.pEnabledFeatures = &device_features;
-	device_create_info.enabledExtensionCount = (uint32_t)device_extensions.size();
-	device_create_info.ppEnabledExtensionNames = device_extensions.data();
-	device_create_info.enabledLayerCount = static_cast<uint32_t>(layers.size());
-	device_create_info.ppEnabledLayerNames = layers.data();
-
-#ifdef VK_EXT_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME
-	VkPhysicalDeviceBufferDeviceAddressFeaturesEXT buffer_device_address_feature = {};
-	buffer_device_address_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES_EXT;
-	buffer_device_address_feature.bufferDeviceAddress = VK_TRUE;
-	device_create_info.pNext = &buffer_device_address_feature;
-#else
-	throw std::runtime_error{"Buffer device address extension not available."};
-#endif
+	void clear() {
+		if (m_ngx_parameters) {
+			NVSDK_NGX_VULKAN_DestroyParameters(m_ngx_parameters);
+			m_ngx_parameters = nullptr;
+		}
 
-	VK_CHECK_THROW(vkCreateDevice(vk_physical_device, &device_create_info, nullptr, &vk_device));
+		if (m_ngx_initialized) {
+			NVSDK_NGX_VULKAN_Shutdown();
+			m_ngx_initialized = false;
+		}
 
-	// -----------------------------------------------
-	// Vulkan queue / command pool / command buffer
-	// -----------------------------------------------
-	vkGetDeviceQueue(vk_device, indices.all_family, 0, &vk_queue);
+		if (m_vk_command_pool) {
+			vkDestroyCommandPool(m_vk_device, m_vk_command_pool, nullptr);
+			m_vk_command_pool = VK_NULL_HANDLE;
+		}
 
-	VkCommandPoolCreateInfo command_pool_info = {};
-	command_pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
-	command_pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
-	command_pool_info.queueFamilyIndex = indices.all_family;
+		if (m_vk_device) {
+			vkDestroyDevice(m_vk_device, nullptr);
+			m_vk_device = VK_NULL_HANDLE;
+		}
 
-	VK_CHECK_THROW(vkCreateCommandPool(vk_device, &command_pool_info, nullptr, &vk_command_pool));
+		if (m_vk_debug_messenger) {
+			auto DestroyDebugUtilsMessengerEXT = [](VkInstance instance, VkDebugUtilsMessengerEXT debugMessenger, const VkAllocationCallbacks* pAllocator) {
+				auto func = (PFN_vkDestroyDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkDestroyDebugUtilsMessengerEXT");
+				if (func != nullptr) {
+					func(instance, debugMessenger, pAllocator);
+				}
+			};
 
-	VkCommandBufferAllocateInfo command_buffer_alloc_info = {};
-	command_buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
-	command_buffer_alloc_info.commandPool = vk_command_pool;
-	command_buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
-	command_buffer_alloc_info.commandBufferCount = 1;
+			DestroyDebugUtilsMessengerEXT(m_vk_instance, m_vk_debug_messenger, nullptr);
+			m_vk_debug_messenger = VK_NULL_HANDLE;
+		}
 
-	VK_CHECK_THROW(vkAllocateCommandBuffers(vk_device, &command_buffer_alloc_info, &vk_command_buffer));
+		if (m_vk_instance) {
+			vkDestroyInstance(m_vk_instance, nullptr);
+			m_vk_instance = VK_NULL_HANDLE;
+		}
+	}
 
-	// -------------------------------
-	// NGX init
-	// -------------------------------
-	std::wstring path;
-#ifdef _WIN32
-	path = fs::path::getcwd().wstr();
-#else
-	std::string tmp = fs::path::getcwd().str();
-	std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> converter;
-	path = converter.from_bytes(tmp);
-#endif
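+	// Returns the index of a memory type that is allowed by type_filter and has all of the requested property flags.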
+	uint32_t vk_find_memory_type(uint32_t type_filter, VkMemoryPropertyFlags properties) {
+		VkPhysicalDeviceMemoryProperties mem_properties;
+		vkGetPhysicalDeviceMemoryProperties(m_vk_physical_device, &mem_properties);
 
-	NGX_CHECK_THROW(NVSDK_NGX_VULKAN_Init_with_ProjectID("ea75345e-5a42-4037-a5c9-59bf94dee157", NVSDK_NGX_ENGINE_TYPE_CUSTOM, "1.0.0", path.c_str(), vk_instance, vk_physical_device, vk_device));
-	ngx_initialized = true;
-
-	// -------------------------------
-	// Ensure DLSS capability
-	// -------------------------------
-	NGX_CHECK_THROW(NVSDK_NGX_VULKAN_GetCapabilityParameters(&ngx_parameters));
-
-	int needs_updated_driver = 0;
-	unsigned int min_driver_version_major = 0;
-	unsigned int min_driver_version_minor = 0;
-	NVSDK_NGX_Result result_updated_driver = ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_NeedsUpdatedDriver, &needs_updated_driver);
-	NVSDK_NGX_Result result_min_driver_version_major = ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMajor, &min_driver_version_major);
-	NVSDK_NGX_Result result_min_driver_version_minor = ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMinor, &min_driver_version_minor);
-	if (result_updated_driver == NVSDK_NGX_Result_Success && result_min_driver_version_major == NVSDK_NGX_Result_Success && result_min_driver_version_minor == NVSDK_NGX_Result_Success) {
-		if (needs_updated_driver) {
-			throw std::runtime_error{fmt::format("Driver too old. Minimum version required is {}.{}", min_driver_version_major, min_driver_version_minor)};
+		for (uint32_t i = 0; i < mem_properties.memoryTypeCount; i++) {
+			if (type_filter & (1 << i) && (mem_properties.memoryTypes[i].propertyFlags & properties) == properties) {
+				return i;
+			}
 		}
+
+		throw std::runtime_error{"Failed to find suitable memory type."};
 	}
 
-	int dlss_available  = 0;
-	NVSDK_NGX_Result ngx_result = ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_Available, &dlss_available);
-	if (ngx_result != NVSDK_NGX_Result_Success || !dlss_available) {
-		ngx_result = NVSDK_NGX_Result_Fail;
-		NVSDK_NGX_Parameter_GetI(ngx_parameters, NVSDK_NGX_Parameter_SuperSampling_FeatureInitResult, (int*)&ngx_result);
-		throw std::runtime_error{fmt::format("DLSS not available: {}", ngx_error_string(ngx_result))};
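+	// Helpers around the single shared command buffer: record for one-time submission, then submit and wait for the device to go idle.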
+	void vk_command_buffer_begin() {
+		VkCommandBufferBeginInfo begin_info = {};
+		begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+		begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+		begin_info.pInheritanceInfo = nullptr;
+
+		VK_CHECK_THROW(vkBeginCommandBuffer(m_vk_command_buffer, &begin_info));
 	}
 
-	tlog::success() << "Initialized Vulkan and NGX on GPU #" << device_id << ": " << physical_device_properties.deviceName;
-}
+	void vk_command_buffer_end() {
+		VK_CHECK_THROW(vkEndCommandBuffer(m_vk_command_buffer));
+	}
 
-size_t dlss_allocated_bytes() {
-	unsigned long long allocated_bytes = 0;
-	if (!ngx_parameters) {
-		return 0;
+	void vk_command_buffer_submit() {
+		VkSubmitInfo submit_info = { VK_STRUCTURE_TYPE_SUBMIT_INFO };
+		submit_info.commandBufferCount = 1;
+		submit_info.pCommandBuffers = &m_vk_command_buffer;
+
+		VK_CHECK_THROW(vkQueueSubmit(m_vk_queue, 1, &submit_info, VK_NULL_HANDLE));
 	}
 
-	try {
-		NGX_CHECK_THROW(NGX_DLSS_GET_STATS(ngx_parameters, &allocated_bytes));
-	} catch (...) {
-		return 0;
+	void vk_synchronize() {
+		VK_CHECK_THROW(vkDeviceWaitIdle(m_vk_device));
 	}
 
-	return allocated_bytes;
-}
+	void vk_command_buffer_submit_sync() {
+		vk_command_buffer_submit();
+		vk_synchronize();
+	}
 
-void vk_command_buffer_begin() {
-	VkCommandBufferBeginInfo begin_info = {};
-	begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
-	begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
-	begin_info.pInheritanceInfo = nullptr;
+	void vk_command_buffer_end_and_submit_sync() {
+		vk_command_buffer_end();
+		vk_command_buffer_submit_sync();
+	}
 
-	VK_CHECK_THROW(vkBeginCommandBuffer(vk_command_buffer, &begin_info));
-}
+	const VkCommandBuffer& vk_command_buffer() const {
+		return m_vk_command_buffer;
+	}
 
-void vk_command_buffer_end() {
-	VK_CHECK_THROW(vkEndCommandBuffer(vk_command_buffer));
-}
+	const VkDevice& vk_device() const {
+		return m_vk_device;
+	}
 
-void vk_command_buffer_submit() {
-	VkSubmitInfo submit_info = { VK_STRUCTURE_TYPE_SUBMIT_INFO };
-	submit_info.commandBufferCount = 1;
-	submit_info.pCommandBuffers = &vk_command_buffer;
+	NVSDK_NGX_Parameter* ngx_parameters() const {
+		return m_ngx_parameters;
+	}
 
-	VK_CHECK_THROW(vkQueueSubmit(vk_queue, 1, &submit_info, VK_NULL_HANDLE));
-}
+	size_t allocated_bytes() const override {
+		unsigned long long allocated_bytes = 0;
+		if (!m_ngx_parameters) {
+			return 0;
+		}
 
-void vk_synchronize() {
-	VK_CHECK_THROW(vkDeviceWaitIdle(vk_device));
-}
+		try {
+			NGX_CHECK_THROW(NGX_DLSS_GET_STATS(m_ngx_parameters, &allocated_bytes));
+		} catch (...) {
+			return 0;
+		}
 
-void vk_command_buffer_submit_sync() {
-	vk_command_buffer_submit();
-	vk_synchronize();
-}
+		return allocated_bytes;
+	}
+
+	std::unique_ptr<IDlss> init_dlss(const Eigen::Vector2i& out_resolution) override;
+
+private:
+	VkInstance m_vk_instance = VK_NULL_HANDLE;
+	VkDebugUtilsMessengerEXT m_vk_debug_messenger = VK_NULL_HANDLE;
+	VkPhysicalDevice m_vk_physical_device = VK_NULL_HANDLE;
+	VkDevice m_vk_device = VK_NULL_HANDLE;
+	VkQueue m_vk_queue = VK_NULL_HANDLE;
+	VkCommandPool m_vk_command_pool = VK_NULL_HANDLE;
+	VkCommandBuffer m_vk_command_buffer = VK_NULL_HANDLE;
+	NVSDK_NGX_Parameter* m_ngx_parameters = nullptr;
+	bool m_ngx_initialized = false;
+};
 
-void vk_command_buffer_end_and_submit_sync() {
-	vk_command_buffer_end();
-	vk_command_buffer_submit_sync();
+std::shared_ptr<IDlssProvider> init_vulkan_and_ngx() {
+	return std::make_shared<VulkanAndNgx>();
 }
 
 class VulkanTexture {
 public:
-	VulkanTexture(const Vector2i& size, uint32_t n_channels) : m_size{size}, m_n_channels{n_channels} {
+	VulkanTexture(std::shared_ptr<VulkanAndNgx> vk, const Vector2i& size, uint32_t n_channels) : m_vk{vk}, m_size{size}, m_n_channels{n_channels} {
 		VkImageCreateInfo image_info{};
 		image_info.sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
 		image_info.imageType = VK_IMAGE_TYPE_2D;
@@ -515,17 +581,17 @@ public:
 
 		image_info.pNext = &ext_image_info;
 
-		VK_CHECK_THROW(vkCreateImage(vk_device, &image_info, nullptr, &m_vk_image));
+		VK_CHECK_THROW(vkCreateImage(m_vk->vk_device(), &image_info, nullptr, &m_vk_image));
 
 		// Create device memory to back up the image
 		VkMemoryRequirements mem_requirements = {};
 
-		vkGetImageMemoryRequirements(vk_device, m_vk_image, &mem_requirements);
+		vkGetImageMemoryRequirements(m_vk->vk_device(), m_vk_image, &mem_requirements);
 
 		VkMemoryAllocateInfo mem_alloc_info = {};
 		mem_alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
 		mem_alloc_info.allocationSize = mem_requirements.size;
-		mem_alloc_info.memoryTypeIndex = vk_find_memory_type(mem_requirements.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
+		mem_alloc_info.memoryTypeIndex = m_vk->vk_find_memory_type(mem_requirements.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
 
 		VkExportMemoryAllocateInfoKHR export_info = {};
 		export_info.sType = VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_KHR;
@@ -533,10 +599,10 @@ public:
 
 		mem_alloc_info.pNext = &export_info;
 
-		VK_CHECK_THROW(vkAllocateMemory(vk_device, &mem_alloc_info, nullptr, &m_vk_device_memory));
-		VK_CHECK_THROW(vkBindImageMemory(vk_device, m_vk_image, m_vk_device_memory, 0));
+		VK_CHECK_THROW(vkAllocateMemory(m_vk->vk_device(), &mem_alloc_info, nullptr, &m_vk_device_memory));
+		VK_CHECK_THROW(vkBindImageMemory(m_vk->vk_device(), m_vk_image, m_vk_device_memory, 0));
 
-		vk_command_buffer_begin();
+		m_vk->vk_command_buffer_begin();
 
 		VkImageMemoryBarrier barrier = {};
 		barrier.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
@@ -554,7 +620,7 @@ public:
 		barrier.dstAccessMask = VK_ACCESS_MEMORY_READ_BIT | VK_ACCESS_MEMORY_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_COLOR_ATTACHMENT_READ_BIT | VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT;
 
 		vkCmdPipelineBarrier(
-			vk_command_buffer,
+			m_vk->vk_command_buffer(),
 			VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT,
 			0,
 			0, nullptr,
@@ -562,7 +628,7 @@ public:
 			1, &barrier
 		);
 
-		vk_command_buffer_end_and_submit_sync();
+		m_vk->vk_command_buffer_end_and_submit_sync();
 
 		// Image view
 		VkImageViewCreateInfo view_info = {};
@@ -572,7 +638,7 @@ public:
 		view_info.format = image_info.format;
 		view_info.subresourceRange = barrier.subresourceRange;
 
-		VK_CHECK_THROW(vkCreateImageView(vk_device, &view_info, nullptr, &m_vk_image_view));
+		VK_CHECK_THROW(vkCreateImageView(m_vk->vk_device(), &view_info, nullptr, &m_vk_image_view));
 
 		// Map to NGX
 		m_ngx_resource = NVSDK_NGX_Create_ImageView_Resource_VK(m_vk_image_view, m_vk_image, view_info.subresourceRange, image_info.format, m_size.x(), m_size.y(), true);
@@ -584,21 +650,21 @@ public:
 		handle_info.sType = VK_STRUCTURE_TYPE_MEMORY_GET_WIN32_HANDLE_INFO_KHR;
 		handle_info.memory = m_vk_device_memory;
 		handle_info.handleType = VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT;
-		auto pfn_vkGetMemory = (PFN_vkGetMemoryWin32HandleKHR)vkGetDeviceProcAddr(vk_device, "vkGetMemoryWin32HandleKHR");
+		auto pfn_vkGetMemory = (PFN_vkGetMemoryWin32HandleKHR)vkGetDeviceProcAddr(m_vk->vk_device(), "vkGetMemoryWin32HandleKHR");
 #else
 		int handle = -1;
 		VkMemoryGetFdInfoKHR handle_info = {};
 		handle_info.sType = VK_STRUCTURE_TYPE_MEMORY_GET_FD_INFO_KHR;
 		handle_info.memory = m_vk_device_memory;
 		handle_info.handleType = VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT_KHR;
-		auto pfn_vkGetMemory = (PFN_vkGetMemoryFdKHR)vkGetDeviceProcAddr(vk_device, "vkGetMemoryFdKHR");
+		auto pfn_vkGetMemory = (PFN_vkGetMemoryFdKHR)vkGetDeviceProcAddr(m_vk->vk_device(), "vkGetMemoryFdKHR");
 #endif
 
 		if (!pfn_vkGetMemory) {
 			throw std::runtime_error{"Failed to locate pfn_vkGetMemory."};
 		}
 
-		VK_CHECK_THROW(pfn_vkGetMemory(vk_device, &handle_info, &handle));
+		VK_CHECK_THROW(pfn_vkGetMemory(m_vk->vk_device(), &handle_info, &handle));
 
 		// Map handle to CUDA memory
 		cudaExternalMemoryHandleDesc external_memory_handle_desc = {};
@@ -687,15 +753,15 @@ public:
 		}
 
 		if (m_vk_image_view) {
-			vkDestroyImageView(vk_device, m_vk_image_view, nullptr);
+			vkDestroyImageView(m_vk->vk_device(), m_vk_image_view, nullptr);
 		}
 
 		if (m_vk_image) {
-			vkDestroyImage(vk_device, m_vk_image, nullptr);
+			vkDestroyImage(m_vk->vk_device(), m_vk_image, nullptr);
 		}
 
 		if (m_vk_device_memory) {
-			vkFreeMemory(vk_device, m_vk_device_memory, nullptr);
+			vkFreeMemory(m_vk->vk_device(), m_vk_device_memory, nullptr);
 		}
 	}
 
@@ -720,6 +786,8 @@ public:
 	}
 
 private:
+	std::shared_ptr<VulkanAndNgx> m_vk;
+
 	Vector2i m_size;
 	uint32_t m_n_channels;
 
@@ -765,7 +833,7 @@ struct DlssFeatureSpecs {
 	}
 };
 
-DlssFeatureSpecs dlss_feature_specs(const Eigen::Vector2i& out_resolution, EDlssQuality quality) {
+DlssFeatureSpecs dlss_feature_specs(NVSDK_NGX_Parameter* ngx_parameters, const Eigen::Vector2i& out_resolution, EDlssQuality quality) {
 	DlssFeatureSpecs specs;
 	specs.quality = quality;
 	specs.out_resolution = out_resolution;
@@ -790,7 +858,7 @@ DlssFeatureSpecs dlss_feature_specs(const Eigen::Vector2i& out_resolution, EDlss
 
 class DlssFeature {
 public:
-	DlssFeature(const DlssFeatureSpecs& specs, bool is_hdr, bool sharpen) : m_specs{specs}, m_is_hdr{is_hdr}, m_sharpen{sharpen} {
+	DlssFeature(std::shared_ptr<VulkanAndNgx> vk_and_ngx, const DlssFeatureSpecs& specs, bool is_hdr, bool sharpen) : m_vk_and_ngx{vk_and_ngx}, m_specs{specs}, m_is_hdr{is_hdr}, m_sharpen{sharpen} {
 		// Initialize DLSS
 		unsigned int creation_node_mask = 1;
 		unsigned int visibility_node_mask = 1;
@@ -799,7 +867,7 @@ public:
 		dlss_create_feature_flags |= true ? NVSDK_NGX_DLSS_Feature_Flags_MVLowRes : 0;
 		dlss_create_feature_flags |= false ? NVSDK_NGX_DLSS_Feature_Flags_MVJittered : 0;
 		dlss_create_feature_flags |= is_hdr ? NVSDK_NGX_DLSS_Feature_Flags_IsHDR : 0;
-		dlss_create_feature_flags |= false ? NVSDK_NGX_DLSS_Feature_Flags_DepthInverted : 0;
+		dlss_create_feature_flags |= true ? NVSDK_NGX_DLSS_Feature_Flags_DepthInverted : 0;
 		dlss_create_feature_flags |= sharpen ? NVSDK_NGX_DLSS_Feature_Flags_DoSharpening : 0;
 		dlss_create_feature_flags |= false ? NVSDK_NGX_DLSS_Feature_Flags_AutoExposure : 0;
 
@@ -815,15 +883,15 @@ public:
 		dlss_create_params.InFeatureCreateFlags = dlss_create_feature_flags;
 
 		{
-			vk_command_buffer_begin();
-			ScopeGuard command_buffer_guard{[&]() { vk_command_buffer_end_and_submit_sync(); }};
+			m_vk_and_ngx->vk_command_buffer_begin();
+			ScopeGuard command_buffer_guard{[&]() { m_vk_and_ngx->vk_command_buffer_end_and_submit_sync(); }};
 
-			NGX_CHECK_THROW(NGX_VULKAN_CREATE_DLSS_EXT(vk_command_buffer, creation_node_mask, visibility_node_mask, &m_ngx_dlss, ngx_parameters, &dlss_create_params));
+			NGX_CHECK_THROW(NGX_VULKAN_CREATE_DLSS_EXT(m_vk_and_ngx->vk_command_buffer(), creation_node_mask, visibility_node_mask, &m_ngx_dlss, m_vk_and_ngx->ngx_parameters(), &dlss_create_params));
 		}
 	}
 
-	DlssFeature(const Eigen::Vector2i& out_resolution, bool is_hdr, bool sharpen, EDlssQuality quality)
-	: DlssFeature{dlss_feature_specs(out_resolution, quality), is_hdr, sharpen} {}
+	DlssFeature(std::shared_ptr<VulkanAndNgx> vk_and_ngx, const Eigen::Vector2i& out_resolution, bool is_hdr, bool sharpen, EDlssQuality quality)
+	: DlssFeature{vk_and_ngx, dlss_feature_specs(vk_and_ngx->ngx_parameters(), out_resolution, quality), is_hdr, sharpen} {}
 
 	~DlssFeature() {
 		cudaDeviceSynchronize();
@@ -832,7 +900,7 @@ public:
 			NVSDK_NGX_VULKAN_ReleaseFeature(m_ngx_dlss);
 		}
 
-		vk_synchronize();
+		m_vk_and_ngx->vk_synchronize();
 	}
 
 	void run(
@@ -850,7 +918,7 @@ public:
 			throw std::runtime_error{"May only specify non-zero sharpening, when DlssFeature has been created with sharpen option."};
 		}
 
-		vk_command_buffer_begin();
+		m_vk_and_ngx->vk_command_buffer_begin();
 
 		NVSDK_NGX_VK_DLSS_Eval_Params dlss_params;
 		memset(&dlss_params, 0, sizeof(dlss_params));
@@ -868,9 +936,9 @@ public:
 		dlss_params.InMVScaleY = 1.0f;
 		dlss_params.InRenderSubrectDimensions = {(uint32_t)in_resolution.x(), (uint32_t)in_resolution.y()};
 
-		NGX_CHECK_THROW(NGX_VULKAN_EVALUATE_DLSS_EXT(vk_command_buffer, m_ngx_dlss, ngx_parameters, &dlss_params));
+		NGX_CHECK_THROW(NGX_VULKAN_EVALUATE_DLSS_EXT(m_vk_and_ngx->vk_command_buffer(), m_ngx_dlss, m_vk_and_ngx->ngx_parameters(), &dlss_params));
 
-		vk_command_buffer_end_and_submit_sync();
+		m_vk_and_ngx->vk_command_buffer_end_and_submit_sync();
 	}
 
 	bool is_hdr() const {
@@ -898,6 +966,8 @@ public:
 	}
 
 private:
+	std::shared_ptr<VulkanAndNgx> m_vk_and_ngx;
+
 	NVSDK_NGX_Handle* m_ngx_dlss = {};
 	DlssFeatureSpecs m_specs;
 	bool m_is_hdr;
@@ -906,28 +976,29 @@ private:
 
 class Dlss : public IDlss {
 public:
-	Dlss(const Eigen::Vector2i& max_out_resolution)
+	Dlss(std::shared_ptr<VulkanAndNgx> vk_and_ngx, const Eigen::Vector2i& max_out_resolution)
 	:
+	m_vk_and_ngx{vk_and_ngx},
 	m_max_out_resolution{max_out_resolution},
 	// Allocate all buffers at output resolution and use dynamic sub-rects
 	// to use subsets of them. This avoids re-allocations when using DLSS
 	// with dynamically changing input resolution.
-	m_frame_buffer{max_out_resolution, 4},
-	m_depth_buffer{max_out_resolution, 1},
-	m_mvec_buffer{max_out_resolution, 2},
-	m_exposure_buffer{{1, 1}, 1},
-	m_output_buffer{max_out_resolution, 4}
+	m_frame_buffer{m_vk_and_ngx, max_out_resolution, 4},
+	m_depth_buffer{m_vk_and_ngx, max_out_resolution, 1},
+	m_mvec_buffer{m_vk_and_ngx, max_out_resolution, 2},
+	m_exposure_buffer{m_vk_and_ngx, {1, 1}, 1},
+	m_output_buffer{m_vk_and_ngx, max_out_resolution, 4}
 	{
 		// Various quality modes of DLSS
 		for (int i = 0; i < (int)EDlssQuality::NumDlssQualitySettings; ++i) {
 			try {
-				auto specs = dlss_feature_specs(max_out_resolution, (EDlssQuality)i);
+				auto specs = dlss_feature_specs(m_vk_and_ngx->ngx_parameters(), max_out_resolution, (EDlssQuality)i);
 
 				// Only emplace the specs if the feature can be created in practice!
-				DlssFeature{specs, true, true};
-				DlssFeature{specs, true, false};
-				DlssFeature{specs, false, true};
-				DlssFeature{specs, false, false};
+				DlssFeature{m_vk_and_ngx, specs, true, true};
+				DlssFeature{m_vk_and_ngx, specs, true, false};
+				DlssFeature{m_vk_and_ngx, specs, false, true};
+				DlssFeature{m_vk_and_ngx, specs, false, false};
 				m_dlss_specs.emplace_back(specs);
 			} catch (...) {}
 		}
@@ -943,13 +1014,13 @@ public:
 
 		for (const auto& out_resolution : reduced_out_resolutions) {
 			try {
-				auto specs = dlss_feature_specs(out_resolution, EDlssQuality::UltraPerformance);
+				auto specs = dlss_feature_specs(m_vk_and_ngx->ngx_parameters(), out_resolution, EDlssQuality::UltraPerformance);
 
 				// Only emplace the specs if the feature can be created in practice!
-				DlssFeature{specs, true, true};
-				DlssFeature{specs, true, false};
-				DlssFeature{specs, false, true};
-				DlssFeature{specs, false, false};
+				DlssFeature{m_vk_and_ngx, specs, true, true};
+				DlssFeature{m_vk_and_ngx, specs, true, false};
+				DlssFeature{m_vk_and_ngx, specs, false, true};
+				DlssFeature{m_vk_and_ngx, specs, false, false};
 				m_dlss_specs.emplace_back(specs);
 			} catch (...) {}
 		}
@@ -977,7 +1048,7 @@ public:
 		}
 
 		if (!m_dlss_feature || m_dlss_feature->is_hdr() != is_hdr || m_dlss_feature->sharpen() != sharpen || m_dlss_feature->quality() != specs.quality || m_dlss_feature->out_resolution() != specs.out_resolution) {
-			m_dlss_feature.reset(new DlssFeature{specs.out_resolution, is_hdr, sharpen, specs.quality});
+			m_dlss_feature.reset(new DlssFeature{m_vk_and_ngx, specs.out_resolution, is_hdr, sharpen, specs.quality});
 		}
 	}
 
@@ -1060,6 +1131,8 @@ public:
 	}
 
 private:
+	std::shared_ptr<VulkanAndNgx> m_vk_and_ngx;
+
 	std::unique_ptr<DlssFeature> m_dlss_feature;
 	std::vector<DlssFeatureSpecs> m_dlss_specs;
 
@@ -1072,47 +1145,8 @@ private:
 	Vector2i m_max_out_resolution;
 };
 
-std::shared_ptr<IDlss> dlss_init(const Eigen::Vector2i& out_resolution) {
-	return std::make_shared<Dlss>(out_resolution);
-}
-
-void vulkan_and_ngx_destroy() {
-	if (ngx_parameters) {
-		NVSDK_NGX_VULKAN_DestroyParameters(ngx_parameters);
-		ngx_parameters = nullptr;
-	}
-
-	if (ngx_initialized) {
-		NVSDK_NGX_VULKAN_Shutdown();
-		ngx_initialized = false;
-	}
-
-	if (vk_command_pool) {
-		vkDestroyCommandPool(vk_device, vk_command_pool, nullptr);
-		vk_command_pool = VK_NULL_HANDLE;
-	}
-
-	if (vk_device) {
-		vkDestroyDevice(vk_device, nullptr);
-		vk_device = VK_NULL_HANDLE;
-	}
-
-	if (vk_debug_messenger) {
-		auto DestroyDebugUtilsMessengerEXT = [](VkInstance instance, VkDebugUtilsMessengerEXT debugMessenger, const VkAllocationCallbacks* pAllocator) {
-			auto func = (PFN_vkDestroyDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkDestroyDebugUtilsMessengerEXT");
-			if (func != nullptr) {
-				func(instance, debugMessenger, pAllocator);
-			}
-		};
-
-		DestroyDebugUtilsMessengerEXT(vk_instance, vk_debug_messenger, nullptr);
-		vk_debug_messenger = VK_NULL_HANDLE;
-	}
-
-	if (vk_instance) {
-		vkDestroyInstance(vk_instance, nullptr);
-		vk_instance = VK_NULL_HANDLE;
-	}
+std::unique_ptr<IDlss> VulkanAndNgx::init_dlss(const Eigen::Vector2i& out_resolution) {
+	return std::make_unique<Dlss>(shared_from_this(), out_resolution);
 }
 
 NGP_NAMESPACE_END
diff --git a/src/main.cu b/src/main.cu
index f05b489598b3cb0208251230e359310954a8b2d3..ac79bd362f70f3e354fd2029bb4c907809dfa8aa 100644
--- a/src/main.cu
+++ b/src/main.cu
@@ -62,6 +62,13 @@ int main_func(const std::vector<std::string>& arguments) {
 		{"no-gui"},
 	};
 
+	Flag vr_flag{
+		parser,
+		"VR",
+		"Enables VR",
+		{"vr"}
+	};
+
 	Flag no_train_flag{
 		parser,
 		"NO_TRAIN",
@@ -170,6 +177,10 @@ int main_func(const std::vector<std::string>& arguments) {
 		testbed.init_window(width_flag ? get(width_flag) : 1920, height_flag ? get(height_flag) : 1080);
 	}
 
+	if (vr_flag) {
+		testbed.init_vr();
+	}
+
 	// Render/training loop
 	while (testbed.frame()) {
 		if (!gui) {
diff --git a/src/marching_cubes.cu b/src/marching_cubes.cu
index 28c28585ab86bd8e8104319ecde973a325fae723..2fc595405c5835268f7f83893e820bfe65cd714c 100644
--- a/src/marching_cubes.cu
+++ b/src/marching_cubes.cu
@@ -98,11 +98,11 @@ bool check_shader(uint32_t handle, const char* desc, bool program) {
 
 uint32_t compile_shader(bool pixel, const char* code) {
 	GLuint g_VertHandle = glCreateShader(pixel ? GL_FRAGMENT_SHADER : GL_VERTEX_SHADER );
-	const char* glsl_version = "#version 330\n";
+	const char* glsl_version = "#version 140\n";
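+	// GLSL 1.40 lacks layout(location) qualifiers, so the shaders below declare plain in/out variables bound by name.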
 	const GLchar* strings[2] = { glsl_version, code};
 	glShaderSource(g_VertHandle, 2, strings, NULL);
 	glCompileShader(g_VertHandle);
-	if (!check_shader(g_VertHandle, pixel?"pixel":"vertex", false)) {
+	if (!check_shader(g_VertHandle, pixel ? "pixel" : "vertex", false)) {
 		glDeleteShader(g_VertHandle);
 		return 0;
 	}
@@ -173,9 +173,9 @@ void draw_mesh_gl(
 
 	if (!program) {
 		vs = compile_shader(false, R"foo(
-layout (location = 0) in vec3 pos;
-layout (location = 1) in vec3 nor;
-layout (location = 2) in vec3 col;
+in vec3 pos;
+in vec3 nor;
+in vec3 col;
 out vec3 vtxcol;
 uniform mat4 camera;
 uniform vec2 f;
@@ -198,16 +198,11 @@ void main()
 }
 )foo");
 		ps = compile_shader(true, R"foo(
-layout (location = 0) out vec4 o;
+out vec4 o;
 in vec3 vtxcol;
 uniform int mode;
 void main() {
-	if (mode == 3) {
-		vec3 tricol = vec3((ivec3(923, 3572, 5423) * gl_PrimitiveID) & 255) * (1.0 / 255.0);
-		o = vec4(tricol, 1.0);
-	} else {
-		o = vec4(vtxcol, 1.0);
-	}
+	o = vec4(vtxcol, 1.0);
 }
 )foo");
 		program = glCreateProgram();
diff --git a/src/nerf_loader.cu b/src/nerf_loader.cu
index 3fd76ca003e845e183bdd3b914e3af12bc21c9c1..69af304f8e3aed090d0ffe81fd46d38dcfd4e1af 100644
--- a/src/nerf_loader.cu
+++ b/src/nerf_loader.cu
@@ -231,6 +231,10 @@ void read_lens(const nlohmann::json& json, Lens& lens, Vector2f& principal_point
 		mode = ELensMode::LatLong;
 	}
 
+	if (json.contains("equirectangular")) {
+		mode = ELensMode::Equirectangular;
+	}
+
 	// If there was an outer distortion mode, don't override it with nothing.
 	if (mode != ELensMode::Perspective) {
 		lens.mode = mode;
diff --git a/src/openxr_hmd.cu b/src/openxr_hmd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b7d2cd4e8e4c8bd5d4b418b7797dbd805c064991
--- /dev/null
+++ b/src/openxr_hmd.cu
@@ -0,0 +1,1249 @@
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
+ * and proprietary rights in and to this software, related documentation
+ * and any modifications thereto.  Any use, reproduction, disclosure or
+ * distribution of this software and related documentation without an express
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
+ */
+
+/** @file   openxr_hmd.cu
+ *  @author Thomas Müller & Ingo Esser & Robert Menzel, NVIDIA
+ *  @brief  Wrapper around the OpenXR API, providing access to
+ *          per-eye framebuffers, lens parameters, visible area,
+ *          view, hand, and eye poses, as well as controller inputs.
+ */
+
+#define NOMINMAX
+
+#include <neural-graphics-primitives/common_device.cuh>
+#include <neural-graphics-primitives/marching_cubes.h>
+#include <neural-graphics-primitives/openxr_hmd.h>
+#include <neural-graphics-primitives/render_buffer.h>
+
+#include <openxr/openxr_reflection.h>
+
+#include <fmt/format.h>
+
+#include <imgui/imgui.h>
+
+#include <tinylogger/tinylogger.h>
+
+#include <tiny-cuda-nn/common.h>
+
+#include <string>
+#include <vector>
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers" // TODO: XR structs are uninitialized apart from their type
+#endif
+
+using namespace Eigen;
+using namespace tcnn;
+
+NGP_NAMESPACE_BEGIN
+
+// XrEnumStr turns an enum value into a string for printing,
+// using the expansion macros and data provided in openxr_reflection.h.
+#define XR_ENUM_CASE_STR(name, val) \
+	case name:                      \
+		return #name;
+#define XR_ENUM_STR(enum_type)                                                     \
+	constexpr const char* XrEnumStr(enum_type e) {                                 \
+		switch (e) {                                                               \
+			XR_LIST_ENUM_##enum_type(XR_ENUM_CASE_STR) default : return "Unknown"; \
+		}                                                                          \
+	}
+
+XR_ENUM_STR(XrViewConfigurationType)
+XR_ENUM_STR(XrEnvironmentBlendMode)
+XR_ENUM_STR(XrReferenceSpaceType)
+XR_ENUM_STR(XrStructureType)
+XR_ENUM_STR(XrSessionState)
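+// For example, XrEnumStr(XR_SESSION_STATE_READY) returns "XR_SESSION_STATE_READY".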
+
+/// Checks the result of an xrXXXXXX call and throws an error on failure
+#define XR_CHECK_THROW(x)                                                                                   \
+	do {                                                                                                              \
+		XrResult result = x;                                                                                          \
+		if (XR_FAILED(result)) {                                                                                      \
+			char buffer[XR_MAX_RESULT_STRING_SIZE];                                                                   \
+			XrResult result_to_string_result = xrResultToString(m_instance, result, buffer);                            \
+			if (XR_FAILED(result_to_string_result)) {                                                                 \
+				throw std::runtime_error{std::string(FILE_LINE " " #x " failed, but could not obtain error string")}; \
+			} else {                                                                                                  \
+				throw std::runtime_error{std::string(FILE_LINE " " #x " failed with error ") + buffer};               \
+			}                                                                                                         \
+		}                                                                                                             \
+	} while(0)
+
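+// The instance parameter is deliberately named m_instance so that XR_CHECK_THROW, which references m_instance, also works here.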
+OpenXRHMD::Swapchain::Swapchain(XrSwapchainCreateInfo& rgba_create_info, XrSwapchainCreateInfo& depth_create_info, XrSession& session, XrInstance& m_instance) {
+	ScopeGuard cleanup_guard{[&]() { clear(); }};
+
+	XR_CHECK_THROW(xrCreateSwapchain(session, &rgba_create_info, &handle));
+
+	width = rgba_create_info.width;
+	height = rgba_create_info.height;
+
+	{
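+		// OpenXR two-call idiom: query the required count first, then fill the array.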
+		uint32_t size;
+		XR_CHECK_THROW(xrEnumerateSwapchainImages(handle, 0, &size, nullptr));
+
+		images_gl.resize(size, {XR_TYPE_SWAPCHAIN_IMAGE_OPENGL_KHR});
+		XR_CHECK_THROW(xrEnumerateSwapchainImages(handle, size, &size, (XrSwapchainImageBaseHeader*)images_gl.data()));
+
+		// One framebuffer per swapchain image
+		framebuffers_gl.resize(size);
+	}
+
+	if (depth_create_info.format != 0) {
+		XR_CHECK_THROW(xrCreateSwapchain(session, &depth_create_info, &depth_handle));
+
+		uint32_t depth_size;
+		XR_CHECK_THROW(xrEnumerateSwapchainImages(depth_handle, 0, &depth_size, nullptr));
+
+		depth_images_gl.resize(depth_size, {XR_TYPE_SWAPCHAIN_IMAGE_OPENGL_KHR});
+		XR_CHECK_THROW(xrEnumerateSwapchainImages(depth_handle, depth_size, &depth_size, (XrSwapchainImageBaseHeader*)depth_images_gl.data()));
+
+		// We might have a different number of depth swapchain images than framebuffers,
+		// so we will need to bind an acquired depth image to the current framebuffer on the
+		// fly later on.
+	}
+
+	glGenFramebuffers(framebuffers_gl.size(), framebuffers_gl.data());
+
+	cleanup_guard.disarm();
+}
+
+OpenXRHMD::Swapchain::~Swapchain() {
+	clear();
+}
+
+void OpenXRHMD::Swapchain::clear() {
+	if (!framebuffers_gl.empty()) {
+		glDeleteFramebuffers(framebuffers_gl.size(), framebuffers_gl.data());
+	}
+
+	if (depth_handle != XR_NULL_HANDLE) {
+		xrDestroySwapchain(depth_handle);
+		depth_handle = XR_NULL_HANDLE;
+	}
+
+	if (handle != XR_NULL_HANDLE) {
+		xrDestroySwapchain(handle);
+		handle = XR_NULL_HANDLE;
+	}
+}
+
+#if defined(XR_USE_PLATFORM_WIN32)
+OpenXRHMD::OpenXRHMD(HDC hdc, HGLRC hglrc) {
+#elif defined(XR_USE_PLATFORM_XLIB)
+OpenXRHMD::OpenXRHMD(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext) {
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+OpenXRHMD::OpenXRHMD(wl_display* display) {
+#endif
+	ScopeGuard cleanup_guard{[&]() { clear(); }};
+
+	init_create_xr_instance();
+	init_get_xr_system();
+	init_configure_xr_views();
+	init_check_for_xr_blend_mode();
+#if defined(XR_USE_PLATFORM_WIN32)
+	init_open_gl(hdc, hglrc);
+#elif defined(XR_USE_PLATFORM_XLIB)
+	init_open_gl(xDisplay, visualid, glxFBConfig, glxDrawable, glxContext);
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	init_open_gl(display);
+#endif
+	init_xr_session();
+	init_xr_actions();
+	init_xr_spaces();
+	init_xr_swapchain_open_gl();
+	init_open_gl_shaders();
+
+	cleanup_guard.disarm();
+	tlog::success() << "Initialized OpenXR for " << m_system_properties.systemName;
+	// tlog::success() << " "
+	// 	<< " depth=" << (m_supports_composition_layer_depth ? "true" : "false")
+	// 	<< " mask=" << (m_supports_hidden_area_mask ? "true" : "false")
+	// 	<< " eye=" << (m_supports_hidden_area_mask ? "true" : "false")
+	// 	;
+}
+
+OpenXRHMD::~OpenXRHMD() {
+	clear();
+}
+
+void OpenXRHMD::clear() {
+	auto xr_destroy = [&](auto& handle, auto destroy_fun) {
+		if (handle != XR_NULL_HANDLE) {
+			destroy_fun(handle);
+			handle = XR_NULL_HANDLE;
+		}
+	};
+
+	xr_destroy(m_pose_action, xrDestroyAction);
+	xr_destroy(m_thumbstick_actions[0], xrDestroyAction);
+	xr_destroy(m_thumbstick_actions[1], xrDestroyAction);
+	xr_destroy(m_press_action, xrDestroyAction);
+	xr_destroy(m_grab_action, xrDestroyAction);
+
+	xr_destroy(m_action_set, xrDestroyActionSet);
+
+	m_swapchains.clear();
+	xr_destroy(m_space, xrDestroySpace);
+	xr_destroy(m_session, xrDestroySession);
+	xr_destroy(m_instance, xrDestroyInstance);
+}
+
+void OpenXRHMD::init_create_xr_instance() {
+	std::vector<const char*> layers = {};
+	std::vector<const char*> extensions = {
+		XR_KHR_OPENGL_ENABLE_EXTENSION_NAME,
+	};
+
+	auto print_extension_properties = [](const char* layer_name) {
+		uint32_t size;
+		xrEnumerateInstanceExtensionProperties(layer_name, 0, &size, nullptr);
+		std::vector<XrExtensionProperties> props(size, {XR_TYPE_EXTENSION_PROPERTIES});
+		xrEnumerateInstanceExtensionProperties(layer_name, size, &size, props.data());
+		tlog::info() << fmt::format("Extensions ({}):", props.size());
+		for (XrExtensionProperties extension : props) {
+			tlog::info() << fmt::format("\t{} (Version {})", extension.extensionName, extension.extensionVersion);
+		}
+	};
+
+	uint32_t size;
+	xrEnumerateApiLayerProperties(0, &size, nullptr);
+	m_api_layer_properties.clear();
+	m_api_layer_properties.resize(size, {XR_TYPE_API_LAYER_PROPERTIES});
+	xrEnumerateApiLayerProperties(size, &size, m_api_layer_properties.data());
+
+	if (m_print_api_layers) {
+		tlog::info() << fmt::format("API Layers ({}):", m_api_layer_properties.size());
+		for (auto p : m_api_layer_properties) {
+			tlog::info() << fmt::format(
+				"{} (v {}.{}.{}, {}) {}",
+				p.layerName,
+				XR_VERSION_MAJOR(p.specVersion),
+				XR_VERSION_MINOR(p.specVersion),
+				XR_VERSION_PATCH(p.specVersion),
+				p.layerVersion,
+				p.description
+			);
+			print_extension_properties(p.layerName);
+		}
+	}
+
+	if (!layers.empty()) {
+		for (const auto& e : layers) {
+			bool found = false;
+			for (XrApiLayerProperties layer : m_api_layer_properties) {
+				if (strcmp(e, layer.layerName) == 0) {
+					found = true;
+					break;
+				}
+			}
+
+			if (!found) {
+				throw std::runtime_error{fmt::format("OpenXR API layer {} not found", e)};
+			}
+		}
+	}
+
+	xrEnumerateInstanceExtensionProperties(nullptr, 0, &size, nullptr);
+	m_instance_extension_properties.clear();
+	m_instance_extension_properties.resize(size, {XR_TYPE_EXTENSION_PROPERTIES});
+	xrEnumerateInstanceExtensionProperties(nullptr, size, &size, m_instance_extension_properties.data());
+
+	if (m_print_extensions) {
+		tlog::info() << fmt::format("Instance extensions ({}):", m_instance_extension_properties.size());
+		for (XrExtensionProperties extension : m_instance_extension_properties) {
+			tlog::info() << fmt::format("\t{} (Version {})", extension.extensionName, extension.extensionVersion);
+		}
+	}
+
+	auto has_extension = [&](const char* e) {
+		for (XrExtensionProperties extension : m_instance_extension_properties) {
+			if (strcmp(e, extension.extensionName) == 0) {
+				return true;
+			}
+		}
+
+		return false;
+	};
+
+	for (const auto& e : extensions) {
+		if (!has_extension(e)) {
+			throw std::runtime_error{fmt::format("Required OpenXR extension {} not found", e)};
+		}
+	}
+
+	auto add_extension_if_supported = [&](const char* extension) {
+		if (has_extension(extension)) {
+			extensions.emplace_back(extension);
+			return true;
+		}
+
+		return false;
+	};
+
+	if (add_extension_if_supported(XR_KHR_COMPOSITION_LAYER_DEPTH_EXTENSION_NAME)) {
+		m_supports_composition_layer_depth = true;
+	}
+
+	if (add_extension_if_supported(XR_KHR_VISIBILITY_MASK_EXTENSION_NAME)) {
+		m_supports_hidden_area_mask = true;
+	}
+
+	if (add_extension_if_supported(XR_EXT_EYE_GAZE_INTERACTION_EXTENSION_NAME)) {
+		m_supports_eye_tracking = true;
+	}
+
+	XrInstanceCreateInfo instance_create_info = {XR_TYPE_INSTANCE_CREATE_INFO};
+	instance_create_info.applicationInfo = {};
+	strncpy(instance_create_info.applicationInfo.applicationName, "Instant Neural Graphics Primitives v" NGP_VERSION, XR_MAX_APPLICATION_NAME_SIZE);
+	instance_create_info.applicationInfo.applicationVersion = 1;
+	strncpy(instance_create_info.applicationInfo.engineName, "Instant Neural Graphics Primitives v" NGP_VERSION, XR_MAX_ENGINE_NAME_SIZE);
+	instance_create_info.applicationInfo.engineVersion = 1;
+	instance_create_info.applicationInfo.apiVersion = XR_CURRENT_API_VERSION;
+	instance_create_info.enabledExtensionCount = (uint32_t)extensions.size();
+	instance_create_info.enabledExtensionNames = extensions.data();
+	instance_create_info.enabledApiLayerCount = (uint32_t)layers.size();
+	instance_create_info.enabledApiLayerNames = layers.data();
+
+	if (XR_FAILED(xrCreateInstance(&instance_create_info, &m_instance))) {
+		throw std::runtime_error{"Failed to create OpenXR instance"};
+	}
+
+	XR_CHECK_THROW(xrGetInstanceProperties(m_instance, &m_instance_properties));
+	if (m_print_instance_properties) {
+		tlog::info() << "Instance Properties";
+		tlog::info() << fmt::format("\t        runtime name: '{}'", m_instance_properties.runtimeName);
+		const auto& v = m_instance_properties.runtimeVersion;
+		tlog::info() << fmt::format(
+			"\t     runtime version: {}.{}.{}",
+			XR_VERSION_MAJOR(v),
+			XR_VERSION_MINOR(v),
+			XR_VERSION_PATCH(v)
+		);
+	}
+}
+
+void OpenXRHMD::init_get_xr_system() {
+	XrSystemGetInfo system_get_info = {XR_TYPE_SYSTEM_GET_INFO, nullptr, XR_FORM_FACTOR_HEAD_MOUNTED_DISPLAY};
+	XR_CHECK_THROW(xrGetSystem(m_instance, &system_get_info, &m_system_id));
+
+	XR_CHECK_THROW(xrGetSystemProperties(m_instance, m_system_id, &m_system_properties));
+	if (m_print_system_properties) {
+		tlog::info() << "System Properties";
+		tlog::info() << fmt::format("\t                name: '{}'", m_system_properties.systemName);
+		tlog::info() << fmt::format("\t            vendorId: {:#x}", m_system_properties.vendorId);
+		tlog::info() << fmt::format("\t            systemId: {:#x}", m_system_properties.systemId);
+		tlog::info() << fmt::format("\t     max layer count: {}", m_system_properties.graphicsProperties.maxLayerCount);
+		tlog::info() << fmt::format("\t       max img width: {}", m_system_properties.graphicsProperties.maxSwapchainImageWidth);
+		tlog::info() << fmt::format("\t      max img height: {}", m_system_properties.graphicsProperties.maxSwapchainImageHeight);
+		tlog::info() << fmt::format("\torientation tracking: {}", m_system_properties.trackingProperties.orientationTracking ? "YES" : "NO");
+		tlog::info() << fmt::format("\t   position tracking: {}", m_system_properties.trackingProperties.positionTracking ? "YES" : "NO");
+	}
+}
+
+void OpenXRHMD::init_configure_xr_views() {
+	uint32_t size;
+	XR_CHECK_THROW(xrEnumerateViewConfigurations(m_instance, m_system_id, 0, &size, nullptr));
+	std::vector<XrViewConfigurationType> view_config_types(size);
+	XR_CHECK_THROW(xrEnumerateViewConfigurations(m_instance, m_system_id, size, &size, view_config_types.data()));
+
+	if (m_print_view_configuration_types) {
+		tlog::info() << fmt::format("View Configuration Types ({}):", view_config_types.size());
+		for (const auto& t : view_config_types) {
+			tlog::info() << fmt::format("\t{}", XrEnumStr(t));
+		}
+	}
+
+	// view configurations we support, in descending preference
+	const std::vector<XrViewConfigurationType> preferred_view_config_types = {
+		//XR_VIEW_CONFIGURATION_TYPE_PRIMARY_QUAD_VARJO,
+		XR_VIEW_CONFIGURATION_TYPE_PRIMARY_STEREO
+	};
+
+	bool found = false;
+	for (const auto& p : preferred_view_config_types) {
+		for (const auto& t : view_config_types) {
+			if (p == t) {
+				found = true;
+				m_view_configuration_type = t;
+				break;
+			}
+		}
+
+		if (found) {
+			break;
+		}
+	}
+
+	if (!found) {
+		throw std::runtime_error{"Could not find a suitable OpenXR view configuration type"};
+	}
+
+	// get view configuration properties
+	XR_CHECK_THROW(xrGetViewConfigurationProperties(m_instance, m_system_id, m_view_configuration_type, &m_view_configuration_properties));
+	if (m_print_view_configuration_properties) {
+		tlog::info() << "View Configuration Properties:";
+		tlog::info() << fmt::format("\t         Type: {}", XrEnumStr(m_view_configuration_type));
+		tlog::info() << fmt::format("\t         FOV Mutable: {}", m_view_configuration_properties.fovMutable ? "YES" : "NO");
+	}
+
+	// enumerate view configuration views
+	XR_CHECK_THROW(xrEnumerateViewConfigurationViews(m_instance, m_system_id, m_view_configuration_type, 0, &size, nullptr));
+	m_view_configuration_views.clear();
+	m_view_configuration_views.resize(size, {XR_TYPE_VIEW_CONFIGURATION_VIEW});
+	XR_CHECK_THROW(xrEnumerateViewConfigurationViews(
+		m_instance,
+		m_system_id,
+		m_view_configuration_type,
+		size,
+		&size,
+		m_view_configuration_views.data()
+	));
+
+	if (m_print_view_configuration_view) {
+		tlog::info() << "View Configuration Views, Width x Height x Samples";
+		for (size_t i = 0; i < m_view_configuration_views.size(); ++i) {
+			const auto& view = m_view_configuration_views[i];
+			tlog::info() << fmt::format(
+				"\tView {}\tRecommended: {}x{}x{}  Max: {}x{}x{}",
+				i,
+				view.recommendedImageRectWidth,
+				view.recommendedImageRectHeight,
+				view.recommendedSwapchainSampleCount,
+				view.maxImageRectWidth,
+				view.maxImageRectHeight,
+				view.maxSwapchainSampleCount
+			);
+		}
+	}
+}
+
+void OpenXRHMD::init_check_for_xr_blend_mode() {
+	// enumerate environment blend modes
+	uint32_t size;
+	XR_CHECK_THROW(xrEnumerateEnvironmentBlendModes(m_instance, m_system_id, m_view_configuration_type, 0, &size, nullptr));
+	m_environment_blend_modes.resize(size);
+	XR_CHECK_THROW(xrEnumerateEnvironmentBlendModes(
+		m_instance,
+		m_system_id,
+		m_view_configuration_type,
+		size,
+		&size,
+		m_environment_blend_modes.data()
+	));
+
+	if (m_print_environment_blend_modes) {
+		tlog::info() << fmt::format("Environment Blend Modes ({}):", m_environment_blend_modes.size());
+	}
+
+	bool found = false;
+	for (const auto& m : m_environment_blend_modes) {
+		if (m_print_environment_blend_modes) {
+			tlog::info() << fmt::format("\t{}", XrEnumStr(m));
+		}
+
+		if (m == m_environment_blend_mode) {
+			found = true;
+		}
+	}
+
+	if (!found) {
+		throw std::runtime_error{fmt::format("OpenXR environment blend mode {} not found", XrEnumStr(m_environment_blend_mode))};
+	}
+}
+
+void OpenXRHMD::init_xr_actions() {
+	// paths for left (0) and right (1) hands
+	XR_CHECK_THROW(xrStringToPath(m_instance, "/user/hand/left", &m_hand_paths[0]));
+	XR_CHECK_THROW(xrStringToPath(m_instance, "/user/hand/right", &m_hand_paths[1]));
+
+	// create action set
+	XrActionSetCreateInfo action_set_create_info{XR_TYPE_ACTION_SET_CREATE_INFO, nullptr, "actionset", "actionset", 0};
+	XR_CHECK_THROW(xrCreateActionSet(m_instance, &action_set_create_info, &m_action_set));
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"hand_pose",
+			XR_ACTION_TYPE_POSE_INPUT,
+			(uint32_t)m_hand_paths.size(),
+			m_hand_paths.data(),
+			"Hand pose"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_pose_action));
+	}
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"thumbstick_left",
+			XR_ACTION_TYPE_VECTOR2F_INPUT,
+			0,
+			nullptr,
+			"Left thumbstick"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_thumbstick_actions[0]));
+	}
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"thumbstick_right",
+			XR_ACTION_TYPE_VECTOR2F_INPUT,
+			0,
+			nullptr,
+			"Right thumbstick"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_thumbstick_actions[1]));
+	}
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"press",
+			XR_ACTION_TYPE_BOOLEAN_INPUT,
+			(uint32_t)m_hand_paths.size(),
+			m_hand_paths.data(),
+			"Press"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_press_action));
+	}
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"grab",
+			XR_ACTION_TYPE_FLOAT_INPUT,
+			(uint32_t)m_hand_paths.size(),
+			m_hand_paths.data(),
+			"Grab"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_grab_action));
+	}
+
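+	// Helpers to convert a path string to an XrPath and to suggest a batch of bindings for one interaction profile.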
+	auto create_binding = [&](XrAction action, const std::string& binding_path_str) {
+		XrPath binding;
+		XR_CHECK_THROW(xrStringToPath(m_instance, binding_path_str.c_str(), &binding));
+		return XrActionSuggestedBinding{action, binding};
+	};
+
+	auto suggest_bindings = [&](const std::string& interaction_profile_path_str, const std::vector<XrActionSuggestedBinding>& bindings) {
+		XrPath interaction_profile;
+		XR_CHECK_THROW(xrStringToPath(m_instance, interaction_profile_path_str.c_str(), &interaction_profile));
+		XrInteractionProfileSuggestedBinding suggested_binding{
+			XR_TYPE_INTERACTION_PROFILE_SUGGESTED_BINDING,
+			nullptr,
+			interaction_profile,
+			(uint32_t)bindings.size(),
+			bindings.data()
+		};
+		XR_CHECK_THROW(xrSuggestInteractionProfileBindings(m_instance, &suggested_binding));
+	};
+
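+	// The runtime picks the best matching interaction profile; khr/simple_controller serves as a minimal pose-only fallback.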
+	suggest_bindings("/interaction_profiles/khr/simple_controller", {
+		create_binding(m_pose_action, "/user/hand/left/input/grip/pose"),
+		create_binding(m_pose_action, "/user/hand/right/input/grip/pose"),
+	});
+
+	auto suggest_controller_bindings = [&](const std::string& xy, const std::string& press, const std::string& grab, const std::string& squeeze, const std::string& interaction_profile_path_str) {
+		std::vector<XrActionSuggestedBinding> bindings = {
+			create_binding(m_pose_action, "/user/hand/left/input/grip/pose"),
+			create_binding(m_pose_action, "/user/hand/right/input/grip/pose"),
+			create_binding(m_thumbstick_actions[0], std::string{"/user/hand/left/input/"} + xy),
+			create_binding(m_thumbstick_actions[1], std::string{"/user/hand/right/input/"} + xy),
+			create_binding(m_press_action, std::string{"/user/hand/left/input/"} + press),
+			create_binding(m_press_action, std::string{"/user/hand/right/input/"} + press),
+			create_binding(m_grab_action, std::string{"/user/hand/left/input/"} + grab),
+			create_binding(m_grab_action, std::string{"/user/hand/right/input/"} + grab),
+		};
+
+		if (!squeeze.empty()) {
+			bindings.emplace_back(create_binding(m_grab_action, std::string{"/user/hand/left/input/"} + squeeze));
+			bindings.emplace_back(create_binding(m_grab_action, std::string{"/user/hand/right/input/"} + squeeze));
+		}
+
+		suggest_bindings(interaction_profile_path_str, bindings);
+	};
+
+	suggest_controller_bindings("trackpad",   "select/click",     "trackpad/click", "",                  "/interaction_profiles/google/daydream_controller");
+	suggest_controller_bindings("trackpad",   "trackpad/click",   "trigger/click",  "squeeze/click",     "/interaction_profiles/htc/vive_controller");
+	suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value",  "squeeze/click",     "/interaction_profiles/microsoft/motion_controller");
+	suggest_controller_bindings("trackpad",   "trackpad/click",   "trigger/click",  "",                  "/interaction_profiles/oculus/go_controller");
+	suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value",  "squeeze/value",     "/interaction_profiles/oculus/touch_controller");
+
+	// Valve Index force squeeze is very sensitive and can cause unwanted grabbing. Only permit trigger-grabbing for now.
+	suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value",  ""/*squeeze/force*/, "/interaction_profiles/valve/index_controller");
+
+	// Xbox controller (currently not functional)
+	suggest_bindings("/interaction_profiles/microsoft/xbox_controller", {
+		create_binding(m_thumbstick_actions[0], std::string{"/user/gamepad/input/thumbstick_left"}),
+		create_binding(m_thumbstick_actions[1], std::string{"/user/gamepad/input/thumbstick_right"}),
+	});
+}
+
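+// The signature of init_open_gl is platform-specific; the preprocessor selects the variant matching the active XR_USE_PLATFORM_* define.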
+#if defined(XR_USE_PLATFORM_WIN32)
+void OpenXRHMD::init_open_gl(HDC hdc, HGLRC hglrc) {
+#elif defined(XR_USE_PLATFORM_XLIB)
+void OpenXRHMD::init_open_gl(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext) {
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+void OpenXRHMD::init_open_gl(wl_display* display) {
+#endif
+	// GL graphics requirements
+	PFN_xrGetOpenGLGraphicsRequirementsKHR xrGetOpenGLGraphicsRequirementsKHR = nullptr;
+	XR_CHECK_THROW(xrGetInstanceProcAddr(
+		m_instance,
+		"xrGetOpenGLGraphicsRequirementsKHR",
+		reinterpret_cast<PFN_xrVoidFunction*>(&xrGetOpenGLGraphicsRequirementsKHR)
+	));
+
+	XrGraphicsRequirementsOpenGLKHR graphics_requirements{XR_TYPE_GRAPHICS_REQUIREMENTS_OPENGL_KHR};
+	XR_CHECK_THROW(xrGetOpenGLGraphicsRequirementsKHR(m_instance, m_system_id, &graphics_requirements));
+	XrVersion min_version = graphics_requirements.minApiVersionSupported;
+	GLint major = 0;
+	GLint minor = 0;
+	glGetIntegerv(GL_MAJOR_VERSION, &major);
+	glGetIntegerv(GL_MINOR_VERSION, &minor);
+	const XrVersion have_version = XR_MAKE_VERSION(major, minor, 0);
+
+	if (have_version < min_version) {
+		tlog::error() << fmt::format(
+			"Required OpenGL version: {}.{}, found OpenGL version: {}.{}",
+			XR_VERSION_MAJOR(min_version),
+			XR_VERSION_MINOR(min_version),
+			major,
+			minor
+		);
+
+		throw std::runtime_error{"Insufficient graphics API support"};
+	}
+
+#if defined(XR_USE_PLATFORM_WIN32)
+	m_graphics_binding.hDC = hdc;
+	m_graphics_binding.hGLRC = hglrc;
+#elif defined(XR_USE_PLATFORM_XLIB)
+	m_graphics_binding.xDisplay = xDisplay;
+	m_graphics_binding.visualid = visualid;
+	m_graphics_binding.glxFBConfig = glxFBConfig;
+	m_graphics_binding.glxDrawable = glxDrawable;
+	m_graphics_binding.glxContext = glxContext;
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	m_graphics_binding.display = display;
+#endif
+}
+
+void OpenXRHMD::init_xr_session() {
+	// create session
+	XrSessionCreateInfo create_info{
+		XR_TYPE_SESSION_CREATE_INFO,
+		reinterpret_cast<const XrBaseInStructure*>(&m_graphics_binding),
+		0,
+		m_system_id
+	};
+
+	XR_CHECK_THROW(xrCreateSession(m_instance, &create_info, &m_session));
+
+	// tlog::info() << fmt::format("Created session {}", fmt::ptr(m_session));
+}
+
+void OpenXRHMD::init_xr_spaces() {
+	// reference space
+	uint32_t size;
+	XR_CHECK_THROW(xrEnumerateReferenceSpaces(m_session, 0, &size, nullptr));
+	m_reference_spaces.clear();
+	m_reference_spaces.resize(size);
+	XR_CHECK_THROW(xrEnumerateReferenceSpaces(m_session, size, &size, m_reference_spaces.data()));
+
+	if (m_print_reference_spaces) {
+		tlog::info() << fmt::format("Reference spaces ({}):", m_reference_spaces.size());
+		for (const auto& r : m_reference_spaces) {
+			tlog::info() << fmt::format("\t{}", XrEnumStr(r));
+		}
+	}
+
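+	// LOCAL reference space: a world-locked, gravity-aligned origin near the user's initial head pose.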
+	XrReferenceSpaceCreateInfo reference_space_create_info{XR_TYPE_REFERENCE_SPACE_CREATE_INFO};
+	reference_space_create_info.referenceSpaceType = XR_REFERENCE_SPACE_TYPE_LOCAL;
+	reference_space_create_info.poseInReferenceSpace = XrPosef{};
+	reference_space_create_info.poseInReferenceSpace.orientation.w = 1.0f;
+	XR_CHECK_THROW(xrCreateReferenceSpace(m_session, &reference_space_create_info, &m_space));
+	XR_CHECK_THROW(xrGetReferenceSpaceBoundsRect(m_session, reference_space_create_info.referenceSpaceType, &m_bounds));
+
+	if (m_print_reference_spaces) {
+		tlog::info() << fmt::format("Using reference space {}", XrEnumStr(reference_space_create_info.referenceSpaceType));
+		tlog::info() << fmt::format("Reference space boundaries: {} x {}", m_bounds.width, m_bounds.height);
+	}
+
+	// action space
+	XrActionSpaceCreateInfo action_space_create_info{XR_TYPE_ACTION_SPACE_CREATE_INFO};
+	action_space_create_info.action = m_pose_action;
+	action_space_create_info.poseInActionSpace.orientation.w = 1.0f;
+	action_space_create_info.subactionPath = m_hand_paths[0];
+	XR_CHECK_THROW(xrCreateActionSpace(m_session, &action_space_create_info, &m_hand_spaces[0]));
+	action_space_create_info.subactionPath = m_hand_paths[1];
+	XR_CHECK_THROW(xrCreateActionSpace(m_session, &action_space_create_info, &m_hand_spaces[1]));
+
+	// attach action set
+	XrSessionActionSetsAttachInfo attach_info{XR_TYPE_SESSION_ACTION_SETS_ATTACH_INFO};
+	attach_info.countActionSets = 1;
+	attach_info.actionSets = &m_action_set;
+	XR_CHECK_THROW(xrAttachSessionActionSets(m_session, &attach_info));
+}
+
+void OpenXRHMD::init_xr_swapchain_open_gl() {
+	// swap chains
+	uint32_t size;
+	XR_CHECK_THROW(xrEnumerateSwapchainFormats(m_session, 0, &size, nullptr));
+	std::vector<int64_t> swapchain_formats(size);
+	XR_CHECK_THROW(xrEnumerateSwapchainFormats(m_session, size, &size, swapchain_formats.data()));
+
+	if (m_print_available_swapchain_formats) {
+		tlog::info() << fmt::format("Swapchain formats ({}):", swapchain_formats.size());
+		for (const auto& f : swapchain_formats) {
+			tlog::info() << fmt::format("\t{:#x}", f);
+		}
+	}
+
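+	// Pick the first of our preferred formats that the runtime supports; sRGB formats come first so the compositor applies gamma correctly.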
+	auto find_compatible_swapchain_format = [&](const std::vector<int64_t>& candidates) {
+		for (auto format : candidates) {
+			if (std::find(std::begin(swapchain_formats), std::end(swapchain_formats), format) != std::end(swapchain_formats)) {
+				return format;
+			}
+		}
+
+		throw std::runtime_error{"No compatible OpenXR swapchain format found"};
+	};
+
+	m_swapchain_rgba_format = find_compatible_swapchain_format({
+		GL_SRGB8_ALPHA8,
+		GL_SRGB8,
+		GL_RGBA8,
+	});
+
+	if (m_supports_composition_layer_depth) {
+		m_swapchain_depth_format = find_compatible_swapchain_format({
+			GL_DEPTH_COMPONENT32F,
+			GL_DEPTH_COMPONENT24,
+			GL_DEPTH_COMPONENT16,
+		});
+	}
+
+	// tlog::info() << fmt::format("Chosen swapchain format: {:#x}", m_swapchain_rgba_format);
+	for (const auto& vcv : m_view_configuration_views) {
+		XrSwapchainCreateInfo rgba_swapchain_create_info{XR_TYPE_SWAPCHAIN_CREATE_INFO};
+		rgba_swapchain_create_info.usageFlags = XR_SWAPCHAIN_USAGE_SAMPLED_BIT | XR_SWAPCHAIN_USAGE_COLOR_ATTACHMENT_BIT;
+		rgba_swapchain_create_info.format = m_swapchain_rgba_format;
+		rgba_swapchain_create_info.sampleCount = 1;
+		rgba_swapchain_create_info.width = vcv.recommendedImageRectWidth;
+		rgba_swapchain_create_info.height = vcv.recommendedImageRectHeight;
+		rgba_swapchain_create_info.faceCount = 1;
+		rgba_swapchain_create_info.arraySize = 1;
+		rgba_swapchain_create_info.mipCount = 1;
+
+		XrSwapchainCreateInfo depth_swapchain_create_info = rgba_swapchain_create_info;
+		depth_swapchain_create_info.usageFlags = XR_SWAPCHAIN_USAGE_SAMPLED_BIT | XR_SWAPCHAIN_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT;
+		depth_swapchain_create_info.format = m_swapchain_depth_format;
+
+		m_swapchains.emplace_back(rgba_swapchain_create_info, depth_swapchain_create_info, m_session, m_instance);
+	}
+}
+
+void OpenXRHMD::init_open_gl_shaders() {
+	// Hidden area mask program
+	{
+		static const char* shader_vert = R"(#version 140
+			in vec2 pos;
+			uniform mat4 project;
+			void main() {
+				vec4 pos = project * vec4(pos, -1.0, 1.0);
+				pos.xyz /= pos.w;
+				pos.y *= -1.0;
+				gl_Position = pos;
+			})";
+
+		static const char* shader_frag = R"(#version 140
+			out vec4 frag_color;
+			void main() {
+				frag_color = vec4(0.0, 0.0, 0.0, 1.0);
+			})";
+
+		GLuint vert = glCreateShader(GL_VERTEX_SHADER);
+		glShaderSource(vert, 1, &shader_vert, NULL);
+		glCompileShader(vert);
+		check_shader(vert, "OpenXR hidden area mask vertex shader", false);
+
+		GLuint frag = glCreateShader(GL_FRAGMENT_SHADER);
+		glShaderSource(frag, 1, &shader_frag, NULL);
+		glCompileShader(frag);
+		check_shader(frag, "OpenXR hidden area mask fragment shader", false);
+
+		m_hidden_area_mask_program = glCreateProgram();
+		glAttachShader(m_hidden_area_mask_program, vert);
+		glAttachShader(m_hidden_area_mask_program, frag);
+		glLinkProgram(m_hidden_area_mask_program);
+		check_shader(m_hidden_area_mask_program, "OpenXR hidden area mask shader program", true);
+
+		glDeleteShader(vert);
+		glDeleteShader(frag);
+	}
+}
+
+void OpenXRHMD::session_state_change(XrSessionState state, ControlFlow& flow) {
+	//tlog::info() << fmt::format("New session state {}", XrEnumStr(state));
+	switch (state) {
+		case XR_SESSION_STATE_READY: {
+			XrSessionBeginInfo session_begin_info{XR_TYPE_SESSION_BEGIN_INFO};
+			session_begin_info.primaryViewConfigurationType = m_view_configuration_type;
+			XR_CHECK_THROW(xrBeginSession(m_session, &session_begin_info));
+			break;
+		}
+		case XR_SESSION_STATE_STOPPING: {
+			XR_CHECK_THROW(xrEndSession(m_session));
+			break;
+		}
+		case XR_SESSION_STATE_EXITING: {
+			flow = ControlFlow::QUIT;
+			break;
+		}
+		case XR_SESSION_STATE_LOSS_PENDING: {
+			flow = ControlFlow::RESTART;
+			break;
+		}
+		default: {
+			break;
+		}
+	}
+}
+
+OpenXRHMD::ControlFlow OpenXRHMD::poll_events() {
+	bool more = true;
+	ControlFlow flow = ControlFlow::CONTINUE;
+	while (more) {
+		// poll events
+		XrEventDataBuffer event {XR_TYPE_EVENT_DATA_BUFFER, nullptr};
+		XrResult result = xrPollEvent(m_instance, &event);
+
+		if (XR_FAILED(result)) {
+			tlog::error() << "xrPollEvent failed";
+			more = false; // Avoid spinning forever if polling keeps failing
+		} else if (XR_SUCCESS == result) {
+			switch (event.type) {
+				case XR_TYPE_EVENT_DATA_SESSION_STATE_CHANGED: {
+					const XrEventDataSessionStateChanged& e = *reinterpret_cast<XrEventDataSessionStateChanged*>(&event);
+					//tlog::info() << "Session state change";
+					//tlog::info() << fmt::format("\t from {}\t   to {}", XrEnumStr(m_session_state), XrEnumStr(e.state));
+					//tlog::info() << fmt::format("\t session {}, time {}", fmt::ptr(e.session), e.time);
+					m_session_state = e.state;
+					session_state_change(e.state, flow);
+					break;
+				}
+
+				case XR_TYPE_EVENT_DATA_INSTANCE_LOSS_PENDING: {
+					flow = ControlFlow::RESTART;
+					break;
+				}
+
+				case XR_TYPE_EVENT_DATA_VISIBILITY_MASK_CHANGED_KHR: {
+					m_hidden_area_masks.clear();
+					break;
+				}
+
+				case XR_TYPE_EVENT_DATA_INTERACTION_PROFILE_CHANGED: {
+					break; // Can ignore
+				}
+
+				default: {
+					tlog::info() << fmt::format("Unhandled event type {}", XrEnumStr(event.type));
+					break;
+				}
+			}
+		} else if (XR_EVENT_UNAVAILABLE == result) {
+			more = false;
+		}
+	}
+	return flow;
+}
+
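+// Copies the GL mask texture, bound as a CUDA surface, into a linear byte buffer. For a 1-byte format, the x coordinate of surf2Dread is already in bytes.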
+__global__ void read_hidden_area_mask_kernel(const Vector2i resolution, cudaSurfaceObject_t surface, uint8_t* __restrict__ mask) {
+	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
+	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
+
+	if (x >= resolution.x() || y >= resolution.y()) {
+		return;
+	}
+
+	uint32_t idx = x + resolution.x() * y;
+	surf2Dread(&mask[idx], surface, x, y);
+}
+
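+// Rasterizes the runtime-provided hidden area mesh into a single-channel GPU mask (1 = visible, 0 = hidden) that the renderer can use to skip occluded pixels.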
+std::shared_ptr<Buffer2D<uint8_t>> OpenXRHMD::rasterize_hidden_area_mask(uint32_t view_index, const XrCompositionLayerProjectionView& view) {
+	if (!m_supports_hidden_area_mask) {
+		return {};
+	}
+
+	PFN_xrGetVisibilityMaskKHR xrGetVisibilityMaskKHR = nullptr;
+	XR_CHECK_THROW(xrGetInstanceProcAddr(
+		m_instance,
+		"xrGetVisibilityMaskKHR",
+		reinterpret_cast<PFN_xrVoidFunction*>(&xrGetVisibilityMaskKHR)
+	));
+
+	XrVisibilityMaskKHR visibility_mask{XR_TYPE_VISIBILITY_MASK_KHR};
+	XR_CHECK_THROW(xrGetVisibilityMaskKHR(m_session, m_view_configuration_type, view_index, XR_VISIBILITY_MASK_TYPE_HIDDEN_TRIANGLE_MESH_KHR, &visibility_mask));
+
+	if (visibility_mask.vertexCountOutput == 0 || visibility_mask.indexCountOutput == 0) {
+		return nullptr;
+	}
+
+	std::vector<XrVector2f> vertices(visibility_mask.vertexCountOutput);
+	std::vector<uint32_t> indices(visibility_mask.indexCountOutput);
+
+	visibility_mask.vertices = vertices.data();
+	visibility_mask.indices = indices.data();
+
+	visibility_mask.vertexCapacityInput = visibility_mask.vertexCountOutput;
+	visibility_mask.indexCapacityInput = visibility_mask.indexCountOutput;
+
+	XR_CHECK_THROW(xrGetVisibilityMaskKHR(m_session, m_view_configuration_type, view_index, XR_VISIBILITY_MASK_TYPE_HIDDEN_TRIANGLE_MESH_KHR, &visibility_mask));
+
+	CUDA_CHECK_THROW(cudaDeviceSynchronize());
+
+	Vector2i size = {view.subImage.imageRect.extent.width, view.subImage.imageRect.extent.height};
+
+	bool tex = glIsEnabled(GL_TEXTURE_2D);
+	bool depth = glIsEnabled(GL_DEPTH_TEST);
+	bool cull = glIsEnabled(GL_CULL_FACE);
+	GLint previous_texture_id;
+	glGetIntegerv(GL_TEXTURE_BINDING_2D, &previous_texture_id);
+
+	if (!tex) glEnable(GL_TEXTURE_2D);
+	if (depth) glDisable(GL_DEPTH_TEST);
+	if (cull) glDisable(GL_CULL_FACE);
+
+	// Generate texture to hold hidden area mask. Single channel, value of 1 means visible and 0 means masked away
+	ngp::GLTexture mask_texture;
+	mask_texture.resize(size, 1, true);
+	glBindTexture(GL_TEXTURE_2D, mask_texture.texture());
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+
+	GLuint framebuffer = 0;
+	glGenFramebuffers(1, &framebuffer);
+	glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);
+	glFramebufferTexture(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, mask_texture.texture(), 0);
+
+	GLenum draw_buffers[1] = {GL_COLOR_ATTACHMENT0};
+	glDrawBuffers(1, draw_buffers);
+
+	glViewport(0, 0, size.x(), size.y());
+
+	// Draw hidden area mask
+	GLuint vao;
+	glGenVertexArrays(1, &vao);
+	glBindVertexArray(vao);
+
+	GLuint vertex_buffer;
+	glGenBuffers(1, &vertex_buffer);
+	glEnableVertexAttribArray(0);
+	glBindBuffer(GL_ARRAY_BUFFER, vertex_buffer);
+	glBufferData(GL_ARRAY_BUFFER, sizeof(XrVector2f) * vertices.size(), vertices.data(), GL_STATIC_DRAW);
+	glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 0, (void*)0);
+
+	GLuint index_buffer;
+	glGenBuffers(1, &index_buffer);
+	glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, index_buffer);
+	glBufferData(GL_ELEMENT_ARRAY_BUFFER, sizeof(uint32_t) * indices.size(), indices.data(), GL_STATIC_DRAW);
+
+	glClearColor(1.0f, 1.0f, 1.0f, 1.0f);
+	glClear(GL_COLOR_BUFFER_BIT);
+	glUseProgram(m_hidden_area_mask_program);
+
+	XrMatrix4x4f proj;
+	XrMatrix4x4f_CreateProjectionFov(&proj, GRAPHICS_OPENGL, view.fov, 1.0f / 128.0f, 128.0f);
+
+	GLint project_id = glGetUniformLocation(m_hidden_area_mask_program, "project");
+	glUniformMatrix4fv(project_id, 1, GL_FALSE, &proj.m[0]);
+
+	glDrawElements(GL_TRIANGLES, indices.size(), GL_UNSIGNED_INT, (void*)0);
+	glFinish();
+
+	glDisableVertexAttribArray(0);
+	glDeleteBuffers(1, &vertex_buffer);
+	glDeleteBuffers(1, &index_buffer);
+	glDeleteVertexArrays(1, &vao);
+	glDeleteFramebuffers(1, &framebuffer);
+
+	glBindVertexArray(0);
+	glUseProgram(0);
+
+	// restore old state
+	if (!tex) glDisable(GL_TEXTURE_2D);
+	if (depth) glEnable(GL_DEPTH_TEST);
+	if (cull) glEnable(GL_CULL_FACE);
+	glBindTexture(GL_TEXTURE_2D, previous_texture_id);
+	glBindFramebuffer(GL_FRAMEBUFFER, 0);
+
+	std::shared_ptr<Buffer2D<uint8_t>> mask = std::make_shared<Buffer2D<uint8_t>>(size);
+
+	const dim3 threads = { 16, 8, 1 };
+	const dim3 blocks = { div_round_up((uint32_t)size.x(), threads.x), div_round_up((uint32_t)size.y(), threads.y), 1 };
+
+	read_hidden_area_mask_kernel<<<blocks, threads>>>(size, mask_texture.surface(), mask->data());
+	CUDA_CHECK_THROW(cudaDeviceSynchronize());
+
+	return mask;
+}
+
+Matrix<float, 3, 4> convert_xr_matrix_to_eigen(const XrMatrix4x4f& m) {
+	Matrix<float, 3, 4> out;
+
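+	// XrMatrix4x4f stores its elements column-major; copy the upper 3x4 block.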
+	for (size_t i = 0; i < 3; ++i) {
+		for (size_t j = 0; j < 4; ++j) {
+			out(i, j) = m.m[i + j * 4];
+		}
+	}
+
+	// Flip Y and Z axes to match NGP conventions
+	out(0, 1) *= -1.f;
+	out(1, 0) *= -1.f;
+
+	out(0, 2) *= -1.f;
+	out(2, 0) *= -1.f;
+
+	out(1, 3) *= -1.f;
+	out(2, 3) *= -1.f;
+
+	return out;
+}
+
+Matrix<float, 3, 4> convert_xr_pose_to_eigen(const XrPosef& pose) {
+	XrMatrix4x4f matrix;
+	XrVector3f unit_scale{1.0f, 1.0f, 1.0f};
+	XrMatrix4x4f_CreateTranslationRotationScale(&matrix, &pose.position, &pose.orientation, &unit_scale);
+	return convert_xr_matrix_to_eigen(matrix);
+}
+
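+// Per-frame flow: wait for and begin the frame, locate the views, acquire swapchain images, then query action state.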
+OpenXRHMD::FrameInfoPtr OpenXRHMD::begin_frame() {
+	XrFrameWaitInfo frame_wait_info{XR_TYPE_FRAME_WAIT_INFO};
+	XR_CHECK_THROW(xrWaitFrame(m_session, &frame_wait_info, &m_frame_state));
+
+	XrFrameBeginInfo frame_begin_info{XR_TYPE_FRAME_BEGIN_INFO};
+	XR_CHECK_THROW(xrBeginFrame(m_session, &frame_begin_info));
+
+	if (!m_frame_state.shouldRender) {
+		return std::make_shared<FrameInfo>();
+	}
+
+	uint32_t num_views = (uint32_t)m_swapchains.size();
+	if (m_view_configuration_views.size() != m_swapchains.size()) {
+		throw std::runtime_error{"Number of swapchains does not match number of configured views"};
+	}
+
+	// locate views
+	std::vector<XrView> views(num_views, {XR_TYPE_VIEW});
+
+	XrViewState view_state{XR_TYPE_VIEW_STATE};
+
+	XrViewLocateInfo view_locate_info{XR_TYPE_VIEW_LOCATE_INFO};
+	view_locate_info.viewConfigurationType = m_view_configuration_type;
+	view_locate_info.displayTime = m_frame_state.predictedDisplayTime;
+	view_locate_info.space = m_space;
+
+	XR_CHECK_THROW(xrLocateViews(m_session, &view_locate_info, &view_state, uint32_t(views.size()), &num_views, views.data()));
+
+	if (!(view_state.viewStateFlags & XR_VIEW_STATE_POSITION_VALID_BIT) || !(view_state.viewStateFlags & XR_VIEW_STATE_ORIENTATION_VALID_BIT)) {
+		return std::make_shared<FrameInfo>();
+	}
+
+	m_hidden_area_masks.resize(num_views);
+
+	// Fill frame information
+	if (!m_previous_frame_info) {
+		m_previous_frame_info = std::make_shared<FrameInfo>();
+	}
+
+	FrameInfoPtr frame_info = std::make_shared<FrameInfo>(*m_previous_frame_info);
+	frame_info->views.resize(m_swapchains.size());
+
+	for (size_t i = 0; i < m_swapchains.size(); ++i) {
+		const auto& sc = m_swapchains[i];
+
+		XrSwapchainImageAcquireInfo image_acquire_info{XR_TYPE_SWAPCHAIN_IMAGE_ACQUIRE_INFO};
+		XrSwapchainImageWaitInfo image_wait_info{XR_TYPE_SWAPCHAIN_IMAGE_WAIT_INFO, nullptr, XR_INFINITE_DURATION};
+
+		uint32_t image_index;
+		XR_CHECK_THROW(xrAcquireSwapchainImage(sc.handle, &image_acquire_info, &image_index));
+		XR_CHECK_THROW(xrWaitSwapchainImage(sc.handle, &image_wait_info));
+
+		FrameInfo::View& v = frame_info->views[i];
+		v.framebuffer = sc.framebuffers_gl[image_index];
+		v.view.pose = views[i].pose;
+		v.view.fov = views[i].fov;
+		v.view.subImage.imageRect = XrRect2Di{{0, 0}, {sc.width, sc.height}};
+		v.view.subImage.imageArrayIndex = 0;
+		v.view.subImage.swapchain = sc.handle;
+
+		glBindFramebuffer(GL_FRAMEBUFFER, sc.framebuffers_gl[image_index]);
+		glClearColor(0.0f, 0.0f, 0.0f, 0.0f);
+		glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
+		glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, sc.images_gl.at(image_index).image, 0);
+
+		if (sc.depth_handle != XR_NULL_HANDLE) {
+			uint32_t depth_image_index;
+			XR_CHECK_THROW(xrAcquireSwapchainImage(sc.depth_handle, &image_acquire_info, &depth_image_index));
+			XR_CHECK_THROW(xrWaitSwapchainImage(sc.depth_handle, &image_wait_info));
+
+			glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, GL_TEXTURE_2D, sc.depth_images_gl.at(depth_image_index).image, 0);
+
+			v.depth_info.subImage.imageRect = XrRect2Di{{0, 0}, {sc.width, sc.height}};
+			v.depth_info.subImage.imageArrayIndex = 0;
+			v.depth_info.subImage.swapchain = sc.depth_handle;
+			v.depth_info.minDepth = 0.0f;
+			v.depth_info.maxDepth = 1.0f;
+
+			// To be overwritten with actual near and far planes by end_frame
+			v.depth_info.nearZ = 1.0f / 128.0f;
+			v.depth_info.farZ = 128.0f;
+		}
+
+		glBindFramebuffer(GL_FRAMEBUFFER, 0);
+
+		if (!m_hidden_area_masks.at(i)) {
+			m_hidden_area_masks.at(i) = rasterize_hidden_area_mask(i, v.view);
+		}
+
+		v.hidden_area_mask = m_hidden_area_masks.at(i);
+		v.pose = convert_xr_pose_to_eigen(v.view.pose);
+	}
+
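+	// Sync the action set so the per-hand action states queried below are up to date for this frame.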
+	XrActiveActionSet active_action_set{m_action_set, XR_NULL_PATH};
+	XrActionsSyncInfo sync_info{XR_TYPE_ACTIONS_SYNC_INFO};
+	sync_info.countActiveActionSets = 1;
+	sync_info.activeActionSets = &active_action_set;
+	XR_CHECK_THROW(xrSyncActions(m_session, &sync_info));
+
+	for (size_t i = 0; i < 2; ++i) {
+		// Hand pose
+		{
+			XrActionStatePose pose_state{XR_TYPE_ACTION_STATE_POSE};
+			XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO};
+			get_info.action = m_pose_action;
+			get_info.subactionPath = m_hand_paths[i];
+			XR_CHECK_THROW(xrGetActionStatePose(m_session, &get_info, &pose_state));
+
+			frame_info->hands[i].pose_active = pose_state.isActive;
+			if (frame_info->hands[i].pose_active) {
+				XrSpaceLocation space_location{XR_TYPE_SPACE_LOCATION};
+				XR_CHECK_THROW(xrLocateSpace(m_hand_spaces[i], m_space, m_frame_state.predictedDisplayTime, &space_location));
+				frame_info->hands[i].pose = convert_xr_pose_to_eigen(space_location.pose);
+			}
+		}
+
+		// Stick
+		{
+			XrActionStateVector2f thumbstick_state{XR_TYPE_ACTION_STATE_VECTOR2F};
+			XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO};
+			get_info.action = m_thumbstick_actions[i];
+			XR_CHECK_THROW(xrGetActionStateVector2f(m_session, &get_info, &thumbstick_state));
+
+			if (thumbstick_state.isActive) {
+				frame_info->hands[i].thumbstick.x() = thumbstick_state.currentState.x;
+				frame_info->hands[i].thumbstick.y() = thumbstick_state.currentState.y;
+			} else {
+				frame_info->hands[i].thumbstick = Vector2f::Zero();
+			}
+		}
+
+		// Press
+		{
+			XrActionStateBoolean press_state{XR_TYPE_ACTION_STATE_BOOLEAN};
+			XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO};
+			get_info.action = m_press_action;
+			get_info.subactionPath = m_hand_paths[i];
+			XR_CHECK_THROW(xrGetActionStateBoolean(m_session, &get_info, &press_state));
+
+			if (press_state.isActive) {
+				frame_info->hands[i].pressing = press_state.currentState != XR_FALSE;
+			} else {
+				frame_info->hands[i].pressing = false;
+			}
+		}
+
+		// Grab
+		{
+			XrActionStateFloat grab_state{XR_TYPE_ACTION_STATE_FLOAT};
+			XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO};
+			get_info.action = m_grab_action;
+			get_info.subactionPath = m_hand_paths[i];
+			XR_CHECK_THROW(xrGetActionStateFloat(m_session, &get_info, &grab_state));
+
+			if (grab_state.isActive) {
+				frame_info->hands[i].grab_strength = grab_state.currentState;
+			} else {
+				frame_info->hands[i].grab_strength = 0.0f;
+			}
+
+			bool was_grabbing = frame_info->hands[i].grabbing;
+			frame_info->hands[i].grabbing = frame_info->hands[i].grab_strength >= 0.5f;
+
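+			// On the first frame of a grab, previous and current positions coincide, avoiding a spurious drag delta.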
+			if (frame_info->hands[i].grabbing) {
+				frame_info->hands[i].prev_grab_pos = was_grabbing ? frame_info->hands[i].grab_pos : frame_info->hands[i].pose.col(3);
+				frame_info->hands[i].grab_pos = frame_info->hands[i].pose.col(3);
+			}
+		}
+	}
+
+	m_previous_frame_info = frame_info;
+	return frame_info;
+}
+
+void OpenXRHMD::end_frame(FrameInfoPtr frame_info, float znear, float zfar) {
+	std::vector<XrCompositionLayerProjectionView> layer_projection_views(frame_info->views.size());
+	for (size_t i = 0; i < layer_projection_views.size(); ++i) {
+		auto& v = frame_info->views[i];
+		auto& view = layer_projection_views[i];
+
+		view = v.view;
+
+		// release swapchain image
+		XrSwapchainImageReleaseInfo release_info{XR_TYPE_SWAPCHAIN_IMAGE_RELEASE_INFO};
+		XR_CHECK_THROW(xrReleaseSwapchainImage(v.view.subImage.swapchain, &release_info));
+
+		if (v.depth_info.subImage.swapchain != XR_NULL_HANDLE) {
+			XR_CHECK_THROW(xrReleaseSwapchainImage(v.depth_info.subImage.swapchain, &release_info));
+			v.depth_info.nearZ = znear;
+			v.depth_info.farZ = zfar;
+			// The following line being commented means that our provided depth buffer
+			// _isn't_ actually passed to the runtime for reprojection. So far,
+			// experimentation has shown that runtimes do a better job at reprojection
+			// without getting a depth buffer from us, so we leave it disabled for now.
+			// view.next = &v.depth_info;
+		}
+	}
+
+	XrCompositionLayerProjection layer{XR_TYPE_COMPOSITION_LAYER_PROJECTION};
+	layer.space = m_space;
+	layer.viewCount = uint32_t(layer_projection_views.size());
+	layer.views = layer_projection_views.data();
+
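+	// Submitting zero layers is valid and tells the runtime there is nothing to composite this frame.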
+	std::vector<XrCompositionLayerBaseHeader*> layers;
+	if (layer.viewCount) {
+		layers.push_back(reinterpret_cast<XrCompositionLayerBaseHeader*>(&layer));
+	}
+
+	XrFrameEndInfo frame_end_info{XR_TYPE_FRAME_END_INFO};
+	frame_end_info.displayTime = m_frame_state.predictedDisplayTime;
+	frame_end_info.environmentBlendMode = m_environment_blend_mode;
+	frame_end_info.layerCount = (uint32_t)layers.size();
+	frame_end_info.layers = layers.data();
+	XR_CHECK_THROW(xrEndFrame(m_session, &frame_end_info));
+}
+
+NGP_NAMESPACE_END
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
diff --git a/src/python_api.cu b/src/python_api.cu
index f69056f3f46a65aff0b7f18e7cec7a5fa56f239b..8a8d436405d11488fd550c77f5ff6468300cde91 100644
--- a/src/python_api.cu
+++ b/src/python_api.cu
@@ -157,6 +157,7 @@ py::array_t<float> Testbed::render_to_cpu(int width, int height, int spp, bool l
 	}
 
 	auto end_cam_matrix = m_smoothed_camera;
+	auto prev_camera_matrix = m_smoothed_camera;
 
 	for (int i = 0; i < spp; ++i) {
 		float start_alpha = ((float)i)/(float)spp * shutter_fraction;
@@ -164,6 +165,9 @@ py::array_t<float> Testbed::render_to_cpu(int width, int height, int spp, bool l
 
 		auto sample_start_cam_matrix = start_cam_matrix;
 		auto sample_end_cam_matrix = log_space_lerp(start_cam_matrix, end_cam_matrix, shutter_fraction);
+		if (i == 0) {
+			prev_camera_matrix = sample_start_cam_matrix;
+		}
 
 		if (path_animation_enabled) {
 			set_camera_from_time(start_time + (end_time-start_time) * (start_alpha + end_alpha) / 2.0f);
@@ -174,7 +178,21 @@ py::array_t<float> Testbed::render_to_cpu(int width, int height, int spp, bool l
 			autofocus();
 		}
 
-		render_frame(sample_start_cam_matrix, sample_end_cam_matrix, Eigen::Vector4f::Zero(), m_windowless_render_surface, !linear);
+		render_frame(
+			m_stream.get(),
+			sample_start_cam_matrix,
+			sample_end_cam_matrix,
+			prev_camera_matrix,
+			m_screen_center,
+			m_relative_focal_length,
+			{0.0f, 0.0f, 0.0f, 1.0f},
+			{},
+			{},
+			m_visualized_dimension,
+			m_windowless_render_surface,
+			!linear
+		);
+		prev_camera_matrix = sample_start_cam_matrix;
 	}
 
 	// For cam smoothing when rendering the next frame.
@@ -303,6 +321,7 @@ PYBIND11_MODULE(pyngp, m) {
 		.value("FTheta", ELensMode::FTheta)
 		.value("LatLong", ELensMode::LatLong)
 		.value("OpenCVFisheye", ELensMode::OpenCVFisheye)
+		.value("Equirectangular", ELensMode::Equirectangular)
 		.export_values();
 
 	py::class_<BoundingBox>(m, "BoundingBox")
@@ -344,12 +363,13 @@ PYBIND11_MODULE(pyngp, m) {
 		.def("clear_training_data", &Testbed::clear_training_data, "Clears training data to free up GPU memory.")
 		// General control
 #ifdef NGP_GUI
-		.def("init_window", &Testbed::init_window, "Init a GLFW window that shows real-time progress and a GUI. 'second_window' creates a second copy of the output in its own window",
+		.def("init_window", &Testbed::init_window, "Init a GLFW window that shows real-time progress and a GUI. 'second_window' creates a second copy of the output in its own window.",
 			py::arg("width"),
 			py::arg("height"),
 			py::arg("hidden") = false,
 			py::arg("second_window") = false
 		)
+		.def("init_vr", &Testbed::init_vr, "Init rendering to a connected and active VR headset. Requires a GUI window to have been previously created via `init_window`.")
 		.def_readwrite("keyboard_event_callback", &Testbed::m_keyboard_event_callback)
 		.def("is_key_pressed", [](py::object& obj, int key) { return ImGui::IsKeyPressed(key); })
 		.def("is_key_down", [](py::object& obj, int key) { return ImGui::IsKeyDown(key); })
@@ -431,6 +451,7 @@ PYBIND11_MODULE(pyngp, m) {
 		.def_readwrite("dynamic_res_target_fps", &Testbed::m_dynamic_res_target_fps)
 		.def_readwrite("fixed_res_factor", &Testbed::m_fixed_res_factor)
 		.def_readwrite("background_color", &Testbed::m_background_color)
+		.def_readwrite("render_transparency_as_checkerboard", &Testbed::m_render_transparency_as_checkerboard)
 		.def_readwrite("shall_train", &Testbed::m_train)
 		.def_readwrite("shall_train_encoding", &Testbed::m_train_encoding)
 		.def_readwrite("shall_train_network", &Testbed::m_train_network)
@@ -493,7 +514,7 @@ PYBIND11_MODULE(pyngp, m) {
 		.def_property("dlss",
 			[](py::object& obj) { return obj.cast<Testbed&>().m_dlss; },
 			[](const py::object& obj, bool value) {
-				if (value && !obj.cast<Testbed&>().m_dlss_supported) {
+				if (value && !obj.cast<Testbed&>().m_dlss_provider) {
 					if (obj.cast<Testbed&>().m_render_window) {
 						throw std::runtime_error{"DLSS not supported."};
 					} else {
@@ -660,7 +681,6 @@ PYBIND11_MODULE(pyngp, m) {
 	image
 		.def_readonly("training", &Testbed::Image::training)
 		.def_readwrite("random_mode", &Testbed::Image::random_mode)
-		.def_readwrite("pos", &Testbed::Image::pos)
 		;
 
 	py::class_<Testbed::Image::Training>(image, "Training")
diff --git a/src/render_buffer.cu b/src/render_buffer.cu
index c6d06e250c5200c96c8455b8de2db5a375a3ff0b..aa0c10dd4950f743bb756c86bff5474d6e962f1c 100644
--- a/src/render_buffer.cu
+++ b/src/render_buffer.cu
@@ -47,30 +47,40 @@ void CudaSurface2D::free() {
 	m_surface = 0;
 	if (m_array) {
 		cudaFreeArray(m_array);
-		g_total_n_bytes_allocated -= m_size.prod() * sizeof(float4);
+		g_total_n_bytes_allocated -= m_size.prod() * sizeof(float) * m_n_channels;
 	}
 	m_array = nullptr;
+	m_size = Vector2i::Zero();
+	m_n_channels = 0;
 }
 
-void CudaSurface2D::resize(const Vector2i& size) {
-	if (size == m_size) {
+void CudaSurface2D::resize(const Vector2i& size, int n_channels) {
+	if (size == m_size && n_channels == m_n_channels) {
 		return;
 	}
 
 	free();
 
-	m_size = size;
-
-	cudaChannelFormatDesc desc = cudaCreateChannelDesc<float4>();
+	cudaChannelFormatDesc desc;
+	switch (n_channels) {
+		case 1: desc = cudaCreateChannelDesc<float>(); break;
+		case 2: desc = cudaCreateChannelDesc<float2>(); break;
+		case 3: desc = cudaCreateChannelDesc<float3>(); break;
+		case 4: desc = cudaCreateChannelDesc<float4>(); break;
+		default: throw std::runtime_error{fmt::format("CudaSurface2D: unsupported number of channels {}", n_channels)};
+	}
 	CUDA_CHECK_THROW(cudaMallocArray(&m_array, &desc, size.x(), size.y(), cudaArraySurfaceLoadStore));
 
-	g_total_n_bytes_allocated += m_size.prod() * sizeof(float4);
+	g_total_n_bytes_allocated += m_size.prod() * sizeof(float) * n_channels;
 
 	struct cudaResourceDesc resource_desc;
 	memset(&resource_desc, 0, sizeof(resource_desc));
 	resource_desc.resType = cudaResourceTypeArray;
 	resource_desc.res.array.array = m_array;
 	CUDA_CHECK_THROW(cudaCreateSurfaceObject(&m_surface, &resource_desc));
+
+	m_size = size;
+	m_n_channels = n_channels;
 }
 
 #ifdef NGP_GUI
@@ -91,14 +101,14 @@ GLuint GLTexture::texture() {
 
 cudaSurfaceObject_t GLTexture::surface() {
 	if (!m_cuda_mapping) {
-		m_cuda_mapping = std::make_unique<CUDAMapping>(texture(), m_size);
+		m_cuda_mapping = std::make_unique<CUDAMapping>(texture(), m_size, m_n_channels);
 	}
 	return m_cuda_mapping->surface();
 }
 
 cudaArray_t GLTexture::array() {
 	if (!m_cuda_mapping) {
-		m_cuda_mapping = std::make_unique<CUDAMapping>(texture(), m_size);
+		m_cuda_mapping = std::make_unique<CUDAMapping>(texture(), m_size, m_n_channels);
 	}
 	return m_cuda_mapping->array();
 }
@@ -108,12 +118,14 @@ void GLTexture::blit_from_cuda_mapping() {
 		return;
 	}
 
-	if (m_internal_format != GL_RGBA32F || m_format != GL_RGBA || m_is_8bit) {
-		throw std::runtime_error{"Can only blit from CUDA mapping if the texture is RGBA float."};
+	if (m_is_8bit) {
+		throw std::runtime_error{"Can only blit from CUDA mapping if the texture is float."};
 	}
 
 	const float* data_cpu = m_cuda_mapping->data_cpu();
-	glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA32F, m_size.x(), m_size.y(), 0, GL_RGBA, GL_FLOAT, data_cpu);
+
+	glBindTexture(GL_TEXTURE_2D, m_texture_id);
+	glTexImage2D(GL_TEXTURE_2D, 0, m_internal_format, m_size.x(), m_size.y(), 0, m_format, GL_FLOAT, data_cpu);
 }
 
 void GLTexture::load(const fs::path& path) {
@@ -173,8 +185,7 @@ void GLTexture::resize(const Vector2i& new_size, int n_channels, bool is_8bit) {
 	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
 }
 
-
-GLTexture::CUDAMapping::CUDAMapping(GLuint texture_id, const Vector2i& size) : m_size{size} {
+GLTexture::CUDAMapping::CUDAMapping(GLuint texture_id, const Vector2i& size, int n_channels) : m_size{size}, m_n_channels{n_channels} {
 	static bool s_is_cuda_interop_supported = !is_wsl();
 	if (s_is_cuda_interop_supported) {
 		cudaError_t err = cudaGraphicsGLRegisterImage(&m_graphics_resource, texture_id, GL_TEXTURE_2D, cudaGraphicsRegisterFlagsSurfaceLoadStore);
@@ -187,8 +198,8 @@ GLTexture::CUDAMapping::CUDAMapping(GLuint texture_id, const Vector2i& size) : m
 	if (!s_is_cuda_interop_supported) {
 		// falling back to a regular cuda surface + CPU copy of data
 		m_cuda_surface = std::make_unique<CudaSurface2D>();
-		m_cuda_surface->resize(size);
-		m_data_cpu.resize(m_size.prod() * 4);
+		m_cuda_surface->resize(size, n_channels);
+		m_data_cpu.resize(m_size.prod() * n_channels);
 		return;
 	}
 
@@ -212,7 +223,7 @@ GLTexture::CUDAMapping::~CUDAMapping() {
 }
 
 const float* GLTexture::CUDAMapping::data_cpu() {
-	CUDA_CHECK_THROW(cudaMemcpy2DFromArray(m_data_cpu.data(), m_size.x() * sizeof(float) * 4, array(), 0, 0, m_size.x() * sizeof(float) * 4, m_size.y(), cudaMemcpyDeviceToHost));
+	CUDA_CHECK_THROW(cudaMemcpy2DFromArray(m_data_cpu.data(), m_size.x() * sizeof(float) * m_n_channels, array(), 0, 0, m_size.x() * sizeof(float) * m_n_channels, m_size.y(), cudaMemcpyDeviceToHost));
 	return m_data_cpu.data();
 }
 #endif //NGP_GUI
@@ -362,11 +373,11 @@ __global__ void overlay_image_kernel(
 	float fx = x+0.5f;
 	float fy = y+0.5f;
 
-	fx-=resolution.x()*0.5f; fx/=zoom; fx+=screen_center.x() * resolution.x();
-	fy-=resolution.y()*0.5f; fy/=zoom; fy+=screen_center.y() * resolution.y();
+	fx -= resolution.x() * 0.5f; fx /= zoom; fx += screen_center.x() * resolution.x();
+	fy -= resolution.y() * 0.5f; fy /= zoom; fy += screen_center.y() * resolution.y();
 
-	float u = (fx-resolution.x()*0.5f) * scale  + image_resolution.x()*0.5f;
-	float v = (fy-resolution.y()*0.5f) * scale  + image_resolution.y()*0.5f;
+	float u = (fx - resolution.x() * 0.5f) * scale  + image_resolution.x() * 0.5f;
+	float v = (fy - resolution.y() * 0.5f) * scale  + image_resolution.y() * 0.5f;
 
 	int srcx = floorf(u);
 	int srcy = floorf(v);
@@ -431,7 +442,8 @@ __global__ void overlay_depth_kernel(
 	float depth_scale,
 	Vector2i image_resolution,
 	int fov_axis,
-	float zoom, Eigen::Vector2f screen_center,
+	float zoom,
+	Eigen::Vector2f screen_center,
 	cudaSurfaceObject_t surface
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
@@ -443,14 +455,14 @@ __global__ void overlay_depth_kernel(
 
 	float scale = image_resolution[fov_axis] / float(resolution[fov_axis]);
 
-	float fx = x+0.5f;
-	float fy = y+0.5f;
+	float fx = x + 0.5f;
+	float fy = y + 0.5f;
 
-	fx-=resolution.x()*0.5f; fx/=zoom; fx+=screen_center.x() * resolution.x();
-	fy-=resolution.y()*0.5f; fy/=zoom; fy+=screen_center.y() * resolution.y();
+	fx -= resolution.x() * 0.5f; fx /= zoom; fx += screen_center.x() * resolution.x();
+	fy -= resolution.y() * 0.5f; fy /= zoom; fy += screen_center.y() * resolution.y();
 
-	float u = (fx-resolution.x()*0.5f) * scale  + image_resolution.x()*0.5f;
-	float v = (fy-resolution.y()*0.5f) * scale  + image_resolution.y()*0.5f;
+	float u = (fx - resolution.x() * 0.5f) * scale + image_resolution.x() * 0.5f;
+	float v = (fy - resolution.y() * 0.5f) * scale + image_resolution.y() * 0.5f;
 
 	int srcx = floorf(u);
 	int srcy = floorf(v);
@@ -568,15 +580,42 @@ __global__ void dlss_splat_kernel(
 	surf2Dwrite(color, surface, x * sizeof(float4), y);
 }
 
+__global__ void depth_splat_kernel(
+	Vector2i resolution,
+	float znear,
+	float zfar,
+	float* __restrict__ depth_buffer,
+	cudaSurfaceObject_t surface
+) {
+	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
+	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
+
+	if (x >= resolution.x() || y >= resolution.y()) {
+		return;
+	}
+
+	uint32_t idx = x + resolution.x() * y;
+	surf2Dwrite(to_ndc_depth(depth_buffer[idx], znear, zfar), surface, x * sizeof(float), y);
+}
+
+void CudaRenderBufferView::clear(cudaStream_t stream) const {
+	size_t n_pixels = resolution.prod();
+	CUDA_CHECK_THROW(cudaMemsetAsync(frame_buffer, 0, n_pixels * sizeof(Array4f), stream));
+	CUDA_CHECK_THROW(cudaMemsetAsync(depth_buffer, 0, n_pixels * sizeof(float), stream));
+}
+
 void CudaRenderBuffer::resize(const Vector2i& res) {
 	m_in_resolution = res;
 	m_frame_buffer.enlarge(res.x() * res.y());
 	m_depth_buffer.enlarge(res.x() * res.y());
+	if (m_depth_target) {
+		m_depth_target->resize(res, 1);
+	}
 	m_accumulate_buffer.enlarge(res.x() * res.y());
 
 	Vector2i out_res = m_dlss ? m_dlss->out_resolution() : res;
 	auto prev_out_res = out_resolution();
-	m_surface_provider->resize(out_res);
+	m_rgba_target->resize(out_res, 4);
 
 	if (out_resolution() != prev_out_res) {
 		reset_accumulation();
@@ -584,8 +623,7 @@ void CudaRenderBuffer::resize(const Vector2i& res) {
 }
 
 void CudaRenderBuffer::clear_frame(cudaStream_t stream) {
-	CUDA_CHECK_THROW(cudaMemsetAsync(m_frame_buffer.data(), 0, m_frame_buffer.bytes(), stream));
-	CUDA_CHECK_THROW(cudaMemsetAsync(m_depth_buffer.data(), 0, m_depth_buffer.bytes(), stream));
+	view().clear(stream);
 }
 
 void CudaRenderBuffer::accumulate(float exposure, cudaStream_t stream) {
@@ -610,10 +648,10 @@ void CudaRenderBuffer::accumulate(float exposure, cudaStream_t stream) {
 	++m_spp;
 }
 
-void CudaRenderBuffer::tonemap(float exposure, const Array4f& background_color, EColorSpace output_color_space, cudaStream_t stream) {
+void CudaRenderBuffer::tonemap(float exposure, const Array4f& background_color, EColorSpace output_color_space, float znear, float zfar, cudaStream_t stream) {
 	assert(m_dlss || out_resolution() == in_resolution());
 
-	auto res = m_dlss ? in_resolution() : out_resolution();
+	auto res = in_resolution();
 	const dim3 threads = { 16, 8, 1 };
 	const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 	tonemap_kernel<<<blocks, threads, 0, stream>>>(
@@ -646,6 +684,10 @@ void CudaRenderBuffer::tonemap(float exposure, const Array4f& background_color,
 		const dim3 out_blocks = { div_round_up((uint32_t)out_res.x(), threads.x), div_round_up((uint32_t)out_res.y(), threads.y), 1 };
 		dlss_splat_kernel<<<out_blocks, threads, 0, stream>>>(out_res, m_dlss->output(), surface());
 	}
+
+	if (m_depth_target) {
+		depth_splat_kernel<<<blocks, threads, 0, stream>>>(res, znear, zfar, depth_buffer(), m_depth_target->surface());
+	}
 }
 
 void CudaRenderBuffer::overlay_image(
@@ -726,10 +768,10 @@ void CudaRenderBuffer::overlay_false_color(Vector2i training_resolution, bool to
 	);
 }
 
-void CudaRenderBuffer::enable_dlss(const Eigen::Vector2i& max_out_res) {
+void CudaRenderBuffer::enable_dlss(IDlssProvider& dlss_provider, const Eigen::Vector2i& max_out_res) {
 #ifdef NGP_VULKAN
 	if (!m_dlss || m_dlss->max_out_resolution() != max_out_res) {
-		m_dlss = dlss_init(max_out_res);
+		m_dlss = dlss_provider.init_dlss(max_out_res);
 	}
 
 	if (m_dlss) {
diff --git a/src/testbed.cu b/src/testbed.cu
index 23735633885daba0ad7550ee7edfe59a93c851de..6f088bec62c0d9f0afc77ee98588ad4ab257a1e5 100644
--- a/src/testbed.cu
+++ b/src/testbed.cu
@@ -186,11 +186,30 @@ void Testbed::set_mode(ETestbedMode mode) {
 	m_distortion = {};
 	m_training_data_available = false;
 
+	// Clear device-owned data that might be mode-specific
+	for (auto&& device : m_devices) {
+		device.clear();
+	}
+
 	// Reset paths that might be attached to the chosen mode
 	m_data_path = {};
 
 	m_testbed_mode = mode;
 
+	// Set various defaults depending on mode
+	if (m_testbed_mode == ETestbedMode::Nerf) {
+		if (m_devices.size() > 1) {
+			m_use_aux_devices = true;
+		}
+
+		if (m_dlss_provider) {
+			m_dlss = true;
+		}
+	} else {
+		m_use_aux_devices = false;
+		m_dlss = false;
+	}
+
 	reset_camera();
 }
 
@@ -348,8 +367,8 @@ void Testbed::reset_accumulation(bool due_to_camera_movement, bool immediate_red
 
 	if (!due_to_camera_movement || !reprojection_available()) {
 		m_windowless_render_surface.reset_accumulation();
-		for (auto& tex : m_render_surfaces) {
-			tex.reset_accumulation();
+		for (auto& view : m_views) {
+			view.render_buffer->reset_accumulation();
 		}
 	}
 }
@@ -359,8 +378,13 @@ void Testbed::set_visualized_dim(int dim) {
 	reset_accumulation();
 }
 
-void Testbed::translate_camera(const Vector3f& rel) {
-	m_camera.col(3) += m_camera.block<3, 3>(0, 0) * rel * m_bounding_radius;
+void Testbed::translate_camera(const Vector3f& rel, const Matrix3f& rot, bool allow_up_down) {
+	Vector3f movement = rot * rel;
+	if (!allow_up_down) {
+		movement -= movement.dot(m_up_dir) * m_up_dir;
+	}
+
+	m_camera.col(3) += movement;
 	reset_accumulation(true);
 }
 
@@ -425,15 +449,28 @@ void Testbed::set_camera_to_training_view(int trainview) {
 	m_scale = std::max((old_look_at - view_pos()).dot(view_dir()), 0.1f);
 	m_nerf.render_with_lens_distortion = true;
 	m_nerf.render_lens = m_nerf.training.dataset.metadata[trainview].lens;
-	m_screen_center = Vector2f::Constant(1.0f) - m_nerf.training.dataset.metadata[0].principal_point;
+	if (!supports_dlss(m_nerf.render_lens.mode)) {
+		m_dlss = false;
+	}
+
+	m_screen_center = Vector2f::Constant(1.0f) - m_nerf.training.dataset.metadata[trainview].principal_point;
+	m_nerf.training.view = trainview;
 }
 
 void Testbed::reset_camera() {
 	m_fov_axis = 1;
-	set_fov(50.625f);
-	m_zoom = 1.f;
+	m_zoom = 1.0f;
 	m_screen_center = Vector2f::Constant(0.5f);
-	m_scale = m_testbed_mode == ETestbedMode::Image ? 1.0f : 1.5f;
+
+	if (m_testbed_mode == ETestbedMode::Image) {
+		// Make image full-screen at the given view distance
+		m_relative_focal_length = Vector2f::Ones();
+		m_scale = 1.0f;
+	} else {
+		set_fov(50.625f);
+		m_scale = 1.5f;
+	}
+
 	m_camera <<
 		1.0f, 0.0f, 0.0f, 0.5f,
 		0.0f, -1.0f, 0.0f, 0.5f,
@@ -630,7 +667,7 @@ void Testbed::imgui() {
 							m_smoothed_camera = m_camera;
 						}
 					} else {
-						m_pip_render_surface->reset_accumulation();
+						m_pip_render_buffer->reset_accumulation();
 					}
 				}
 			}
@@ -639,7 +676,7 @@ void Testbed::imgui() {
 				float w = ImGui::GetContentRegionAvail().x;
 				if (m_camera_path.update_cam_from_path) {
 					m_picture_in_picture_res = 0;
-					ImGui::Image((ImTextureID)(size_t)m_render_textures.front()->texture(), ImVec2(w, w * 9.0f / 16.0f));
+					ImGui::Image((ImTextureID)(size_t)m_rgba_render_textures.front()->texture(), ImVec2(w, w * 9.0f / 16.0f));
 				} else {
 					m_picture_in_picture_res = (float)std::min((int(w)+31)&(~31), 1920/4);
 					ImGui::Image((ImTextureID)(size_t)m_pip_render_texture->texture(), ImVec2(w, w * 9.0f / 16.0f));
@@ -684,7 +721,7 @@ void Testbed::imgui() {
 
 				auto elapsed = std::chrono::steady_clock::now() - m_camera_path.render_start_time;
 
-				uint32_t progress = m_camera_path.render_frame_idx * m_camera_path.render_settings.spp + m_render_surfaces.front().spp();
+				uint32_t progress = m_camera_path.render_frame_idx * m_camera_path.render_settings.spp + m_views.front().render_buffer->spp();
 				uint32_t goal = m_camera_path.render_settings.n_frames() * m_camera_path.render_settings.spp;
 				auto est_remaining = elapsed * (float)(goal - progress) / std::max(progress, 1u);
 
@@ -718,7 +755,11 @@ void Testbed::imgui() {
 
 	ImGui::Begin("instant-ngp v" NGP_VERSION);
 
-	size_t n_bytes = tcnn::total_n_bytes_allocated() + g_total_n_bytes_allocated + dlss_allocated_bytes();
+	size_t n_bytes = tcnn::total_n_bytes_allocated() + g_total_n_bytes_allocated;
+	if (m_dlss_provider) {
+		n_bytes += m_dlss_provider->allocated_bytes();
+	}
+
 	ImGui::Text("Frame: %.2f ms (%.1f FPS); Mem: %s", m_frame_ms.ema_val(), 1000.0f / m_frame_ms.ema_val(), bytes_to_string(n_bytes).c_str());
 	bool accum_reset = false;
 
@@ -728,41 +769,58 @@ void Testbed::imgui() {
 		if (imgui_colored_button(m_train ? "Stop training" : "Start training", 0.4)) {
 			set_train(!m_train);
 		}
+
+
+		ImGui::SameLine();
+		if (imgui_colored_button("Reset training", 0.f)) {
+			reload_network_from_file();
+		}
+
 		ImGui::SameLine();
-		ImGui::Checkbox("Train encoding", &m_train_encoding);
+		ImGui::Checkbox("encoding", &m_train_encoding);
 		ImGui::SameLine();
-		ImGui::Checkbox("Train network", &m_train_network);
+		ImGui::Checkbox("network", &m_train_network);
 		ImGui::SameLine();
-		ImGui::Checkbox("Random levels", &m_max_level_rand_training);
+		ImGui::Checkbox("rand levels", &m_max_level_rand_training);
 		if (m_testbed_mode == ETestbedMode::Nerf) {
-			ImGui::Checkbox("Train envmap", &m_nerf.training.train_envmap);
+			ImGui::Checkbox("envmap", &m_nerf.training.train_envmap);
 			ImGui::SameLine();
-			ImGui::Checkbox("Train extrinsics", &m_nerf.training.optimize_extrinsics);
+			ImGui::Checkbox("extrinsics", &m_nerf.training.optimize_extrinsics);
 			ImGui::SameLine();
-			ImGui::Checkbox("Train exposure", &m_nerf.training.optimize_exposure);
+			ImGui::Checkbox("exposure", &m_nerf.training.optimize_exposure);
 			ImGui::SameLine();
-			ImGui::Checkbox("Train distortion", &m_nerf.training.optimize_distortion);
+			ImGui::Checkbox("distortion", &m_nerf.training.optimize_distortion);
+
 			if (m_nerf.training.dataset.n_extra_learnable_dims) {
-				ImGui::Checkbox("Train latent codes", &m_nerf.training.optimize_extra_dims);
+				ImGui::SameLine();
+				ImGui::Checkbox("latents", &m_nerf.training.optimize_extra_dims);
 			}
+
+
 			static bool export_extrinsics_in_quat_format = true;
-			if (imgui_colored_button("Export extrinsics", 0.4f)) {
-				m_nerf.training.export_camera_extrinsics(m_imgui.extrinsics_path, export_extrinsics_in_quat_format);
+			static bool extrinsics_have_been_optimized = false;
+
+			if (m_nerf.training.optimize_extrinsics) {
+				extrinsics_have_been_optimized = true;
 			}
 
-			ImGui::SameLine();
-			ImGui::PushItemWidth(400.f);
-			ImGui::InputText("File##Extrinsics file path", m_imgui.extrinsics_path, sizeof(m_imgui.extrinsics_path));
-			ImGui::PopItemWidth();
-			ImGui::SameLine();
-			ImGui::Checkbox("Quaternion format", &export_extrinsics_in_quat_format);
-		}
-		if (imgui_colored_button("Reset training", 0.f)) {
-			reload_network_from_file();
+			if (extrinsics_have_been_optimized) {
+				if (imgui_colored_button("Export extrinsics", 0.4f)) {
+					m_nerf.training.export_camera_extrinsics(m_imgui.extrinsics_path, export_extrinsics_in_quat_format);
+				}
+
+				ImGui::SameLine();
+				ImGui::Checkbox("as quaternions", &export_extrinsics_in_quat_format);
+				ImGui::InputText("File##Extrinsics file path", m_imgui.extrinsics_path, sizeof(m_imgui.extrinsics_path));
+			}
 		}
+
+		ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f);
+		ImGui::SliderInt("Batch size", (int*)&m_training_batch_size, 1 << 12, 1 << 22, "%d", ImGuiSliderFlags_Logarithmic);
 		ImGui::SameLine();
 		ImGui::DragInt("Seed", (int*)&m_seed, 1.0f, 0, std::numeric_limits<int>::max());
-		ImGui::SliderInt("Batch size", (int*)&m_training_batch_size, 1 << 12, 1 << 22, "%d", ImGuiSliderFlags_Logarithmic);
+		ImGui::PopItemWidth();
+
 		m_training_batch_size = next_multiple(m_training_batch_size, batch_size_granularity);
 
 		if (m_train) {
@@ -778,9 +836,11 @@ void Testbed::imgui() {
 		} else {
 			ImGui::Text("Training paused");
 		}
+
 		if (m_testbed_mode == ETestbedMode::Nerf) {
 			ImGui::Text("Rays/batch: %d, Samples/ray: %.2f, Batch size: %d/%d", m_nerf.training.counters_rgb.rays_per_batch, (float)m_nerf.training.counters_rgb.measured_batch_size / (float)m_nerf.training.counters_rgb.rays_per_batch, m_nerf.training.counters_rgb.measured_batch_size, m_nerf.training.counters_rgb.measured_batch_size_before_compaction);
 		}
+
 		float elapsed_training = std::chrono::duration<float>(std::chrono::steady_clock::now() - m_training_start_time_point).count();
 		ImGui::Text("Steps: %d, Loss: %0.6f (%0.2f dB), Elapsed: %.1fs", m_training_step, m_loss_scalar.ema_val(), linear_to_db(m_loss_scalar.ema_val()), elapsed_training);
 		ImGui::PlotLines("loss graph", m_loss_graph.data(), std::min(m_loss_graph_samples, m_loss_graph.size()), (m_loss_graph_samples < m_loss_graph.size()) ? 0 : (m_loss_graph_samples % m_loss_graph.size()), 0, FLT_MAX, FLT_MAX, ImVec2(0, 50.f));
@@ -848,85 +908,79 @@ void Testbed::imgui() {
 	if (!m_training_data_available) { ImGui::EndDisabled(); }
 
 	if (ImGui::CollapsingHeader("Rendering", ImGuiTreeNodeFlags_DefaultOpen)) {
-		ImGui::Checkbox("Render", &m_render);
-		ImGui::SameLine();
-
-		const auto& render_tex = m_render_surfaces.front();
-		std::string spp_string = m_dlss ? std::string{""} : fmt::format("({} spp)", std::max(render_tex.spp(), 1u));
-		ImGui::Text(": %.01fms for %dx%d %s", m_render_ms.ema_val(), render_tex.in_resolution().x(), render_tex.in_resolution().y(), spp_string.c_str());
-
-		if (m_dlss_supported) {
-			if (!m_single_view) {
-				ImGui::BeginDisabled();
-				m_dlss = false;
-			}
-
-			if (ImGui::Checkbox("DLSS", &m_dlss)) {
-				accum_reset = true;
+		if (!m_hmd) {
+			if (ImGui::Button("Connect to VR/AR headset")) {
+				try {
+					init_vr();
+				} catch (const std::runtime_error& e) {
+					imgui_error_string = e.what();
+					ImGui::OpenPopup("Error");
+				}
 			}
-
-			if (render_tex.dlss()) {
-				ImGui::SameLine();
-				ImGui::Text("(automatic quality setting: %s)", DlssQualityStrArray[(int)render_tex.dlss()->quality()]);
-				ImGui::SliderFloat("DLSS sharpening", &m_dlss_sharpening, 0.0f, 1.0f, "%.02f");
+		} else if (ImGui::TreeNodeEx("VR/AR settings", ImGuiTreeNodeFlags_DefaultOpen)) {
+			if (m_devices.size() > 1 && m_testbed_mode == ETestbedMode::Nerf) {
+				ImGui::Checkbox("Multi-GPU rendering (one per eye)", &m_use_aux_devices);
 			}
 
-			if (!m_single_view) {
-				ImGui::EndDisabled();
+			accum_reset |= ImGui::Checkbox("Foveated rendering", &m_foveated_rendering) && !m_dlss;
+			if (m_foveated_rendering) {
+				accum_reset |= ImGui::SliderFloat("Maximum foveation", &m_foveated_rendering_max_scaling, 1.0f, 16.0f, "%.01f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat) && !m_dlss;
 			}
+			ImGui::TreePop();
 		}
 
-		ImGui::Checkbox("Dynamic resolution", &m_dynamic_res);
+		ImGui::Checkbox("Render", &m_render);
+		ImGui::SameLine();
+
+		const auto& render_buffer = m_views.front().render_buffer;
+		std::string spp_string = m_dlss ? std::string{""} : fmt::format("({} spp)", std::max(render_buffer->spp(), 1u));
+		ImGui::Text(": %.01fms for %dx%d %s", m_render_ms.ema_val(), render_buffer->in_resolution().x(), render_buffer->in_resolution().y(), spp_string.c_str());
+
+		ImGui::SameLine();
 		if (ImGui::Checkbox("VSync", &m_vsync)) {
 			glfwSwapInterval(m_vsync ? 1 : 0);
 		}
-		ImGui::SliderFloat("Target FPS", &m_dynamic_res_target_fps, 2.0f, 144.0f, "%.01f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
-		ImGui::SliderInt("Max spp", &m_max_spp, 0, 1024, "%d", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
 
-		if (!m_dynamic_res) {
-			ImGui::SliderInt("Fixed resolution factor", &m_fixed_res_factor, 8, 64);
-		}
-
-		if (m_testbed_mode == ETestbedMode::Nerf && m_nerf.training.dataset.has_light_dirs) {
-			Vector3f light_dir = m_nerf.light_dir.normalized();
-			if (ImGui::TreeNodeEx("Light Dir (Polar)", ImGuiTreeNodeFlags_DefaultOpen)) {
-				float phi = atan2f(m_nerf.light_dir.x(), m_nerf.light_dir.z());
-				float theta = asinf(m_nerf.light_dir.y());
-				bool spin = ImGui::SliderFloat("Light Dir Theta", &theta, -PI() / 2.0f, PI() / 2.0f);
-				spin |= ImGui::SliderFloat("Light Dir Phi", &phi, -PI(), PI());
-				if (spin) {
-					float sin_phi, cos_phi;
-					sincosf(phi, &sin_phi, &cos_phi);
-					float cos_theta=cosf(theta);
-					m_nerf.light_dir = {sin_phi * cos_theta,sinf(theta),cos_phi * cos_theta};
-					accum_reset = true;
-				}
-				ImGui::TreePop();
-			}
-			if (ImGui::TreeNode("Light Dir (Cartesian)")) {
-				accum_reset |= ImGui::SliderFloat("Light Dir X", ((float*)(&m_nerf.light_dir)) + 0, -1.0f, 1.0f);
-				accum_reset |= ImGui::SliderFloat("Light Dir Y", ((float*)(&m_nerf.light_dir)) + 1, -1.0f, 1.0f);
-				accum_reset |= ImGui::SliderFloat("Light Dir Z", ((float*)(&m_nerf.light_dir)) + 2, -1.0f, 1.0f);
-				ImGui::TreePop();
-			}
+		if (!m_dlss_provider) { ImGui::BeginDisabled(); }
+		accum_reset |= ImGui::Checkbox("DLSS", &m_dlss);
+
+		if (render_buffer->dlss()) {
+			ImGui::SameLine();
+			ImGui::Text("(%s)", DlssQualityStrArray[(int)render_buffer->dlss()->quality()]);
+			ImGui::SameLine();
+			ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f);
+			ImGui::SliderFloat("Sharpening", &m_dlss_sharpening, 0.0f, 1.0f, "%.02f");
+			ImGui::PopItemWidth();
+		}
+
+		if (!m_dlss_provider) {
+			ImGui::SameLine();
+#ifdef NGP_VULKAN
+			ImGui::Text("(unsupported on this system)");
+#else
+			ImGui::Text("(Vulkan was missing at compilation time)");
+#endif
+			ImGui::EndDisabled();
 		}
-		if (m_testbed_mode == ETestbedMode::Nerf && m_nerf.training.dataset.n_extra_learnable_dims) {
-			accum_reset |= ImGui::SliderInt("training image latent code for inference", (int*)&m_nerf.extra_dim_idx_for_inference, 0, m_nerf.training.dataset.n_images-1);
+
+		ImGui::Checkbox("Dynamic resolution", &m_dynamic_res);
+		ImGui::SameLine();
+		ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f);
+		if (m_dynamic_res) {
+			ImGui::SliderFloat("Target FPS", &m_dynamic_res_target_fps, 2.0f, 144.0f, "%.01f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
+		} else {
+			ImGui::SliderInt("Resolution factor", &m_fixed_res_factor, 8, 64);
 		}
+		ImGui::PopItemWidth();
+
 		accum_reset |= ImGui::Combo("Render mode", (int*)&m_render_mode, RenderModeStr);
-		if (m_testbed_mode == ETestbedMode::Nerf)  {
-			accum_reset |= ImGui::Combo("Groundtruth Render mode", (int*)&m_ground_truth_render_mode, GroundTruthRenderModeStr);
-			accum_reset |= ImGui::SliderFloat("Groundtruth Alpha", &m_ground_truth_alpha, 0.0f, 1.0f, "%.02f", ImGuiSliderFlags_AlwaysClamp);
-		}
-		accum_reset |= ImGui::Combo("Color space", (int*)&m_color_space, ColorSpaceStr);
 		accum_reset |= ImGui::Combo("Tonemap curve", (int*)&m_tonemap_curve, TonemapCurveStr);
 		accum_reset |= ImGui::ColorEdit4("Background", &m_background_color[0]);
+
 		if (ImGui::SliderFloat("Exposure", &m_exposure, -5.f, 5.f)) {
 			set_exposure(m_exposure);
 		}
 
-		accum_reset |= ImGui::Checkbox("Snap to pixel centers", &m_snap_to_pixel_centers);
-
 		float max_diam = (m_aabb.max-m_aabb.min).maxCoeff();
 		float render_diam = (m_render_aabb.max-m_render_aabb.min).maxCoeff();
 		float old_render_diam = render_diam;
@@ -988,11 +1042,52 @@ void Testbed::imgui() {
 			m_edit_render_aabb = false;
 		}
 
+		if (ImGui::TreeNode("Advanced rendering options")) {
+			ImGui::SliderInt("Max spp", &m_max_spp, 0, 1024, "%d", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
+			accum_reset |= ImGui::Checkbox("Render transparency as checkerboard", &m_render_transparency_as_checkerboard);
+			accum_reset |= ImGui::Combo("Color space", (int*)&m_color_space, ColorSpaceStr);
+			accum_reset |= ImGui::Checkbox("Snap to pixel centers", &m_snap_to_pixel_centers);
+
+			ImGui::TreePop();
+		}
+
 		if (m_testbed_mode == ETestbedMode::Nerf && ImGui::TreeNode("NeRF rendering options")) {
-			accum_reset |= ImGui::Checkbox("Apply lens distortion", &m_nerf.render_with_lens_distortion);
+			if (m_nerf.training.dataset.has_light_dirs) {
+				Vector3f light_dir = m_nerf.light_dir.normalized();
+				if (ImGui::TreeNodeEx("Light Dir (Polar)", ImGuiTreeNodeFlags_DefaultOpen)) {
+					float phi = atan2f(m_nerf.light_dir.x(), m_nerf.light_dir.z());
+					float theta = asinf(m_nerf.light_dir.y());
+					bool spin = ImGui::SliderFloat("Light Dir Theta", &theta, -PI() / 2.0f, PI() / 2.0f);
+					spin |= ImGui::SliderFloat("Light Dir Phi", &phi, -PI(), PI());
+					if (spin) {
+						float sin_phi, cos_phi;
+						sincosf(phi, &sin_phi, &cos_phi);
+						float cos_theta = cosf(theta);
+						m_nerf.light_dir = {sin_phi * cos_theta, sinf(theta), cos_phi * cos_theta};
+						accum_reset = true;
+					}
+					ImGui::TreePop();
+				}
+
+				if (ImGui::TreeNode("Light Dir (Cartesian)")) {
+					accum_reset |= ImGui::SliderFloat("Light Dir X", ((float*)(&m_nerf.light_dir)) + 0, -1.0f, 1.0f);
+					accum_reset |= ImGui::SliderFloat("Light Dir Y", ((float*)(&m_nerf.light_dir)) + 1, -1.0f, 1.0f);
+					accum_reset |= ImGui::SliderFloat("Light Dir Z", ((float*)(&m_nerf.light_dir)) + 2, -1.0f, 1.0f);
+					ImGui::TreePop();
+				}
+			}
+
+			if (m_nerf.training.dataset.n_extra_learnable_dims) {
+				accum_reset |= ImGui::SliderInt("training image latent code for inference", (int*)&m_nerf.extra_dim_idx_for_inference, 0, m_nerf.training.dataset.n_images-1);
+			}
+
+			accum_reset |= ImGui::Combo("Groundtruth render mode", (int*)&m_ground_truth_render_mode, GroundTruthRenderModeStr);
+			accum_reset |= ImGui::SliderFloat("Groundtruth alpha", &m_ground_truth_alpha, 0.0f, 1.0f, "%.02f", ImGuiSliderFlags_AlwaysClamp);
+
+			bool lens_changed = ImGui::Checkbox("Apply lens distortion", &m_nerf.render_with_lens_distortion);
 
 			if (m_nerf.render_with_lens_distortion) {
-				accum_reset |= ImGui::Combo("Lens mode", (int*)&m_nerf.render_lens.mode, LensModeStr);
+				lens_changed |= ImGui::Combo("Lens mode", (int*)&m_nerf.render_lens.mode, LensModeStr);
 				if (m_nerf.render_lens.mode == ELensMode::OpenCV) {
 					accum_reset |= ImGui::InputFloat("k1", &m_nerf.render_lens.params[0], 0.f, 0.f, "%.5f");
 					accum_reset |= ImGui::InputFloat("k2", &m_nerf.render_lens.params[1], 0.f, 0.f, "%.5f");
@@ -1012,6 +1107,12 @@ void Testbed::imgui() {
 					accum_reset |= ImGui::InputFloat("f_theta p3", &m_nerf.render_lens.params[3], 0.f, 0.f, "%.5f");
 					accum_reset |= ImGui::InputFloat("f_theta p4", &m_nerf.render_lens.params[4], 0.f, 0.f, "%.5f");
 				}
+
+				if (lens_changed && !supports_dlss(m_nerf.render_lens.mode)) {
+					m_dlss = false;
+				}
+
+				accum_reset |= lens_changed;
 			}
 
 			accum_reset |= ImGui::SliderFloat("Min transmittance", &m_nerf.render_min_transmittance, 0.0f, 1.0f, "%.3f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
@@ -1032,6 +1133,7 @@ void Testbed::imgui() {
 			}
 
 			accum_reset |= ImGui::Checkbox("Analytic normals", &m_sdf.analytic_normals);
+			accum_reset |= ImGui::Checkbox("Floor", &m_floor_enable);
 
 			accum_reset |= ImGui::SliderFloat("Normals epsilon", &m_sdf.fd_normals_epsilon, 0.00001f, 0.1f, "%.6g", ImGuiSliderFlags_Logarithmic);
 			accum_reset |= ImGui::SliderFloat("Maximum distance", &m_sdf.maximum_distance, 0.00001f, 0.1f, "%.6g", ImGuiSliderFlags_Logarithmic);
@@ -1131,24 +1233,29 @@ void Testbed::imgui() {
 	}
 
 	if (ImGui::CollapsingHeader("Camera", ImGuiTreeNodeFlags_DefaultOpen)) {
-		if (ImGui::SliderFloat("Aperture size", &m_aperture_size, 0.0f, 0.1f)) {
+		ImGui::Checkbox("First person controls", &m_fps_camera);
+		ImGui::SameLine();
+		ImGui::Checkbox("Smooth motion", &m_camera_smoothing);
+		ImGui::SameLine();
+		ImGui::Checkbox("Autofocus", &m_autofocus);
+		ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f);
+		if (ImGui::SliderFloat("Aperture size", &m_aperture_size, 0.0f, 1.0f, "%.3f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat)) {
 			m_dlss = false;
 			accum_reset = true;
 		}
+		ImGui::SameLine();
+		accum_reset |= ImGui::SliderFloat("Focus depth", &m_slice_plane_z, -m_bounding_radius, m_bounding_radius);
+
 		float local_fov = fov();
 		if (ImGui::SliderFloat("Field of view", &local_fov, 0.0f, 120.0f)) {
 			set_fov(local_fov);
 			accum_reset = true;
 		}
+		ImGui::SameLine();
 		accum_reset |= ImGui::SliderFloat("Zoom", &m_zoom, 1.f, 10.f);
-		if (m_testbed_mode == ETestbedMode::Sdf) {
-			accum_reset |= ImGui::Checkbox("Floor", &m_floor_enable);
-			ImGui::SameLine();
-		}
+		ImGui::PopItemWidth();
+
 
-		ImGui::Checkbox("First person controls", &m_fps_camera);
-		ImGui::Checkbox("Smooth camera motion", &m_camera_smoothing);
-		ImGui::Checkbox("Autofocus", &m_autofocus);
 
 		if (ImGui::TreeNode("Advanced camera settings")) {
 			accum_reset |= ImGui::SliderFloat2("Screen center", &m_screen_center.x(), 0.f, 1.f);
@@ -1218,7 +1325,7 @@ void Testbed::imgui() {
 		}
 	}
 
-	if (ImGui::CollapsingHeader("Snapshot")) {
+	if (ImGui::CollapsingHeader("Snapshot", ImGuiTreeNodeFlags_DefaultOpen)) {
 		ImGui::Text("Snapshot");
 		ImGui::SameLine();
 		if (ImGui::Button("Save")) {
@@ -1329,7 +1436,7 @@ void Testbed::imgui() {
 			ImGui::Text("%dx%dx%d", res3d.x(), res3d.y(), res3d.z());
 			float thresh_range = (m_testbed_mode == ETestbedMode::Sdf) ? 0.5f : 10.f;
 			ImGui::SliderFloat("MC density threshold",&m_mesh.thresh, -thresh_range, thresh_range);
-			ImGui::Combo("Mesh render mode", (int*)&m_mesh_render_mode, "Off\0Vertex Colors\0Vertex Normals\0Face IDs\0");
+			ImGui::Combo("Mesh render mode", (int*)&m_mesh_render_mode, "Off\0Vertex Colors\0Vertex Normals\0\0");
 			ImGui::Checkbox("Unwrap mesh", &m_mesh.unwrap);
 			if (uint32_t tricount = m_mesh.indices.size()/3) {
 				ImGui::InputText("##OBJFile", m_imgui.mesh_path, sizeof(m_imgui.mesh_path));
@@ -1446,11 +1553,11 @@ void Testbed::draw_visualizations(ImDrawList* list, const Matrix<float, 3, 4>& c
 		view2world.setIdentity();
 		view2world.block<3,4>(0,0) = camera_matrix;
 
-		auto focal = calc_focal_length(Vector2i::Ones(), m_fov_axis, m_zoom);
+		auto focal = calc_focal_length(Vector2i::Ones(), m_relative_focal_length, m_fov_axis, m_zoom);
 		float zscale = 1.0f / focal[m_fov_axis];
 
 		float xyscale = (float)m_window_res[m_fov_axis];
-		Vector2f screen_center = render_screen_center();
+		Vector2f screen_center = render_screen_center(m_screen_center);
 		view2proj <<
 			xyscale, 0,       (float)m_window_res.x()*screen_center.x()*zscale, 0,
 			0,       xyscale, (float)m_window_res.y()*screen_center.y()*zscale, 0,
@@ -1478,12 +1585,12 @@ void Testbed::draw_visualizations(ImDrawList* list, const Matrix<float, 3, 4>& c
 				float flx = focal.x();
 				float fly = focal.y();
 				Matrix<float, 4, 4> view2proj_guizmo;
-				float zfar = 100.f;
-				float znear = 0.1f;
+				float zfar = m_ndc_zfar;
+				float znear = m_ndc_znear;
 				view2proj_guizmo <<
-					fly*2.f/aspect, 0, 0, 0,
-					0, -fly*2.f, 0, 0,
-					0, 0, (zfar+znear)/(zfar-znear), -(2.f*zfar*znear) / (zfar-znear),
+					fly * 2.f / aspect, 0, 0, 0,
+					0, -fly * 2.f, 0, 0,
+					0, 0, (zfar + znear) / (zfar - znear), -(2.f * zfar * znear) / (zfar - znear),
 					0, 0, 1, 0;
 				ImGuizmo::SetRect(0, 0, io.DisplaySize.x, io.DisplaySize.y);
 
@@ -1502,8 +1609,8 @@ void Testbed::draw_visualizations(ImDrawList* list, const Matrix<float, 3, 4>& c
 			}
 		}
 
-		if (m_camera_path.imgui_viz(list, view2proj, world2proj, world2view, focal, aspect)) {
-			m_pip_render_surface->reset_accumulation();
+		if (m_camera_path.imgui_viz(list, view2proj, world2proj, world2view, focal, aspect, m_ndc_znear, m_ndc_zfar)) {
+			m_pip_render_buffer->reset_accumulation();
 		}
 	}
 }
@@ -1635,11 +1742,19 @@ bool Testbed::keyboard_event() {
 	}
 
 	if (ImGui::IsKeyPressed('=') || ImGui::IsKeyPressed('+')) {
-		m_camera_velocity *= 1.5f;
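+		// In first-person mode, +/- adjust the fly-through speed; otherwise they zoom by changing the scene scale.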
+		if (m_fps_camera) {
+			m_camera_velocity *= 1.5f;
+		} else {
+			set_scale(m_scale * 1.1f);
+		}
 	}
 
 	if (ImGui::IsKeyPressed('-') || ImGui::IsKeyPressed('_')) {
-		m_camera_velocity /= 1.5f;
+		if (m_fps_camera) {
+			m_camera_velocity /= 1.5f;
+		} else {
+			set_scale(m_scale / 1.1f);
+		}
 	}
 
 	// WASD camera movement
@@ -1675,46 +1790,66 @@ bool Testbed::keyboard_event() {
 
 	if (translate_vec != Vector3f::Zero()) {
 		m_fps_camera = true;
-		translate_camera(translate_vec);
+
+		// If VR is active, movement that isn't aligned with the current view
+		// direction is _very_ jarring to the user, so make keyboard-based
+		// movement aligned with the VR view, even though it is not an intended
+		// movement mechanism. (Users should use controllers.)
+		translate_camera(translate_vec, m_hmd && m_hmd->is_visible() ? m_views.front().camera0.block<3, 3>(0, 0) : m_camera.block<3, 3>(0, 0));
 	}
 
 	return false;
 }
 
-void Testbed::mouse_wheel(Vector2f m, float delta) {
+void Testbed::mouse_wheel() {
+	float delta = ImGui::GetIO().MouseWheel;
 	if (delta == 0) {
 		return;
 	}
 
-	if (!ImGui::GetIO().WantCaptureMouse) {
-		float scale_factor = pow(1.1f, -delta);
-		m_image.pos = (m_image.pos - m) / scale_factor + m;
-		set_scale(m_scale * scale_factor);
+	float scale_factor = pow(1.1f, -delta);
+	set_scale(m_scale * scale_factor);
+
+	// When in image mode, zoom around the hovered point.
+	if (m_testbed_mode == ETestbedMode::Image) {
+		Vector2i mouse = {ImGui::GetMousePos().x, ImGui::GetMousePos().y};
+		Vector3f offset = get_3d_pos_from_pixel(*m_views.front().render_buffer, mouse) - look_at();
+
+		// Don't center around infinitely distant points.
+		if (offset.norm() < 256.0f) {
+			m_camera.col(3) += offset * (1.0f - scale_factor);
+		}
 	}
 
 	reset_accumulation(true);
 }
 
-void Testbed::mouse_drag(const Vector2f& rel, int button) {
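+// Builds a rotation from mouse/thumbstick angles: yaw about the world up axis composed with pitch about the camera's side axis.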
+Matrix3f Testbed::rotation_from_angles(const Vector2f& angles) const {
 	Vector3f up = m_up_dir;
 	Vector3f side = m_camera.col(0);
+	return (AngleAxisf(angles.x(), up) * AngleAxisf(angles.y(), side)).matrix();
+}
 
-	bool is_left_held = (button & 1) != 0;
-	bool is_right_held = (button & 2) != 0;
+void Testbed::mouse_drag() {
+	Vector2f rel = Vector2f{ImGui::GetIO().MouseDelta.x, ImGui::GetIO().MouseDelta.y} / (float)m_window_res[m_fov_axis];
+	Vector2i mouse = {ImGui::GetMousePos().x, ImGui::GetMousePos().y};
+
+	Vector3f up = m_up_dir;
+	Vector3f side = m_camera.col(0);
 
 	bool shift = ImGui::GetIO().KeyMods & ImGuiKeyModFlags_Shift;
-	if (is_left_held) {
+
+	// Left held
+	if (ImGui::GetIO().MouseDown[0]) {
 		if (shift) {
-			auto mouse = ImGui::GetMousePos();
-			determine_autofocus_target_from_pixel({mouse.x, mouse.y});
+			m_autofocus_target = get_3d_pos_from_pixel(*m_views.front().render_buffer, mouse);
+			m_autofocus = true;
+
 			reset_accumulation();
 		} else {
 			float rot_sensitivity = m_fps_camera ? 0.35f : 1.0f;
-			Matrix3f rot =
-				(AngleAxisf(static_cast<float>(-rel.x() * 2 * PI() * rot_sensitivity), up) * // Scroll sideways around up vector
-				AngleAxisf(static_cast<float>(-rel.y() * 2 * PI() * rot_sensitivity), side)).matrix(); // Scroll around side vector
+			Matrix3f rot = rotation_from_angles(-rel * 2 * PI() * rot_sensitivity);
 
-			m_image.pos += rel;
 			if (m_fps_camera) {
 				m_camera.block<3, 3>(0, 0) = rot * m_camera.block<3, 3>(0, 0);
 			} else {
@@ -1729,11 +1864,9 @@ void Testbed::mouse_drag(const Vector2f& rel, int button) {
 		}
 	}
 
-	if (is_right_held) {
-		Matrix3f rot =
-			(AngleAxisf(static_cast<float>(-rel.x() * 2 * PI()), up) * // Scroll sideways around up vector
-			AngleAxisf(static_cast<float>(-rel.y() * 2 * PI()), side)).matrix(); // Scroll around side vector
-
+	// Right held
+	if (ImGui::GetIO().MouseDown[1]) {
+		Matrix3f rot = rotation_from_angles(-rel * 2 * PI());
 		if (m_render_mode == ERenderMode::Shade) {
 			m_sun_dir = rot.transpose() * m_sun_dir;
 		}
@@ -1742,14 +1875,27 @@ void Testbed::mouse_drag(const Vector2f& rel, int button) {
 		reset_accumulation();
 	}
 
-	bool is_middle_held = (button & 4) != 0;
-	if (is_middle_held) {
-		translate_camera({-rel.x(), -rel.y(), 0.0f});
+	// Middle pressed
+	if (ImGui::GetIO().MouseClicked[2]) {
+		m_drag_depth = get_depth_from_renderbuffer(*m_views.front().render_buffer, mouse.cast<float>().cwiseQuotient(m_window_res.cast<float>()));
+	}
+
+	// Middle held
+	if (ImGui::GetIO().MouseDown[2]) {
+		Vector3f translation = Vector3f{-rel.x(), -rel.y(), 0.0f} / m_zoom;
+
+		// If we have a valid depth value, scale the scene translation by it such that the
+		// hovered point in 3D space stays under the cursor.
+		if (m_drag_depth < 256.0f) {
+			translation *= m_drag_depth / m_relative_focal_length[m_fov_axis];
+		}
+
+		translate_camera(translation, m_camera.block<3, 3>(0, 0));
 	}
 }
 
-bool Testbed::begin_frame_and_handle_user_input() {
-	if (glfwWindowShouldClose(m_glfw_window) || ImGui::IsKeyDown(GLFW_KEY_ESCAPE) || ImGui::IsKeyDown(GLFW_KEY_Q)) {
+bool Testbed::begin_frame() {
+	if (glfwWindowShouldClose(m_glfw_window) || ImGui::IsKeyPressed(GLFW_KEY_ESCAPE) || ImGui::IsKeyPressed(GLFW_KEY_Q)) {
 		destroy_window();
 		return false;
 	}
@@ -1769,21 +1915,18 @@ bool Testbed::begin_frame_and_handle_user_input() {
 	ImGui::NewFrame();
 	ImGuizmo::BeginFrame();
 
+	return true;
+}
+
+void Testbed::handle_user_input() {
 	if (ImGui::IsKeyPressed(GLFW_KEY_TAB) || ImGui::IsKeyPressed(GLFW_KEY_GRAVE_ACCENT)) {
 		m_imgui.enabled = !m_imgui.enabled;
 	}
 
-	ImVec2 m = ImGui::GetMousePos();
-	int mb = 0;
-	float mw = 0.f;
-	ImVec2 relm = {};
-	if (!ImGui::IsAnyItemActive() && !ImGuizmo::IsUsing() && !ImGuizmo::IsOver()) {
-		relm = ImGui::GetIO().MouseDelta;
-		if (ImGui::GetIO().MouseDown[0]) mb |= 1;
-		if (ImGui::GetIO().MouseDown[1]) mb |= 2;
-		if (ImGui::GetIO().MouseDown[2]) mb |= 4;
-		mw = ImGui::GetIO().MouseWheel;
-		relm = {relm.x / (float)m_window_res.y(), relm.y / (float)m_window_res.y()};
+	// Only respond to mouse inputs when not interacting with ImGui
+	if (!ImGui::IsAnyItemActive() && !ImGuizmo::IsUsing() && !ImGuizmo::IsOver() && !ImGui::GetIO().WantCaptureMouse) {
+		mouse_wheel();
+		mouse_drag();
 	}
 
 	if (m_testbed_mode == ETestbedMode::Nerf && (m_render_ground_truth || m_nerf.training.render_error_overlay)) {
@@ -1791,21 +1934,150 @@ bool Testbed::begin_frame_and_handle_user_input() {
 		int bestimage = find_best_training_view(-1);
 		if (bestimage >= 0) {
 			m_nerf.training.view = bestimage;
-			if (mb == 0) {// snap camera to ground truth view on mouse up
+			if (ImGui::GetIO().MouseReleased[0]) { // snap camera to ground truth view on mouse up
 				set_camera_to_training_view(m_nerf.training.view);
 			}
 		}
 	}
 
 	keyboard_event();
-	mouse_wheel({m.x / (float)m_window_res.y(), m.y / (float)m_window_res.y()}, mw);
-	mouse_drag({relm.x, relm.y}, mb);
 
 	if (m_imgui.enabled) {
 		imgui();
 	}
+}
 
-	return true;
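+// Transforms a position from VR tracking space to world space: rotate by the camera orientation, apply the scene scale, then offset by the camera position.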
+Vector3f Testbed::vr_to_world(const Vector3f& pos) const {
+	return m_camera.block<3, 3>(0, 0) * pos * m_scale + m_camera.col(3);
+}
+
+void Testbed::begin_vr_frame_and_handle_vr_input() {
+	if (!m_hmd) {
+		m_vr_frame_info = nullptr;
+		return;
+	}
+
+	m_hmd->poll_events();
+	if (!m_hmd->must_run_frame_loop()) {
+		m_vr_frame_info = nullptr;
+		return;
+	}
+
+	m_vr_frame_info = m_hmd->begin_frame();
+
+	const auto& views = m_vr_frame_info->views;
+	size_t n_views = views.size();
+	if (n_views > 0) {
+		set_n_views(n_views);
+
+		Vector2i total_size = Vector2i::Zero();
+		for (size_t i = 0; i < n_views; ++i) {
+			Vector2i view_resolution = {views[i].view.subImage.imageRect.extent.width, views[i].view.subImage.imageRect.extent.height};
+			total_size += view_resolution;
+
+			m_views[i].full_resolution = view_resolution;
+
+			// Apply the VR pose relative to the world camera transform.
+			m_views[i].camera0.block<3, 3>(0, 0) = m_camera.block<3, 3>(0, 0) * views[i].pose.block<3, 3>(0, 0);
+			m_views[i].camera0.col(3) = vr_to_world(views[i].pose.col(3));
+			m_views[i].camera1 = m_views[i].camera0;
+
+			m_views[i].visualized_dimension = m_visualized_dimension;
+
+			const auto& xr_fov = views[i].view.fov;
+
+			// Compute the distance on the image plane (1 unit away from the camera) that an angle of the respective FOV spans
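+			// OpenXR reports per-side half-angles in radians; the factor 360/PI doubles them and converts to the degrees expected by fov_to_focal_length.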
+			Vector2f rel_focal_length_left_down = 0.5f * fov_to_focal_length(Vector2i::Ones(), Vector2f{360.0f * xr_fov.angleLeft / PI(), 360.0f * xr_fov.angleDown / PI()});
+			Vector2f rel_focal_length_right_up = 0.5f * fov_to_focal_length(Vector2i::Ones(), Vector2f{360.0f * xr_fov.angleRight / PI(), 360.0f * xr_fov.angleUp / PI()});
+
+			// Compute total distance (for X and Y) that is spanned on the image plane.
+			m_views[i].relative_focal_length = rel_focal_length_right_up - rel_focal_length_left_down;
+
+			// Compute fraction of that distance that is spanned by the right-up part and set screen center accordingly.
+			Vector2f ratio = rel_focal_length_right_up.cwiseQuotient(m_views[i].relative_focal_length);
+			m_views[i].screen_center = { 1.0f - ratio.x(), ratio.y() };
+
+			// Fix up weirdness in the rendering pipeline: scale the focal length of the non-FOV axis by the view's aspect ratio.
+			m_views[i].relative_focal_length[(m_fov_axis+1)%2] *= (float)view_resolution[(m_fov_axis+1)%2] / (float)view_resolution[m_fov_axis];
+			m_views[i].render_buffer->set_hidden_area_mask(views[i].hidden_area_mask);
+
+			// Render each view on a different GPU (if available)
+			m_views[i].device = m_use_aux_devices ? &m_devices.at(i % m_devices.size()) : &primary_device();
+		}
+
+		// Put all the views next to each other, but at half size
+		glfwSetWindowSize(m_glfw_window, total_size.x() / 2, (total_size.y() / 2) / n_views);
+
+		// VR controller input
+		const auto& hands = m_vr_frame_info->hands;
+		m_fps_camera = true;
+
+		// TRANSLATE BY STICK (if not pressing the stick)
+		if (!hands[0].pressing) {
+			Vector3f translate_vec = Vector3f{hands[0].thumbstick.x(), 0.0f, hands[0].thumbstick.y()} * m_camera_velocity * m_frame_ms.val() / 1000.0f;
+			if (translate_vec != Vector3f::Zero()) {
+				translate_camera(translate_vec, m_views.front().camera0.block<3, 3>(0, 0), false);
+			}
+		}
+
+		// TURN BY STICK (if not pressing the stick)
+		if (!hands[1].pressing) {
+			auto prev_camera = m_camera;
+
+			// Turn around the up vector (equivalent to x-axis mouse drag) with right joystick left/right
+			float sensitivity = 0.35f;
+			m_camera.block<3, 3>(0, 0) = rotation_from_angles({-2.0f * PI() * sensitivity * hands[1].thumbstick.x() * m_frame_ms.val() / 1000.0f, 0.0f}) * m_camera.block<3, 3>(0, 0);
+
+			// Translate the camera such that the rotation is centered about the current view position rather than the world origin
+			m_camera.col(3) += prev_camera.block<3, 3>(0, 0) * views[0].pose.col(3) * m_scale - m_camera.block<3, 3>(0, 0) * views[0].pose.col(3) * m_scale;
+		}
+
+		// TRANSLATE, SCALE, AND ROTATE BY GRAB
+		{
+			bool both_grabbing = hands[0].grabbing && hands[1].grabbing;
+			float drag_factor = both_grabbing ? 0.5f : 1.0f;
+
+			if (both_grabbing) {
+				Vector3f prev_diff = hands[0].prev_grab_pos - hands[1].prev_grab_pos;
+				Vector3f diff = hands[0].grab_pos - hands[1].grab_pos;
+				Vector3f center = 0.5f * (hands[0].grab_pos + hands[1].grab_pos);
+
+				Vector3f center_world = vr_to_world(center);
+
+				// Scale about the center position of the two grabbing hands, which makes the scaling feel similar to pinch-to-zoom on a phone.
+				float scale = m_scale * prev_diff.norm() / diff.norm();
+				m_camera.col(3) = (view_pos() - center_world) * (scale / m_scale) + center_world;
+				m_scale = scale;
+
+				// Take rotational component and project it to the nearest rotation about the up vector.
+				// We don't want to rotate the scene about any other axis.
+				Vector3f rot = prev_diff.normalized().cross(diff.normalized());
+				float rot_radians = std::asin(m_up_dir.dot(rot));
+
+				auto prev_camera = m_camera;
+				m_camera.block<3, 3>(0, 0) = AngleAxisf(rot_radians, m_up_dir) * m_camera.block<3, 3>(0, 0);
+				m_camera.col(3) += prev_camera.block<3, 3>(0, 0) * center * m_scale - m_camera.block<3, 3>(0, 0) * center * m_scale;
+			}
+
+			for (const auto& hand : hands) {
+				if (hand.grabbing) {
+					m_camera.col(3) -= drag_factor * m_camera.block<3, 3>(0, 0) * hand.drag() * m_scale;
+				}
+			}
+		}
+
+		// ERASE OCCUPANCY WHEN PRESSING STICK/TRACKPAD
+		if (m_testbed_mode == ETestbedMode::Nerf) {
+			for (const auto& hand : hands) {
+				if (hand.pressing) {
+					mark_density_grid_in_sphere_empty(vr_to_world(hand.pose.col(3)), m_scale * 0.05f, m_stream.get());
+				}
+			}
+		}
+	}
 }
 
 void Testbed::SecondWindow::draw(GLuint texture) {
@@ -1834,11 +2106,164 @@ void Testbed::SecondWindow::draw(GLuint texture) {
 	glfwMakeContextCurrent(old_context);
 }
 
+void Testbed::init_opengl_shaders() {
+	static const char* shader_vert = R"(#version 140
+		out vec2 UVs;
+		void main() {
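+			// Fullscreen-triangle trick: derive UVs in [0,2]^2 from gl_VertexID so no vertex buffer is needed.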
+			UVs = vec2((gl_VertexID << 1) & 2, gl_VertexID & 2);
+			gl_Position = vec4(UVs * 2.0 - 1.0, 0.0, 1.0);
+		})";
+
+	static const char* shader_frag = R"(#version 140
+		in vec2 UVs;
+		out vec4 frag_color;
+		uniform sampler2D rgba_texture;
+		uniform sampler2D depth_texture;
+
+		struct FoveationWarp {
+			float al, bl, cl;
+			float am, bm;
+			float ar, br, cr;
+			float switch_left, switch_right;
+			float inv_switch_left, inv_switch_right;
+		};
+
+		uniform FoveationWarp warp_x;
+		uniform FoveationWarp warp_y;
+
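+		// Inverts the piecewise foveation warp: solve the quadratic y = a*x^2 + b*x + c for x in the
+		// outer regions and the linear segment y = am*x + bm in the middle.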
+		float unwarp(in FoveationWarp warp, float y) {
+			y = clamp(y, 0.0, 1.0);
+			if (y < warp.inv_switch_left) {
+				return (sqrt(-4.0 * warp.al * warp.cl + 4.0 * warp.al * y + warp.bl * warp.bl) - warp.bl) / (2.0 * warp.al);
+			} else if (y > warp.inv_switch_right) {
+				return (sqrt(-4.0 * warp.ar * warp.cr + 4.0 * warp.ar * y + warp.br * warp.br) - warp.br) / (2.0 * warp.ar);
+			} else {
+				return (y - warp.bm) / warp.am;
+			}
+		}
+
+		vec2 unwarp(in vec2 pos) {
+			return vec2(unwarp(warp_x, pos.x), unwarp(warp_y, pos.y));
+		}
+
+		void main() {
+			vec2 tex_coords = UVs;
+			tex_coords.y = 1.0 - tex_coords.y;
+			tex_coords = unwarp(tex_coords);
+			frag_color = texture(rgba_texture, tex_coords.xy);
+			// Uncomment the following line to visualize the depth buffer for debugging.
+			// frag_color = vec4(vec3(texture(depth_texture, tex_coords.xy).r), 1.0);
+			gl_FragDepth = texture(depth_texture, tex_coords.xy).r;
+		})";
+
+	GLuint vert = glCreateShader(GL_VERTEX_SHADER);
+	glShaderSource(vert, 1, &shader_vert, NULL);
+	glCompileShader(vert);
+	check_shader(vert, "Blit vertex shader", false);
+
+	GLuint frag = glCreateShader(GL_FRAGMENT_SHADER);
+	glShaderSource(frag, 1, &shader_frag, NULL);
+	glCompileShader(frag);
+	check_shader(frag, "Blit fragment shader", false);
+
+	m_blit_program = glCreateProgram();
+	glAttachShader(m_blit_program, vert);
+	glAttachShader(m_blit_program, frag);
+	glLinkProgram(m_blit_program);
+	check_shader(m_blit_program, "Blit shader program", true);
+
+	glDeleteShader(vert);
+	glDeleteShader(frag);
+
+	glGenVertexArrays(1, &m_blit_vao);
+}
+
+void Testbed::blit_texture(const Foveation& foveation, GLint rgba_texture, GLint rgba_filter_mode, GLint depth_texture, GLint framebuffer, const Vector2i& offset, const Vector2i& resolution) {
+	if (m_blit_program == 0) {
+		return;
+	}
+
+	// Blit image to OpenXR swapchain.
+	// Note that the OpenXR swapchain is 8-bit while the rendering happens in a float texture.
+	// Since some XR runtimes do not support float swapchains, we cannot render into the swapchain directly.
+
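+	// Save the pieces of GL state that are modified below so they can be restored afterwards.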
+	bool tex = glIsEnabled(GL_TEXTURE_2D);
+	bool depth = glIsEnabled(GL_DEPTH_TEST);
+	bool cull = glIsEnabled(GL_CULL_FACE);
+
+	if (!tex) glEnable(GL_TEXTURE_2D);
+	if (!depth) glEnable(GL_DEPTH_TEST);
+	if (cull) glDisable(GL_CULL_FACE);
+
+	glDepthFunc(GL_ALWAYS);
+	glDepthMask(GL_TRUE);
+
+	glBindVertexArray(m_blit_vao);
+	glUseProgram(m_blit_program);
+	glUniform1i(glGetUniformLocation(m_blit_program, "rgba_texture"), 0);
+	glUniform1i(glGetUniformLocation(m_blit_program, "depth_texture"), 1);
+
+	auto bind_warp = [&](const FoveationPiecewiseQuadratic& warp, const std::string& uniform_name) {
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".al").c_str()), warp.al);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".bl").c_str()), warp.bl);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".cl").c_str()), warp.cl);
+
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".am").c_str()), warp.am);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".bm").c_str()), warp.bm);
+
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".ar").c_str()), warp.ar);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".br").c_str()), warp.br);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".cr").c_str()), warp.cr);
+
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".switch_left").c_str()), warp.switch_left);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".switch_right").c_str()), warp.switch_right);
+
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".inv_switch_left").c_str()), warp.inv_switch_left);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".inv_switch_right").c_str()), warp.inv_switch_right);
+	};
+
+	bind_warp(foveation.warp_x, "warp_x");
+	bind_warp(foveation.warp_y, "warp_y");
+
+	glActiveTexture(GL_TEXTURE1);
+	glBindTexture(GL_TEXTURE_2D, depth_texture);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+
+	glActiveTexture(GL_TEXTURE0);
+	glBindTexture(GL_TEXTURE_2D, rgba_texture);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, rgba_filter_mode);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, rgba_filter_mode);
+
+	glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);
+	glViewport(offset.x(), offset.y(), resolution.x(), resolution.y());
+
+	glDrawArrays(GL_TRIANGLES, 0, 3);
+
+	glBindVertexArray(0);
+	glUseProgram(0);
+
+	glDepthFunc(GL_LESS);
+
+	// restore old state
+	if (!tex) glDisable(GL_TEXTURE_2D);
+	if (!depth) glDisable(GL_DEPTH_TEST);
+	if (cull) glEnable(GL_CULL_FACE);
+	glBindFramebuffer(GL_FRAMEBUFFER, 0);
+}
+
 void Testbed::draw_gui() {
 	// Make sure all the cuda code finished its business here
 	CUDA_CHECK_THROW(cudaDeviceSynchronize());
-	if (!m_render_textures.empty())
-		m_second_window.draw((GLuint)m_render_textures.front()->texture());
+
+	if (!m_rgba_render_textures.empty()) {
+		m_second_window.draw((GLuint)m_rgba_render_textures.front()->texture());
+	}
+
 	glfwMakeContextCurrent(m_glfw_window);
 	int display_w, display_h;
 	glfwGetFramebufferSize(m_glfw_window, &display_w, &display_h);
@@ -1846,56 +2271,42 @@ void Testbed::draw_gui() {
 	glClearColor(0.f, 0.f, 0.f, 0.f);
 	glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
 
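+	// Composite the rendered views and GUI with premultiplied alpha over the cleared background.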
+	glEnable(GL_BLEND);
+	glBlendEquationSeparate(GL_FUNC_ADD, GL_FUNC_ADD);
+	glBlendFuncSeparate(GL_ONE, GL_ONE_MINUS_SRC_ALPHA, GL_ONE, GL_ONE_MINUS_SRC_ALPHA);
 
-	ImDrawList* list = ImGui::GetBackgroundDrawList();
-	list->AddCallback([](const ImDrawList*, const ImDrawCmd*) {
-		glBlendEquationSeparate(GL_FUNC_ADD, GL_FUNC_ADD);
-		glBlendFuncSeparate(GL_ONE, GL_ONE_MINUS_SRC_ALPHA, GL_ONE, GL_ONE_MINUS_SRC_ALPHA);
-	}, nullptr);
-
-	if (m_single_view) {
-		list->AddImageQuad((ImTextureID)(size_t)m_render_textures.front()->texture(), ImVec2{0.f, 0.f}, ImVec2{(float)display_w, 0.f}, ImVec2{(float)display_w, (float)display_h}, ImVec2{0.f, (float)display_h}, ImVec2(0, 0), ImVec2(1, 0), ImVec2(1, 1), ImVec2(0, 1));
-	} else {
-		m_dlss = false;
+	Vector2i extent = Vector2f{(float)display_w / m_n_views.x(), (float)display_h / m_n_views.y()}.cast<int>();
 
-		int i = 0;
-		for (int y = 0; y < m_n_views.y(); ++y) {
-			for (int x = 0; x < m_n_views.x(); ++x) {
-				if (i >= m_render_surfaces.size()) {
-					break;
-				}
+	int i = 0;
+	for (int y = 0; y < m_n_views.y(); ++y) {
+		for (int x = 0; x < m_n_views.x(); ++x) {
+			if (i >= m_views.size()) {
+				break;
+			}
 
-				Vector2f top_left{x * m_view_size.x(), y * m_view_size.y()};
-
-				list->AddImageQuad(
-					(ImTextureID)(size_t)m_render_textures[i]->texture(),
-					ImVec2{top_left.x(),                          top_left.y()                         },
-					ImVec2{top_left.x() + (float)m_view_size.x(), top_left.y()                         },
-					ImVec2{top_left.x() + (float)m_view_size.x(), top_left.y() + (float)m_view_size.y()},
-					ImVec2{top_left.x(),                          top_left.y() + (float)m_view_size.y()},
-					ImVec2(0, 0),
-					ImVec2(1, 0),
-					ImVec2(1, 1),
-					ImVec2(0, 1)
-				);
+			auto& view = m_views[i];
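+			// The GL framebuffer origin is bottom-left, so flip the vertical placement of each view's tile.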
+			Vector2i top_left{x * extent.x(), display_h - (y + 1) * extent.y()};
+			blit_texture(view.foveation, m_rgba_render_textures.at(i)->texture(), m_foveated_rendering ? GL_LINEAR : GL_NEAREST, m_depth_render_textures.at(i)->texture(), 0, top_left, extent);
 
-				++i;
-			}
+			++i;
 		}
 	}
+	glFinish();
+	glViewport(0, 0, display_w, display_h);
 
+
+	ImDrawList* list = ImGui::GetBackgroundDrawList();
 	list->AddCallback(ImDrawCallback_ResetRenderState, nullptr);
 
 	auto draw_mesh = [&]() {
 		glClear(GL_DEPTH_BUFFER_BIT);
 		Vector2i res(display_w, display_h);
-		Vector2f focal_length = calc_focal_length(res, m_fov_axis, m_zoom);
-		Vector2f screen_center = render_screen_center();
-		draw_mesh_gl(m_mesh.verts, m_mesh.vert_normals, m_mesh.vert_colors, m_mesh.indices, res, focal_length, m_smoothed_camera, screen_center, (int)m_mesh_render_mode);
+		Vector2f focal_length = calc_focal_length(res, m_relative_focal_length, m_fov_axis, m_zoom);
+		draw_mesh_gl(m_mesh.verts, m_mesh.vert_normals, m_mesh.vert_colors, m_mesh.indices, res, focal_length, m_smoothed_camera, render_screen_center(m_screen_center), (int)m_mesh_render_mode);
 	};
 
 	// Visualizations are only meaningful when rendering a single view
-	if (m_single_view) {
+	if (m_views.size() == 1) {
 		if (m_mesh.verts.size() != 0 && m_mesh.indices.size() != 0 && m_mesh_render_mode != EMeshRenderMode::Off) {
 			list->AddCallback([](const ImDrawList*, const ImDrawCmd* cmd) {
 				(*(decltype(draw_mesh)*)cmd->UserCallbackData)();
@@ -1955,7 +2366,7 @@ void Testbed::prepare_next_camera_path_frame() {
 	// If we're rendering a video, we'd like to accumulate multiple spp
 	// for motion blur. Hence dump the frame once the target spp has been reached
 	// and only reset _then_.
-	if (m_render_surfaces.front().spp() == m_camera_path.render_settings.spp) {
+	if (m_views.front().render_buffer->spp() == m_camera_path.render_settings.spp) {
 		auto tmp_dir = fs::path{"tmp"};
 		if (!tmp_dir.exists()) {
 			if (!fs::create_directory(tmp_dir)) {
@@ -1965,7 +2376,7 @@ void Testbed::prepare_next_camera_path_frame() {
 			}
 		}
 
-		Vector2i res = m_render_surfaces.front().out_resolution();
+		Vector2i res = m_views.front().render_buffer->out_resolution();
 		const dim3 threads = { 16, 8, 1 };
 		const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 
@@ -1973,7 +2384,7 @@ void Testbed::prepare_next_camera_path_frame() {
 		to_8bit_color_kernel<<<blocks, threads>>>(
 			res,
 			EColorSpace::SRGB, // the GUI always renders in SRGB
-			m_render_surfaces.front().surface(),
+			m_views.front().render_buffer->surface(),
 			image_data.data()
 		);
 
@@ -2047,7 +2458,7 @@ void Testbed::prepare_next_camera_path_frame() {
 	const auto& rs = m_camera_path.render_settings;
 	m_camera_path.play_time = (float)((double)m_camera_path.render_frame_idx / (double)rs.n_frames());
 
-	if (m_render_surfaces.front().spp() == 0) {
+	if (m_views.front().render_buffer->spp() == 0) {
 		set_camera_from_time(m_camera_path.play_time);
 		apply_camera_smoothing(rs.frame_milliseconds());
 
@@ -2109,134 +2520,204 @@ void Testbed::train_and_render(bool skip_rendering) {
 		autofocus();
 	}
 
-	if (m_single_view) {
-		// Should have been created when the window was created.
-		assert(!m_render_surfaces.empty());
+#ifdef NGP_GUI
+	if (m_hmd && m_hmd->is_visible()) {
+		for (auto& view : m_views) {
+			view.visualized_dimension = m_visualized_dimension;
+		}
 
-		auto& render_buffer = m_render_surfaces.front();
+		m_n_views = {m_views.size(), 1};
 
-		{
-			// Don't count the time being spent allocating buffers and resetting DLSS as part of the frame time.
-			// Otherwise the dynamic resolution calculations for following frames will be thrown out of whack
-			// and may even start oscillating.
-			auto skip_start = std::chrono::steady_clock::now();
-			ScopeGuard skip_timing_guard{[&]() {
-				start += std::chrono::steady_clock::now() - skip_start;
-			}};
-			if (m_dlss) {
-				render_buffer.enable_dlss(m_window_res);
-				m_aperture_size = 0.0f;
-			} else {
-				render_buffer.disable_dlss();
-			}
+		m_nerf.render_with_lens_distortion = false;
+		reset_accumulation(true);
+	} else if (m_single_view) {
+		set_n_views(1);
+		m_n_views = {1, 1};
 
-			auto render_res = render_buffer.in_resolution();
-			if (render_res.isZero() || (m_train && m_training_step == 0)) {
-				render_res = m_window_res/16;
-			} else {
-				render_res = render_res.cwiseMin(m_window_res);
-			}
+		auto& view = m_views.front();
+
+		view.full_resolution = m_window_res;
+
+		view.camera0 = m_smoothed_camera;
+
+		// Motion blur over the fraction of time that the shutter is open. Interpolate in log-space to preserve rotations.
+		view.camera1 = m_camera_path.rendering ? log_space_lerp(m_smoothed_camera, m_camera_path.render_frame_end_camera, m_camera_path.render_settings.shutter_fraction) : view.camera0;
+
+		view.visualized_dimension = m_visualized_dimension;
+		view.relative_focal_length = m_relative_focal_length;
+		view.screen_center = m_screen_center;
+		view.render_buffer->set_hidden_area_mask(nullptr);
+		view.foveation = {};
+		view.device = &primary_device();
+	} else {
+		int n_views = n_dimensions_to_visualize()+1;
+
+		float d = std::sqrt((float)m_window_res.x() * (float)m_window_res.y() / (float)n_views);
+
+		int nx = (int)std::ceil((float)m_window_res.x() / d);
+		int ny = (int)std::ceil((float)n_views / (float)nx);
+
+		m_n_views = {nx, ny};
+		Vector2i view_size = {m_window_res.x() / nx, m_window_res.y() / ny};
+
+		set_n_views(n_views);
+
+		int i = 0;
+		for (int y = 0; y < ny; ++y) {
+			for (int x = 0; x < nx; ++x) {
+				if (i >= n_views) {
+					break;
+				}
 
-			float render_time_per_fullres_frame = m_render_ms.val() / (float)render_res.x() / (float)render_res.y() * (float)m_window_res.x() * (float)m_window_res.y();
+				m_views[i].full_resolution = view_size;
 
-			// Make sure we don't starve training with slow rendering
-			float factor = std::sqrt(1000.0f / m_dynamic_res_target_fps / render_time_per_fullres_frame);
-			if (!m_dynamic_res) {
-				factor = 8.f/(float)m_fixed_res_factor;
+				m_views[i].camera0 = m_views[i].camera1 = m_smoothed_camera;
+				m_views[i].visualized_dimension = i-1;
+				m_views[i].relative_focal_length = m_relative_focal_length;
+				m_views[i].screen_center = m_screen_center;
+				m_views[i].render_buffer->set_hidden_area_mask(nullptr);
+				m_views[i].foveation = {};
+				m_views[i].device = &primary_device();
+				++i;
 			}
+		}
+	}
 
-			factor = tcnn::clamp(factor, 1.0f/16.0f, 1.0f);
+	if (m_dlss) {
+		m_aperture_size = 0.0f;
+		if (!supports_dlss(m_nerf.render_lens.mode)) {
+			m_nerf.render_with_lens_distortion = false;
+		}
+	}
 
-			if (factor > m_last_render_res_factor * 1.2f || factor < m_last_render_res_factor * 0.8f || factor == 1.0f || !m_dynamic_res) {
-				render_res = (m_window_res.cast<float>() * factor).cast<int>().cwiseMin(m_window_res).cwiseMax(m_window_res/16);
-				m_last_render_res_factor = factor;
+	// Update dynamic res and DLSS
+	{
+		// Don't count the time being spent allocating buffers and resetting DLSS as part of the frame time.
+		// Otherwise the dynamic resolution calculations for following frames will be thrown out of whack
+		// and may even start oscillating.
+		auto skip_start = std::chrono::steady_clock::now();
+		ScopeGuard skip_timing_guard{[&]() {
+			start += std::chrono::steady_clock::now() - skip_start;
+		}};
+
+		size_t n_pixels = 0, n_pixels_full_res = 0;
+		for (const auto& view : m_views) {
+			n_pixels += view.render_buffer->in_resolution().prod();
+			n_pixels_full_res += view.full_resolution.prod();
+		}
+
+		float pixel_ratio = (n_pixels == 0 || (m_train && m_training_step == 0)) ? (1.0f / 256.0f) : ((float)n_pixels / (float)n_pixels_full_res);
+
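+		// Assuming render time scales linearly with pixel count, the per-axis resolution factor that
+		// hits the target frame time is sqrt(pixel_ratio * target_frame_ms / measured_frame_ms).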
+		float factor = std::sqrt(pixel_ratio / m_render_ms.val() * 1000.0f / m_dynamic_res_target_fps);
+		if (!m_dynamic_res) {
+			factor = 8.f / (float)m_fixed_res_factor;
+		}
+
+		factor = tcnn::clamp(factor, 1.0f / 16.0f, 1.0f);
+
+		for (auto&& view : m_views) {
+			if (m_dlss) {
+				view.render_buffer->enable_dlss(*m_dlss_provider, view.full_resolution);
+			} else {
+				view.render_buffer->disable_dlss();
 			}
 
+			Vector2i render_res = view.render_buffer->in_resolution();
+			Vector2i new_render_res = (view.full_resolution.cast<float>() * factor).cast<int>().cwiseMin(view.full_resolution).cwiseMax(view.full_resolution / 16);
+
 			if (m_camera_path.rendering) {
-				render_res = m_camera_path.render_settings.resolution;
-				m_last_render_res_factor = 1.0f;
+				new_render_res = m_camera_path.render_settings.resolution;
 			}
 
-			if (render_buffer.dlss()) {
-				render_res = render_buffer.dlss()->clamp_resolution(render_res);
-				render_buffer.dlss()->update_feature(render_res, render_buffer.dlss()->is_hdr(), render_buffer.dlss()->sharpen());
+			float ratio = std::sqrt((float)render_res.prod() / (float)new_render_res.prod());
+			if (ratio > 1.2f || ratio < 0.8f || factor == 1.0f || !m_dynamic_res || m_camera_path.rendering) {
+				render_res = new_render_res;
 			}
 
-			render_buffer.resize(render_res);
-		}
+			if (view.render_buffer->dlss()) {
+				render_res = view.render_buffer->dlss()->clamp_resolution(render_res);
+				view.render_buffer->dlss()->update_feature(render_res, view.render_buffer->dlss()->is_hdr(), view.render_buffer->dlss()->sharpen());
+			}
 
-		render_frame(
-			m_smoothed_camera,
-			m_camera_path.rendering ? log_space_lerp(m_smoothed_camera, m_camera_path.render_frame_end_camera, m_camera_path.render_settings.shutter_fraction) : m_smoothed_camera,
-			{0.0f, 0.0f, 0.0f, 1.0f},
-			render_buffer
-		);
+			view.render_buffer->resize(render_res);
 
-#ifdef NGP_GUI
-		m_render_textures.front()->blit_from_cuda_mapping();
+			if (m_foveated_rendering) {
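+				// Diameter, in warped image coordinates, of the central region that remains at full resolution.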
+				float foveation_warped_full_res_diameter = 0.55f;
+				Vector2f resolution_scale = render_res.cast<float>().cwiseQuotient(view.full_resolution.cast<float>());
 
-		if (m_picture_in_picture_res > 0) {
-			Vector2i res(m_picture_in_picture_res, m_picture_in_picture_res * 9/16);
-			m_pip_render_surface->resize(res);
-			if (m_pip_render_surface->spp() < 8) {
-				// a bit gross, but let's copy the keyframe's state into the global state in order to not have to plumb through the fov etc to render_frame.
-				CameraKeyframe backup = copy_camera_to_keyframe();
-				CameraKeyframe pip_kf = m_camera_path.eval_camera_path(m_camera_path.play_time);
-				set_camera_from_keyframe(pip_kf);
-				render_frame(pip_kf.m(), pip_kf.m(), Eigen::Vector4f::Zero(), *m_pip_render_surface);
-				set_camera_from_keyframe(backup);
+				// Only start foveation when DLSS is off or when DLSS is asked to do more than 1.5x upscaling.
+				// The reason for the 1.5x threshold is that DLSS can do up to 3x upscaling, at which point a foveation
+				// factor of 2x = 3.0x/1.5x corresponds exactly to bilinear super sampling, which is helpful in
+				// suppressing DLSS's artifacts.
+				float foveation_begin_factor = m_dlss ? 1.5f : 1.0f;
 
-				m_pip_render_texture->blit_from_cuda_mapping();
+				resolution_scale = (resolution_scale * foveation_begin_factor).cwiseMin(1.0f).cwiseMax(1.0f / m_foveated_rendering_max_scaling);
+				view.foveation = {resolution_scale, Vector2f::Ones() - view.screen_center, Vector2f::Constant(foveation_warped_full_res_diameter * 0.5f)};
+			} else {
+				view.foveation = {};
 			}
 		}
-#endif
-	} else {
-#ifdef NGP_GUI
-		// Don't do DLSS when multi-view rendering
-		m_dlss = false;
-		m_render_surfaces.front().disable_dlss();
+	}
 
-		int n_views = n_dimensions_to_visualize()+1;
+	// Make sure all in-use auxiliary GPUs have the latest model and bitfield
+	std::unordered_set<CudaDevice*> devices_in_use;
+	for (auto& view : m_views) {
+		if (!view.device || devices_in_use.count(view.device) != 0) {
+			continue;
+		}
 
-		float d = std::sqrt((float)m_window_res.x() * (float)m_window_res.y() / (float)n_views);
+		devices_in_use.insert(view.device);
+		sync_device(*view.render_buffer, *view.device);
+	}
 
-		int nx = (int)std::ceil((float)m_window_res.x() / d);
-		int ny = (int)std::ceil((float)n_views / (float)nx);
+	{
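+		// Render all views in parallel, each on its own stream and (if enabled) its own GPU, then gather the results in order.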
+		SyncedMultiStream synced_streams{m_stream.get(), m_views.size()};
+
+		std::vector<std::future<void>> futures(m_views.size());
+		for (size_t i = 0; i < m_views.size(); ++i) {
+			auto& view = m_views[i];
+			futures[i] = view.device->enqueue_task([this, &view, stream=synced_streams.get(i)]() {
+				auto device_guard = use_device(stream, *view.render_buffer, *view.device);
+				render_frame_main(*view.device, view.camera0, view.camera1, view.screen_center, view.relative_focal_length, {0.0f, 0.0f, 0.0f, 1.0f}, view.foveation, view.visualized_dimension);
+			});
+		}
 
-		m_n_views = {nx, ny};
-		m_view_size = {m_window_res.x() / nx, m_window_res.y() / ny};
+		for (size_t i = 0; i < m_views.size(); ++i) {
+			auto& view = m_views[i];
 
-		while (m_render_surfaces.size() > n_views) {
-			m_render_surfaces.pop_back();
-		}
+			if (futures[i].valid()) {
+				futures[i].get();
+			}
 
-		m_render_textures.resize(n_views);
-		while (m_render_surfaces.size() < n_views) {
-			size_t idx = m_render_surfaces.size();
-			m_render_textures[idx] = std::make_shared<GLTexture>();
-			m_render_surfaces.emplace_back(m_render_textures[idx]);
+			render_frame_epilogue(synced_streams.get(i), view.camera0, view.prev_camera, view.screen_center, view.relative_focal_length, view.foveation, view.prev_foveation, *view.render_buffer, true);
+			view.prev_camera = view.camera0;
+			view.prev_foveation = view.foveation;
 		}
+	}
 
-		int i = 0;
-		for (int y = 0; y < ny; ++y) {
-			for (int x = 0; x < nx; ++x) {
-				if (i >= n_views) {
-					return;
-				}
-
-				m_visualized_dimension = i-1;
-				m_render_surfaces[i].resize(m_view_size);
+	for (size_t i = 0; i < m_views.size(); ++i) {
+		m_rgba_render_textures.at(i)->blit_from_cuda_mapping();
+		m_depth_render_textures.at(i)->blit_from_cuda_mapping();
+	}
 
-				render_frame(m_smoothed_camera, m_smoothed_camera, Eigen::Vector4f::Zero(), m_render_surfaces[i]);
+	if (m_picture_in_picture_res > 0) {
+		Vector2i res(m_picture_in_picture_res, m_picture_in_picture_res * 9/16);
+		m_pip_render_buffer->resize(res);
+		if (m_pip_render_buffer->spp() < 8) {
+			// A bit gross, but copy the keyframe's state into the global state so we don't have to plumb the FOV etc. through to render_frame.
+			CameraKeyframe backup = copy_camera_to_keyframe();
+			CameraKeyframe pip_kf = m_camera_path.eval_camera_path(m_camera_path.play_time);
+			set_camera_from_keyframe(pip_kf);
+			render_frame(m_stream.get(), pip_kf.m(), pip_kf.m(), pip_kf.m(), m_screen_center, m_relative_focal_length, Eigen::Vector4f::Zero(), {}, {}, m_visualized_dimension, *m_pip_render_buffer);
+			set_camera_from_keyframe(backup);
 
-				m_render_textures[i]->blit_from_cuda_mapping();
-				++i;
-			}
+			m_pip_render_texture->blit_from_cuda_mapping();
 		}
-#else
-		throw std::runtime_error{"Multi-view rendering is only supported when compiling with NGP_GUI."};
-#endif
 	}
+#endif
+
+	CUDA_CHECK_THROW(cudaStreamSynchronize(m_stream.get()));
 }
 
 
@@ -2262,7 +2743,6 @@ void Testbed::create_second_window() {
 		win_x = 0x40000000;
 		win_y = 0x40000000;
 		static const char* copy_shader_vert = "\
-			layout (location = 0)\n\
 			in vec2 vertPos_data;\n\
 			out vec2 texCoords;\n\
 			void main(){\n\
@@ -2300,7 +2780,8 @@ void Testbed::create_second_window() {
 		1.0f, 1.0f,
 		1.0f, 1.0f,
 		1.0f, -1.0f,
-		-1.0f, -1.0f};
+		-1.0f, -1.0f
+	};
 	glBindBuffer(GL_ARRAY_BUFFER, m_second_window.vbo);
 	glBufferData(GL_ARRAY_BUFFER, sizeof(fsquadVerts), fsquadVerts, GL_STATIC_DRAW);
 	glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 2 * sizeof(float), (void *)0);
@@ -2308,6 +2789,84 @@ void Testbed::create_second_window() {
 	glBindBuffer(GL_ARRAY_BUFFER, 0);
 	glBindVertexArray(0);
 }
+
+void Testbed::init_vr() {
+	try {
+		if (!m_glfw_window) {
+			throw std::runtime_error{"`init_window` must be called before `init_vr`"};
+		}
+
+#if defined(XR_USE_PLATFORM_WIN32)
+		m_hmd = std::make_unique<OpenXRHMD>(wglGetCurrentDC(), glfwGetWGLContext(m_glfw_window));
+#elif defined(XR_USE_PLATFORM_XLIB)
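+		// OpenXR's GLX graphics binding requires the visual and GLXFBConfig of the GLFW-created context, so query them back via GLX_FBCONFIG_ID.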
+		Display* xDisplay = glfwGetX11Display();
+		GLXContext glxContext = glfwGetGLXContext(m_glfw_window);
+
+		int glxFBConfigXID = 0;
+		glXQueryContext(xDisplay, glxContext, GLX_FBCONFIG_ID, &glxFBConfigXID);
+		int attributes[3] = { GLX_FBCONFIG_ID, glxFBConfigXID, 0 };
+		int nelements = 1;
+		GLXFBConfig* pglxFBConfig = glXChooseFBConfig(xDisplay, 0, attributes, &nelements);
+		if (nelements != 1 || !pglxFBConfig) {
+			throw std::runtime_error{"init_vr(): Couldn't obtain GLXFBConfig"};
+		}
+
+		GLXFBConfig glxFBConfig = *pglxFBConfig;
+
+		XVisualInfo* visualInfo = glXGetVisualFromFBConfig(xDisplay, glxFBConfig);
+		if (!visualInfo) {
+			throw std::runtime_error{"init_vr(): Couldn't obtain XVisualInfo"};
+		}
+
+		m_hmd = std::make_unique<OpenXRHMD>(xDisplay, visualInfo->visualid, glxFBConfig, glXGetCurrentDrawable(), glxContext);
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+		m_hmd = std::make_unique<OpenXRHMD>(glfwGetWaylandDisplay());
+#endif
+
+		// DLSS + sharpening is instrumental in getting VR to look good.
+		if (m_dlss_provider) {
+			m_dlss = true;
+			m_foveated_rendering = true;
+
+			// VERY aggressive performance settings (detriment to quality)
+			// to allow maintaining VR-adequate frame rates.
+			m_nerf.render_min_transmittance = 0.2f;
+		}
+
+		// If multiple GPUs are available, shoot for 60 fps in VR.
+		// Otherwise, it wouldn't be realistic to expect more than 30.
+		m_dynamic_res_target_fps = m_devices.size() > 1 ? 60 : 30;
+
+		// Many VR runtimes perform optical flow for automatic reprojection / motion smoothing.
+		// This breaks down for solid-color background, sometimes leading to artifacts. Hence:
+		// set background color to transparent and, in spherical_checkerboard_kernel(...),
+		// blend a checkerboard. If the user desires a solid background nonetheless, they can
+		// set the background color to have an alpha value of 1.0 manually via the GUI or via Python.
+		m_background_color = {0.0f, 0.0f, 0.0f, 0.0f};
+		m_render_transparency_as_checkerboard = true;
+	} catch (const std::runtime_error& e) {
+		if (std::string{e.what()}.find("XR_ERROR_FORM_FACTOR_UNAVAILABLE") != std::string::npos) {
+			throw std::runtime_error{"Could not initialize VR. Ensure that SteamVR, OculusVR, or any other OpenXR-compatible runtime is running. Also set it as the active OpenXR runtime."};
+		} else {
+			throw std::runtime_error{fmt::format("Could not initialize VR: {}", e.what())};
+		}
+	}
+}
+
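+// Grows or shrinks the set of views, keeping the RGBA textures, depth textures, and CUDA render buffers in sync.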
+void Testbed::set_n_views(size_t n_views) {
+	while (m_views.size() > n_views) {
+		m_views.pop_back();
+	}
+
+	m_rgba_render_textures.resize(n_views);
+	m_depth_render_textures.resize(n_views);
+	while (m_views.size() < n_views) {
+		size_t idx = m_views.size();
+		m_rgba_render_textures[idx] = std::make_shared<GLTexture>();
+		m_depth_render_textures[idx] = std::make_shared<GLTexture>();
+		m_views.emplace_back(View{std::make_shared<CudaRenderBuffer>(m_rgba_render_textures[idx], m_depth_render_textures[idx])});
+	}
+}
 #endif //NGP_GUI
 
 void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) {
@@ -2322,23 +2881,22 @@ void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) {
 	}
 
 #ifdef NGP_VULKAN
-	try {
-		vulkan_and_ngx_init();
-		m_dlss_supported = true;
-		if (m_testbed_mode == ETestbedMode::Nerf) {
-			m_dlss = true;
+	// Only try to initialize DLSS (Vulkan+NGX) if the
+	// GPU is sufficiently new. Older GPUs don't support
+	// DLSS, so it is preferable not to make a futile
+	// attempt that would emit a warning and confuse users.
+	if (primary_device().compute_capability() >= 70) {
+		try {
+			m_dlss_provider = init_vulkan_and_ngx();
+			if (m_testbed_mode == ETestbedMode::Nerf) {
+				m_dlss = true;
+			}
+		} catch (const std::runtime_error& e) {
+			tlog::warning() << "Could not initialize Vulkan and NGX. DLSS not supported. (" << e.what() << ")";
 		}
-	} catch (const std::runtime_error& e) {
-		tlog::warning() << "Could not initialize Vulkan and NGX. DLSS not supported. (" << e.what() << ")";
 	}
-#else
-	m_dlss_supported = false;
 #endif
 
-	glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3);
-	glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
-	glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
-	glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GLFW_TRUE);
 	glfwWindowHint(GLFW_VISIBLE, hidden ? GLFW_FALSE : GLFW_TRUE);
 	std::string title = "Instant Neural Graphics Primitives";
 	m_glfw_window = glfwCreateWindow(m_window_res.x(), m_window_res.y(), title.c_str(), NULL, NULL);
@@ -2358,6 +2916,16 @@ void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) {
 #endif
 	glfwSwapInterval(0); // Disable vsync
 
+	GLint gl_version_minor, gl_version_major;
+	glGetIntegerv(GL_MINOR_VERSION, &gl_version_minor);
+	glGetIntegerv(GL_MAJOR_VERSION, &gl_version_major);
+
+	if (gl_version_major < 3 || (gl_version_major == 3 && gl_version_minor < 1)) {
+		throw std::runtime_error{fmt::format("Unsupported OpenGL version {}.{}. instant-ngp requires at least OpenGL 3.1", gl_version_major, gl_version_minor)};
+	}
+
+	tlog::success() << "Initialized OpenGL version " << glGetString(GL_VERSION);
+
 	glfwSetWindowUserPointer(m_glfw_window, this);
 	glfwSetDropCallback(m_glfw_window, [](GLFWwindow* window, int count, const char** paths) {
 		Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window);
@@ -2420,26 +2988,40 @@ void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) {
 	IMGUI_CHECKVERSION();
 	ImGui::CreateContext();
 	ImGuiIO& io = ImGui::GetIO(); (void)io;
-	//io.ConfigFlags |= ImGuiConfigFlags_NavEnableKeyboard;     // Enable Keyboard Controls
-	io.ConfigInputTrickleEventQueue = false; // new ImGui event handling seems to make camera controls laggy if this is true.
+
+	// By default, imgui places its configuration (state of the GUI -- size of windows,
+	// which regions are expanded, etc.) in ./imgui.ini relative to the working directory.
+	// Instead, we would like to place imgui.ini in the directory that the
+	// instant-ngp project resides in.
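+	// (Keep the string static: ImGui stores the raw pointer, so it must outlive the ImGui context.)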
+	static std::string ini_filename;
+	ini_filename = (get_root_dir()/"imgui.ini").str();
+	io.IniFilename = ini_filename.c_str();
+
+	// ImGui's new event handling seems to make camera controls laggy when input
+	// trickling is enabled, so disable it.
+	io.ConfigInputTrickleEventQueue = false;
 	ImGui::StyleColorsDark();
 	ImGui_ImplGlfw_InitForOpenGL(m_glfw_window, true);
-	ImGui_ImplOpenGL3_Init("#version 330 core");
+	ImGui_ImplOpenGL3_Init("#version 140");
 
 	ImGui::GetStyle().ScaleAllSizes(xscale);
 	ImFontConfig font_cfg;
 	font_cfg.SizePixels = 13.0f * xscale;
 	io.Fonts->AddFontDefault(&font_cfg);
 
+	init_opengl_shaders();
+
 	// Make sure there's at least one usable render texture
-	m_render_textures = { std::make_shared<GLTexture>() };
+	m_rgba_render_textures = { std::make_shared<GLTexture>() };
+	m_depth_render_textures = { std::make_shared<GLTexture>() };
 
-	m_render_surfaces.clear();
-	m_render_surfaces.emplace_back(m_render_textures.front());
-	m_render_surfaces.front().resize(m_window_res);
+	m_views.clear();
+	m_views.emplace_back(View{std::make_shared<CudaRenderBuffer>(m_rgba_render_textures.front(), m_depth_render_textures.front())});
+	m_views.front().full_resolution = m_window_res;
+	m_views.front().render_buffer->resize(m_views.front().full_resolution);
 
 	m_pip_render_texture = std::make_shared<GLTexture>();
-	m_pip_render_surface = std::make_unique<CudaRenderBuffer>(m_pip_render_texture);
+	m_pip_render_buffer = std::make_unique<CudaRenderBuffer>(m_pip_render_texture);
 
 	m_render_window = true;
 
@@ -2457,16 +3039,17 @@ void Testbed::destroy_window() {
 		throw std::runtime_error{"Window must be initialized to be destroyed."};
 	}
 
-	m_render_surfaces.clear();
-	m_render_textures.clear();
+	m_hmd.reset();
 
-	m_pip_render_surface.reset();
+	m_views.clear();
+	m_rgba_render_textures.clear();
+	m_depth_render_textures.clear();
+
+	m_pip_render_buffer.reset();
 	m_pip_render_texture.reset();
 
-#ifdef NGP_VULKAN
-	m_dlss_supported = m_dlss = false;
-	vulkan_and_ngx_destroy();
-#endif
+	m_dlss = false;
+	m_dlss_provider.reset();
 
 	ImGui_ImplOpenGL3_Shutdown();
 	ImGui_ImplGlfw_Shutdown();
@@ -2474,6 +3057,9 @@ void Testbed::destroy_window() {
 	glfwDestroyWindow(m_glfw_window);
 	glfwTerminate();
 
+	m_blit_program = 0;
+	m_blit_vao = 0;
+
 	m_glfw_window = nullptr;
 	m_render_window = false;
 #endif //NGP_GUI
@@ -2482,9 +3068,12 @@ void Testbed::destroy_window() {
 bool Testbed::frame() {
 #ifdef NGP_GUI
 	if (m_render_window) {
-		if (!begin_frame_and_handle_user_input()) {
+		if (!begin_frame()) {
 			return false;
 		}
+
+		begin_vr_frame_and_handle_vr_input();
+		handle_user_input();
 	}
 #endif
 
@@ -2496,7 +3085,7 @@ bool Testbed::frame() {
 	}
 	bool skip_rendering = m_render_skip_due_to_lack_of_camera_movement_counter++ != 0;
 
-	if (!m_dlss && m_max_spp > 0 && !m_render_surfaces.empty() && m_render_surfaces.front().spp() >= m_max_spp) {
+	if (!m_dlss && m_max_spp > 0 && !m_views.empty() && m_views.front().render_buffer->spp() >= m_max_spp) {
 		skip_rendering = true;
 		if (!m_train) {
 			std::this_thread::sleep_for(1ms);
@@ -2508,6 +3097,12 @@ bool Testbed::frame() {
 		skip_rendering = false;
 	}
 
+#ifdef NGP_GUI
+	if (m_hmd && m_hmd->is_visible()) {
+		skip_rendering = false;
+	}
+#endif
+
 	if (!skip_rendering || (std::chrono::steady_clock::now() - m_last_gui_draw_time_point) > 25ms) {
 		redraw_gui_next_frame();
 	}
@@ -2540,6 +3135,32 @@ bool Testbed::frame() {
 
 		ImGui::EndFrame();
 	}
+
+	if (m_vr_frame_info) {
+		// If HMD is visible to the user, splat rendered images to the HMD
+		if (m_hmd->is_visible()) {
+			size_t n_views = std::min(m_views.size(), m_vr_frame_info->views.size());
+
+			// Blit textures to the OpenXR-owned framebuffers (each corresponding to one eye)
+			for (size_t i = 0; i < n_views; ++i) {
+				const auto& vr_view = m_vr_frame_info->views.at(i);
+
+				Vector2i resolution = {
+					vr_view.view.subImage.imageRect.extent.width,
+					vr_view.view.subImage.imageRect.extent.height,
+				};
+
+				blit_texture(m_views.at(i).foveation, m_rgba_render_textures.at(i)->texture(), GL_LINEAR, m_depth_render_textures.at(i)->texture(), vr_view.framebuffer, Vector2i::Zero(), resolution);
+			}
+
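+			// Make sure the blits above have completed before the textures are
+			// handed off to the OpenXR runtime in end_frame() below.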
+			glFinish();
+		}
+
+		// Far and near planes are intentionally reversed, because we map depth inversely
+		// to z. That is, a window-space depth of 1 refers to the near plane and a depth
+		// of 0 to the far plane, which results in much better numeric precision.
+		m_hmd->end_frame(m_vr_frame_info, m_ndc_zfar / m_scale, m_ndc_znear / m_scale);
+	}
 #endif
 
 	return true;
@@ -2579,8 +3200,10 @@ void Testbed::set_camera_from_keyframe(const CameraKeyframe& k) {
 }
 
 void Testbed::set_camera_from_time(float t) {
-	if (m_camera_path.keyframes.empty())
+	if (m_camera_path.keyframes.empty()) {
 		return;
+	}
+
 	set_camera_from_keyframe(m_camera_path.eval_camera_path(t));
 }
 
@@ -2711,6 +3334,8 @@ void Testbed::reset_network(bool clear_density_grid) {
 	if (clear_density_grid) {
 		m_nerf.density_grid.memset(0);
 		m_nerf.density_grid_bitfield.memset(0);
+
+		set_all_devices_dirty();
 	}
 
 	m_loss_graph_samples = 0;
@@ -2723,6 +3348,13 @@ void Testbed::reset_network(bool clear_density_grid) {
 	json& optimizer_config = config["optimizer"];
 	json& network_config = config["network"];
 
+	// If the network config is incomplete, the following (currently disabled)
+	// check would avoid doing further work.
+	/*
+	if (config.is_null() || encoding_config.is_null() || loss_config.is_null() || optimizer_config.is_null() || network_config.is_null()) {
+		return;
+	}
+	*/
+
 	auto dims = network_dims();
 
 	if (m_testbed_mode == ETestbedMode::Nerf) {
@@ -2798,16 +3430,22 @@ void Testbed::reset_network(bool clear_density_grid) {
 
 		uint32_t n_dir_dims = 3;
 		uint32_t n_extra_dims = m_nerf.training.dataset.n_extra_dims();
-		m_network = m_nerf_network = std::make_shared<NerfNetwork<precision_t>>(
-			dims.n_pos,
-			n_dir_dims,
-			n_extra_dims,
-			dims.n_pos + 1, // The offset of 1 comes from the dt member variable of NerfCoordinate. HACKY
-			encoding_config,
-			dir_encoding_config,
-			network_config,
-			rgb_network_config
-		);
+
+		// Instantiate an additional model for each auxiliary GPU
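+		// Each device needs its own copy of the model; inference parameters are
+		// copied from the primary device in sync_device() whenever a device is
+		// marked dirty.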
+		for (auto& device : m_devices) {
+			device.set_nerf_network(std::make_shared<NerfNetwork<precision_t>>(
+				dims.n_pos,
+				n_dir_dims,
+				n_extra_dims,
+				dims.n_pos + 1, // The offset of 1 comes from the dt member variable of NerfCoordinate. HACKY
+				encoding_config,
+				dir_encoding_config,
+				network_config,
+				rgb_network_config
+			));
+		}
+
+		m_network = m_nerf_network = primary_device().nerf_network();
 
 		m_encoding = m_nerf_network->encoding();
 		n_encoding_params = m_encoding->n_params() + m_nerf_network->dir_encoding()->n_params();
@@ -2873,7 +3511,12 @@ void Testbed::reset_network(bool clear_density_grid) {
 			}
 		}
 
-		m_network = std::make_shared<NetworkWithInputEncoding<precision_t>>(m_encoding, dims.n_output, network_config);
+		for (auto& device : m_devices) {
+			device.set_network(std::make_shared<NetworkWithInputEncoding<precision_t>>(m_encoding, dims.n_output, network_config));
+		}
+
+		m_network = primary_device().network();
+
 		n_encoding_params = m_encoding->n_params();
 
 		tlog::info()
@@ -2910,6 +3553,7 @@ void Testbed::reset_network(bool clear_density_grid) {
 		}
 	}
 
+	set_all_devices_dirty();
 }
 
 Testbed::Testbed(ETestbedMode mode) {
@@ -2955,6 +3599,28 @@ Testbed::Testbed(ETestbedMode mode) {
 		tlog::warning() << "This program was compiled for >=" << MIN_GPU_ARCH << " and may thus behave unexpectedly.";
 	}
 
+	m_devices.emplace_back(active_device, true);
+
+	// Multi-GPU is only supported in NeRF mode for now
+	int n_devices = cuda_device_count();
+	for (int i = 0; i < n_devices; ++i) {
+		if (i == active_device) {
+			continue;
+		}
+
+		if (cuda_compute_capability(i) >= MIN_GPU_ARCH) {
+			m_devices.emplace_back(i, false);
+		}
+	}
+
+	if (m_devices.size() > 1) {
+		tlog::success() << "Detected auxiliary GPUs:";
+		for (size_t i = 1; i < m_devices.size(); ++i) {
+			const auto& device = m_devices[i];
+			tlog::success() << "  #" << device.id() << ": " << device.name() << " [" << device.compute_capability() << "]";
+		}
+	}
+
 	m_network_config = {
 		{"loss", {
 			{"otype", "L2"}
@@ -3032,6 +3698,8 @@ void Testbed::train(uint32_t batch_size) {
 		throw std::runtime_error{"Cannot train without a mode."};
 	}
 
+	set_all_devices_dirty();
+
 	// If we don't have a trainer, as can happen when having loaded training data or changed modes without having
 	// explicitly loaded a new neural network.
 	if (!m_trainer) {
@@ -3097,18 +3765,16 @@ void Testbed::train(uint32_t batch_size) {
 	}
 }
 
-Vector2f Testbed::calc_focal_length(const Vector2i& resolution, int fov_axis, float zoom) const {
-	return m_relative_focal_length * resolution[fov_axis] * zoom;
+Vector2f Testbed::calc_focal_length(const Vector2i& resolution, const Vector2f& relative_focal_length, int fov_axis, float zoom) const {
+	return relative_focal_length * resolution[fov_axis] * zoom;
 }
 
-Vector2f Testbed::render_screen_center() const {
-	// see pixel_to_ray for how screen center is used; 0.5,0.5 is 'normal'. we flip so that it becomes the point in the original image we want to center on.
-	auto screen_center = m_screen_center;
-	return {(0.5f-screen_center.x())*m_zoom + 0.5f, (0.5-screen_center.y())*m_zoom + 0.5f};
+Vector2f Testbed::render_screen_center(const Vector2f& screen_center) const {
+	// See pixel_to_ray for how the screen center is used; (0.5, 0.5) is 'normal'.
+	// We flip it so that it becomes the point in the original image we want to center on.
+	return (Vector2f::Constant(0.5f) - screen_center) * m_zoom + Vector2f::Constant(0.5f);
 }
 
 __global__ void dlss_prep_kernel(
-	ETestbedMode mode,
 	Vector2i resolution,
 	uint32_t sample_index,
 	Vector2f focal_length,
@@ -3116,18 +3782,16 @@ __global__ void dlss_prep_kernel(
 	Vector3f parallax_shift,
 	bool snap_to_pixel_centers,
 	float* depth_buffer,
+	const float znear,
+	const float zfar,
 	Matrix<float, 3, 4> camera,
 	Matrix<float, 3, 4> prev_camera,
 	cudaSurfaceObject_t depth_surface,
 	cudaSurfaceObject_t mvec_surface,
 	cudaSurfaceObject_t exposure_surface,
-	Lens lens,
-	const float view_dist,
-	const float prev_view_dist,
-	const Vector2f image_pos,
-	const Vector2f prev_image_pos,
-	const Vector2i image_resolution,
-	const Vector2i quilting_dims
+	Foveation foveation,
+	Foveation prev_foveation,
+	Lens lens
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -3141,26 +3805,11 @@ __global__ void dlss_prep_kernel(
 	uint32_t x_orig = x;
 	uint32_t y_orig = y;
 
-	if (quilting_dims != Vector2i::Ones()) {
-		apply_quilting(&x, &y, resolution, parallax_shift, quilting_dims);
-	}
-
 	const float depth = depth_buffer[idx];
-	Vector2f mvec = mode == ETestbedMode::Image ? motion_vector_2d(
-		sample_index,
-		{x, y},
-		resolution.cwiseQuotient(quilting_dims),
-		image_resolution,
-		screen_center,
-		view_dist,
-		prev_view_dist,
-		image_pos,
-		prev_image_pos,
-		snap_to_pixel_centers
-	) : motion_vector_3d(
+	Vector2f mvec = motion_vector(
 		sample_index,
 		{x, y},
-		resolution.cwiseQuotient(quilting_dims),
+		resolution,
 		focal_length,
 		camera,
 		prev_camera,
@@ -3168,13 +3817,16 @@ __global__ void dlss_prep_kernel(
 		parallax_shift,
 		snap_to_pixel_centers,
 		depth,
+		foveation,
+		prev_foveation,
 		lens
 	);
 
 	surf2Dwrite(make_float2(mvec.x(), mvec.y()), mvec_surface, x_orig * sizeof(float2), y_orig);
 
-	// Scale depth buffer to be guaranteed in [0,1].
-	surf2Dwrite(std::min(std::max(depth / 128.0f, 0.0f), 1.0f), depth_surface, x_orig * sizeof(float), y_orig);
+	// DLSS was trained on games, which presumably used standard normalized device
+	// coordinate (NDC) depth buffers. So: convert our depth to NDC with reasonable
+	// near and far planes.
+	surf2Dwrite(to_ndc_depth(depth, znear, zfar), depth_surface, x_orig * sizeof(float), y_orig);
 
 	// First thread write an exposure factor of 1. Since DLSS will run on tonemapped data,
 	// exposure is assumed to already have been applied to DLSS' inputs.
@@ -3183,22 +3835,202 @@ __global__ void dlss_prep_kernel(
 	}
 }
 
-void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matrix<float, 3, 4>& camera_matrix1, const Vector4f& nerf_rolling_shutter, CudaRenderBuffer& render_buffer, bool to_srgb) {
-	Vector2i max_res = m_window_res.cwiseMax(render_buffer.in_resolution());
+__global__ void spherical_checkerboard_kernel(
+	Vector2i resolution,
+	Vector2f focal_length,
+	Matrix<float, 3, 4> camera,
+	Vector2f screen_center,
+	Vector3f parallax_shift,
+	Foveation foveation,
+	Lens lens,
+	Array4f* frame_buffer
+) {
+	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
+	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
+
+	if (x >= resolution.x() || y >= resolution.y()) {
+		return;
+	}
+
+	Ray ray = pixel_to_ray(
+		0,
+		{x, y},
+		resolution,
+		focal_length,
+		camera,
+		screen_center,
+		parallax_shift,
+		false,
+		0.0f,
+		1.0f,
+		0.0f,
+		foveation,
+		{}, // No need for hidden area mask
+		lens
+	);
+
+	// Blend with checkerboard to break up reprojection weirdness in some VR runtimes
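+	// Swap y and z (presumably so the checkerboard's poles align with the scene's
+	// up axis) and scale the spherical coordinates such that each checker tile
+	// spans PI/32 radians.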
+	host_device_swap(ray.d.z(), ray.d.y());
+	Vector2f spherical = dir_to_spherical(ray.d.normalized()) * 32.0f / PI();
+	const Array4f dark_gray = {0.5f, 0.5f, 0.5f, 1.0f};
+	const Array4f light_gray = {0.55f, 0.55f, 0.55f, 1.0f};
+	Array4f checker = fabsf(fmodf(floorf(spherical.x()) + floorf(spherical.y()), 2.0f)) < 0.5f ? dark_gray : light_gray;
+
+	uint32_t idx = x + resolution.x() * y;
+	frame_buffer[idx] += (1.0f - frame_buffer[idx].w()) * checker;
+}
+
+__global__ void vr_overlay_hands_kernel(
+	Vector2i resolution,
+	Vector2f focal_length,
+	Matrix<float, 3, 4> camera,
+	Vector2f screen_center,
+	Vector3f parallax_shift,
+	Foveation foveation,
+	Lens lens,
+	Vector3f left_hand_pos,
+	float left_grab_strength,
+	Array4f left_hand_color,
+	Vector3f right_hand_pos,
+	float right_grab_strength,
+	Array4f right_hand_color,
+	float hand_radius,
+	EColorSpace output_color_space,
+	cudaSurfaceObject_t surface
+	// TODO: overwrite depth buffer
+) {
+	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
+	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
+
+	if (x >= resolution.x() || y >= resolution.y()) {
+		return;
+	}
+
+	Ray ray = pixel_to_ray(
+		0,
+		{x, y},
+		resolution,
+		focal_length,
+		camera,
+		screen_center,
+		parallax_shift,
+		false,
+		0.0f,
+		1.0f,
+		0.0f,
+		foveation,
+		{}, // No need for hidden area mask
+		lens
+	);
+
+	Array4f color = Array4f::Zero();
+	auto composit_hand = [&](Vector3f hand_pos, float grab_strength, Array4f hand_color) {
+		// Don't render the hand indicator if it's behind the ray origin.
+		if (ray.d.dot(hand_pos - ray.o) < 0.0f) {
+			return;
+		}
+
+		float distance = ray.distance_to(hand_pos);
+
+		Array4f base_color = Array4f::Zero();
+		const Array4f border_color = {0.4f, 0.4f, 0.4f, 0.4f};
+
+		// Divide hand radius into an inner part (4/5ths) and a border (1/5th).
+		float radius = hand_radius * 0.8f;
+		float border_width = hand_radius * 0.2f;
+
+		// When grabbing, shrink the inner part as a visual indicator.
+		radius *= 0.5f + 0.5f * (1.0f - grab_strength);
+
+		if (distance < radius) {
+			base_color = hand_color;
+		} else if (distance < radius + border_width) {
+			base_color = border_color;
+		} else {
+			return;
+		}
+
+		// Make hand color opaque when grabbing.
+		base_color.w() = grab_strength + (1.0f - grab_strength) * base_color.w();
+		color += base_color * (1.0f - color.w());
+	};
+
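+	// Composit whichever hand is closer along the view ray first, such that
+	// front-to-back alpha blending produces the correct occlusion order.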
+	if (ray.d.dot(left_hand_pos - ray.o) < ray.d.dot(right_hand_pos - ray.o)) {
+		composit_hand(left_hand_pos, left_grab_strength, left_hand_color);
+		composit_hand(right_hand_pos, right_grab_strength, right_hand_color);
+	} else {
+		composit_hand(right_hand_pos, right_grab_strength, right_hand_color);
+		composit_hand(left_hand_pos, left_grab_strength, left_hand_color);
+	}
+
+	// Blend with existing color of pixel
+	Array4f prev_color;
+	surf2Dread((float4*)&prev_color, surface, x * sizeof(float4), y);
+	if (output_color_space == EColorSpace::SRGB) {
+		prev_color.head<3>() = srgb_to_linear(prev_color.head<3>());
+	}
 
-	render_buffer.clear_frame(m_stream.get());
+	color += (1.0f - color.w()) * prev_color;
+
+	if (output_color_space == EColorSpace::SRGB) {
+		color.head<3>() = linear_to_srgb(color.head<3>());
+	}
+
+	surf2Dwrite(to_float4(color), surface, x * sizeof(float4), y);
+}
+
+void Testbed::render_frame(
+	cudaStream_t stream,
+	const Matrix<float, 3, 4>& camera_matrix0,
+	const Matrix<float, 3, 4>& camera_matrix1,
+	const Matrix<float, 3, 4>& prev_camera_matrix,
+	const Vector2f& orig_screen_center,
+	const Vector2f& relative_focal_length,
+	const Vector4f& nerf_rolling_shutter,
+	const Foveation& foveation,
+	const Foveation& prev_foveation,
+	int visualized_dimension,
+	CudaRenderBuffer& render_buffer,
+	bool to_srgb,
+	CudaDevice* device
+) {
+	if (!device) {
+		device = &primary_device();
+	}
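+
+	// Rendering proceeds in two stages: render_frame_main() executes on the target
+	// device's own stream (potentially an auxiliary GPU), whereas
+	// render_frame_epilogue() (DLSS prep, overlays, tonemapping) always runs on the
+	// primary device using the caller's stream.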
 
-	Vector2f focal_length = calc_focal_length(render_buffer.in_resolution(), m_fov_axis, m_zoom);
-	Vector2f screen_center = render_screen_center();
+	sync_device(render_buffer, *device);
+
+	{
+		auto device_guard = use_device(stream, render_buffer, *device);
+		render_frame_main(*device, camera_matrix0, camera_matrix1, orig_screen_center, relative_focal_length, nerf_rolling_shutter, foveation, visualized_dimension);
+	}
+
+	render_frame_epilogue(stream, camera_matrix0, prev_camera_matrix, orig_screen_center, relative_focal_length, foveation, prev_foveation, render_buffer, to_srgb);
+}
+
+void Testbed::render_frame_main(
+	CudaDevice& device,
+	const Matrix<float, 3, 4>& camera_matrix0,
+	const Matrix<float, 3, 4>& camera_matrix1,
+	const Vector2f& orig_screen_center,
+	const Vector2f& relative_focal_length,
+	const Vector4f& nerf_rolling_shutter,
+	const Foveation& foveation,
+	int visualized_dimension
+) {
+	device.render_buffer_view().clear(device.stream());
 
 	if (!m_network) {
 		return;
 	}
 
+	Vector2f focal_length = calc_focal_length(device.render_buffer_view().resolution, relative_focal_length, m_fov_axis, m_zoom);
+	Vector2f screen_center = render_screen_center(orig_screen_center);
+
 	switch (m_testbed_mode) {
 		case ETestbedMode::Nerf:
 			if (!m_render_ground_truth || m_ground_truth_alpha < 1.0f) {
-				render_nerf(render_buffer, max_res, focal_length, camera_matrix0, camera_matrix1, nerf_rolling_shutter, screen_center, m_stream.get());
+				render_nerf(device.stream(), device.render_buffer_view(), *device.nerf_network(), device.data().density_grid_bitfield_ptr, focal_length, camera_matrix0, camera_matrix1, nerf_rolling_shutter, screen_center, foveation, visualized_dimension);
 			}
 			break;
 		case ETestbedMode::Sdf:
@@ -3219,15 +4051,13 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 							m_sdf.brick_data.data(),
 							m_sdf.triangles_gpu.data(),
 							false,
-							m_stream.get()
+							device.stream()
 						);
 					}
 				}
+
 				distance_fun_t distance_fun =
 					m_render_ground_truth ? (distance_fun_t)[&](uint32_t n_elements, const Vector3f* positions, float* distances, cudaStream_t stream) {
-						if (n_elements == 0) {
-							return;
-						}
 						if (m_sdf.groundtruth_mode == ESDFGroundTruthMode::SDFBricks) {
 							// linear_kernel(sdf_brick_kernel, 0, stream,
 							// 	n_elements,
@@ -3244,17 +4074,14 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 							m_sdf.triangle_bvh->signed_distance_gpu(
 								n_elements,
 								m_sdf.mesh_sdf_mode,
-								(Vector3f*)positions,
+								positions,
 								distances,
 								m_sdf.triangles_gpu.data(),
 								false,
-								m_stream.get()
+								stream
 							);
 						}
 					} : (distance_fun_t)[&](uint32_t n_elements, const Vector3f* positions, float* distances, cudaStream_t stream) {
-						if (n_elements == 0) {
-							return;
-						}
 						n_elements = next_multiple(n_elements, tcnn::batch_size_granularity);
 						GPUMatrix<float> positions_matrix((float*)positions, 3, n_elements);
 						GPUMatrix<float, RM> distances_matrix(distances, 1, n_elements);
@@ -3265,53 +4092,64 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 					m_render_ground_truth ? (normals_fun_t)[&](uint32_t n_elements, const Vector3f* positions, Vector3f* normals, cudaStream_t stream) {
 						// NO-OP. Normals will automatically be populated by raytrace
 					} : (normals_fun_t)[&](uint32_t n_elements, const Vector3f* positions, Vector3f* normals, cudaStream_t stream) {
-						if (n_elements == 0) {
-							return;
-						}
-
 						n_elements = next_multiple(n_elements, tcnn::batch_size_granularity);
-
 						GPUMatrix<float> positions_matrix((float*)positions, 3, n_elements);
 						GPUMatrix<float> normals_matrix((float*)normals, 3, n_elements);
 						m_network->input_gradient(stream, 0, positions_matrix, normals_matrix);
 					};
 
 				render_sdf(
+					device.stream(),
 					distance_fun,
 					normals_fun,
-					render_buffer,
-					max_res,
+					device.render_buffer_view(),
 					focal_length,
 					camera_matrix0,
 					screen_center,
-					m_stream.get()
+					foveation,
+					visualized_dimension
 				);
 			}
 			break;
 		case ETestbedMode::Image:
-			render_image(render_buffer, m_stream.get());
+			render_image(device.stream(), device.render_buffer_view(), focal_length, camera_matrix0, screen_center, foveation, visualized_dimension);
 			break;
 		case ETestbedMode::Volume:
-			render_volume(render_buffer, focal_length, camera_matrix0, screen_center, m_stream.get());
+			render_volume(device.stream(), device.render_buffer_view(), focal_length, camera_matrix0, screen_center, foveation);
 			break;
 		default:
-			throw std::runtime_error{"Invalid render mode."};
+			// No-op if no mode is active
+			break;
 	}
+}
+
+void Testbed::render_frame_epilogue(
+	cudaStream_t stream,
+	const Matrix<float, 3, 4>& camera_matrix0,
+	const Matrix<float, 3, 4>& prev_camera_matrix,
+	const Vector2f& orig_screen_center,
+	const Vector2f& relative_focal_length,
+	const Foveation& foveation,
+	const Foveation& prev_foveation,
+	CudaRenderBuffer& render_buffer,
+	bool to_srgb
+) {
+	Vector2f focal_length = calc_focal_length(render_buffer.in_resolution(), relative_focal_length, m_fov_axis, m_zoom);
+	Vector2f screen_center = render_screen_center(orig_screen_center);
 
 	render_buffer.set_color_space(m_color_space);
 	render_buffer.set_tonemap_curve(m_tonemap_curve);
 
+	Lens lens = (m_testbed_mode == ETestbedMode::Nerf && m_nerf.render_with_lens_distortion) ? m_nerf.render_lens : Lens{};
+
 	// Prepare DLSS data: motion vectors, scaled depth, exposure
 	if (render_buffer.dlss()) {
 		auto res = render_buffer.in_resolution();
 
-		bool distortion = m_testbed_mode == ETestbedMode::Nerf && m_nerf.render_with_lens_distortion;
-
 		const dim3 threads = { 16, 8, 1 };
 		const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 
-		dlss_prep_kernel<<<blocks, threads, 0, m_stream.get()>>>(
-			m_testbed_mode,
+		dlss_prep_kernel<<<blocks, threads, 0, stream>>>(
 			res,
 			render_buffer.spp(),
 			focal_length,
@@ -3319,29 +4157,49 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 			m_parallax_shift,
 			m_snap_to_pixel_centers,
 			render_buffer.depth_buffer(),
+			m_ndc_znear,
+			m_ndc_zfar,
 			camera_matrix0,
-			m_prev_camera,
+			prev_camera_matrix,
 			render_buffer.dlss()->depth(),
 			render_buffer.dlss()->mvec(),
 			render_buffer.dlss()->exposure(),
-			distortion ? m_nerf.render_lens : Lens{},
-			m_scale,
-			m_prev_scale,
-			m_image.pos,
-			m_image.prev_pos,
-			m_image.resolution,
-			m_quilting_dims
+			foveation,
+			prev_foveation,
+			lens
 		);
 
 		render_buffer.set_dlss_sharpening(m_dlss_sharpening);
 	}
 
-	m_prev_camera = camera_matrix0;
-	m_prev_scale = m_scale;
-	m_image.prev_pos = m_image.pos;
+	EColorSpace output_color_space = to_srgb ? EColorSpace::SRGB : EColorSpace::Linear;
+
+	if (m_render_transparency_as_checkerboard) {
+		Matrix<float, 3, 4> checkerboard_transform = Matrix<float, 3, 4>::Identity();
+
+#ifdef NGP_GUI
+		if (m_vr_frame_info && !m_vr_frame_info->views.empty()) {
+			checkerboard_transform = m_vr_frame_info->views[0].pose;
+		}
+#endif
+
+		auto res = render_buffer.in_resolution();
+		const dim3 threads = { 16, 8, 1 };
+		const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
+		spherical_checkerboard_kernel<<<blocks, threads, 0, stream>>>(
+			res,
+			focal_length,
+			checkerboard_transform,
+			screen_center,
+			m_parallax_shift,
+			foveation,
+			lens,
+			render_buffer.frame_buffer()
+		);
+	}
 
-	render_buffer.accumulate(m_exposure, m_stream.get());
-	render_buffer.tonemap(m_exposure, m_background_color, to_srgb ? EColorSpace::SRGB : EColorSpace::Linear, m_stream.get());
+	render_buffer.accumulate(m_exposure, stream);
+	render_buffer.tonemap(m_exposure, m_background_color, output_color_space, m_ndc_znear, m_ndc_zfar, stream);
 
 	if (m_testbed_mode == ETestbedMode::Nerf) {
 		// Overlay the ground truth image if requested
@@ -3352,14 +4210,14 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 					m_ground_truth_alpha,
 					Array3f::Constant(m_exposure) + m_nerf.training.cam_exposure[m_nerf.training.view].variable(),
 					m_background_color,
-					to_srgb ? EColorSpace::SRGB : EColorSpace::Linear,
+					output_color_space,
 					metadata.pixels,
 					metadata.image_data_type,
 					metadata.resolution,
 					m_fov_axis,
 					m_zoom,
 					Vector2f::Constant(0.5f),
-					m_stream.get()
+					stream
 				);
 			} else if (m_ground_truth_render_mode == EGroundTruthRenderMode::Depth && metadata.depth) {
 				render_buffer.overlay_depth(
@@ -3370,7 +4228,7 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 					m_fov_axis,
 					m_zoom,
 					Vector2f::Constant(0.5f),
-					m_stream.get()
+					stream
 				);
 			}
 		}
@@ -3385,39 +4243,67 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 			}
 			size_t emap_size = error_map_res.x() * error_map_res.y();
 			err_data += emap_size * m_nerf.training.view;
-			static GPUMemory<float> average_error;
+
+			GPUMemory<float> average_error;
 			average_error.enlarge(1);
 			average_error.memset(0);
 			const float* aligned_err_data_s = (const float*)(((size_t)err_data)&~15);
 			const float* aligned_err_data_e = (const float*)(((size_t)(err_data+emap_size))&~15);
 			size_t reduce_size = aligned_err_data_e - aligned_err_data_s;
-			reduce_sum(aligned_err_data_s, [reduce_size] __device__ (float val) { return max(val,0.f) / (reduce_size); }, average_error.data(), reduce_size, m_stream.get());
+			reduce_sum(aligned_err_data_s, [reduce_size] __device__ (float val) { return max(val,0.f) / (reduce_size); }, average_error.data(), reduce_size, stream);
 			auto const &metadata = m_nerf.training.dataset.metadata[m_nerf.training.view];
-			render_buffer.overlay_false_color(metadata.resolution, to_srgb, m_fov_axis, m_stream.get(), err_data, error_map_res, average_error.data(), m_nerf.training.error_overlay_brightness, m_render_ground_truth);
+			render_buffer.overlay_false_color(metadata.resolution, to_srgb, m_fov_axis, stream, err_data, error_map_res, average_error.data(), m_nerf.training.error_overlay_brightness, m_render_ground_truth);
 		}
 	}
 
-	CUDA_CHECK_THROW(cudaStreamSynchronize(m_stream.get()));
-}
+#ifdef NGP_GUI
+	// If in VR, indicate the hand position and render transparent background
+	if (m_vr_frame_info) {
+		auto& hands = m_vr_frame_info->hands;
 
-void Testbed::determine_autofocus_target_from_pixel(const Vector2i& focus_pixel) {
-	float depth;
-
-	const auto& surface = m_render_surfaces.front();
-	if (surface.depth_buffer()) {
-		auto res = surface.in_resolution();
-		Vector2i depth_pixel = focus_pixel.cast<float>().cwiseProduct(res.cast<float>()).cwiseQuotient(m_window_res.cast<float>()).cast<int>();
-		depth_pixel = depth_pixel.cwiseMin(res).cwiseMax(0);
+		auto res = render_buffer.out_resolution();
+		const dim3 threads = { 16, 8, 1 };
+		const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
+		vr_overlay_hands_kernel<<<blocks, threads, 0, stream>>>(
+			res,
+			focal_length.cwiseProduct(render_buffer.out_resolution().cast<float>()).cwiseQuotient(render_buffer.in_resolution().cast<float>()),
+			camera_matrix0,
+			screen_center,
+			m_parallax_shift,
+			foveation,
+			lens,
+			vr_to_world(hands[0].pose.col(3)),
+			hands[0].grab_strength,
+			{hands[0].pressing ? 0.8f : 0.0f, 0.0f, 0.0f, 0.8f},
+			vr_to_world(hands[1].pose.col(3)),
+			hands[1].grab_strength,
+			{hands[1].pressing ? 0.8f : 0.0f, 0.0f, 0.0f, 0.8f},
+			0.05f * m_scale, // Hand radius
+			output_color_space,
+			render_buffer.surface()
+		);
+	}
+#endif
+}
 
-		CUDA_CHECK_THROW(cudaMemcpy(&depth, surface.depth_buffer() + depth_pixel.x() + depth_pixel.y() * res.x(), sizeof(float), cudaMemcpyDeviceToHost));
-	} else {
-		depth = m_scale;
+float Testbed::get_depth_from_renderbuffer(const CudaRenderBuffer& render_buffer, const Vector2f& uv) {
+	if (!render_buffer.depth_buffer()) {
+		return m_scale;
 	}
 
-	auto ray = pixel_to_ray_pinhole(0, focus_pixel, m_window_res, calc_focal_length(m_window_res, m_fov_axis, m_zoom), m_smoothed_camera, render_screen_center());
+	float depth;
+	auto res = render_buffer.in_resolution();
+	Vector2i depth_pixel = uv.cwiseProduct(res.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(res - Vector2i::Ones());
+
+	CUDA_CHECK_THROW(cudaMemcpy(&depth, render_buffer.depth_buffer() + depth_pixel.x() + depth_pixel.y() * res.x(), sizeof(float), cudaMemcpyDeviceToHost));
+	return depth;
+}
 
-	m_autofocus_target = ray.o + ray.d * depth;
-	m_autofocus = true; // If someone shift-clicked, that means they want the AUTOFOCUS
+Vector3f Testbed::get_3d_pos_from_pixel(const CudaRenderBuffer& render_buffer, const Vector2i& pixel) {
+	float depth = get_depth_from_renderbuffer(render_buffer, pixel.cast<float>().cwiseQuotient(m_window_res.cast<float>()));
+	auto ray = pixel_to_ray_pinhole(0, pixel, m_window_res, calc_focal_length(m_window_res, m_relative_focal_length, m_fov_axis, m_zoom), m_smoothed_camera, render_screen_center(m_screen_center));
+	return ray(depth);
 }
 
 void Testbed::autofocus() {
@@ -3593,7 +4479,7 @@ void Testbed::load_snapshot(const fs::path& path) {
 			density_grid[i] = (float)density_grid_fp16[i];
 		});
 
-		if (m_nerf.density_grid.size() == NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE() * (m_nerf.max_cascade + 1)) {
+		if (m_nerf.density_grid.size() == NERF_GRID_N_CELLS() * (m_nerf.max_cascade + 1)) {
 			update_density_grid_mean_and_bitfield(nullptr);
 		} else if (m_nerf.density_grid.size() != 0) {
 			// A size of 0 indicates that the density grid was never populated, which is a valid state of a (yet) untrained model.
@@ -3614,6 +4500,106 @@ void Testbed::load_snapshot(const fs::path& path) {
 	m_loss_scalar.set(m_network_config["snapshot"]["loss"]);
 
 	m_trainer->deserialize(m_network_config["snapshot"]);
+
+	set_all_devices_dirty();
+}
+
+void Testbed::CudaDevice::set_nerf_network(const std::shared_ptr<NerfNetwork<precision_t>>& nerf_network) {
+	m_network = m_nerf_network = nerf_network;
+}
+
+void Testbed::sync_device(CudaRenderBuffer& render_buffer, Testbed::CudaDevice& device) {
+	if (!device.dirty()) {
+		return;
+	}
+
+	if (device.is_primary()) {
+		device.data().density_grid_bitfield_ptr = m_nerf.density_grid_bitfield.data();
+		device.data().hidden_area_mask = render_buffer.hidden_area_mask();
+		device.set_dirty(false);
+		return;
+	}
+
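+	// Synchronize the device's stream with the primary stream such that the
+	// peer-to-peer copies below read fully written data.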
+	m_stream.signal(device.stream());
+
+	int active_device = cuda_device();
+	auto guard = device.device_guard();
+
+	device.data().density_grid_bitfield.resize(m_nerf.density_grid_bitfield.size());
+	if (m_nerf.density_grid_bitfield.size() > 0) {
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(device.data().density_grid_bitfield.data(), device.id(), m_nerf.density_grid_bitfield.data(), active_device, m_nerf.density_grid_bitfield.bytes(), device.stream()));
+	}
+
+	device.data().density_grid_bitfield_ptr = device.data().density_grid_bitfield.data();
+
+	if (m_network) {
+		device.data().params.resize(m_network->n_params());
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(device.data().params.data(), device.id(), m_network->inference_params(), active_device, device.data().params.bytes(), device.stream()));
+		device.nerf_network()->set_params(device.data().params.data(), device.data().params.data(), nullptr);
+	}
+
+	if (render_buffer.hidden_area_mask()) {
+		auto ham = std::make_shared<Buffer2D<uint8_t>>(render_buffer.hidden_area_mask()->resolution());
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(ham->data(), device.id(), render_buffer.hidden_area_mask()->data(), active_device, ham->bytes(), device.stream()));
+		device.data().hidden_area_mask = ham;
+	} else {
+		device.data().hidden_area_mask = nullptr;
+	}
+
+	device.set_dirty(false);
+}
+
+// From https://stackoverflow.com/questions/20843271/passing-a-non-copyable-closure-object-to-stdfunction-parameter
+template <class F>
+auto make_copyable_function(F&& f) {
+	using dF = std::decay_t<F>;
+	auto spf = std::make_shared<dF>(std::forward<F>(f));
+	return [spf](auto&&... args) -> decltype(auto) {
+		return (*spf)( decltype(args)(args)... );
+	};
+}
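+// Example usage: std::function-style type erasure requires copyable callables,
+// so a lambda capturing a move-only GPUMemoryArena::Allocation (as in
+// use_device() below) must be wrapped:
+//   ScopeGuard guard{make_copyable_function([alloc = std::move(alloc)]() { /* ... */ })};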
+
+ScopeGuard Testbed::use_device(cudaStream_t stream, CudaRenderBuffer& render_buffer, Testbed::CudaDevice& device) {
+	device.wait_for(stream);
+
+	if (device.is_primary()) {
+		device.set_render_buffer_view(render_buffer.view());
+		return ScopeGuard{[&device, stream]() {
+			device.set_render_buffer_view({});
+			device.signal(stream);
+		}};
+	}
+
+	int active_device = cuda_device();
+	auto guard = device.device_guard();
+
+	size_t n_pixels = render_buffer.in_resolution().prod();
+
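+	// Auxiliary devices cannot render into the primary device's buffers directly,
+	// so allocate device-local scratch frame and depth buffers here. The ScopeGuard
+	// returned below copies their contents back to the original render buffer.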
+	GPUMemoryArena::Allocation alloc;
+	auto scratch = allocate_workspace_and_distribute<Array4f, float>(device.stream(), &alloc, n_pixels, n_pixels);
+
+	device.set_render_buffer_view({
+		std::get<0>(scratch),
+		std::get<1>(scratch),
+		render_buffer.in_resolution(),
+		render_buffer.spp(),
+		device.data().hidden_area_mask,
+	});
+
+	return ScopeGuard{make_copyable_function([&render_buffer, &device, guard=std::move(guard), alloc=std::move(alloc), active_device, stream]() {
+		// Copy device's render buffer's data onto the original render buffer
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(render_buffer.frame_buffer(), active_device, device.render_buffer_view().frame_buffer, device.id(), render_buffer.in_resolution().prod() * sizeof(Array4f), device.stream()));
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(render_buffer.depth_buffer(), active_device, device.render_buffer_view().depth_buffer, device.id(), render_buffer.in_resolution().prod() * sizeof(float), device.stream()));
+
+		device.set_render_buffer_view({});
+		device.signal(stream);
+	})};
+}
+
+void Testbed::set_all_devices_dirty() {
+	for (auto& device : m_devices) {
+		device.set_dirty(true);
+	}
 }
 
 void Testbed::load_camera_path(const fs::path& path) {
diff --git a/src/testbed_image.cu b/src/testbed_image.cu
index 5be8aa74bf6f11b99d70b8663835956b66137fd0..59d385a7840d561eb389fbdc8bca34cbff606ed2 100644
--- a/src/testbed_image.cu
+++ b/src/testbed_image.cu
@@ -77,14 +77,20 @@ __global__ void stratify2_kernel(uint32_t n_elements, uint32_t log2_batch_size,
 }
 
 __global__ void init_image_coords(
+	uint32_t sample_index,
 	Vector2f* __restrict__ positions,
+	float* __restrict__ depth_buffer,
 	Vector2i resolution,
-	Vector2i image_resolution,
-	float view_dist,
-	Vector2f image_pos,
+	float aspect,
+	Vector2f focal_length,
+	Matrix<float, 3, 4> camera_matrix,
 	Vector2f screen_center,
+	Vector3f parallax_shift,
 	bool snap_to_pixel_centers,
-	uint32_t sample_index
+	float plane_z,
+	float aperture_size,
+	Foveation foveation,
+	Buffer2DView<const uint8_t> hidden_area_mask
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -93,49 +99,47 @@ __global__ void init_image_coords(
 		return;
 	}
 
-	uint32_t idx = x + resolution.x() * y;
-	positions[idx] = pixel_to_image_uv(
+	// The image is displayed on the plane [0.5, 0.5, 0.5] + [X, Y, 0] to facilitate
+	// a top-down view by default, while permitting general camera movements (for
+	// motion vectors and code sharing with 3D tasks).
+	// Hence: generate rays and intersect that plane.
+	Ray ray = pixel_to_ray(
 		sample_index,
 		{x, y},
 		resolution,
-		image_resolution,
+		focal_length,
+		camera_matrix,
 		screen_center,
-		view_dist,
-		image_pos,
-		snap_to_pixel_centers
+		parallax_shift,
+		snap_to_pixel_centers,
+		0.0f, // near distance
+		plane_z,
+		aperture_size,
+		foveation,
+		hidden_area_mask
 	);
-}
 
-// #define COLOR_SPACE_CONVERT convert to ycrcb experiment - causes some color shift tho it does lead to very slightly sharper edges. not a net win if you like colors :)
-#define CHROMA_SCALE 0.2f
-
-__global__ void colorspace_convert_image_half(Vector2i resolution, const char* __restrict__ texture) {
-	uint32_t x = blockIdx.x * blockDim.x + threadIdx.x;
-	uint32_t y = blockIdx.y * blockDim.y + threadIdx.y;
-	if (x >= resolution.x() || y >= resolution.y()) return;
-	__half val[4];
-	*(int2*)&val[0] = ((int2*)texture)[y * resolution.x() + x];
-	float R=val[0],G=val[1],B=val[2];
-	val[0]=(0.2126f * R + 0.7152f * G + 0.0722f * B);
-	val[1]=((-0.1146f * R - 0.3845f * G + 0.5f * B)+0.f)*CHROMA_SCALE;
-	val[2]=((0.5f * R - 0.4542f * G - 0.0458f * B)+0.f)*CHROMA_SCALE;
-	((int2*)texture)[y * resolution.x() + x] = *(int2*)&val[0];
-}
+	// Intersect the Z=0.5 plane
+	float t = ray.is_valid() ? (0.5f - ray.o.z()) / ray.d.z() : -1.0f;
+
+	uint32_t idx = x + resolution.x() * y;
+	if (t <= 0.0f) {
+		depth_buffer[idx] = MAX_DEPTH();
+		positions[idx] = -Vector2f::Ones();
+		return;
+	}
+
+	Vector2f uv = ray(t).head<2>();
+
+	// Flip from world coordinates where Y goes up to image coordinates where Y goes down.
+	// Also, multiply the x-axis by the image's aspect ratio to make it have the right proportions.
+	uv = (uv - Vector2f::Constant(0.5f)).cwiseProduct(Vector2f{aspect, -1.0f}) + Vector2f::Constant(0.5f);
 
-__global__ void colorspace_convert_image_float(Vector2i resolution, const char* __restrict__ texture) {
-	uint32_t x = blockIdx.x * blockDim.x + threadIdx.x;
-	uint32_t y = blockIdx.y * blockDim.y + threadIdx.y;
-	if (x >= resolution.x() || y >= resolution.y()) return;
-	float val[4];
-	*(float4*)&val[0] = ((float4*)texture)[y * resolution.x() + x];
-	float R=val[0],G=val[1],B=val[2];
-	val[0]=(0.2126f * R + 0.7152f * G + 0.0722f * B);
-	val[1]=((-0.1146f * R - 0.3845f * G + 0.5f * B)+0.f)*CHROMA_SCALE;
-	val[2]=((0.5f * R - 0.4542f * G - 0.0458f * B)+0.f)*CHROMA_SCALE;
-	((float4*)texture)[y * resolution.x() + x] = *(float4*)&val[0];
+	depth_buffer[idx] = t;
+	positions[idx] = uv;
 }
 
-__global__ void shade_kernel_image(Vector2i resolution, const Vector2f* __restrict__ positions, const Array3f* __restrict__ colors, Array4f* __restrict__ frame_buffer, float* __restrict__ depth_buffer, bool linear_colors) {
+__global__ void shade_kernel_image(Vector2i resolution, const Vector2f* __restrict__ positions, const Array3f* __restrict__ colors, Array4f* __restrict__ frame_buffer, bool linear_colors) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
 
@@ -148,7 +152,6 @@ __global__ void shade_kernel_image(Vector2i resolution, const Vector2f* __restri
 	const Vector2f uv = positions[idx];
 	if (uv.x() < 0.0f || uv.x() > 1.0f || uv.y() < 0.0f || uv.y() > 1.0f) {
 		frame_buffer[idx] = Array4f::Zero();
-		depth_buffer[idx] = 1e10f;
 		return;
 	}
 
@@ -158,16 +161,7 @@ __global__ void shade_kernel_image(Vector2i resolution, const Vector2f* __restri
 		color = srgb_to_linear(color);
 	}
 
-#ifdef COLOR_SPACE_CONVERT
-	float Y=color.x(), Cb =color.y()*(1.f/CHROMA_SCALE) -0.f, Cr = color.z() * (1.f/CHROMA_SCALE) - 0.f;
-	float R = Y                + 1.5748f * Cr;
-	float G = Y - 0.1873f * Cb - 0.4681 * Cr;
-	float B = Y + 1.8556f * Cb;
-	frame_buffer[idx] = {R, G, B, 1.0f};
-#else
 	frame_buffer[idx] = {color.x(), color.y(), color.z(), 1.0f};
-#endif
-	depth_buffer[idx] = 1.0f;
 }
 
 template <typename T, uint32_t stride>
@@ -291,8 +285,16 @@ void Testbed::train_image(size_t target_batch_size, bool get_loss_scalar, cudaSt
 	m_training_step++;
 }
 
-void Testbed::render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream) {
-	auto res = render_buffer.in_resolution();
+void Testbed::render_image(
+	cudaStream_t stream,
+	const CudaRenderBufferView& render_buffer,
+	const Vector2f& focal_length,
+	const Matrix<float, 3, 4>& camera_matrix,
+	const Vector2f& screen_center,
+	const Foveation& foveation,
+	int visualized_dimension
+) {
+	auto res = render_buffer.resolution;
 
 	// Make sure we have enough memory reserved to render at the requested resolution
 	size_t n_pixels = (size_t)res.x() * res.y();
@@ -300,18 +302,27 @@ void Testbed::render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream)
 	m_image.render_coords.enlarge(n_elements);
 	m_image.render_out.enlarge(n_elements);
 
+	float plane_z = m_slice_plane_z + m_scale;
+	float aspect = (float)m_image.resolution.y() / (float)m_image.resolution.x();
+
 	// Generate 2D coords at which to query the network
 	const dim3 threads = { 16, 8, 1 };
 	const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 	init_image_coords<<<blocks, threads, 0, stream>>>(
+		render_buffer.spp,
 		m_image.render_coords.data(),
+		render_buffer.depth_buffer,
 		res,
-		m_image.resolution,
-		m_scale,
-		m_image.pos,
-		m_screen_center - Vector2f::Constant(0.5f),
+		aspect,
+		focal_length,
+		camera_matrix,
+		screen_center,
+		m_parallax_shift,
 		m_snap_to_pixel_centers,
-		render_buffer.spp()
+		plane_z,
+		m_aperture_size,
+		foveation,
+		render_buffer.hidden_area_mask ? render_buffer.hidden_area_mask->const_view() : Buffer2DView<const uint8_t>{}
 	);
 
 	// Obtain colors for each 2D coord
@@ -338,10 +349,10 @@ void Testbed::render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream)
 	}
 
 	if (!m_render_ground_truth) {
-		if (m_visualized_dimension >= 0) {
+		if (visualized_dimension >= 0) {
 			GPUMatrix<float> positions_matrix((float*)m_image.render_coords.data(), 2, n_elements);
 			GPUMatrix<float> colors_matrix((float*)m_image.render_out.data(), 3, n_elements);
-			m_network->visualize_activation(stream, m_visualized_layer, m_visualized_dimension, positions_matrix, colors_matrix);
+			m_network->visualize_activation(stream, m_visualized_layer, visualized_dimension, positions_matrix, colors_matrix);
 		} else {
 			GPUMatrix<float> positions_matrix((float*)m_image.render_coords.data(), 2, n_elements);
 			GPUMatrix<float> colors_matrix((float*)m_image.render_out.data(), 3, n_elements);
@@ -354,8 +365,7 @@ void Testbed::render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream)
 		res,
 		m_image.render_coords.data(),
 		m_image.render_out.data(),
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer(),
+		render_buffer.frame_buffer,
 		m_image.training.linear_colors
 	);
 }
diff --git a/src/testbed_nerf.cu b/src/testbed_nerf.cu
index 33a663d4894ad43422b55af9fa2b8d97eec7dd4c..ba563a6d1115ebd38bbe864e1e99bc6c44d326a5 100644
--- a/src/testbed_nerf.cu
+++ b/src/testbed_nerf.cu
@@ -377,34 +377,60 @@ __global__ void mark_untrained_density_grid(const uint32_t n_elements,  float* _
 	uint32_t y = tcnn::morton3D_invert(pos_idx>>1);
 	uint32_t z = tcnn::morton3D_invert(pos_idx>>2);
 
-	Vector3f pos = ((Vector3f{(float)x+0.5f, (float)y+0.5f, (float)z+0.5f}) / NERF_GRIDSIZE() - Vector3f::Constant(0.5f)) * scalbnf(1.0f, level) + Vector3f::Constant(0.5f);
-	float voxel_radius = 0.5f*SQRT3()*scalbnf(1.0f, level) / NERF_GRIDSIZE();
-	int count = 0;
-	for (uint32_t j=0; j < n_training_images; ++j) {
-		if (metadata[j].lens.mode == ELensMode::FTheta || metadata[j].lens.mode == ELensMode::LatLong || metadata[j].lens.mode == ELensMode::OpenCVFisheye) {
-			// not supported for now
-			count++;
-			break;
+	float voxel_size = scalbnf(1.0f / NERF_GRIDSIZE(), level);
+	Vector3f pos = (Vector3f{(float)x, (float)y, (float)z} / NERF_GRIDSIZE() - Vector3f::Constant(0.5f)) * scalbnf(1.0f, level) + Vector3f::Constant(0.5f);
+
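+	// Enumerate the eight corners of the voxel; the cell counts as seen by a
+	// training view if any of its corners projects into that view's image.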
+	Vector3f corners[8] = {
+		pos + Vector3f{0.0f,       0.0f,       0.0f      },
+		pos + Vector3f{voxel_size, 0.0f,       0.0f      },
+		pos + Vector3f{0.0f,       voxel_size, 0.0f      },
+		pos + Vector3f{voxel_size, voxel_size, 0.0f      },
+		pos + Vector3f{0.0f,       0.0f,       voxel_size},
+		pos + Vector3f{voxel_size, 0.0f,       voxel_size},
+		pos + Vector3f{0.0f,       voxel_size, voxel_size},
+		pos + Vector3f{voxel_size, voxel_size, voxel_size},
+	};
+
+	// Number of training views that need to see a voxel cell
+	// at minimum for that cell to be marked trainable.
+	// Floaters can be reduced by increasing this value to 2,
+	// but at the cost of certain reconstruction artifacts.
+	const uint32_t min_count = 1;
+	uint32_t count = 0;
+
+	for (uint32_t j = 0; j < n_training_images && count < min_count; ++j) {
+		const auto& xform = training_xforms[j].start;
+		const auto& m = metadata[j];
+
+		if (m.lens.mode == ELensMode::FTheta || m.lens.mode == ELensMode::LatLong || m.lens.mode == ELensMode::Equirectangular) {
+			// FTheta lenses don't have a forward mapping, so they are assumed to see
+			// everything. LatLong and equirectangular lenses see everything by definition.
+			++count;
+			continue;
 		}
-		float half_resx = metadata[j].resolution.x() * 0.5f;
-		float half_resy = metadata[j].resolution.y() * 0.5f;
-		Matrix<float, 3, 4> xform = training_xforms[j].start;
-		Vector3f ploc = pos - xform.col(3);
-		float x = ploc.dot(xform.col(0));
-		float y = ploc.dot(xform.col(1));
-		float z = ploc.dot(xform.col(2));
-		if (z > 0.f) {
-			auto focal = metadata[j].focal_length;
-			// TODO - add a box / plane intersection to stop thomas from murdering me
-			if (fabsf(x) - voxel_radius < z / focal.x() * half_resx && fabsf(y) - voxel_radius < z / focal.y() * half_resy) {
-				count++;
-				if (count > 0) break;
+
+		for (uint32_t k = 0; k < 8; ++k) {
+			// Only consider voxel corners in front of the camera
+			Vector3f dir = (corners[k] - xform.col(3)).normalized();
+			if (dir.dot(xform.col(2)) < 1e-4f) {
+				continue;
+			}
+
+			// Check if voxel corner projects onto the image plane, i.e. uv must be in (0, 1)^2
+			Vector2f uv = pos_to_uv(corners[k], m.resolution, m.focal_length, xform, m.principal_point, Vector3f::Zero(), {}, m.lens);
+
+			// `pos_to_uv` is _not_ injective in the presence of lens distortion (which breaks down outside of the image plane).
+			// So we need to check whether the produced uv location generates a ray that matches the ray that we started with.
+			Ray ray = uv_to_ray(0.0f, uv, m.resolution, m.focal_length, xform, m.principal_point, Vector3f::Zero(), 0.0f, 1.0f, 0.0f, {}, {}, m.lens);
+			if ((ray.d.normalized() - dir).norm() < 1e-3f && uv.x() > 0.0f && uv.y() > 0.0f && uv.x() < 1.0f && uv.y() < 1.0f) {
+				++count;
+				break;
 			}
 		}
 	}
 
-	if (clear_visible_voxels || (grid_out[i] < 0) != (count <= 0)) {
-		grid_out[i] = (count > 0) ? 0.f : -1.f;
+	if (clear_visible_voxels || (grid_out[i] < 0) != (count < min_count)) {
+		grid_out[i] = (count >= min_count) ? 0.f : -1.f;
 	}
 }
 
@@ -525,9 +551,10 @@ __global__ void grid_samples_half_to_float(const uint32_t n_elements, BoundingBo
 		Vector3f pos = unwarp_position(coords_in[i].p, aabb);
 		float grid_density = cascaded_grid_at(pos, grid_in, mip_from_pos(pos, max_cascade));
 		if (grid_density < NERF_MIN_OPTICAL_THICKNESS()) {
-			mlp = -10000.f;
+			mlp = -10000.0f;
 		}
 	}
+
 	dst[i] = mlp;
 }
 
@@ -777,8 +804,6 @@ __global__ void composite_kernel_nerf(
 	BoundingBox aabb,
 	float glow_y_cutoff,
 	int glow_mode,
-	const uint32_t n_training_images,
-	const TrainingXForm* __restrict__ training_xforms,
 	Matrix<float, 3, 4> camera_matrix,
 	Vector2f focal_length,
 	float depth_scale,
@@ -1038,18 +1063,18 @@ inline __device__ float pdf_2d(Vector2f sample, uint32_t img, const Vector2i& re
 }
 
 inline __device__ Vector2f nerf_random_image_pos_training(default_rng_t& rng, const Vector2i& resolution, bool snap_to_pixel_centers, const float* __restrict__ cdf_x_cond_y, const float* __restrict__ cdf_y, const Vector2i& cdf_res, uint32_t img, float* __restrict__ pdf = nullptr) {
-	Vector2f xy = random_val_2d(rng);
+	Vector2f uv = random_val_2d(rng);
 
 	if (cdf_x_cond_y) {
-		xy = sample_cdf_2d(xy, img, cdf_res, cdf_x_cond_y, cdf_y, pdf);
+		uv = sample_cdf_2d(uv, img, cdf_res, cdf_x_cond_y, cdf_y, pdf);
 	} else if (pdf) {
 		*pdf = 1.0f;
 	}
 
 	if (snap_to_pixel_centers) {
-		xy = (xy.cwiseProduct(resolution.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(resolution - Vector2i::Ones()).cast<float>() + Vector2f::Constant(0.5f)).cwiseQuotient(resolution.cast<float>());
+		uv = (uv.cwiseProduct(resolution.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(resolution - Vector2i::Ones()).cast<float>() + Vector2f::Constant(0.5f)).cwiseQuotient(resolution.cast<float>());
 	}
-	return xy;
+	return uv;
 }
 
 inline __device__ uint32_t image_idx(uint32_t base_idx, uint32_t n_rays, uint32_t n_rays_total, uint32_t n_training_images, const float* __restrict__ cdf = nullptr, float* __restrict__ pdf = nullptr) {
@@ -1096,8 +1121,7 @@ __global__ void generate_training_samples_nerf(
 	bool snap_to_pixel_centers,
 	bool train_envmap,
 	float cone_angle_constant,
-	const float* __restrict__ distortion_data,
-	const Vector2i distortion_resolution,
+	Buffer2DView<const Vector2f> distortion,
 	const float* __restrict__ cdf_x_cond_y,
 	const float* __restrict__ cdf_y,
 	const float* __restrict__ cdf_img,
@@ -1112,11 +1136,11 @@ __global__ void generate_training_samples_nerf(
 	Eigen::Vector2i resolution = metadata[img].resolution;
 
 	rng.advance(i * N_MAX_RANDOM_SAMPLES_PER_RAY());
-	Vector2f xy = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, cdf_res, img);
+	Vector2f uv = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, cdf_res, img);
 
 	// Negative values indicate masked-away regions
-	size_t pix_idx = pixel_idx(xy, resolution, 0);
-	if (read_rgba(xy, resolution, metadata[img].pixels, metadata[img].image_data_type).x() < 0.0f) {
+	size_t pix_idx = pixel_idx(uv, resolution, 0);
+	if (read_rgba(uv, resolution, metadata[img].pixels, metadata[img].image_data_type).x() < 0.0f) {
 		return;
 	}
 
@@ -1129,7 +1153,7 @@ __global__ void generate_training_samples_nerf(
 	const float* extra_dims = extra_dims_gpu + img * n_extra_dims;
 	const Lens lens = metadata[img].lens;
 
-	const Matrix<float, 3, 4> xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, xy, motionblur_time);
+	const Matrix<float, 3, 4> xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, uv, motionblur_time);
 
 	Ray ray_unnormalized;
 	const Ray* rays_in_unnormalized = metadata[img].rays;
@@ -1138,16 +1162,16 @@ __global__ void generate_training_samples_nerf(
 		ray_unnormalized = rays_in_unnormalized[pix_idx];
 
 		/* DEBUG - compare the stored rays to the computed ones
-		const Matrix<float, 3, 4> xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, xy, 0.f);
+		const Matrix<float, 3, 4> xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, uv, 0.f);
 		Ray ray2;
 		ray2.o = xform.col(3);
-		ray2.d = f_theta_distortion(xy, principal_point, lens);
+		ray2.d = f_theta_distortion(uv, principal_point, lens);
 		ray2.d = (xform.block<3, 3>(0, 0) * ray2.d).normalized();
 		if (i==1000) {
 			printf("\n%d uv %0.3f,%0.3f pixel %0.2f,%0.2f transform from [%0.5f %0.5f %0.5f] to [%0.5f %0.5f %0.5f]\n"
 				" origin    [%0.5f %0.5f %0.5f] vs [%0.5f %0.5f %0.5f]\n"
 				" direction [%0.5f %0.5f %0.5f] vs [%0.5f %0.5f %0.5f]\n"
-			, img,xy.x(), xy.y(), xy.x()*resolution.x(), xy.y()*resolution.y(),
+			, img,uv.x(), uv.y(), uv.x()*resolution.x(), uv.y()*resolution.y(),
 				training_xforms[img].start.col(3).x(),training_xforms[img].start.col(3).y(),training_xforms[img].start.col(3).z(),
 				training_xforms[img].end.col(3).x(),training_xforms[img].end.col(3).y(),training_xforms[img].end.col(3).z(),
 				ray_unnormalized.o.x(),ray_unnormalized.o.y(),ray_unnormalized.o.z(),
@@ -1157,31 +1181,10 @@ __global__ void generate_training_samples_nerf(
 		}
 		*/
 	} else {
-		// Rays need to be inferred from the camera matrix
-		ray_unnormalized.o = xform.col(3);
-		if (lens.mode == ELensMode::FTheta) {
-			ray_unnormalized.d = f_theta_undistortion(xy - principal_point, lens.params, {0.f, 0.f, 1.f});
-		} else if (lens.mode == ELensMode::LatLong) {
-			ray_unnormalized.d = latlong_to_dir(xy);
-		} else {
-			ray_unnormalized.d = {
-				(xy.x()-principal_point.x())*resolution.x() / focal_length.x(),
-				(xy.y()-principal_point.y())*resolution.y() / focal_length.y(),
-				1.0f,
-			};
-
-			if (lens.mode == ELensMode::OpenCV) {
-				iterative_opencv_lens_undistortion(lens.params, &ray_unnormalized.d.x(), &ray_unnormalized.d.y());
-			} else if (lens.mode == ELensMode::OpenCVFisheye) {
-				iterative_opencv_fisheye_lens_undistortion(lens.params, &ray_unnormalized.d.x(), &ray_unnormalized.d.y());
-			}
+		ray_unnormalized = uv_to_ray(0, uv, resolution, focal_length, xform, principal_point, Vector3f::Zero(), 0.0f, 1.0f, 0.0f, {}, {}, lens, distortion);
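+		// If no valid ray could be generated (e.g. the undistortion failed), fall
+		// back to a ray from the camera origin along the camera's forward axis.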
+		if (!ray_unnormalized.is_valid()) {
+			ray_unnormalized = {xform.col(3), xform.col(2)};
 		}
-
-		if (distortion_data) {
-			ray_unnormalized.d.head<2>() += read_image<2>(distortion_data, distortion_resolution, xy);
-		}
-
-		ray_unnormalized.d = (xform.block<3, 3>(0, 0) * ray_unnormalized.d); // NOT normalized
 	}
 
 	Eigen::Vector3f ray_d_normalized = ray_unnormalized.d.normalized();
@@ -1278,7 +1281,7 @@ __global__ void compute_loss_kernel_train_nerf(
 	const uint32_t* __restrict__ rays_counter,
 	float loss_scale,
 	int padded_output_width,
-	const float* __restrict__ envmap_data,
+	Buffer2DView<const Eigen::Array4f> envmap,
 	float* __restrict__ envmap_gradient,
 	const Vector2i envmap_resolution,
 	ELossType envmap_loss_type,
@@ -1374,8 +1377,8 @@ __global__ void compute_loss_kernel_train_nerf(
 	uint32_t img = image_idx(ray_idx, n_rays, n_rays_total, n_training_images, cdf_img, &img_pdf);
 	Eigen::Vector2i resolution = metadata[img].resolution;
 
-	float xy_pdf = 1.0f;
-	Vector2f xy = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, error_map_cdf_res, img, &xy_pdf);
+	float uv_pdf = 1.0f;
+	Vector2f uv = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, error_map_cdf_res, img, &uv_pdf);
 	float max_level = max_level_rand_training ? (random_val(rng) * 2.0f) : 1.0f; // Multiply by 2 to ensure 50% of training is at max level
 
 	if (train_with_random_bg_color) {
@@ -1386,16 +1389,16 @@ __global__ void compute_loss_kernel_train_nerf(
 	// Composit background behind envmap
 	Array4f envmap_value;
 	Vector3f dir;
-	if (envmap_data) {
+	if (envmap) {
 		dir = rays_in_unnormalized[i].d.normalized();
-		envmap_value = read_envmap(envmap_data, envmap_resolution, dir);
+		envmap_value = read_envmap(envmap, dir);
 		background_color = envmap_value.head<3>() + background_color * (1.0f - envmap_value.w());
 	}
 
 	Array3f exposure_scale = (0.6931471805599453f * exposure[img]).exp();
-	// Array3f rgbtarget = composit_and_lerp(xy, resolution, img, training_images, background_color, exposure_scale);
-	// Array3f rgbtarget = composit(xy, resolution, img, training_images, background_color, exposure_scale);
-	Array4f texsamp = read_rgba(xy, resolution, metadata[img].pixels, metadata[img].image_data_type);
+	// Array3f rgbtarget = composit_and_lerp(uv, resolution, img, training_images, background_color, exposure_scale);
+	// Array3f rgbtarget = composit(uv, resolution, img, training_images, background_color, exposure_scale);
+	Array4f texsamp = read_rgba(uv, resolution, metadata[img].pixels, metadata[img].image_data_type);
 
 	Array3f rgbtarget;
 	if (train_in_linear_colors || color_space == EColorSpace::Linear) {
@@ -1437,9 +1440,9 @@ __global__ void compute_loss_kernel_train_nerf(
 	dloss_doutput += compacted_base * padded_output_width;
 
 	LossAndGradient lg = loss_and_gradient(rgbtarget, rgb_ray, loss_type);
-	lg.loss /= img_pdf * xy_pdf;
+	lg.loss /= img_pdf * uv_pdf;
 
-	float target_depth = rays_in_unnormalized[i].d.norm() * ((depth_supervision_lambda > 0.0f && metadata[img].depth) ? read_depth(xy, resolution, metadata[img].depth) : -1.0f);
+	float target_depth = rays_in_unnormalized[i].d.norm() * ((depth_supervision_lambda > 0.0f && metadata[img].depth) ? read_depth(uv, resolution, metadata[img].depth) : -1.0f);
 	LossAndGradient lg_depth = loss_and_gradient(Array3f::Constant(target_depth), Array3f::Constant(depth_ray), depth_loss_type);
 	float depth_loss_gradient = target_depth > 0.0f ? depth_supervision_lambda * lg_depth.gradient.x() : 0;
 
@@ -1447,7 +1450,7 @@ __global__ void compute_loss_kernel_train_nerf(
 	// Essentially: variance reduction, but otherwise the same optimization.
 	// We _don't_ want that. If importance sampling is enabled, we _do_ actually want
 	// to change the weighting of the loss function. So don't divide.
-	// lg.gradient /= img_pdf * xy_pdf;
+	// lg.gradient /= img_pdf * uv_pdf;
 
 	float mean_loss = lg.loss.mean();
 	if (loss_output) {
@@ -1455,7 +1458,7 @@ __global__ void compute_loss_kernel_train_nerf(
 	}
 
 	if (error_map) {
-		const Vector2f pos = (xy.cwiseProduct(error_map_res.cast<float>()) - Vector2f::Constant(0.5f)).cwiseMax(0.0f).cwiseMin(error_map_res.cast<float>() - Vector2f::Constant(1.0f + 1e-4f));
+		const Vector2f pos = (uv.cwiseProduct(error_map_res.cast<float>()) - Vector2f::Constant(0.5f)).cwiseMax(0.0f).cwiseMin(error_map_res.cast<float>() - Vector2f::Constant(1.0f + 1e-4f));
 		const Vector2i pos_int = pos.cast<int>();
 		const Vector2f weight = pos - pos_int.cast<float>();
 
@@ -1466,7 +1469,7 @@ __global__ void compute_loss_kernel_train_nerf(
 		};
 
 		if (sharpness_data && aabb.contains(hitpoint)) {
-			Vector2i sharpness_pos = xy.cwiseProduct(sharpness_resolution.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(sharpness_resolution - Vector2i::Constant(1));
+			Vector2i sharpness_pos = uv.cwiseProduct(sharpness_resolution.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(sharpness_resolution - Vector2i::Constant(1));
 			float sharp = sharpness_data[img * sharpness_resolution.prod() + sharpness_pos.y() * sharpness_resolution.x() + sharpness_pos.x()] + 1e-6f;
 
 			// The maximum value of positive floats interpreted in uint format is the same as the maximum value of the floats.
@@ -1549,7 +1552,7 @@ __global__ void compute_loss_kernel_train_nerf(
 
 	if (exposure_gradient) {
 		// Assume symmetric loss
-		Array3f dloss_by_dgt = -lg.gradient / xy_pdf;
+		Array3f dloss_by_dgt = -lg.gradient / uv_pdf;
 
 		if (!train_in_linear_colors) {
 			dloss_by_dgt /= srgb_to_linear_derivative(rgbtarget);
@@ -1656,9 +1659,9 @@ __global__ void compute_cam_gradient_train_nerf(
 	}
 
 	rng.advance(ray_idx * N_MAX_RANDOM_SAMPLES_PER_RAY());
-	float xy_pdf = 1.0f;
+	float uv_pdf = 1.0f;
 
-	Vector2f xy = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, error_map_res, img, &xy_pdf);
+	Vector2f uv = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, error_map_res, img, &uv_pdf);
 
 	if (distortion_gradient) {
 		// Projection of the raydir gradient onto the plane normal to raydir,
@@ -1671,14 +1674,14 @@ __global__ void compute_cam_gradient_train_nerf(
 		Vector3f image_plane_gradient = xform.block<3,3>(0,0).inverse() * orthogonal_ray_gradient;
 
 		// Splat the resulting 2D image plane gradient into the distortion params
-		deposit_image_gradient<2>(image_plane_gradient.head<2>() / xy_pdf, distortion_gradient, distortion_gradient_weight, distortion_resolution, xy);
+		deposit_image_gradient<2>(image_plane_gradient.head<2>() / uv_pdf, distortion_gradient, distortion_gradient_weight, distortion_resolution, uv);
 	}
 
 	if (cam_pos_gradient) {
 		// Atomically reduce the ray gradient into the xform gradient
 		NGP_PRAGMA_UNROLL
 		for (uint32_t j = 0; j < 3; ++j) {
-			atomicAdd(&cam_pos_gradient[img][j], ray_gradient.o[j] / xy_pdf);
+			atomicAdd(&cam_pos_gradient[img][j], ray_gradient.o[j] / uv_pdf);
 		}
 	}
 
@@ -1692,7 +1695,7 @@ __global__ void compute_cam_gradient_train_nerf(
 		// Atomically reduce the ray gradient into the xform gradient
 		NGP_PRAGMA_UNROLL
 		for (uint32_t j = 0; j < 3; ++j) {
-			atomicAdd(&cam_rot_gradient[img][j], angle_axis[j] / xy_pdf);
+			atomicAdd(&cam_rot_gradient[img][j], angle_axis[j] / uv_pdf);
 		}
 	}
 }
@@ -1811,15 +1814,14 @@ __global__ void init_rays_with_payload_kernel_nerf(
 	float near_distance,
 	float plane_z,
 	float aperture_size,
+	Foveation foveation,
 	Lens lens,
-	const float* __restrict__ envmap_data,
-	const Vector2i envmap_resolution,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer,
-	const float* __restrict__ distortion_data,
-	const Vector2i distortion_resolution,
-	ERenderMode render_mode,
-	Vector2i quilting_dims
+	Buffer2DView<const Eigen::Array4f> envmap,
+	Array4f* __restrict__ frame_buffer,
+	float* __restrict__ depth_buffer,
+	Buffer2DView<const uint8_t> hidden_area_mask,
+	Buffer2DView<const Eigen::Vector2f> distortion,
+	ERenderMode render_mode
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -1834,34 +1836,37 @@ __global__ void init_rays_with_payload_kernel_nerf(
 		aperture_size = 0.0;
 	}
 
-	if (quilting_dims != Vector2i::Ones()) {
-		apply_quilting(&x, &y, resolution, parallax_shift, quilting_dims);
-	}
-
-	// TODO: pixel_to_ray also immediately computes u,v for the pixel, so this is somewhat redundant
-	float u = (x + 0.5f) * (1.f / resolution.x());
-	float v = (y + 0.5f) * (1.f / resolution.y());
-	float ray_time = rolling_shutter.x() + rolling_shutter.y() * u + rolling_shutter.z() * v + rolling_shutter.w() * ld_random_val(sample_index, idx * 72239731);
-	Ray ray = pixel_to_ray(
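+	// Jitter the sample within its pixel by a low-discrepancy offset (a fixed
+	// offset when snapping to pixel centers), then normalize to uv in [0,1]^2.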
+	Vector2f pixel_offset = ld_random_pixel_offset(snap_to_pixel_centers ? 0 : sample_index);
+	Vector2f uv = Vector2f{(float)x + pixel_offset.x(), (float)y + pixel_offset.y()}.cwiseQuotient(resolution.cast<float>());
+	float ray_time = rolling_shutter.x() + rolling_shutter.y() * uv.x() + rolling_shutter.z() * uv.y() + rolling_shutter.w() * ld_random_val(sample_index, idx * 72239731);
+	Ray ray = uv_to_ray(
 		sample_index,
-		{x, y},
-		resolution.cwiseQuotient(quilting_dims),
+		uv,
+		resolution,
 		focal_length,
 		camera_matrix0 * ray_time + camera_matrix1 * (1.f - ray_time),
 		screen_center,
 		parallax_shift,
-		snap_to_pixel_centers,
 		near_distance,
 		plane_z,
 		aperture_size,
+		foveation,
+		hidden_area_mask,
 		lens,
-		distortion_data,
-		distortion_resolution
+		distortion
 	);
 
 	NerfPayload& payload = payloads[idx];
 	payload.max_weight = 0.0f;
 
+	depth_buffer[idx] = MAX_DEPTH();
+
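+	// Rays culled during generation (e.g. by the XR hidden area mask) come back
+	// invalid; park them at their origin and mark them dead.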
+	if (!ray.is_valid()) {
+		payload.origin = ray.o;
+		payload.alive = false;
+		return;
+	}
+
 	if (plane_z < 0) {
 		float n = ray.d.norm();
 		payload.origin = ray.o;
@@ -1870,21 +1875,19 @@ __global__ void init_rays_with_payload_kernel_nerf(
 		payload.idx = idx;
 		payload.n_steps = 0;
 		payload.alive = false;
-		depthbuffer[idx] = -plane_z;
+		depth_buffer[idx] = -plane_z;
 		return;
 	}
 
-	depthbuffer[idx] = 1e10f;
-
 	ray.d = ray.d.normalized();
 
-	if (envmap_data) {
-		framebuffer[idx] = read_envmap(envmap_data, envmap_resolution, ray.d);
+	if (envmap) {
+		frame_buffer[idx] = read_envmap(envmap, ray.d);
 	}
 
 	float t = fmaxf(render_aabb.ray_intersect(render_aabb_to_local * ray.o, render_aabb_to_local * ray.d).x(), 0.0f) + 1e-6f;
 
-	if (!render_aabb.contains(render_aabb_to_local * (ray.o + ray.d * t))) {
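+	// ray(t) evaluates ray.o + t * ray.d; if the entry point into the render
+	// AABB lies outside it, the ray misses the scene and is terminated.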
+	if (!render_aabb.contains(render_aabb_to_local * ray(t))) {
 		payload.origin = ray.o;
 		payload.alive = false;
 		return;
@@ -1892,13 +1895,14 @@ __global__ void init_rays_with_payload_kernel_nerf(
 
 	if (render_mode == ERenderMode::Distortion) {
 		Vector2f offset = Vector2f::Zero();
-		if (distortion_data) {
-			offset += read_image<2>(distortion_data, distortion_resolution, Vector2f((float)x + 0.5f, (float)y + 0.5f).cwiseQuotient(resolution.cast<float>()));
+		if (distortion) {
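+			// Bilinearly sample the distortion field at the pixel center.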
+			offset += distortion.at_lerp(Vector2f{(float)x + 0.5f, (float)y + 0.5f}.cwiseQuotient(resolution.cast<float>()));
 		}
-		framebuffer[idx].head<3>() = to_rgb(offset * 50.0f);
-		framebuffer[idx].w() = 1.0f;
-		depthbuffer[idx] = 1.0f;
-		payload.origin = ray.o + ray.d * 10000.0f;
+
+		frame_buffer[idx].head<3>() = to_rgb(offset * 50.0f);
+		frame_buffer[idx].w() = 1.0f;
+		depth_buffer[idx] = 1.0f;
+		payload.origin = ray(MAX_DEPTH());
 		payload.alive = false;
 		return;
 	}
@@ -1987,21 +1991,20 @@ void Testbed::NerfTracer::init_rays_from_camera(
 	const Vector4f& rolling_shutter,
 	const Vector2f& screen_center,
 	const Vector3f& parallax_shift,
-	const Vector2i& quilting_dims,
 	bool snap_to_pixel_centers,
 	const BoundingBox& render_aabb,
 	const Matrix3f& render_aabb_to_local,
 	float near_distance,
 	float plane_z,
 	float aperture_size,
+	const Foveation& foveation,
 	const Lens& lens,
-	const float* envmap_data,
-	const Vector2i& envmap_resolution,
-	const float* distortion_data,
-	const Vector2i& distortion_resolution,
+	const Buffer2DView<const Array4f>& envmap,
+	const Buffer2DView<const Vector2f>& distortion,
 	Eigen::Array4f* frame_buffer,
 	float* depth_buffer,
-	uint8_t* grid,
+	const Buffer2DView<const uint8_t>& hidden_area_mask,
+	const uint8_t* grid,
 	int show_accel,
 	float cone_angle_constant,
 	ERenderMode render_mode,
@@ -2029,15 +2032,14 @@ void Testbed::NerfTracer::init_rays_from_camera(
 		near_distance,
 		plane_z,
 		aperture_size,
+		foveation,
 		lens,
-		envmap_data,
-		envmap_resolution,
+		envmap,
 		frame_buffer,
 		depth_buffer,
-		distortion_data,
-		distortion_resolution,
-		render_mode,
-		quilting_dims
+		hidden_area_mask,
+		distortion,
+		render_mode
 	);
 
 	m_n_rays_initialized = resolution.x() * resolution.y();
@@ -2064,8 +2066,6 @@ uint32_t Testbed::NerfTracer::trace(
 	const BoundingBox& render_aabb,
 	const Eigen::Matrix3f& render_aabb_to_local,
 	const BoundingBox& train_aabb,
-	const uint32_t n_training_images,
-	const TrainingXForm* training_xforms,
 	const Vector2f& focal_length,
 	float cone_angle_constant,
 	const uint8_t* grid,
@@ -2156,8 +2156,6 @@ uint32_t Testbed::NerfTracer::trace(
 			train_aabb,
 			glow_y_cutoff,
 			glow_mode,
-			n_training_images,
-			training_xforms,
 			camera_matrix,
 			focal_length,
 			depth_scale,
@@ -2261,51 +2259,59 @@ const float* Testbed::get_inference_extra_dims(cudaStream_t stream) const {
 	return dims_gpu;
 }
 
-void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_res, const Vector2f& focal_length, const Matrix<float, 3, 4>& camera_matrix0, const Matrix<float, 3, 4>& camera_matrix1, const Vector4f& rolling_shutter, const Vector2f& screen_center, cudaStream_t stream) {
+void Testbed::render_nerf(
+	cudaStream_t stream,
+	const CudaRenderBufferView& render_buffer,
+	NerfNetwork<precision_t>& nerf_network,
+	const uint8_t* density_grid_bitfield,
+	const Vector2f& focal_length,
+	const Matrix<float, 3, 4>& camera_matrix0,
+	const Matrix<float, 3, 4>& camera_matrix1,
+	const Vector4f& rolling_shutter,
+	const Vector2f& screen_center,
+	const Foveation& foveation,
+	int visualized_dimension
+) {
 	float plane_z = m_slice_plane_z + m_scale;
 	if (m_render_mode == ERenderMode::Slice) {
 		plane_z = -plane_z;
 	}
 
-	ERenderMode render_mode = m_visualized_dimension > -1 ? ERenderMode::EncodingVis : m_render_mode;
+	ERenderMode render_mode = visualized_dimension > -1 ? ERenderMode::EncodingVis : m_render_mode;
 
 	const float* extra_dims_gpu = get_inference_extra_dims(stream);
 
 	NerfTracer tracer;
 
-	// Our motion vector code can't undo f-theta and grid distortions -- so don't render these if DLSS is enabled.
-	bool render_opencv_lens = m_nerf.render_with_lens_distortion && (!render_buffer.dlss() || m_nerf.render_lens.mode == ELensMode::OpenCV || m_nerf.render_lens.mode == ELensMode::OpenCVFisheye);
-	bool render_grid_distortion = m_nerf.render_with_lens_distortion && !render_buffer.dlss();
-
-	Lens lens = render_opencv_lens ? m_nerf.render_lens : Lens{};
-
+	// Our motion vector code can't undo grid distortions -- so don't render grid distortion if DLSS is enabled
+	auto grid_distortion = m_nerf.render_with_lens_distortion && !m_dlss ? m_distortion.inference_view() : Buffer2DView<const Vector2f>{};
+	Lens lens = m_nerf.render_with_lens_distortion ? m_nerf.render_lens : Lens{};
 
 	tracer.init_rays_from_camera(
-		render_buffer.spp(),
-		m_network->padded_output_width(),
-		m_nerf_network->n_extra_dims(),
-		render_buffer.in_resolution(),
+		render_buffer.spp,
+		nerf_network.padded_output_width(),
+		nerf_network.n_extra_dims(),
+		render_buffer.resolution,
 		focal_length,
 		camera_matrix0,
 		camera_matrix1,
 		rolling_shutter,
 		screen_center,
 		m_parallax_shift,
-		m_quilting_dims,
 		m_snap_to_pixel_centers,
 		m_render_aabb,
 		m_render_aabb_to_local,
 		m_render_near_distance,
 		plane_z,
 		m_aperture_size,
+		foveation,
 		lens,
-		m_envmap.envmap->inference_params(),
-		m_envmap.resolution,
-		render_grid_distortion ? m_distortion.map->inference_params() : nullptr,
-		m_distortion.resolution,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer(),
-		m_nerf.density_grid_bitfield.data(),
+		m_envmap.inference_view(),
+		grid_distortion,
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer,
+		render_buffer.hidden_area_mask ? render_buffer.hidden_area_mask->const_view() : Buffer2DView<const uint8_t>{},
+		density_grid_bitfield,
 		m_nerf.show_accel,
 		m_nerf.cone_angle_constant,
 		render_mode,
@@ -2318,20 +2324,18 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 	} else {
 		float depth_scale = 1.0f / m_nerf.training.dataset.scale;
 		n_hit = tracer.trace(
-			*m_nerf_network,
+			nerf_network,
 			m_render_aabb,
 			m_render_aabb_to_local,
 			m_aabb,
-			m_nerf.training.n_images_for_training,
-			m_nerf.training.transforms.data(),
 			focal_length,
 			m_nerf.cone_angle_constant,
-			m_nerf.density_grid_bitfield.data(),
+			density_grid_bitfield,
 			render_mode,
 			camera_matrix1,
 			depth_scale,
 			m_visualized_layer,
-			m_visualized_dimension,
+			visualized_dimension,
 			m_nerf.rgb_activation,
 			m_nerf.density_activation,
 			m_nerf.show_accel,
@@ -2347,19 +2351,19 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 	if (m_render_mode == ERenderMode::Slice) {
 		// Store colors in the normal buffer
 		uint32_t n_elements = next_multiple(n_hit, tcnn::batch_size_granularity);
-		const uint32_t floats_per_coord = sizeof(NerfCoordinate) / sizeof(float) + m_nerf_network->n_extra_dims();
-		const uint32_t extra_stride = m_nerf_network->n_extra_dims() * sizeof(float); // extra stride on top of base NerfCoordinate struct
+		const uint32_t floats_per_coord = sizeof(NerfCoordinate) / sizeof(float) + nerf_network.n_extra_dims();
+		const uint32_t extra_stride = nerf_network.n_extra_dims() * sizeof(float); // extra stride on top of base NerfCoordinate struct
 
 		GPUMatrix<float> positions_matrix{floats_per_coord, n_elements, stream};
 		GPUMatrix<float> rgbsigma_matrix{4, n_elements, stream};
 
 		linear_kernel(generate_nerf_network_inputs_at_current_position, 0, stream, n_hit, m_aabb, rays_hit.payload, PitchedPtr<NerfCoordinate>((NerfCoordinate*)positions_matrix.data(), 1, 0, extra_stride), extra_dims_gpu );
 
-		if (m_visualized_dimension == -1) {
-			m_network->inference(stream, positions_matrix, rgbsigma_matrix);
+		if (visualized_dimension == -1) {
+			nerf_network.inference(stream, positions_matrix, rgbsigma_matrix);
 			linear_kernel(compute_nerf_rgba, 0, stream, n_hit, (Array4f*)rgbsigma_matrix.data(), m_nerf.rgb_activation, m_nerf.density_activation, 0.01f, false);
 		} else {
-			m_network->visualize_activation(stream, m_visualized_layer, m_visualized_dimension, positions_matrix, rgbsigma_matrix);
+			nerf_network.visualize_activation(stream, m_visualized_layer, visualized_dimension, positions_matrix, rgbsigma_matrix);
 		}
 
 		linear_kernel(shade_kernel_nerf, 0, stream,
@@ -2369,8 +2373,8 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 			rays_hit.payload,
 			m_render_mode,
 			m_nerf.training.linear_colors,
-			render_buffer.frame_buffer(),
-			render_buffer.depth_buffer()
+			render_buffer.frame_buffer,
+			render_buffer.depth_buffer
 		);
 		return;
 	}
@@ -2382,8 +2386,8 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 		rays_hit.payload,
 		m_render_mode,
 		m_nerf.training.linear_colors,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer()
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer
 	);
 
 	if (render_mode == ERenderMode::Cost) {
@@ -2673,7 +2677,20 @@ void Testbed::load_nerf(const fs::path& data_path) {
 			throw std::runtime_error{"NeRF data path must either be a json file or a directory containing json files."};
 		}
 
+		const auto prev_aabb_scale = m_nerf.training.dataset.aabb_scale;
+
 		m_nerf.training.dataset = ngp::load_nerf(json_paths, m_nerf.sharpen);
+
+		// Check whether the NeRF network has been configured before;
+		// if it hasn't, there is nothing to reset yet.
+		bool previously_configured = !m_network_config["rgb_network"].is_null()
+		                          && !m_network_config["dir_encoding"].is_null();
+
+		if (m_nerf.training.dataset.aabb_scale != prev_aabb_scale && previously_configured) {
+			// The AABB scale affects network size indirectly. If it changed after loading,
+			// we need to reset the previously configured network to keep a consistent internal state.
+			reset_network();
+		}
 	}
 
 	load_nerf_post();
@@ -2785,6 +2802,40 @@ void Testbed::update_density_grid_mean_and_bitfield(cudaStream_t stream) {
 	for (uint32_t level = 1; level < NERF_CASCADES(); ++level) {
 		linear_kernel(bitfield_max_pool, 0, stream, n_elements/64, m_nerf.get_density_grid_bitfield_mip(level-1), m_nerf.get_density_grid_bitfield_mip(level));
 	}
+
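+	// Mark all devices dirty so they re-synchronize the updated grid state.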
+	set_all_devices_dirty();
+}
+
+__global__ void mark_density_grid_in_sphere_empty_kernel(const uint32_t n_elements, float* density_grid, Vector3f pos, float radius) {
+	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
+	if (i >= n_elements) return;
+
+	// Decode this cell's cascade level and Morton-encoded grid position.
+	uint32_t level = i / NERF_GRID_N_CELLS();
+	uint32_t pos_idx = i % NERF_GRID_N_CELLS();
+
+	uint32_t x = tcnn::morton3D_invert(pos_idx>>0);
+	uint32_t y = tcnn::morton3D_invert(pos_idx>>1);
+	uint32_t z = tcnn::morton3D_invert(pos_idx>>2);
+
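+	// Conservative bound: a sphere around the cell center whose radius is the
+	// cell's full diagonal at this cascade level.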
+	float cell_radius = scalbnf(SQRT3(), level) / NERF_GRIDSIZE();
+	Vector3f cell_pos = ((Vector3f{(float)x+0.5f, (float)y+0.5f, (float)z+0.5f}) / NERF_GRIDSIZE() - Vector3f::Constant(0.5f)) * scalbnf(1.0f, level) + Vector3f::Constant(0.5f);
+
+	// Disable if the cell touches the sphere (conservatively, by bounding the cell with a sphere)
+	if ((pos - cell_pos).norm() < radius + cell_radius) {
+		density_grid[i] = -1.0f;
+	}
+}
+
+void Testbed::mark_density_grid_in_sphere_empty(const Vector3f& pos, float radius, cudaStream_t stream) {
+	const uint32_t n_elements = NERF_GRID_N_CELLS() * (m_nerf.max_cascade + 1);
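+	// Nothing to do if the density grid isn't allocated at the expected size.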
+	if (m_nerf.density_grid.size() != n_elements) {
+		return;
+	}
+
+	linear_kernel(mark_density_grid_in_sphere_empty_kernel, 0, stream, n_elements, m_nerf.density_grid.data(), pos, radius);
+
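+	// Refresh the mean density and occupancy bitfield so ray marching treats the cleared cells as empty.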
+	update_density_grid_mean_and_bitfield(stream);
 }
 
 void Testbed::NerfCounters::prepare_for_training_steps(cudaStream_t stream) {
@@ -3167,8 +3218,7 @@ void Testbed::train_nerf_step(uint32_t target_batch_size, Testbed::NerfCounters&
 		m_nerf.training.snap_to_pixel_centers,
 		m_nerf.training.train_envmap,
 		m_nerf.cone_angle_constant,
-		m_distortion.map->params(),
-		m_distortion.resolution,
+		m_distortion.view(),
 		sample_focal_plane_proportional_to_error ? m_nerf.training.error_map.cdf_x_cond_y.data() : nullptr,
 		sample_focal_plane_proportional_to_error ? m_nerf.training.error_map.cdf_y.data() : nullptr,
 		sample_image_proportional_to_error ? m_nerf.training.error_map.cdf_img.data() : nullptr,
@@ -3197,7 +3247,7 @@ void Testbed::train_nerf_step(uint32_t target_batch_size, Testbed::NerfCounters&
 		ray_counter,
 		LOSS_SCALE,
 		padded_output_width,
-		m_envmap.envmap->params(),
+		m_envmap.view(),
 		envmap_gradient,
 		m_envmap.resolution,
 		m_envmap.loss_type,
diff --git a/src/testbed_sdf.cu b/src/testbed_sdf.cu
index aced131525d5198518e14ed2a97ac75e180f6a6d..2332bec11f2b36773872657d6c941edab8f52a81 100644
--- a/src/testbed_sdf.cu
+++ b/src/testbed_sdf.cu
@@ -156,7 +156,7 @@ __global__ void advance_pos_kernel_sdf(
 	BoundingBox aabb,
 	float floor_y,
 	const TriangleOctreeNode* __restrict__ octree_nodes,
-	int max_depth,
+	int max_octree_depth,
 	float distance_scale,
 	float maximum_distance,
 	float k,
@@ -181,8 +181,8 @@ __global__ void advance_pos_kernel_sdf(
 	pos += distance * payload.dir;
 
 	// Skip over regions not covered by the octree
-	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_depth, pos)) {
-		float octree_distance = (TriangleOctree::ray_intersect(octree_nodes, max_depth, pos, payload.dir) + 1e-6f);
+	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_octree_depth, pos)) {
+		float octree_distance = (TriangleOctree::ray_intersect(octree_nodes, max_octree_depth, pos, payload.dir) + 1e-6f);
 		distance += octree_distance;
 		pos += octree_distance * payload.dir;
 	}
@@ -242,7 +242,7 @@ __global__ void prepare_shadow_rays(const uint32_t n_elements,
 	SdfPayload* __restrict__ payloads,
 	BoundingBox aabb,
 	const TriangleOctreeNode* __restrict__ octree_nodes,
-	int max_depth
+	int max_octree_depth
 ) {
 	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
 	if (i >= n_elements) return;
@@ -256,21 +256,21 @@ __global__ void prepare_shadow_rays(const uint32_t n_elements,
 	float t = fmaxf(aabb.ray_intersect(view_pos, dir).x() + 1e-6f, 0.0f);
 	view_pos += t * dir;
 
-	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_depth, view_pos)) {
-		t = fmaxf(0.0f, TriangleOctree::ray_intersect(octree_nodes, max_depth, view_pos, dir) + 1e-6f);
+	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_octree_depth, view_pos)) {
+		t = fmaxf(0.0f, TriangleOctree::ray_intersect(octree_nodes, max_octree_depth, view_pos, dir) + 1e-6f);
 		view_pos += t * dir;
 	}
 
 	positions[i] = view_pos;
 
 	if (!aabb.contains(view_pos)) {
-		distances[i] = 10000.0f;
+		distances[i] = MAX_DEPTH();
 		payload.alive = false;
 		min_visibility[i] = 1.0f;
 		return;
 	}
 
-	distances[i] = 10000.0f;
+	distances[i] = MAX_DEPTH();
 	payload.idx = i;
 	payload.dir = dir;
 	payload.n_steps = 0;
@@ -322,13 +322,13 @@ __global__ void shade_kernel_sdf(
 
 	// The normal in memory isn't normalized yet
 	Vector3f normal = normals[i].normalized();
-
 	Vector3f pos = positions[i];
 	bool floor = false;
-	if (pos.y() < floor_y+0.001f && payload.dir.y() < 0.f) {
+	if (pos.y() < floor_y + 0.001f && payload.dir.y() < 0.f) {
 		normal = Vector3f(0.f, 1.f, 0.f);
 		floor = true;
 	}
+
 	Vector3f cam_pos = camera_matrix.col(3);
 	Vector3f cam_fwd = camera_matrix.col(2);
 	float ao = powf(0.92f, payload.n_steps * 0.5f) * (1.f / 0.92f);
@@ -456,12 +456,12 @@ __global__ void scale_to_aabb_kernel(uint32_t n_elements, BoundingBox aabb, Vect
 	inout[i] = aabb.min + inout[i].cwiseProduct(aabb.diag());
 }
 
-__global__ void compare_signs_kernel(uint32_t n_elements, const Vector3f *positions, const float *distances_ref, const float *distances_model, uint32_t *counters, const TriangleOctreeNode* octree_nodes, int max_depth) {
+__global__ void compare_signs_kernel(uint32_t n_elements, const Vector3f *positions, const float *distances_ref, const float *distances_model, uint32_t *counters, const TriangleOctreeNode* octree_nodes, int max_octree_depth) {
 	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
 	if (i >= n_elements) return;
 	bool inside1 = distances_ref[i]<=0.f;
 	bool inside2 = distances_model[i]<=0.f;
-	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_depth, positions[i])) {
+	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_octree_depth, positions[i])) {
 		inside2=inside1; // assume, when using the octree, that the model is always correct outside the octree.
 		atomicAdd(&counters[6],1); // outside the octree
 	} else {
@@ -506,12 +506,13 @@ __global__ void init_rays_with_payload_kernel_sdf(
 	float near_distance,
 	float plane_z,
 	float aperture_size,
-	const float* __restrict__ envmap_data,
-	const Vector2i envmap_resolution,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer,
+	Foveation foveation,
+	Buffer2DView<const Eigen::Array4f> envmap,
+	Array4f* __restrict__ frame_buffer,
+	float* __restrict__ depth_buffer,
+	Buffer2DView<const uint8_t> hidden_area_mask,
 	const TriangleOctreeNode* __restrict__ octree_nodes = nullptr,
-	int max_depth = 0
+	int max_octree_depth = 0
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -526,30 +527,54 @@ __global__ void init_rays_with_payload_kernel_sdf(
 		aperture_size = 0.0;
 	}
 
-	Ray ray = pixel_to_ray(sample_index, {x, y}, resolution, focal_length, camera_matrix, screen_center, parallax_shift, snap_to_pixel_centers, near_distance, plane_z, aperture_size);
+	Ray ray = pixel_to_ray(
+		sample_index,
+		{x, y},
+		resolution,
+		focal_length,
+		camera_matrix,
+		screen_center,
+		parallax_shift,
+		snap_to_pixel_centers,
+		near_distance,
+		plane_z,
+		aperture_size,
+		foveation,
+		hidden_area_mask
+	);
+
+	distances[idx] = MAX_DEPTH();
+	depth_buffer[idx] = MAX_DEPTH();
+
+	SdfPayload& payload = payloads[idx];
 
-	distances[idx] = 10000.0f;
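+	// Invalid rays (e.g. culled by the XR hidden area mask) are parked immediately.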
+	if (!ray.is_valid()) {
+		payload.dir = ray.d;
+		payload.idx = idx;
+		payload.n_steps = 0;
+		payload.alive = false;
+		positions[idx] = ray.o;
+		return;
+	}
 
 	if (plane_z < 0) {
 		float n = ray.d.norm();
-		SdfPayload& payload = payloads[idx];
 		payload.dir = (1.0f/n) * ray.d;
 		payload.idx = idx;
 		payload.n_steps = 0;
 		payload.alive = false;
 		positions[idx] = ray.o - plane_z * ray.d;
-		depthbuffer[idx] = -plane_z;
+		depth_buffer[idx] = -plane_z;
 		return;
 	}
 
-	depthbuffer[idx] = 1e10f;
-
 	ray.d = ray.d.normalized();
 	float t = max(aabb.ray_intersect(ray.o, ray.d).x(), 0.0f);
-	ray.o = ray.o + (t + 1e-6f) * ray.d;
 
-	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_depth, ray.o)) {
-		t = max(0.0f, TriangleOctree::ray_intersect(octree_nodes, max_depth, ray.o, ray.d));
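+	// Ray::advance(t) moves the origin t units along the ray direction.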
+	ray.advance(t + 1e-6f);
+
+	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_octree_depth, ray.o)) {
+		t = max(0.0f, TriangleOctree::ray_intersect(octree_nodes, max_octree_depth, ray.o, ray.d));
 		if (ray.o.y() > floor_y && ray.d.y() < 0.f) {
 			float floor_dist = -(ray.o.y() - floor_y) / ray.d.y();
 			if (floor_dist > 0.f) {
@@ -557,16 +582,15 @@ __global__ void init_rays_with_payload_kernel_sdf(
 			}
 		}
 
-		ray.o = ray.o + (t + 1e-6f) * ray.d;
+		ray.advance(t + 1e-6f);
 	}
 
 	positions[idx] = ray.o;
 
-	if (envmap_data) {
-		framebuffer[idx] = read_envmap(envmap_data, envmap_resolution, ray.d);
+	if (envmap) {
+		frame_buffer[idx] = read_envmap(envmap, ray.d);
 	}
 
-	SdfPayload& payload = payloads[idx];
 	payload.dir = ray.d;
 	payload.idx = idx;
 	payload.n_steps = 0;
@@ -600,10 +624,11 @@ void Testbed::SphereTracer::init_rays_from_camera(
 	float near_distance,
 	float plane_z,
 	float aperture_size,
-	const float* envmap_data,
-	const Vector2i& envmap_resolution,
+	const Foveation& foveation,
+	const Buffer2DView<const Array4f>& envmap,
 	Array4f* frame_buffer,
 	float* depth_buffer,
+	const Buffer2DView<const uint8_t>& hidden_area_mask,
 	const TriangleOctree* octree,
 	uint32_t n_octree_levels,
 	cudaStream_t stream
@@ -630,10 +655,11 @@ void Testbed::SphereTracer::init_rays_from_camera(
 		near_distance,
 		plane_z,
 		aperture_size,
-		envmap_data,
-		envmap_resolution,
+		foveation,
+		envmap,
 		frame_buffer,
 		depth_buffer,
+		hidden_area_mask,
 		octree ? octree->nodes_gpu() : nullptr,
 		octree ? n_octree_levels : 0
 	);
@@ -840,14 +866,15 @@ void Testbed::FiniteDifferenceNormalsApproximator::normal(uint32_t n_elements, c
 }
 
 void Testbed::render_sdf(
+	cudaStream_t stream,
 	const distance_fun_t& distance_function,
 	const normals_fun_t& normals_function,
-	CudaRenderBuffer& render_buffer,
-	const Vector2i& max_res,
+	const CudaRenderBufferView& render_buffer,
 	const Vector2f& focal_length,
 	const Matrix<float, 3, 4>& camera_matrix,
 	const Vector2f& screen_center,
-	cudaStream_t stream
+	const Foveation& foveation,
+	int visualized_dimension
 ) {
 	float plane_z = m_slice_plane_z + m_scale;
 	if (m_render_mode == ERenderMode::Slice) {
@@ -865,8 +892,8 @@ void Testbed::render_sdf(
 	BoundingBox sdf_bounding_box = m_aabb;
 	sdf_bounding_box.inflate(m_sdf.zero_offset);
 	tracer.init_rays_from_camera(
-		render_buffer.spp(),
-		render_buffer.in_resolution(),
+		render_buffer.spp,
+		render_buffer.resolution,
 		focal_length,
 		camera_matrix,
 		screen_center,
@@ -877,10 +904,11 @@ void Testbed::render_sdf(
 		m_render_near_distance,
 		plane_z,
 		m_aperture_size,
-		m_envmap.envmap->inference_params(),
-		m_envmap.resolution,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer(),
+		foveation,
+		m_envmap.inference_view(),
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer,
+		render_buffer.hidden_area_mask ? render_buffer.hidden_area_mask->const_view() : Buffer2DView<const uint8_t>{},
 		octree_ptr,
 		n_octree_levels,
 		stream
@@ -912,10 +940,11 @@ void Testbed::render_sdf(
 	} else {
 		n_hit = trace(tracer);
 	}
+
 	RaysSdfSoa& rays_hit = m_render_mode == ERenderMode::Slice || gt_raytrace ? tracer.rays_init() : tracer.rays_hit();
 
 	if (m_render_mode == ERenderMode::Slice) {
-		if (m_visualized_dimension == -1) {
+		if (visualized_dimension == -1) {
 			distance_function(n_hit, rays_hit.pos, rays_hit.distance, stream);
 			extract_dimension_pos_neg_kernel<float><<<n_blocks_linear(n_hit*3), n_threads_linear, 0, stream>>>(n_hit*3, 0, 1, 3, rays_hit.distance, CM, (float*)rays_hit.normal);
 		} else {
@@ -924,11 +953,11 @@ void Testbed::render_sdf(
 
 			GPUMatrix<float> positions_matrix((float*)rays_hit.pos, 3, n_elements);
 			GPUMatrix<float> colors_matrix((float*)rays_hit.normal, 3, n_elements);
-			m_network->visualize_activation(stream, m_visualized_layer, m_visualized_dimension, positions_matrix, colors_matrix);
+			m_network->visualize_activation(stream, m_visualized_layer, visualized_dimension, positions_matrix, colors_matrix);
 		}
 	}
 
-	ERenderMode render_mode = (m_visualized_dimension > -1 || m_render_mode == ERenderMode::Slice) ? ERenderMode::EncodingVis : m_render_mode;
+	ERenderMode render_mode = (visualized_dimension > -1 || m_render_mode == ERenderMode::Slice) ? ERenderMode::EncodingVis : m_render_mode;
 	if (render_mode == ERenderMode::Shade || render_mode == ERenderMode::Normals) {
 		if (m_sdf.analytic_normals || gt_raytrace) {
 			normals_function(n_hit, rays_hit.pos, rays_hit.normal, stream);
@@ -964,6 +993,7 @@ void Testbed::render_sdf(
 				octree_ptr ? octree_ptr->nodes_gpu() : nullptr,
 				n_octree_levels
 			);
+
 			uint32_t n_hit_shadow = trace(shadow_tracer);
 			auto& shadow_rays_hit = gt_raytrace ? shadow_tracer.rays_init() : shadow_tracer.rays_hit();
 
@@ -984,7 +1014,7 @@ void Testbed::render_sdf(
 
 		GPUMatrix<float> positions_matrix((float*)rays_hit.pos, 3, n_elements);
 		GPUMatrix<float> colors_matrix((float*)rays_hit.normal, 3, n_elements);
-		m_network->visualize_activation(stream, m_visualized_layer, m_visualized_dimension, positions_matrix, colors_matrix);
+		m_network->visualize_activation(stream, m_visualized_layer, visualized_dimension, positions_matrix, colors_matrix);
 	}
 
 	linear_kernel(shade_kernel_sdf, 0, stream,
@@ -1000,8 +1030,8 @@ void Testbed::render_sdf(
 		rays_hit.normal,
 		rays_hit.distance,
 		rays_hit.payload,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer()
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer
 	);
 
 	if (render_mode == ERenderMode::Cost) {
diff --git a/src/testbed_volume.cu b/src/testbed_volume.cu
index 10306cb0ca6e0dbca0f59f32923c9c84df79b6a8..c8c7a09a7236b2740b0d6f5b5da5d1496a68673d 100644
--- a/src/testbed_volume.cu
+++ b/src/testbed_volume.cu
@@ -218,10 +218,11 @@ __global__ void init_rays_volume(
 	float near_distance,
 	float plane_z,
 	float aperture_size,
-	const float* __restrict__ envmap_data,
-	const Vector2i envmap_resolution,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer,
+	Foveation foveation,
+	Buffer2DView<const Array4f> envmap,
+	Array4f* __restrict__ frame_buffer,
+	float* __restrict__ depth_buffer,
+	Buffer2DView<const uint8_t> hidden_area_mask,
 	default_rng_t rng,
 	const uint8_t *bitgrid,
 	float distance_scale,
@@ -240,20 +241,42 @@ __global__ void init_rays_volume(
 	if (plane_z < 0) {
 		aperture_size = 0.0;
 	}
-	Ray ray = pixel_to_ray(sample_index, {x, y}, resolution, focal_length, camera_matrix, screen_center, parallax_shift, snap_to_pixel_centers, near_distance, plane_z, aperture_size);
+
+	Ray ray = pixel_to_ray(
+		sample_index,
+		{x, y},
+		resolution,
+		focal_length,
+		camera_matrix,
+		screen_center,
+		parallax_shift,
+		snap_to_pixel_centers,
+		near_distance,
+		plane_z,
+		aperture_size,
+		foveation,
+		hidden_area_mask
+	);
+
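+	// Pixels culled during ray generation (e.g. by the hidden area mask) produce
+	// invalid rays; write far depth and skip them.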
+	if (!ray.is_valid()) {
+		depth_buffer[idx] = MAX_DEPTH();
+		return;
+	}
+
 	ray.d = ray.d.normalized();
 	auto box_intersection = aabb.ray_intersect(ray.o, ray.d);
 	float t = max(box_intersection.x(), 0.0f);
-	ray.o = ray.o + (t + 1e-6f) * ray.d;
+	ray.advance(t + 1e-6f);
 	float scale = distance_scale / global_majorant;
+
 	if (t >= box_intersection.y() || !walk_to_next_event(rng, aabb, ray.o, ray.d, bitgrid, scale)) {
-		framebuffer[idx] = proc_envmap_render(ray.d, up_dir, sun_dir, sky_col);
-		depthbuffer[idx] = 1e10f;
+		frame_buffer[idx] = proc_envmap_render(ray.d, up_dir, sun_dir, sky_col);
+		depth_buffer[idx] = MAX_DEPTH();
 	} else {
 		uint32_t dstidx = atomicAdd(pixel_counter, 1);
 		positions[dstidx] = ray.o;
 		payloads[dstidx] = {ray.d, Array4f::Constant(0.f), idx};
-		depthbuffer[idx] = camera_matrix.col(2).dot(ray.o - camera_matrix.col(3));
+		depth_buffer[idx] = camera_matrix.col(2).dot(ray.o - camera_matrix.col(3));
 	}
 }
 
@@ -276,8 +299,7 @@ __global__ void volume_render_kernel_gt(
 	float distance_scale,
 	float albedo,
 	float scattering,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer
+	Array4f* __restrict__ frame_buffer
 ) {
 	uint32_t idx = threadIdx.x + blockDim.x * blockIdx.x;
 	if (idx>=n_pixels || idx>=pixel_counter_in[0])
@@ -325,7 +347,7 @@ __global__ void volume_render_kernel_gt(
 	} else {
 		col = proc_envmap_render(dir, up_dir, sun_dir, sky_col);
 	}
-	framebuffer[pixidx] = col;
+	frame_buffer[pixidx] = col;
 }
 
 __global__ void volume_render_kernel_step(
@@ -351,8 +373,7 @@ __global__ void volume_render_kernel_step(
 	float distance_scale,
 	float albedo,
 	float scattering,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer,
+	Array4f* __restrict__ frame_buffer,
 	bool force_finish_ray
 ) {
 	uint32_t idx = threadIdx.x + blockDim.x * blockIdx.x;
@@ -382,23 +403,25 @@ __global__ void volume_render_kernel_step(
 	payload.col.w() += alpha;
 	if (payload.col.w() > 0.99f || !walk_to_next_event(rng, aabb, pos, dir, bitgrid, scale) || force_finish_ray) {
 		payload.col += (1.f-payload.col.w()) * proc_envmap_render(dir, up_dir, sun_dir, sky_col);
-		framebuffer[pixidx] = payload.col;
+		frame_buffer[pixidx] = payload.col;
 		return;
 	}
 	uint32_t dstidx = atomicAdd(pixel_counter_out, 1);
-	positions_out[dstidx]=pos;
-	payloads_out[dstidx]=payload;
+	positions_out[dstidx] = pos;
+	payloads_out[dstidx] = payload;
 }
 
-void Testbed::render_volume(CudaRenderBuffer& render_buffer,
+void Testbed::render_volume(
+	cudaStream_t stream,
+	const CudaRenderBufferView& render_buffer,
 	const Vector2f& focal_length,
 	const Matrix<float, 3, 4>& camera_matrix,
 	const Vector2f& screen_center,
-	cudaStream_t stream
+	const Foveation& foveation
 ) {
 	float plane_z = m_slice_plane_z + m_scale;
 	float distance_scale = 1.f/std::max(m_volume.inv_distance_scale,0.01f);
-	auto res = render_buffer.in_resolution();
+	auto res = render_buffer.resolution;
 
 	size_t n_pixels = (size_t)res.x() * res.y();
 	for (uint32_t i=0;i<2;++i) {
@@ -413,7 +436,7 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 	const dim3 threads = { 16, 8, 1 };
 	const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 	init_rays_volume<<<blocks, threads, 0, stream>>>(
-		render_buffer.spp(),
+		render_buffer.spp,
 		m_volume.pos[0].data(),
 		m_volume.payload[0].data(),
 		m_volume.hit_counter.data(),
@@ -427,10 +450,11 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 		m_render_near_distance,
 		plane_z,
 		m_aperture_size,
-		m_envmap.envmap->inference_params(),
-		m_envmap.resolution,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer(),
+		foveation,
+		m_envmap.inference_view(),
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer,
+		render_buffer.hidden_area_mask ? render_buffer.hidden_area_mask->const_view() : Buffer2DView<const uint8_t>{},
 		m_rng,
 		m_volume.bitgrid.data(),
 		distance_scale,
@@ -466,8 +490,7 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 			distance_scale,
 			std::min(m_volume.albedo,0.995f),
 			m_volume.scattering,
-			render_buffer.frame_buffer(),
-			render_buffer.depth_buffer()
+			render_buffer.frame_buffer
 		);
 		m_rng.advance(n_pixels*256);
 	} else {
@@ -508,8 +531,7 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 				distance_scale,
 				std::min(m_volume.albedo,0.995f),
 				m_volume.scattering,
-				render_buffer.frame_buffer(),
-				render_buffer.depth_buffer(),
+				render_buffer.frame_buffer,
 				(iter>=max_iter-1)
 			);
 			m_rng.advance(n_pixels*256);