diff --git a/.gitmodules b/.gitmodules
index 37d13bdac9a6f3ffedcf2fb1299600a66bf0adcf..23cf10097cbd7800f269aec2452b9745705a9815 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -28,3 +28,6 @@
 [submodule "dependencies/zlib"]
 	path = dependencies/zlib
 	url = https://github.com/Tom94/zlib
+[submodule "dependencies/OpenXR-SDK"]
+	path = dependencies/OpenXR-SDK
+	url = https://github.com/KhronosGroup/OpenXR-SDK.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8ea96eb9860e7854f7007440a17a0254f72689d7..92d926af9326d69fe84de5750f97da2c6cfd52f4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -119,6 +119,38 @@ if (NGP_BUILD_WITH_GUI)
 		endif()
 	endif()
 
+	# OpenXR
+	if (WIN32)
+		list(APPEND NGP_DEFINITIONS -DXR_USE_PLATFORM_WIN32 -DGLFW_EXPOSE_NATIVE_WGL)
+	elseif (UNIX AND NOT APPLE)
+		list(APPEND NGP_DEFINITIONS -DGLFW_EXPOSE_NATIVE_GLX)
+		if (JK_USE_WAYLAND)
+			set(PRESENTATION_BACKEND wayland CACHE STRING " " FORCE)
+			set(BUILD_WITH_XLIB_HEADERS OFF CACHE BOOL " " FORCE)
+			set(BUILD_WITH_XCB_HEADERS OFF CACHE BOOL " " FORCE)
+			set(BUILD_WITH_WAYLAND_HEADERS ON CACHE BOOL " " FORCE)
+			list(APPEND NGP_DEFINITIONS -DGLFW_EXPOSE_NATIVE_WAYLAND -DXR_USE_PLATFORM_WAYLAND)
+		else()
+			set(PRESENTATION_BACKEND xlib CACHE STRING " " FORCE)
+			set(BUILD_WITH_XLIB_HEADERS ON CACHE BOOL " " FORCE)
+			set(BUILD_WITH_XCB_HEADERS OFF CACHE BOOL " " FORCE)
+			set(BUILD_WITH_WAYLAND_HEADERS OFF CACHE BOOL " " FORCE)
+			list(APPEND NGP_DEFINITIONS -DGLFW_EXPOSE_NATIVE_X11 -DXR_USE_PLATFORM_XLIB)
+		endif()
+	else()
+		message(FATAL_ERROR "No OpenXR platform set for this OS")
+	endif()
+
+	add_subdirectory(dependencies/OpenXR-SDK)
+
+	list(APPEND NGP_INCLUDE_DIRECTORIES "dependencies/OpenXR-SDK/include" "dependencies/OpenXR-SDK/src/common")
+	list(APPEND NGP_LIBRARIES openxr_loader)
+	list(APPEND GUI_SOURCES src/openxr_hmd.cu)
+
+	# OpenGL
+	find_package(OpenGL REQUIRED)
+
+	# GLFW
 	set(GLFW_BUILD_EXAMPLES OFF CACHE BOOL " " FORCE)
 	set(GLFW_BUILD_TESTS OFF CACHE BOOL " " FORCE)
 	set(GLFW_BUILD_DOCS OFF CACHE BOOL " " FORCE)
diff --git a/LICENSE.txt b/LICENSE.txt
index 2cfc50b56f42a59b9bbeb82e56b527ad8deace84..34191874786e0339d672cd74d10175bec14b9f9d 100644
--- a/LICENSE.txt
+++ b/LICENSE.txt
@@ -1,4 +1,4 @@
-Copyright (c) 2022, NVIDIA Corporation & affiliates. All rights reserved.
+Copyright (c) 2022-2023, NVIDIA Corporation & affiliates. All rights reserved.
 
 
 NVIDIA Source Code License for instant neural graphics primitives
diff --git a/README.md b/README.md
index bdf06b55e32fda8788d8c62bb361e5f7504cda7f..46451b8dd833284e8a8bf9bf5817cf646ec98f49 100644
--- a/README.md
+++ b/README.md
@@ -179,7 +179,10 @@ Here are the main keyboard controls for the __instant-ngp__ application.
 | Spacebar / C    | Move up / down. |
 | = or + / - or _ | Increase / decrease camera velocity. |
 | E / Shift+E     | Increase / decrease exposure. |
+| Tab             | Toggle menu visibility. |
 | T               | Toggle training. After around two minutes training tends to settle down, so can be toggled off. |
+| { }             | Go to the first/last training set image's camera view. |
+| [ ]             | Go to the previous/next training set image's camera view. |
 | R               | Reload network from file. |
 | Shift+R         | Reset camera. |
 | O               | Toggle visualization or accumulated error map. |
diff --git a/dependencies/OpenXR-SDK b/dependencies/OpenXR-SDK
new file mode 160000
index 0000000000000000000000000000000000000000..e2da9ce83a4388c9622da328bf48548471261290
--- /dev/null
+++ b/dependencies/OpenXR-SDK
@@ -0,0 +1 @@
+Subproject commit e2da9ce83a4388c9622da328bf48548471261290
diff --git a/include/neural-graphics-primitives/camera_path.h b/include/neural-graphics-primitives/camera_path.h
index 8d3fe24792b182606950fba15df410bbb692816c..9d121757233a323295f0f722f2025df523b0dffe 100644
--- a/include/neural-graphics-primitives/camera_path.h
+++ b/include/neural-graphics-primitives/camera_path.h
@@ -132,8 +132,8 @@ struct CameraPath {
 #ifdef NGP_GUI
 	ImGuizmo::MODE m_gizmo_mode = ImGuizmo::LOCAL;
 	ImGuizmo::OPERATION m_gizmo_op = ImGuizmo::TRANSLATE;
-	bool imgui_viz(ImDrawList* list, Eigen::Matrix<float, 4, 4>& view2proj, Eigen::Matrix<float, 4, 4>& world2proj, Eigen::Matrix<float, 4, 4>& world2view, Eigen::Vector2f focal, float aspect);
 	int imgui(char path_filename_buf[1024], float frame_milliseconds, Eigen::Matrix<float, 3, 4>& camera, float slice_plane_z, float scale, float fov, float aperture_size, float bounding_radius, const Eigen::Matrix<float, 3, 4>& first_xform, int glow_mode, float glow_y_cutoff);
+	bool imgui_viz(ImDrawList* list, Eigen::Matrix<float, 4, 4>& view2proj, Eigen::Matrix<float, 4, 4>& world2proj, Eigen::Matrix<float, 4, 4>& world2view, Eigen::Vector2f focal, float aspect, float znear, float zfar);
 #endif
 };
 
diff --git a/include/neural-graphics-primitives/common.h b/include/neural-graphics-primitives/common.h
index 8d3f6068b8462f29fedaa4837f3596d5187d1993..3a632b4d2f51fa1189c3fe9cf1c53b3d3052359f 100644
--- a/include/neural-graphics-primitives/common.h
+++ b/include/neural-graphics-primitives/common.h
@@ -232,8 +232,13 @@ enum class ELensMode : int {
 	FTheta,
 	LatLong,
 	OpenCVFisheye,
+	Equirectangular,
 };
-static constexpr const char* LensModeStr = "Perspective\0OpenCV\0F-Theta\0LatLong\0OpenCV Fisheye\0\0";
+static constexpr const char* LensModeStr = "Perspective\0OpenCV\0F-Theta\0LatLong\0OpenCV Fisheye\0Equirectangular\0\0";
+
+inline bool supports_dlss(ELensMode mode) {
+	return mode == ELensMode::Perspective || mode == ELensMode::OpenCV || mode == ELensMode::OpenCVFisheye;
+}
 
 struct Lens {
 	ELensMode mode = ELensMode::Perspective;
@@ -343,6 +348,47 @@ private:
 	std::chrono::time_point<std::chrono::steady_clock> m_creation_time;
 };
 
+template <typename T>
+struct Buffer2DView {
+	T* data = nullptr;
+	Eigen::Vector2i resolution = Eigen::Vector2i::Zero();
+
+	// Lookup via integer pixel position (no bounds checking)
+	NGP_HOST_DEVICE T at(const Eigen::Vector2i& xy) const {
+		return data[xy.x() + xy.y() * resolution.x()];
+	}
+
+	// Lookup via UV coordinates in [0,1]^2
+	NGP_HOST_DEVICE T at(const Eigen::Vector2f& uv) const {
+		Eigen::Vector2i xy = resolution.cast<float>().cwiseProduct(uv).cast<int>().cwiseMax(0).cwiseMin(resolution - Eigen::Vector2i::Ones());
+		return at(xy);
+	}
+
+	// Lookup via UV coordinates in [0,1]^2 and LERP the nearest texels
+	NGP_HOST_DEVICE T at_lerp(const Eigen::Vector2f& uv) const {
+		const Eigen::Vector2f xy_float = resolution.cast<float>().cwiseProduct(uv);
+		const Eigen::Vector2i xy = xy_float.cast<int>();
+
+		const Eigen::Vector2f weight = xy_float - xy.cast<float>();
+
+		auto read_val = [&](Eigen::Vector2i pos) {
+			pos = pos.cwiseMax(0).cwiseMin(resolution - Eigen::Vector2i::Ones());
+			return at(pos);
+		};
+
+		return (
+			(1 - weight.x()) * (1 - weight.y()) * read_val({xy.x(), xy.y()}) +
+			(weight.x()) * (1 - weight.y()) * read_val({xy.x()+1, xy.y()}) +
+			(1 - weight.x()) * (weight.y()) * read_val({xy.x(), xy.y()+1}) +
+			(weight.x()) * (weight.y()) * read_val({xy.x()+1, xy.y()+1})
+		);
+	}
+
+	NGP_HOST_DEVICE operator bool() const {
+		return data;
+	}
+};
+
 uint8_t* load_stbi(const fs::path& path, int* width, int* height, int* comp, int req_comp);
 float* load_stbi_float(const fs::path& path, int* width, int* height, int* comp, int req_comp);
 uint16_t* load_stbi_16(const fs::path& path, int* width, int* height, int* comp, int req_comp);
diff --git a/include/neural-graphics-primitives/common_device.cuh b/include/neural-graphics-primitives/common_device.cuh
index 389fc3bdab7aff5925f14c5cfe88d8aebbd94e39..361c62bafd81b58ce3eabcbe79f02b472f7767d1 100644
--- a/include/neural-graphics-primitives/common_device.cuh
+++ b/include/neural-graphics-primitives/common_device.cuh
@@ -28,6 +28,52 @@ NGP_NAMESPACE_BEGIN
 using precision_t = tcnn::network_precision_t;
 
 
+// The maximum depth that can be produced when rendering a frame.
+// Chosen somewhat low (rather than std::numeric_limits<float>::infinity())
+// to permit numerically stable reprojection and DLSS operation,
+// even when rendering the infinitely distant horizon.
+inline constexpr __device__ float MAX_DEPTH() { return 16384.0f; }
+
+template <typename T>
+class Buffer2D {
+public:
+	Buffer2D() = default;
+	Buffer2D(const Eigen::Vector2i& resolution) {
+		resize(resolution);
+	}
+
+	T* data() const {
+		return m_data.data();
+	}
+
+	size_t bytes() const {
+		return m_data.bytes();
+	}
+
+	void resize(const Eigen::Vector2i& resolution) {
+		m_data.resize(resolution.prod());
+		m_resolution = resolution;
+	}
+
+	const Eigen::Vector2i& resolution() const {
+		return m_resolution;
+	}
+
+	Buffer2DView<T> view() const {
+		// Row major for now.
+		return {data(), m_resolution};
+	}
+
+	Buffer2DView<const T> const_view() const {
+		// Row major for now.
+		return {data(), m_resolution};
+	}
+
+private:
+	tcnn::GPUMemory<T> m_data;
+	Eigen::Vector2i m_resolution;
+};
+
 inline NGP_HOST_DEVICE float srgb_to_linear(float srgb) {
 	if (srgb <= 0.04045f) {
 		return srgb / 12.92f;
@@ -76,42 +122,9 @@ inline NGP_HOST_DEVICE Eigen::Array3f linear_to_srgb_derivative(const Eigen::Arr
 	return {linear_to_srgb_derivative(x.x()), linear_to_srgb_derivative(x.y()), (linear_to_srgb_derivative(x.z()))};
 }
 
-template <uint32_t N_DIMS, typename T>
-NGP_HOST_DEVICE Eigen::Matrix<float, N_DIMS, 1> read_image(const T* __restrict__ data, const Eigen::Vector2i& resolution, const Eigen::Vector2f& pos) {
-	const Eigen::Vector2f pos_float = Eigen::Vector2f{pos.x() * (float)(resolution.x()-1), pos.y() * (float)(resolution.y()-1)};
-	const Eigen::Vector2i texel = pos_float.cast<int>();
-
-	const Eigen::Vector2f weight = pos_float - texel.cast<float>();
-
-	auto read_val = [&](Eigen::Vector2i pos) {
-		pos.x() = std::max(std::min(pos.x(), resolution.x()-1), 0);
-		pos.y() = std::max(std::min(pos.y(), resolution.y()-1), 0);
-
-		Eigen::Matrix<float, N_DIMS, 1> result;
-		if (std::is_same<T, float>::value) {
-			result = *(Eigen::Matrix<T, N_DIMS, 1>*)&data[(pos.x() + pos.y() * resolution.x()) * N_DIMS];
-		} else {
-			auto val = *(tcnn::vector_t<T, N_DIMS>*)&data[(pos.x() + pos.y() * resolution.x()) * N_DIMS];
-
-			NGP_PRAGMA_UNROLL
-			for (uint32_t i = 0; i < N_DIMS; ++i) {
-				result[i] = (float)val[i];
-			}
-		}
-		return result;
-	};
-
-	return (
-		(1 - weight.x()) * (1 - weight.y()) * read_val({texel.x(), texel.y()}) +
-		(weight.x()) * (1 - weight.y()) * read_val({texel.x()+1, texel.y()}) +
-		(1 - weight.x()) * (weight.y()) * read_val({texel.x(), texel.y()+1}) +
-		(weight.x()) * (weight.y()) * read_val({texel.x()+1, texel.y()+1})
-	);
-}
-
 template <uint32_t N_DIMS, typename T>
 __device__ void deposit_image_gradient(const Eigen::Matrix<float, N_DIMS, 1>& value, T* __restrict__ gradient, T* __restrict__ gradient_weight, const Eigen::Vector2i& resolution, const Eigen::Vector2f& pos) {
-	const Eigen::Vector2f pos_float = Eigen::Vector2f{pos.x() * (resolution.x()-1), pos.y() * (resolution.y()-1)};
+	const Eigen::Vector2f pos_float = resolution.cast<float>().cwiseProduct(pos);
 	const Eigen::Vector2i texel = pos_float.cast<int>();
 
 	const Eigen::Vector2f weight = pos_float - texel.cast<float>();
@@ -142,6 +155,138 @@ __device__ void deposit_image_gradient(const Eigen::Matrix<float, N_DIMS, 1>& va
 	deposit_val(value, (weight.x()) * (weight.y()), {texel.x()+1, texel.y()+1});
 }
 
+struct FoveationPiecewiseQuadratic {
+	NGP_HOST_DEVICE FoveationPiecewiseQuadratic() = default;
+
+	FoveationPiecewiseQuadratic(float center_pixel_steepness, float center_inverse_piecewise_y, float center_radius) {
+		float center_inverse_radius = center_radius * center_pixel_steepness;
+		float left_inverse_piecewise_switch = center_inverse_piecewise_y - center_inverse_radius;
+		float right_inverse_piecewise_switch = center_inverse_piecewise_y + center_inverse_radius;
+
+		if (left_inverse_piecewise_switch < 0) {
+			left_inverse_piecewise_switch = 0.0f;
+		}
+
+		if (right_inverse_piecewise_switch > 1) {
+			right_inverse_piecewise_switch = 1.0f;
+		}
+
+		float am = center_pixel_steepness;
+		float d = (right_inverse_piecewise_switch - left_inverse_piecewise_switch) / center_pixel_steepness / 2;
+
+		// binary search for l,r,bm since analytical is very complex
+		float bm;
+		float m_min = 0.0f;
+		float m_max = 1.0f;
+		for (uint32_t i = 0; i < 20; i++) {
+			float m = (m_min + m_max) / 2.0f;
+			float l = m - d;
+			float r = m + d;
+
+			bm = -((am - 1) * l * l) / (r * r - 2 * r + l * l + 1);
+
+			float l_actual = (left_inverse_piecewise_switch - bm) / am;
+			float r_actual = (right_inverse_piecewise_switch - bm) / am;
+			float m_actual = (l_actual + r_actual) / 2;
+
+			if (m_actual > m) {
+				m_min = m;
+			} else {
+				m_max = m;
+			}
+		}
+
+		float l = (left_inverse_piecewise_switch - bm) / am;
+		float r = (right_inverse_piecewise_switch - bm) / am;
+
+		// Full linear case. Default construction covers this.
+		if ((l == 0.0f && r == 1.0f) || (am == 1.0f)) {
+			return;
+		}
+
+		// write out solution
+		switch_left = l;
+		switch_right = r;
+		this->am = am;
+		al = (am - 1) / (r * r - 2 * r + l * l + 1);
+		bl = (am * (r * r - 2 * r + 1) + am * l * l + (2 - 2 * am) * l) / (r * r - 2 * r + l * l + 1);
+		cl = 0;
+		this->bm = bm = -((am - 1) * l * l) / (r * r - 2 * r + l * l + 1);
+		ar = -(am - 1) / (r * r - 2 * r + l * l + 1);
+		br = (am * (r * r + 1) - 2 * r + am * l * l) / (r * r - 2 * r + l * l + 1);
+		cr = -(am * r * r - r * r + (am - 1) * l * l) / (r * r - 2 * r + l * l + 1);
+
+		inv_switch_left = am * switch_left + bm;
+		inv_switch_right = am * switch_right + bm;
+	}
+
+	// left parabola: al * x^2 + bl * x + cl
+	float al = 0.0f, bl = 0.0f, cl = 0.0f;
+	// middle linear piece: am * x + bm.  am should give 1:1 pixel mapping between warped size and full size.
+	float am = 1.0f, bm = 0.0f;
+	// right parabola: al * x^2 + bl * x + cl
+	float ar = 0.0f, br = 0.0f, cr = 0.0f;
+
+	// points where left and right switch over from quadratic to linear
+	float switch_left = 0.0f, switch_right = 1.0f;
+	// same, in inverted space
+	float inv_switch_left = 0.0f, inv_switch_right = 1.0f;
+
+	NGP_HOST_DEVICE float warp(float x) const {
+		x = tcnn::clamp(x, 0.0f, 1.0f);
+		if (x < switch_left) {
+			return al * x * x + bl * x + cl;
+		} else if (x > switch_right) {
+			return ar * x * x + br * x + cr;
+		} else {
+			return am * x + bm;
+		}
+	}
+
+	NGP_HOST_DEVICE float unwarp(float y) const {
+		y = tcnn::clamp(y, 0.0f, 1.0f);
+		if (y < inv_switch_left) {
+			return (std::sqrt(-4 * al * cl + 4 * al * y + bl * bl) - bl) / (2 * al);
+		} else if (y > inv_switch_right) {
+			return (std::sqrt(-4 * ar * cr + 4 * ar * y + br * br) - br) / (2 * ar);
+		} else {
+			return (y - bm) / am;
+		}
+	}
+
+	NGP_HOST_DEVICE float density(float x) const {
+		x = tcnn::clamp(x, 0.0f, 1.0f);
+		if (x < switch_left) {
+			return 2 * al * x + bl;
+		} else if (x > switch_right) {
+			return 2 * ar * x + br;
+		} else {
+			return am;
+		}
+	}
+};
+
+struct Foveation {
+	NGP_HOST_DEVICE Foveation() = default;
+
+	Foveation(const Eigen::Vector2f& center_pixel_steepness, const Eigen::Vector2f& center_inverse_piecewise_y, const Eigen::Vector2f& center_radius)
+	: warp_x{center_pixel_steepness.x(), center_inverse_piecewise_y.x(), center_radius.x()}, warp_y{center_pixel_steepness.y(), center_inverse_piecewise_y.y(), center_radius.y()} {}
+
+	FoveationPiecewiseQuadratic warp_x, warp_y;
+
+	NGP_HOST_DEVICE Eigen::Vector2f warp(const Eigen::Vector2f& x) const {
+		return {warp_x.warp(x.x()), warp_y.warp(x.y())};
+	}
+
+	NGP_HOST_DEVICE Eigen::Vector2f unwarp(const Eigen::Vector2f& y) const {
+		return {warp_x.unwarp(y.x()), warp_y.unwarp(y.y())};
+	}
+
+	NGP_HOST_DEVICE float density(const Eigen::Vector2f& x) const {
+		return warp_x.density(x.x()) * warp_y.density(x.y());
+	}
+};
+
 template <typename T>
 NGP_HOST_DEVICE inline void opencv_lens_distortion_delta(const T* extra_params, const T u, const T v, T* du, T* dv) {
 	const T k1 = extra_params[0];
@@ -292,37 +437,53 @@ inline NGP_HOST_DEVICE Eigen::Vector3f latlong_to_dir(const Eigen::Vector2f& uv)
 	return {sp * ct, st, cp * ct};
 }
 
-inline NGP_HOST_DEVICE Ray pixel_to_ray(
+inline NGP_HOST_DEVICE Eigen::Vector3f equirectangular_to_dir(const Eigen::Vector2f& uv) {
+	float ct = (uv.y() - 0.5f) * 2.0f;
+	float st = std::sqrt(std::max(1.0f - ct * ct, 0.0f));
+	float phi = (uv.x() - 0.5f) * PI() * 2.0f;
+	float sp, cp;
+	sincosf(phi, &sp, &cp);
+	return {sp * st, ct, cp * st};
+}
+
+inline NGP_HOST_DEVICE Ray uv_to_ray(
 	uint32_t spp,
-	const Eigen::Vector2i& pixel,
+	const Eigen::Vector2f& uv,
 	const Eigen::Vector2i& resolution,
 	const Eigen::Vector2f& focal_length,
 	const Eigen::Matrix<float, 3, 4>& camera_matrix,
 	const Eigen::Vector2f& screen_center,
-	const Eigen::Vector3f& parallax_shift,
-	bool snap_to_pixel_centers = false,
+	const Eigen::Vector3f& parallax_shift = Eigen::Vector3f::Zero(),
 	float near_distance = 0.0f,
 	float focus_z = 1.0f,
 	float aperture_size = 0.0f,
+	const Foveation& foveation = {},
+	Buffer2DView<const uint8_t> hidden_area_mask = {},
 	const Lens& lens = {},
-	const float* __restrict__ distortion_grid = nullptr,
-	const Eigen::Vector2i distortion_grid_resolution = Eigen::Vector2i::Zero()
+	Buffer2DView<const Eigen::Vector2f> distortion = {}
 ) {
-	Eigen::Vector2f offset = ld_random_pixel_offset(snap_to_pixel_centers ? 0 : spp);
-	Eigen::Vector2f uv = (pixel.cast<float>() + offset).cwiseQuotient(resolution.cast<float>());
+	Eigen::Vector2f warped_uv = foveation.warp(uv);
+
+	// Check the hidden area mask _after_ applying foveation, because foveation will be undone
+	// before blitting to the framebuffer to which the hidden area mask corresponds.
+	if (hidden_area_mask && !hidden_area_mask.at(warped_uv)) {
+		return Ray::invalid();
+	}
 
 	Eigen::Vector3f dir;
 	if (lens.mode == ELensMode::FTheta) {
-		dir = f_theta_undistortion(uv - screen_center, lens.params, {1000.f, 0.f, 0.f});
-		if (dir.x() == 1000.f) {
-			return {{1000.f, 0.f, 0.f}, {0.f, 0.f, 1.f}}; // return a point outside the aabb so the pixel is not rendered
+		dir = f_theta_undistortion(warped_uv - screen_center, lens.params, {0.f, 0.f, 0.f});
+		if (dir == Eigen::Vector3f::Zero()) {
+			return Ray::invalid();
 		}
 	} else if (lens.mode == ELensMode::LatLong) {
-		dir = latlong_to_dir(uv);
+		dir = latlong_to_dir(warped_uv);
+	} else if (lens.mode == ELensMode::Equirectangular) {
+		dir = equirectangular_to_dir(warped_uv);
 	} else {
 		dir = {
-			(uv.x() - screen_center.x()) * (float)resolution.x() / focal_length.x(),
-			(uv.y() - screen_center.y()) * (float)resolution.y() / focal_length.y(),
+			(warped_uv.x() - screen_center.x()) * (float)resolution.x() / focal_length.x(),
+			(warped_uv.y() - screen_center.y()) * (float)resolution.y() / focal_length.y(),
 			1.0f
 		};
 
@@ -332,8 +493,9 @@ inline NGP_HOST_DEVICE Ray pixel_to_ray(
 			iterative_opencv_fisheye_lens_undistortion(lens.params, &dir.x(), &dir.y());
 		}
 	}
-	if (distortion_grid) {
-		dir.head<2>() += read_image<2>(distortion_grid, distortion_grid_resolution, uv);
+
+	if (distortion) {
+		dir.head<2>() += distortion.at_lerp(warped_uv);
 	}
 
 	Eigen::Vector3f head_pos = {parallax_shift.x(), parallax_shift.y(), 0.f};
@@ -341,26 +503,60 @@ inline NGP_HOST_DEVICE Ray pixel_to_ray(
 	dir = camera_matrix.block<3, 3>(0, 0) * dir;
 
 	Eigen::Vector3f origin = camera_matrix.block<3, 3>(0, 0) * head_pos + camera_matrix.col(3);
-	
-	if (aperture_size > 0.0f) {
+	if (aperture_size != 0.0f) {
 		Eigen::Vector3f lookat = origin + dir * focus_z;
-		Eigen::Vector2f blur = aperture_size * square2disk_shirley(ld_random_val_2d(spp, (uint32_t)pixel.x() * 19349663 + (uint32_t)pixel.y() * 96925573) * 2.0f - Eigen::Vector2f::Ones());
+		Eigen::Vector2f blur = aperture_size * square2disk_shirley(ld_random_val_2d(spp, uv.cwiseProduct(resolution.cast<float>()).cast<int>().dot(Eigen::Vector2i{19349663, 96925573})) * 2.0f - Eigen::Vector2f::Ones());
 		origin += camera_matrix.block<3, 2>(0, 0) * blur;
 		dir = (lookat - origin) / focus_z;
 	}
-	
-	origin += dir * near_distance;
 
+	origin += dir * near_distance;
 	return {origin, dir};
 }
 
-inline NGP_HOST_DEVICE Eigen::Vector2f pos_to_pixel(
+inline NGP_HOST_DEVICE Ray pixel_to_ray(
+	uint32_t spp,
+	const Eigen::Vector2i& pixel,
+	const Eigen::Vector2i& resolution,
+	const Eigen::Vector2f& focal_length,
+	const Eigen::Matrix<float, 3, 4>& camera_matrix,
+	const Eigen::Vector2f& screen_center,
+	const Eigen::Vector3f& parallax_shift = Eigen::Vector3f::Zero(),
+	bool snap_to_pixel_centers = false,
+	float near_distance = 0.0f,
+	float focus_z = 1.0f,
+	float aperture_size = 0.0f,
+	const Foveation& foveation = {},
+	Buffer2DView<const uint8_t> hidden_area_mask = {},
+	const Lens& lens = {},
+	Buffer2DView<const Eigen::Vector2f> distortion = {}
+) {
+	return uv_to_ray(
+		spp,
+		(pixel.cast<float>() + ld_random_pixel_offset(snap_to_pixel_centers ? 0 : spp)).cwiseQuotient(resolution.cast<float>()),
+		resolution,
+		focal_length,
+		camera_matrix,
+		screen_center,
+		parallax_shift,
+		near_distance,
+		focus_z,
+		aperture_size,
+		foveation,
+		hidden_area_mask,
+		lens,
+		distortion
+	);
+}
+
+inline NGP_HOST_DEVICE Eigen::Vector2f pos_to_uv(
 	const Eigen::Vector3f& pos,
 	const Eigen::Vector2i& resolution,
 	const Eigen::Vector2f& focal_length,
 	const Eigen::Matrix<float, 3, 4>& camera_matrix,
 	const Eigen::Vector2f& screen_center,
 	const Eigen::Vector3f& parallax_shift,
+	const Foveation& foveation = {},
 	const Lens& lens = {}
 ) {
 	// Express ray in terms of camera frame
@@ -386,10 +582,23 @@ inline NGP_HOST_DEVICE Eigen::Vector2f pos_to_pixel(
 	dir.y() += dv;
 
 	Eigen::Vector2f uv = Eigen::Vector2f{dir.x(), dir.y()}.cwiseProduct(focal_length).cwiseQuotient(resolution.cast<float>()) + screen_center;
-	return uv.cwiseProduct(resolution.cast<float>());
+	return foveation.unwarp(uv);
 }
 
-inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector_3d(
+inline NGP_HOST_DEVICE Eigen::Vector2f pos_to_pixel(
+	const Eigen::Vector3f& pos,
+	const Eigen::Vector2i& resolution,
+	const Eigen::Vector2f& focal_length,
+	const Eigen::Matrix<float, 3, 4>& camera_matrix,
+	const Eigen::Vector2f& screen_center,
+	const Eigen::Vector3f& parallax_shift,
+	const Foveation& foveation = {},
+	const Lens& lens = {}
+) {
+	return pos_to_uv(pos, resolution, focal_length, camera_matrix, screen_center, parallax_shift, foveation, lens).cwiseProduct(resolution.cast<float>());
+}
+
+inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector(
 	const uint32_t sample_index,
 	const Eigen::Vector2i& pixel,
 	const Eigen::Vector2i& resolution,
@@ -400,6 +609,8 @@ inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector_3d(
 	const Eigen::Vector3f& parallax_shift,
 	const bool snap_to_pixel_centers,
 	const float depth,
+	const Foveation& foveation = {},
+	const Foveation& prev_foveation = {},
 	const Lens& lens = {}
 ) {
 	Ray ray = pixel_to_ray(
@@ -414,98 +625,39 @@ inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector_3d(
 		0.0f,
 		1.0f,
 		0.0f,
-		lens,
-		nullptr,
-		Eigen::Vector2i::Zero()
+		foveation,
+		{}, // No hidden area mask
+		lens
 	);
 
 	Eigen::Vector2f prev_pixel = pos_to_pixel(
-		ray.o + ray.d * depth,
+		ray(depth),
 		resolution,
 		focal_length,
 		prev_camera,
 		screen_center,
 		parallax_shift,
+		prev_foveation,
 		lens
 	);
 
 	return prev_pixel - (pixel.cast<float>() + ld_random_pixel_offset(sample_index));
 }
 
-inline NGP_HOST_DEVICE Eigen::Vector2f pixel_to_image_uv(
-	const uint32_t sample_index,
-	const Eigen::Vector2i& pixel,
-	const Eigen::Vector2i& resolution,
-	const Eigen::Vector2i& image_resolution,
-	const Eigen::Vector2f& screen_center,
-	const float view_dist,
-	const Eigen::Vector2f& image_pos,
-	const bool snap_to_pixel_centers
-) {
-	Eigen::Vector2f jit = ld_random_pixel_offset(snap_to_pixel_centers ? 0 : sample_index);
-	Eigen::Vector2f offset = screen_center.cwiseProduct(resolution.cast<float>()) + jit;
-
-	float y_scale = view_dist;
-	float x_scale = y_scale * resolution.x() / resolution.y();
-
-	return {
-		((x_scale * (pixel.x() + offset.x())) / resolution.x() - view_dist * image_pos.x()) / image_resolution.x() * image_resolution.y(),
-		(y_scale * (pixel.y() + offset.y())) / resolution.y() - view_dist * image_pos.y()
-	};
-}
-
-inline NGP_HOST_DEVICE Eigen::Vector2f image_uv_to_pixel(
-	const Eigen::Vector2f& uv,
-	const Eigen::Vector2i& resolution,
-	const Eigen::Vector2i& image_resolution,
-	const Eigen::Vector2f& screen_center,
-	const float view_dist,
-	const Eigen::Vector2f& image_pos
-) {
-	Eigen::Vector2f offset = screen_center.cwiseProduct(resolution.cast<float>());
-
-	float y_scale = view_dist;
-	float x_scale = y_scale * resolution.x() / resolution.y();
+// Maps view-space depth (physical units) in the range [znear, zfar] hyperbolically to
+// the interval [1, 0]. This is the reverse-z-component of "normalized device coordinates",
+// which are commonly used in rasterization, where linear interpolation in screen space
+// has to be equivalent to linear interpolation in real space (which, in turn, is
+// guaranteed by the hyperbolic mapping of depth). This format is commonly found in
+// z-buffers, and hence expected by downstream image processing functions, such as DLSS
+// and VR reprojection.
+inline NGP_HOST_DEVICE float to_ndc_depth(float z, float n, float f) {
+	// View depth outside of the view frustum leads to output outside of [0, 1]
+	z = tcnn::clamp(z, n, f);
 
-	return {
-		((uv.x() / image_resolution.y() * image_resolution.x()) + view_dist * image_pos.x()) * resolution.x() / x_scale - offset.x(),
-		(uv.y() + view_dist * image_pos.y()) * resolution.y() / y_scale - offset.y()
-	};
-}
-
-inline NGP_HOST_DEVICE Eigen::Vector2f motion_vector_2d(
-	const uint32_t sample_index,
-	const Eigen::Vector2i& pixel,
-	const Eigen::Vector2i& resolution,
-	const Eigen::Vector2i& image_resolution,
-	const Eigen::Vector2f& screen_center,
-	const float view_dist,
-	const float prev_view_dist,
-	const Eigen::Vector2f& image_pos,
-	const Eigen::Vector2f& prev_image_pos,
-	const bool snap_to_pixel_centers
-) {
-	Eigen::Vector2f uv = pixel_to_image_uv(
-		sample_index,
-		pixel,
-		resolution,
-		image_resolution,
-		screen_center,
-		view_dist,
-		image_pos,
-		snap_to_pixel_centers
-	);
-
-	Eigen::Vector2f prev_pixel = image_uv_to_pixel(
-		uv,
-		resolution,
-		image_resolution,
-		screen_center,
-		prev_view_dist,
-		prev_image_pos
-	);
-
-	return prev_pixel - (pixel.cast<float>() + ld_random_pixel_offset(sample_index));
+	float scale = n / (n - f);
+	float bias = -f * scale;
+	return tcnn::clamp((z * scale + bias) / z, 0.0f, 1.0f);
 }
 
 inline NGP_HOST_DEVICE float fov_to_focal_length(int resolution, float degrees) {
@@ -587,7 +739,8 @@ inline NGP_HOST_DEVICE void apply_quilting(uint32_t* x, uint32_t* y, const Eigen
 
 	if (quilting_dims == Eigen::Vector2i{2, 1}) {
 		// Likely VR: parallax_shift.x() is the IPD in this case. The following code centers the camera matrix between both eyes.
-		parallax_shift.x() = idx ? (-0.5f * parallax_shift.x()) : (0.5f * parallax_shift.x());
+		// idx == 0 -> left eye -> -1/2 x
+		parallax_shift.x() = (idx == 0) ? (-0.5f * parallax_shift.x()) : (0.5f * parallax_shift.x());
 	} else {
 		// Likely HoloPlay lenticular display: in this case, `parallax_shift.z()` is the inverse height of the head above the display.
 		// The following code computes the x-offset of views as a function of this.
diff --git a/include/neural-graphics-primitives/dlss.h b/include/neural-graphics-primitives/dlss.h
index dbe86fccca1cd0d911535b340e8c2a05bd406a5b..49d16f95925bd0b032c44440431d7ee72a433ef3 100644
--- a/include/neural-graphics-primitives/dlss.h
+++ b/include/neural-graphics-primitives/dlss.h
@@ -54,16 +54,16 @@ public:
 	virtual EDlssQuality quality() const = 0;
 };
 
-#ifdef NGP_VULKAN
-std::shared_ptr<IDlss> dlss_init(const Eigen::Vector2i& out_resolution);
+class IDlssProvider {
+public:
+	virtual ~IDlssProvider() {}
 
-void vulkan_and_ngx_init();
-size_t dlss_allocated_bytes();
-void vulkan_and_ngx_destroy();
-#else
-inline size_t dlss_allocated_bytes() {
-	return 0;
-}
+	virtual size_t allocated_bytes() const = 0;
+	virtual std::unique_ptr<IDlss> init_dlss(const Eigen::Vector2i& out_resolution) = 0;
+};
+
+#ifdef NGP_VULKAN
+std::shared_ptr<IDlssProvider> init_vulkan_and_ngx();
 #endif
 
 NGP_NAMESPACE_END
diff --git a/include/neural-graphics-primitives/envmap.cuh b/include/neural-graphics-primitives/envmap.cuh
index 6960800719147f13170d8347a6493c83064291cc..7ba6698fec85609c9ef981a5c7a305ace0399c77 100644
--- a/include/neural-graphics-primitives/envmap.cuh
+++ b/include/neural-graphics-primitives/envmap.cuh
@@ -26,31 +26,22 @@
 
 NGP_NAMESPACE_BEGIN
 
-template <typename T>
-__device__ Eigen::Array4f read_envmap(const T* __restrict__ envmap_data, const Eigen::Vector2i envmap_resolution, const Eigen::Vector3f& dir) {
+inline __device__ Eigen::Array4f read_envmap(const Buffer2DView<const Eigen::Array4f>& envmap, const Eigen::Vector3f& dir) {
 	auto dir_cyl = dir_to_spherical_unorm({dir.z(), -dir.x(), dir.y()});
 
-	auto envmap_float = Eigen::Vector2f{dir_cyl.y() * (envmap_resolution.x()-1), dir_cyl.x() * (envmap_resolution.y()-1)};
+	auto envmap_float = Eigen::Vector2f{dir_cyl.y() * (envmap.resolution.x()-1), dir_cyl.x() * (envmap.resolution.y()-1)};
 	Eigen::Vector2i envmap_texel = envmap_float.cast<int>();
 
 	auto weight = envmap_float - envmap_texel.cast<float>();
 
 	auto read_val = [&](Eigen::Vector2i pos) {
 		if (pos.x() < 0) {
-			pos.x() += envmap_resolution.x();
-		} else if (pos.x() >= envmap_resolution.x()) {
-			pos.x() -= envmap_resolution.x();
-		}
-		pos.y() = std::max(std::min(pos.y(), envmap_resolution.y()-1), 0);
-
-		Eigen::Array4f result;
-		if (std::is_same<T, float>::value) {
-			result = *(Eigen::Array4f*)&envmap_data[(pos.x() + pos.y() * envmap_resolution.x()) * 4];
-		} else {
-			auto val = *(tcnn::vector_t<T, 4>*)&envmap_data[(pos.x() + pos.y() * envmap_resolution.x()) * 4];
-			result = {(float)val[0], (float)val[1], (float)val[2], (float)val[3]};
+			pos.x() += envmap.resolution.x();
+		} else if (pos.x() >= envmap.resolution.x()) {
+			pos.x() -= envmap.resolution.x();
 		}
-		return result;
+		pos.y() = std::max(std::min(pos.y(), envmap.resolution.y()-1), 0);
+		return envmap.at(pos);
 	};
 
 	auto result = (
diff --git a/include/neural-graphics-primitives/openxr_hmd.h b/include/neural-graphics-primitives/openxr_hmd.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee442ef9aa5f07abeae2c2f78a06f3d376c59916
--- /dev/null
+++ b/include/neural-graphics-primitives/openxr_hmd.h
@@ -0,0 +1,261 @@
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
+ * and proprietary rights in and to this software, related documentation
+ * and any modifications thereto.  Any use, reproduction, disclosure or
+ * distribution of this software and related documentation without an express
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
+ */
+
+/** @file   openxr_hmd.h
+ *  @author Thomas Müller & Ingo Esser & Robert Menzel, NVIDIA
+ *  @brief  Wrapper around the OpenXR API, providing access to
+ *          per-eye framebuffers, lens parameters, visible area,
+ *          view, hand, and eye poses, as well as controller inputs.
+ */
+
+#pragma once
+
+#ifdef _WIN32
+#  include <GL/gl3w.h>
+#else
+#  include <GL/glew.h>
+#endif
+
+#define XR_USE_GRAPHICS_API_OPENGL
+
+#include <neural-graphics-primitives/common_device.cuh>
+
+#include <openxr/openxr.h>
+#include <xr_linear.h>
+#include <xr_dependencies.h>
+#include <openxr/openxr_platform.h>
+
+#include <Eigen/Dense>
+
+#include <tiny-cuda-nn/gpu_memory.h>
+
+#include <array>
+#include <memory>
+#include <vector>
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers" //TODO: XR struct are uninitiaized apart from their type
+#endif
+
+NGP_NAMESPACE_BEGIN
+
+class OpenXRHMD {
+public:
+	enum class ControlFlow {
+		CONTINUE,
+		RESTART,
+		QUIT,
+	};
+
+	struct FrameInfo {
+		struct View {
+			GLuint framebuffer;
+			XrCompositionLayerProjectionView view{XR_TYPE_COMPOSITION_LAYER_PROJECTION_VIEW};
+			XrCompositionLayerDepthInfoKHR depth_info{XR_TYPE_COMPOSITION_LAYER_DEPTH_INFO_KHR};
+			std::shared_ptr<Buffer2D<uint8_t>> hidden_area_mask = nullptr;
+			Eigen::Matrix<float, 3, 4> pose;
+		};
+		struct Hand {
+			Eigen::Matrix<float, 3, 4> pose;
+			bool pose_active = false;
+			Eigen::Vector2f thumbstick = Eigen::Vector2f::Zero();
+			float grab_strength = 0.0f;
+			bool grabbing = false;
+			bool pressing = false;
+			Eigen::Vector3f grab_pos;
+			Eigen::Vector3f prev_grab_pos;
+			Eigen::Vector3f drag() const {
+				return grab_pos - prev_grab_pos;
+			}
+		};
+		std::vector<View> views;
+		Hand hands[2];
+	};
+	using FrameInfoPtr = std::shared_ptr<FrameInfo>;
+
+	// RAII OpenXRHMD with OpenGL
+#if defined(XR_USE_PLATFORM_WIN32)
+	OpenXRHMD(HDC hdc, HGLRC hglrc);
+#elif defined(XR_USE_PLATFORM_XLIB)
+	OpenXRHMD(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext);
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	OpenXRHMD(wl_display* display);
+#endif
+
+	virtual ~OpenXRHMD();
+
+	// disallow copy / move
+	OpenXRHMD(const OpenXRHMD&) = delete;
+	OpenXRHMD& operator=(const OpenXRHMD&) = delete;
+	OpenXRHMD(OpenXRHMD&&) = delete;
+	OpenXRHMD& operator=(OpenXRHMD&&) = delete;
+
+	void clear();
+
+	// poll events, handle state changes, return control flow information
+	ControlFlow poll_events();
+
+	// begin OpenXR frame, return views to render
+	FrameInfoPtr begin_frame();
+	// must be called for each begin_frame
+	void end_frame(FrameInfoPtr frame_info, float znear, float zfar);
+
+	// if true call begin_frame and end_frame - does not imply visibility
+	bool must_run_frame_loop() const {
+		return
+			m_session_state == XR_SESSION_STATE_READY ||
+			m_session_state == XR_SESSION_STATE_SYNCHRONIZED ||
+			m_session_state == XR_SESSION_STATE_VISIBLE ||
+			m_session_state == XR_SESSION_STATE_FOCUSED;
+	}
+
+	// if true, VR is being rendered to the HMD
+	bool is_visible() const {
+		// XR_SESSION_STATE_VISIBLE -> app content is shown in HMD
+		// XR_SESSION_STATE_FOCUSED -> VISIBLE + input is send to app
+		return m_session_state == XR_SESSION_STATE_VISIBLE || m_session_state == XR_SESSION_STATE_FOCUSED;
+	}
+
+private:
+	// steps of the init process, called from the constructor
+	void init_create_xr_instance();
+	void init_get_xr_system();
+	void init_configure_xr_views();
+	void init_check_for_xr_blend_mode();
+	void init_xr_actions();
+
+#if defined(XR_USE_PLATFORM_WIN32)
+	void init_open_gl(HDC hdc, HGLRC hglrc);
+#elif defined(XR_USE_PLATFORM_XLIB)
+	void init_open_gl(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext);
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	void init_open_gl(wl_display* display);
+#endif
+
+	void init_xr_session();
+	void init_xr_spaces();
+	void init_xr_swapchain_open_gl();
+	void init_open_gl_shaders();
+
+	// session state change
+	void session_state_change(XrSessionState state, ControlFlow& flow);
+
+	std::shared_ptr<Buffer2D<uint8_t>> rasterize_hidden_area_mask(uint32_t view_index, const XrCompositionLayerProjectionView& view);
+	// system/instance
+	XrInstance m_instance{XR_NULL_HANDLE};
+	XrSystemId m_system_id = {};
+	XrInstanceProperties m_instance_properties = {XR_TYPE_INSTANCE_PROPERTIES};
+	XrSystemProperties m_system_properties = {XR_TYPE_SYSTEM_PROPERTIES};
+	std::vector<XrApiLayerProperties> m_api_layer_properties;
+	std::vector<XrExtensionProperties> m_instance_extension_properties;
+
+	// view and blending
+	XrViewConfigurationType m_view_configuration_type = {};
+	XrViewConfigurationProperties m_view_configuration_properties = {XR_TYPE_VIEW_CONFIGURATION_PROPERTIES};
+	std::vector<XrViewConfigurationView> m_view_configuration_views;
+	std::vector<XrEnvironmentBlendMode> m_environment_blend_modes;
+	XrEnvironmentBlendMode m_environment_blend_mode = {XR_ENVIRONMENT_BLEND_MODE_OPAQUE};
+
+	// actions
+	std::array<XrPath, 2> m_hand_paths;
+	std::array<XrSpace, 2> m_hand_spaces;
+	XrAction m_pose_action{XR_NULL_HANDLE};
+	XrAction m_press_action{XR_NULL_HANDLE};
+	XrAction m_grab_action{XR_NULL_HANDLE};
+
+	// Two separate actions for Xbox controller support
+	std::array<XrAction, 2> m_thumbstick_actions;
+
+	XrActionSet m_action_set{XR_NULL_HANDLE};
+
+#if defined(XR_USE_PLATFORM_WIN32)
+	XrGraphicsBindingOpenGLWin32KHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_WIN32_KHR};
+#elif defined(XR_USE_PLATFORM_XLIB)
+	XrGraphicsBindingOpenGLXlibKHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_XLIB_KHR};
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	XrGraphicsBindingOpenGLWaylandKHR m_graphics_binding{XR_TYPE_GRAPHICS_BINDING_OPENGL_WAYLAND_KHR};
+#endif
+
+	XrSession m_session{XR_NULL_HANDLE};
+	XrSessionState m_session_state{XR_SESSION_STATE_UNKNOWN};
+
+	// reference space
+	std::vector<XrReferenceSpaceType> m_reference_spaces;
+	XrSpace m_space{XR_NULL_HANDLE};
+	XrExtent2Df m_bounds;
+
+	// swap chains
+	struct Swapchain {
+		Swapchain(XrSwapchainCreateInfo& rgba_create_info, XrSwapchainCreateInfo& depth_create_info, XrSession& session, XrInstance& xr_instance);
+		Swapchain(const Swapchain&) = delete;
+		Swapchain& operator=(const Swapchain&) = delete;
+		Swapchain(Swapchain&& other) {
+			*this = std::move(other);
+		}
+		Swapchain& operator=(Swapchain&& other) {
+			std::swap(handle, other.handle);
+			std::swap(depth_handle, other.depth_handle);
+			std::swap(width, other.width);
+			std::swap(height, other.height);
+			images_gl = std::move(other.images_gl);
+			depth_images_gl = std::move(other.depth_images_gl);
+			framebuffers_gl = std::move(other.framebuffers_gl);
+			return *this;
+		}
+		virtual ~Swapchain();
+
+		void clear();
+
+		XrSwapchain handle{XR_NULL_HANDLE};
+		XrSwapchain depth_handle{XR_NULL_HANDLE};
+
+		int32_t width = 0;
+		int32_t height = 0;
+		std::vector<XrSwapchainImageOpenGLKHR> images_gl;
+		std::vector<XrSwapchainImageOpenGLKHR> depth_images_gl;
+		std::vector<GLuint> framebuffers_gl;
+	};
+
+	int64_t m_swapchain_rgba_format = 0;
+	std::vector<Swapchain> m_swapchains;
+
+	bool m_supports_composition_layer_depth = false;
+	int64_t m_swapchain_depth_format = 0;
+
+	bool m_supports_hidden_area_mask = false;
+	std::vector<std::shared_ptr<Buffer2D<uint8_t>>> m_hidden_area_masks;
+
+	bool m_supports_eye_tracking = false;
+
+	// frame data
+	XrFrameState m_frame_state{XR_TYPE_FRAME_STATE};
+	FrameInfoPtr m_previous_frame_info;
+
+	GLuint m_hidden_area_mask_program = 0;
+
+	// print more debug info during OpenXRs init:
+	const bool m_print_api_layers = false;
+	const bool m_print_extensions = false;
+	const bool m_print_system_properties = false;
+	const bool m_print_instance_properties = false;
+	const bool m_print_view_configuration_types = false;
+	const bool m_print_view_configuration_properties = false;
+	const bool m_print_view_configuration_view = false;
+	const bool m_print_environment_blend_modes = false;
+	const bool m_print_available_swapchain_formats = false;
+	const bool m_print_reference_spaces = false;
+};
+
+NGP_NAMESPACE_END
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
diff --git a/include/neural-graphics-primitives/random_val.cuh b/include/neural-graphics-primitives/random_val.cuh
index 667c938dd4c37ddecda10b8939ec1cd27c51a96e..e436b18ba23500a909a1ef27fca01b62bb0999dc 100644
--- a/include/neural-graphics-primitives/random_val.cuh
+++ b/include/neural-graphics-primitives/random_val.cuh
@@ -61,11 +61,16 @@ inline __host__ __device__ Eigen::Vector2f dir_to_cylindrical(const Eigen::Vecto
 	return {(cos_theta + 1.0f) / 2.0f, (phi / (2.0f * PI())) + 0.5f};
 }
 
-inline __host__ __device__ Eigen::Vector2f dir_to_spherical_unorm(const Eigen::Vector3f& d) {
+inline __host__ __device__ Eigen::Vector2f dir_to_spherical(const Eigen::Vector3f& d) {
 	const float cos_theta = fminf(fmaxf(d.z(), -1.0f), 1.0f);
 	const float theta = acosf(cos_theta);
 	float phi = std::atan2(d.y(), d.x());
-	return {theta / PI(), (phi / (2.0f * PI()) + 0.5f)};
+	return {theta, phi};
+}
+
+inline __host__ __device__ Eigen::Vector2f dir_to_spherical_unorm(const Eigen::Vector3f& d) {
+	Eigen::Vector2f spherical = dir_to_spherical(d);
+	return {spherical.x() / PI(), (spherical.y() / (2.0f * PI()) + 0.5f)};
 }
 
 template <typename RNG>
diff --git a/include/neural-graphics-primitives/render_buffer.h b/include/neural-graphics-primitives/render_buffer.h
index 2e29f72a6fd896c815645fa81273bbf20e219ff2..0e51364f36f228209b125f6f507c41dc0e2f43fb 100644
--- a/include/neural-graphics-primitives/render_buffer.h
+++ b/include/neural-graphics-primitives/render_buffer.h
@@ -34,7 +34,7 @@ public:
 	virtual cudaSurfaceObject_t surface() = 0;
 	virtual cudaArray_t array() = 0;
 	virtual Eigen::Vector2i resolution() const = 0;
-	virtual void resize(const Eigen::Vector2i&) = 0;
+	virtual void resize(const Eigen::Vector2i&, int n_channels = 4) = 0;
 };
 
 class CudaSurface2D : public SurfaceProvider {
@@ -50,7 +50,7 @@ public:
 
 	void free();
 
-	void resize(const Eigen::Vector2i& size) override;
+	void resize(const Eigen::Vector2i& size, int n_channels) override;
 
 	cudaSurfaceObject_t surface() override {
 		return m_surface;
@@ -65,7 +65,8 @@ public:
 	}
 
 private:
-	Eigen::Vector2i m_size = Eigen::Vector2i::Constant(0);
+	Eigen::Vector2i m_size = Eigen::Vector2i::Zero();
+	int m_n_channels = 0;
 	cudaArray_t m_array;
 	cudaSurfaceObject_t m_surface;
 };
@@ -111,10 +112,10 @@ public:
 
 	void load(const uint8_t* data, Eigen::Vector2i new_size, int n_channels);
 
-	void resize(const Eigen::Vector2i& new_size, int n_channels, bool is_8bit = false);
+	void resize(const Eigen::Vector2i& new_size, int n_channels, bool is_8bit);
 
-	void resize(const Eigen::Vector2i& new_size) override {
-		resize(new_size, 4);
+	void resize(const Eigen::Vector2i& new_size, int n_channels) override {
+		resize(new_size, n_channels, false);
 	}
 
 	Eigen::Vector2i resolution() const override {
@@ -124,7 +125,7 @@ public:
 private:
 	class CUDAMapping {
 	public:
-		CUDAMapping(GLuint texture_id, const Eigen::Vector2i& size);
+		CUDAMapping(GLuint texture_id, const Eigen::Vector2i& size, int n_channels);
 		~CUDAMapping();
 
 		cudaSurfaceObject_t surface() const { return m_cuda_surface ? m_cuda_surface->surface() : m_surface; }
@@ -141,6 +142,7 @@ private:
 		cudaSurfaceObject_t m_surface = {};
 
 		Eigen::Vector2i m_size;
+		int m_n_channels;
 		std::vector<float> m_data_cpu;
 
 		std::unique_ptr<CudaSurface2D> m_cuda_surface;
@@ -157,9 +159,20 @@ private:
 };
 #endif //NGP_GUI
 
+struct CudaRenderBufferView {
+	Eigen::Array4f* frame_buffer = nullptr;
+	float* depth_buffer = nullptr;
+	Eigen::Vector2i resolution = Eigen::Vector2i::Zero();
+	uint32_t spp = 0;
+
+	std::shared_ptr<Buffer2D<uint8_t>> hidden_area_mask = nullptr;
+
+	void clear(cudaStream_t stream) const;
+};
+
 class CudaRenderBuffer {
 public:
-	CudaRenderBuffer(const std::shared_ptr<SurfaceProvider>& surf) : m_surface_provider{surf} {}
+	CudaRenderBuffer(const std::shared_ptr<SurfaceProvider>& rgba, const std::shared_ptr<SurfaceProvider>& depth = nullptr) : m_rgba_target{rgba}, m_depth_target{depth} {}
 
 	CudaRenderBuffer(const CudaRenderBuffer& other) = delete;
 	CudaRenderBuffer& operator=(const CudaRenderBuffer& other) = delete;
@@ -167,7 +180,7 @@ public:
 	CudaRenderBuffer& operator=(CudaRenderBuffer&& other) = default;
 
 	cudaSurfaceObject_t surface() {
-		return m_surface_provider->surface();
+		return m_rgba_target->surface();
 	}
 
 	Eigen::Vector2i in_resolution() const {
@@ -175,7 +188,7 @@ public:
 	}
 
 	Eigen::Vector2i out_resolution() const {
-		return m_surface_provider->resolution();
+		return m_rgba_target->resolution();
 	}
 
 	void resize(const Eigen::Vector2i& res);
@@ -204,11 +217,21 @@ public:
 		return m_accumulate_buffer.data();
 	}
 
+	CudaRenderBufferView view() const {
+		return {
+			frame_buffer(),
+			depth_buffer(),
+			in_resolution(),
+			spp(),
+			hidden_area_mask(),
+		};
+	}
+
 	void clear_frame(cudaStream_t stream);
 
 	void accumulate(float exposure, cudaStream_t stream);
 
-	void tonemap(float exposure, const Eigen::Array4f& background_color, EColorSpace output_color_space, cudaStream_t stream);
+	void tonemap(float exposure, const Eigen::Array4f& background_color, EColorSpace output_color_space, float znear, float zfar, cudaStream_t stream);
 
 	void overlay_image(
 		float alpha,
@@ -238,7 +261,7 @@ public:
 	void overlay_false_color(Eigen::Vector2i training_resolution, bool to_srgb, int fov_axis, cudaStream_t stream, const float *error_map, Eigen::Vector2i error_map_resolution, const float *average, float brightness, bool viridis);
 
 	SurfaceProvider& surface_provider() {
-		return *m_surface_provider;
+		return *m_rgba_target;
 	}
 
 	void set_color_space(EColorSpace color_space) {
@@ -255,22 +278,30 @@ public:
 		}
 	}
 
-	void enable_dlss(const Eigen::Vector2i& max_out_res);
+	void enable_dlss(IDlssProvider& dlss_provider, const Eigen::Vector2i& max_out_res);
 	void disable_dlss();
 	void set_dlss_sharpening(float value) {
 		m_dlss_sharpening = value;
 	}
 
-	const std::shared_ptr<IDlss>& dlss() const {
+	const std::unique_ptr<IDlss>& dlss() const {
 		return m_dlss;
 	}
 
+	void set_hidden_area_mask(const std::shared_ptr<Buffer2D<uint8_t>>& hidden_area_mask) {
+		m_hidden_area_mask = hidden_area_mask;
+	}
+
+	const std::shared_ptr<Buffer2D<uint8_t>>& hidden_area_mask() const {
+		return m_hidden_area_mask;
+	}
+
 private:
 	uint32_t m_spp = 0;
 	EColorSpace m_color_space = EColorSpace::Linear;
 	ETonemapCurve m_tonemap_curve = ETonemapCurve::Identity;
 
-	std::shared_ptr<IDlss> m_dlss;
+	std::unique_ptr<IDlss> m_dlss;
 	float m_dlss_sharpening = 0.0f;
 
 	Eigen::Vector2i m_in_resolution = Eigen::Vector2i::Zero();
@@ -279,7 +310,10 @@ private:
 	tcnn::GPUMemory<float> m_depth_buffer;
 	tcnn::GPUMemory<Eigen::Array4f> m_accumulate_buffer;
 
-	std::shared_ptr<SurfaceProvider> m_surface_provider;
+	std::shared_ptr<Buffer2D<uint8_t>> m_hidden_area_mask = nullptr;
+
+	std::shared_ptr<SurfaceProvider> m_rgba_target;
+	std::shared_ptr<SurfaceProvider> m_depth_target;
 };
 
 NGP_NAMESPACE_END
diff --git a/include/neural-graphics-primitives/testbed.h b/include/neural-graphics-primitives/testbed.h
index e6db007b6ce0c39437b4ac51ff1820566c368ed3..9459285110f29eb678dac7cd0c9063db5c2bef51 100644
--- a/include/neural-graphics-primitives/testbed.h
+++ b/include/neural-graphics-primitives/testbed.h
@@ -26,6 +26,10 @@
 #include <neural-graphics-primitives/thread_pool.h>
 #include <neural-graphics-primitives/trainable_buffer.cuh>
 
+#ifdef NGP_GUI
+#  include <neural-graphics-primitives/openxr_hmd.h>
+#endif
+
 #include <tiny-cuda-nn/multi_stream.h>
 #include <tiny-cuda-nn/random.h>
 
@@ -95,10 +99,11 @@ public:
 			float near_distance,
 			float plane_z,
 			float aperture_size,
-			const float* envmap_data,
-			const Eigen::Vector2i& envmap_resolution,
+			const Foveation& foveation,
+			const Buffer2DView<const Eigen::Array4f>& envmap,
 			Eigen::Array4f* frame_buffer,
 			float* depth_buffer,
+			const Buffer2DView<const uint8_t>& hidden_area_mask,
 			const TriangleOctree* octree,
 			uint32_t n_octree_levels,
 			cudaStream_t stream
@@ -151,21 +156,20 @@ public:
 			const Eigen::Vector4f& rolling_shutter,
 			const Eigen::Vector2f& screen_center,
 			const Eigen::Vector3f& parallax_shift,
-			const Eigen::Vector2i& quilting_dims,
 			bool snap_to_pixel_centers,
 			const BoundingBox& render_aabb,
 			const Eigen::Matrix3f& render_aabb_to_local,
 			float near_distance,
 			float plane_z,
 			float aperture_size,
+			const Foveation& foveation,
 			const Lens& lens,
-			const float* envmap_data,
-			const Eigen::Vector2i& envmap_resolution,
-			const float* distortion_data,
-			const Eigen::Vector2i& distortion_resolution,
+			const Buffer2DView<const Eigen::Array4f>& envmap,
+			const Buffer2DView<const Eigen::Vector2f>& distortion,
 			Eigen::Array4f* frame_buffer,
 			float* depth_buffer,
-			uint8_t* grid,
+			const Buffer2DView<const uint8_t>& hidden_area_mask,
+			const uint8_t* grid,
 			int show_accel,
 			float cone_angle_constant,
 			ERenderMode render_mode,
@@ -177,8 +181,6 @@ public:
 			const BoundingBox& render_aabb,
 			const Eigen::Matrix3f& render_aabb_to_local,
 			const BoundingBox& train_aabb,
-			const uint32_t n_training_images,
-			const TrainingXForm* training_xforms,
 			const Eigen::Vector2f& focal_length,
 			float cone_angle_constant,
 			const uint8_t* grid,
@@ -250,7 +252,11 @@ public:
 		int count;
 	};
 
-	static constexpr float LOSS_SCALE = 128.f;
+	// Due to mixed-precision training, small loss values can lead to
+	// underflow (round to zero) in the gradient computations. Hence,
+	// scale the loss (and thereby gradients) up by this factor and
+	// divide it out in the optimizer later on.
+	static constexpr float LOSS_SCALE = 128.0f;
 
 	struct NetworkDims {
 		uint32_t n_input;
@@ -265,30 +271,91 @@ public:
 
 	NetworkDims network_dims() const;
 
-	void render_volume(CudaRenderBuffer& render_buffer,
-		const Eigen::Vector2f& focal_length,
-		const Eigen::Matrix<float, 3, 4>& camera_matrix,
-		const Eigen::Vector2f& screen_center,
-		cudaStream_t stream
-	);
 	void train_volume(size_t target_batch_size, bool get_loss_scalar, cudaStream_t stream);
 	void training_prep_volume(uint32_t batch_size, cudaStream_t stream) {}
 	void load_volume(const fs::path& data_path);
 
+	class CudaDevice;
+
+	const float* get_inference_extra_dims(cudaStream_t stream) const;
+	void render_nerf(
+		cudaStream_t stream,
+		const CudaRenderBufferView& render_buffer,
+		NerfNetwork<precision_t>& nerf_network,
+		const uint8_t* density_grid_bitfield,
+		const Eigen::Vector2f& focal_length,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix0,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix1,
+		const Eigen::Vector4f& rolling_shutter,
+		const Eigen::Vector2f& screen_center,
+		const Foveation& foveation,
+		int visualized_dimension
+	);
 	void render_sdf(
+		cudaStream_t stream,
 		const distance_fun_t& distance_function,
 		const normals_fun_t& normals_function,
-		CudaRenderBuffer& render_buffer,
-		const Eigen::Vector2i& max_res,
+		const CudaRenderBufferView& render_buffer,
 		const Eigen::Vector2f& focal_length,
 		const Eigen::Matrix<float, 3, 4>& camera_matrix,
 		const Eigen::Vector2f& screen_center,
-		cudaStream_t stream
+		const Foveation& foveation,
+		int visualized_dimension
+	);
+	void render_image(
+		cudaStream_t stream,
+		const CudaRenderBufferView& render_buffer,
+		const Eigen::Vector2f& focal_length,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix,
+		const Eigen::Vector2f& screen_center,
+		const Foveation& foveation,
+		int visualized_dimension
+	);
+	void render_volume(
+		cudaStream_t stream,
+		const CudaRenderBufferView& render_buffer,
+		const Eigen::Vector2f& focal_length,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix,
+		const Eigen::Vector2f& screen_center,
+		const Foveation& foveation
+	);
+
+	void render_frame(
+		cudaStream_t stream,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix0,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix1,
+		const Eigen::Matrix<float, 3, 4>& prev_camera_matrix,
+		const Eigen::Vector2f& screen_center,
+		const Eigen::Vector2f& relative_focal_length,
+		const Eigen::Vector4f& nerf_rolling_shutter,
+		const Foveation& foveation,
+		const Foveation& prev_foveation,
+		int visualized_dimension,
+		CudaRenderBuffer& render_buffer,
+		bool to_srgb = true,
+		CudaDevice* device = nullptr
+	);
+	void render_frame_main(
+		CudaDevice& device,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix0,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix1,
+		const Eigen::Vector2f& screen_center,
+		const Eigen::Vector2f& relative_focal_length,
+		const Eigen::Vector4f& nerf_rolling_shutter,
+		const Foveation& foveation,
+		int visualized_dimension
+	);
+	void render_frame_epilogue(
+		cudaStream_t stream,
+		const Eigen::Matrix<float, 3, 4>& camera_matrix0,
+		const Eigen::Matrix<float, 3, 4>& prev_camera_matrix,
+		const Eigen::Vector2f& screen_center,
+		const Eigen::Vector2f& relative_focal_length,
+		const Foveation& foveation,
+		const Foveation& prev_foveation,
+		CudaRenderBuffer& render_buffer,
+		bool to_srgb = true
 	);
-	const float* get_inference_extra_dims(cudaStream_t stream) const;
-	void render_nerf(CudaRenderBuffer& render_buffer, const Eigen::Vector2i& max_res, const Eigen::Vector2f& focal_length, const Eigen::Matrix<float, 3, 4>& camera_matrix0, const Eigen::Matrix<float, 3, 4>& camera_matrix1, const Eigen::Vector4f& rolling_shutter, const Eigen::Vector2f& screen_center, cudaStream_t stream);
-	void render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream);
-	void render_frame(const Eigen::Matrix<float, 3, 4>& camera_matrix0, const Eigen::Matrix<float, 3, 4>& camera_matrix1, const Eigen::Vector4f& nerf_rolling_shutter, CudaRenderBuffer& render_buffer, bool to_srgb = true) ;
 	void visualize_nerf_cameras(ImDrawList* list, const Eigen::Matrix<float, 4, 4>& world2proj);
 	fs::path find_network_config(const fs::path& network_config_path);
 	nlohmann::json load_network_config(const fs::path& network_config_path);
@@ -310,9 +377,10 @@ public:
 	void set_min_level(float minlevel);
 	void set_visualized_dim(int dim);
 	void set_visualized_layer(int layer);
-	void translate_camera(const Eigen::Vector3f& rel);
-	void mouse_drag(const Eigen::Vector2f& rel, int button);
-	void mouse_wheel(Eigen::Vector2f m, float delta);
+	void translate_camera(const Eigen::Vector3f& rel, const Eigen::Matrix3f& rot, bool allow_up_down = true);
+	Eigen::Matrix3f rotation_from_angles(const Eigen::Vector2f& angles) const;
+	void mouse_drag();
+	void mouse_wheel();
 	void load_file(const fs::path& path);
 	void set_nerf_camera_matrix(const Eigen::Matrix<float, 3, 4>& cam);
 	Eigen::Vector3f look_at() const;
@@ -334,6 +402,7 @@ public:
 	void generate_training_samples_sdf(Eigen::Vector3f* positions, float* distances, uint32_t n_to_generate, cudaStream_t stream, bool uniform_only);
 	void update_density_grid_nerf(float decay, uint32_t n_uniform_density_grid_samples, uint32_t n_nonuniform_density_grid_samples, cudaStream_t stream);
 	void update_density_grid_mean_and_bitfield(cudaStream_t stream);
+	void mark_density_grid_in_sphere_empty(const Eigen::Vector3f& pos, float radius, cudaStream_t stream);
 
 	struct NerfCounters {
 		tcnn::GPUMemory<uint32_t> numsteps_counter; // number of steps each ray took
@@ -364,8 +433,8 @@ public:
 	void training_prep_sdf(uint32_t batch_size, cudaStream_t stream);
 	void training_prep_image(uint32_t batch_size, cudaStream_t stream) {}
 	void train(uint32_t batch_size);
-	Eigen::Vector2f calc_focal_length(const Eigen::Vector2i& resolution, int fov_axis, float zoom) const ;
-	Eigen::Vector2f render_screen_center() const ;
+	Eigen::Vector2f calc_focal_length(const Eigen::Vector2i& resolution, const Eigen::Vector2f& relative_focal_length, int fov_axis, float zoom) const;
+	Eigen::Vector2f render_screen_center(const Eigen::Vector2f& screen_center) const;
 	void optimise_mesh_step(uint32_t N_STEPS);
 	void compute_mesh_vertex_colors();
 	tcnn::GPUMemory<float> get_density_on_grid(Eigen::Vector3i res3d, const BoundingBox& aabb, const Eigen::Matrix3f& render_aabb_to_local); // network version (nerf or sdf)
@@ -373,9 +442,8 @@ public:
 	tcnn::GPUMemory<Eigen::Array4f> get_rgba_on_grid(Eigen::Vector3i res3d, Eigen::Vector3f ray_dir, bool voxel_centers, float depth, bool density_as_alpha = false);
 	int marching_cubes(Eigen::Vector3i res3d, const BoundingBox& render_aabb, const Eigen::Matrix3f& render_aabb_to_local, float thresh);
 
-	// Determines the 3d focus point by rendering a little 16x16 depth image around
-	// the mouse cursor and picking the median depth.
-	void determine_autofocus_target_from_pixel(const Eigen::Vector2i& focus_pixel);
+	float get_depth_from_renderbuffer(const CudaRenderBuffer& render_buffer, const Eigen::Vector2f& uv);
+	Eigen::Vector3f get_3d_pos_from_pixel(const CudaRenderBuffer& render_buffer, const Eigen::Vector2i& focus_pixel);
 	void autofocus();
 	size_t n_params();
 	size_t first_encoder_param();
@@ -396,7 +464,10 @@ public:
 	void destroy_window();
 	void apply_camera_smoothing(float elapsed_ms);
 	int find_best_training_view(int default_view);
-	bool begin_frame_and_handle_user_input();
+	bool begin_frame();
+	void handle_user_input();
+	Eigen::Vector3f vr_to_world(const Eigen::Vector3f& pos) const;
+	void begin_vr_frame_and_handle_vr_input();
 	void gather_histograms();
 	void draw_gui();
 	bool frame();
@@ -479,18 +550,18 @@ public:
 	bool m_dynamic_res = true;
 	float m_dynamic_res_target_fps = 20.0f;
 	int m_fixed_res_factor = 8;
-	float m_last_render_res_factor = 1.0f;
 	float m_scale = 1.0;
-	float m_prev_scale = 1.0;
 	float m_aperture_size = 0.0f;
 	Eigen::Vector2f m_relative_focal_length = Eigen::Vector2f::Ones();
 	uint32_t m_fov_axis = 1;
 	float m_zoom = 1.f; // 2d zoom factor (for insets?)
 	Eigen::Vector2f m_screen_center = Eigen::Vector2f::Constant(0.5f); // center of 2d zoom
 
+	float m_ndc_znear = 1.0f / 32.0f;
+	float m_ndc_zfar = 128.0f;
+
 	Eigen::Matrix<float, 3, 4> m_camera = Eigen::Matrix<float, 3, 4>::Zero();
 	Eigen::Matrix<float, 3, 4> m_smoothed_camera = Eigen::Matrix<float, 3, 4>::Zero();
-	Eigen::Matrix<float, 3, 4> m_prev_camera = Eigen::Matrix<float, 3, 4>::Zero();
 	size_t m_render_skip_due_to_lack_of_camera_movement_counter = 0;
 
 	bool m_fps_camera = false;
@@ -505,8 +576,6 @@ public:
 	float m_bounding_radius = 1;
 	float m_exposure = 0.f;
 
-	Eigen::Vector2i m_quilting_dims = Eigen::Vector2i::Ones();
-
 	ERenderMode m_render_mode = ERenderMode::Shade;
 	EMeshRenderMode m_mesh_render_mode = EMeshRenderMode::VertexNormals;
 
@@ -520,19 +589,31 @@ public:
 		void draw(GLuint texture);
 	} m_second_window;
 
+	float m_drag_depth = 1.0f;
+
+	// The VAO will be empty, but we need a valid one for attribute-less rendering
+	GLuint m_blit_vao = 0;
+	GLuint m_blit_program = 0;
+
+	void init_opengl_shaders();
+	void blit_texture(const Foveation& foveation, GLint rgba_texture, GLint rgba_filter_mode, GLint depth_texture, GLint framebuffer, const Eigen::Vector2i& offset, const Eigen::Vector2i& resolution);
+
 	void create_second_window();
 
+	std::unique_ptr<OpenXRHMD> m_hmd;
+	OpenXRHMD::FrameInfoPtr m_vr_frame_info;
+	void init_vr();
+	void set_n_views(size_t n_views);
+
 	std::function<bool()> m_keyboard_event_callback;
 
 	std::shared_ptr<GLTexture> m_pip_render_texture;
-	std::vector<std::shared_ptr<GLTexture>> m_render_textures;
+	std::vector<std::shared_ptr<GLTexture>> m_rgba_render_textures;
+	std::vector<std::shared_ptr<GLTexture>> m_depth_render_textures;
 #endif
 
-	ThreadPool m_thread_pool;
-	std::vector<std::future<void>> m_render_futures;
 
-	std::vector<CudaRenderBuffer> m_render_surfaces;
-	std::unique_ptr<CudaRenderBuffer> m_pip_render_surface;
+	std::unique_ptr<CudaRenderBuffer> m_pip_render_buffer;
 
 	SharedQueue<std::unique_ptr<ICallable>> m_task_queue;
 
@@ -731,8 +812,6 @@ public:
 	};
 
 	struct Image {
-		Eigen::Vector2f pos = Eigen::Vector2f::Constant(0.0f);
-		Eigen::Vector2f prev_pos = Eigen::Vector2f::Constant(0.0f);
 		tcnn::GPUMemory<char> data;
 
 		EDataType type = EDataType::Float;
@@ -785,7 +864,7 @@ public:
 	EColorSpace m_color_space = EColorSpace::Linear;
 	ETonemapCurve m_tonemap_curve = ETonemapCurve::Identity;
 	bool m_dlss = false;
-	bool m_dlss_supported = false;
+	std::shared_ptr<IDlssProvider> m_dlss_provider;
 	float m_dlss_sharpening = 0.0f;
 
 	// 3D stuff
@@ -814,13 +893,35 @@ public:
 	Eigen::Array4f m_background_color = {0.0f, 0.0f, 0.0f, 1.0f};
 
 	bool m_vsync = false;
+	bool m_render_transparency_as_checkerboard = false;
 
 	// Visualization of neuron activations
 	int m_visualized_dimension = -1;
 	int m_visualized_layer = 0;
+
+	struct View {
+		std::shared_ptr<CudaRenderBuffer> render_buffer;
+		Eigen::Vector2i full_resolution = {1, 1};
+		int visualized_dimension = 0;
+
+		Eigen::Matrix<float, 3, 4> camera0 = Eigen::Matrix<float, 3, 4>::Zero();
+		Eigen::Matrix<float, 3, 4> camera1 = Eigen::Matrix<float, 3, 4>::Zero();
+		Eigen::Matrix<float, 3, 4> prev_camera = Eigen::Matrix<float, 3, 4>::Zero();
+
+		Foveation foveation;
+		Foveation prev_foveation;
+
+		Eigen::Vector2f relative_focal_length;
+		Eigen::Vector2f screen_center;
+
+		CudaDevice* device = nullptr;
+	};
+
+	std::vector<View> m_views;
 	Eigen::Vector2i m_n_views = {1, 1};
-	Eigen::Vector2i m_view_size = {1, 1};
-	bool m_single_view = true; // Whether a single neuron is visualized, or all in a tiled grid
+
+	bool m_single_view = true;
+
 	float m_picture_in_picture_res = 0.f; // if non zero, requests a small second picture :)
 
 	struct ImGuiVars {
@@ -835,9 +936,10 @@ public:
 	} m_imgui;
 
 	bool m_visualize_unit_cube = false;
-	bool m_snap_to_pixel_centers = false;
 	bool m_edit_render_aabb = false;
 
+	bool m_snap_to_pixel_centers = false;
+
 	Eigen::Vector3f m_parallax_shift = {0.0f, 0.0f, 0.0f}; // to shift the viewer's origin by some amount in camera space
 
 	// CUDA stuff
@@ -863,6 +965,172 @@ public:
 	bool m_train_encoding = true;
 	bool m_train_network = true;
 
+	class CudaDevice {
+	public:
+		struct Data {
+			tcnn::GPUMemory<uint8_t> density_grid_bitfield;
+			uint8_t* density_grid_bitfield_ptr;
+
+			tcnn::GPUMemory<precision_t> params;
+			std::shared_ptr<Buffer2D<uint8_t>> hidden_area_mask;
+		};
+
+		CudaDevice(int id, bool is_primary) : m_id{id}, m_is_primary{is_primary} {
+			auto guard = device_guard();
+			m_stream = std::make_unique<tcnn::StreamAndEvent>();
+			m_data = std::make_unique<Data>();
+			m_render_worker = std::make_unique<ThreadPool>(is_primary ? 0u : 1u);
+		}
+
+		CudaDevice(const CudaDevice&) = delete;
+		CudaDevice& operator=(const CudaDevice&) = delete;
+
+		CudaDevice(CudaDevice&&) = default;
+		CudaDevice& operator=(CudaDevice&&) = default;
+
+		tcnn::ScopeGuard device_guard() {
+			int prev_device = tcnn::cuda_device();
+			if (prev_device == m_id) {
+				return {};
+			}
+
+			tcnn::set_cuda_device(m_id);
+			return tcnn::ScopeGuard{[prev_device]() {
+				tcnn::set_cuda_device(prev_device);
+			}};
+		}
+
+		int id() const {
+			return m_id;
+		}
+
+		bool is_primary() const {
+			return m_is_primary;
+		}
+
+		std::string name() const {
+			return tcnn::cuda_device_name(m_id);
+		}
+
+		int compute_capability() const {
+			return tcnn::cuda_compute_capability(m_id);
+		}
+
+		cudaStream_t stream() const {
+			return m_stream->get();
+		}
+
+		void wait_for(cudaStream_t stream) const {
+			CUDA_CHECK_THROW(cudaEventRecord(m_primary_device_event.event, stream));
+			m_stream->wait_for(m_primary_device_event.event);
+		}
+
+		void signal(cudaStream_t stream) const {
+			m_stream->signal(stream);
+		}
+
+		const CudaRenderBufferView& render_buffer_view() const {
+			return m_render_buffer_view;
+		}
+
+		void set_render_buffer_view(const CudaRenderBufferView& view) {
+			m_render_buffer_view = view;
+		}
+
+		Data& data() const {
+			return *m_data;
+		}
+
+		bool dirty() const {
+			return m_dirty;
+		}
+
+		void set_dirty(bool value) {
+			m_dirty = value;
+		}
+
+		void set_network(const std::shared_ptr<tcnn::Network<float, precision_t>>& network) {
+			m_network = network;
+		}
+
+		void set_nerf_network(const std::shared_ptr<NerfNetwork<precision_t>>& nerf_network);
+
+		const std::shared_ptr<tcnn::Network<float, precision_t>>& network() const {
+			return m_network;
+		}
+
+		const std::shared_ptr<NerfNetwork<precision_t>>& nerf_network() const {
+			return m_nerf_network;
+		}
+
+		void clear() {
+			m_data = std::make_unique<Data>();
+			m_render_buffer_view = {};
+			m_network = {};
+			m_nerf_network = {};
+			set_dirty(true);
+		}
+
+		template <class F>
+		auto enqueue_task(F&& f) -> std::future<std::result_of_t <F()>> {
+			if (is_primary()) {
+				return std::async(std::launch::deferred, std::forward<F>(f));
+			} else {
+				return m_render_worker->enqueue_task(std::forward<F>(f));
+			}
+		}
+
+	private:
+		int m_id;
+		bool m_is_primary;
+		std::unique_ptr<tcnn::StreamAndEvent> m_stream;
+		struct Event {
+			Event() {
+				CUDA_CHECK_THROW(cudaEventCreate(&event));
+			}
+
+			~Event() {
+				cudaEventDestroy(event);
+			}
+
+			Event(const Event&) = delete;
+			Event& operator=(const Event&) = delete;
+			Event(Event&& other) { *this = std::move(other); }
+			Event& operator=(Event&& other) {
+				std::swap(event, other.event);
+				return *this;
+			}
+
+			cudaEvent_t event = {};
+		};
+		Event m_primary_device_event;
+		std::unique_ptr<Data> m_data;
+		CudaRenderBufferView m_render_buffer_view = {};
+
+		std::shared_ptr<tcnn::Network<float, precision_t>> m_network;
+		std::shared_ptr<NerfNetwork<precision_t>> m_nerf_network;
+
+		bool m_dirty = true;
+
+		std::unique_ptr<ThreadPool> m_render_worker;
+	};
+
+	void sync_device(CudaRenderBuffer& render_buffer, CudaDevice& device);
+	tcnn::ScopeGuard use_device(cudaStream_t stream, CudaRenderBuffer& render_buffer, CudaDevice& device);
+	void set_all_devices_dirty();
+
+	std::vector<CudaDevice> m_devices;
+	CudaDevice& primary_device() {
+		return m_devices.front();
+	}
+
+	ThreadPool m_thread_pool;
+	std::vector<std::future<void>> m_render_futures;
+
+	bool m_use_aux_devices = false;
+	bool m_foveated_rendering = false;
+	float m_foveated_rendering_max_scaling = 2.0f;
+
 	fs::path m_data_path;
 	fs::path m_network_config_path = "base.json";
 
@@ -876,8 +1144,8 @@ public:
 	uint32_t network_width(uint32_t layer) const;
 	uint32_t network_num_forward_activations() const;
 
-	std::shared_ptr<tcnn::Loss<precision_t>> m_loss;
 	// Network & training stuff
+	std::shared_ptr<tcnn::Loss<precision_t>> m_loss;
 	std::shared_ptr<tcnn::Optimizer<precision_t>> m_optimizer;
 	std::shared_ptr<tcnn::Encoding<precision_t>> m_encoding;
 	std::shared_ptr<tcnn::Network<float, precision_t>> m_network;
@@ -890,6 +1158,22 @@ public:
 
 		Eigen::Vector2i resolution;
 		ELossType loss_type;
+
+		Buffer2DView<const Eigen::Array4f> inference_view() const {
+			if (!envmap) {
+				return {};
+			}
+
+			return {(const Eigen::Array4f*)envmap->inference_params(), resolution};
+		}
+
+		Buffer2DView<const Eigen::Array4f> view() const {
+			if (!envmap) {
+				return {};
+			}
+
+			return {(const Eigen::Array4f*)envmap->params(), resolution};
+		}
 	} m_envmap;
 
 	struct TrainableDistortionMap {
@@ -897,6 +1181,22 @@ public:
 		std::shared_ptr<TrainableBuffer<2, 2, float>> map;
 		std::shared_ptr<tcnn::Trainer<float, float, float>> trainer;
 		Eigen::Vector2i resolution;
+
+		Buffer2DView<const Eigen::Vector2f> inference_view() const {
+			if (!map) {
+				return {};
+			}
+
+			return {(const Eigen::Vector2f*)map->inference_params(), resolution};
+		}
+
+		Buffer2DView<const Eigen::Vector2f> view() const {
+			if (!map) {
+				return {};
+			}
+
+			return {(const Eigen::Vector2f*)map->params(), resolution};
+		}
 	} m_distortion;
 	std::shared_ptr<NerfNetwork<precision_t>> m_nerf_network;
 };
diff --git a/scripts/colmap2nerf.py b/scripts/colmap2nerf.py
index 68d6bd9ece6a74588e8ac87f9197d8c9da910079..30f5104a0f597dea4e645cf141ef8584d6e61136 100755
--- a/scripts/colmap2nerf.py
+++ b/scripts/colmap2nerf.py
@@ -414,8 +414,7 @@ if __name__ == "__main__":
 		from detectron2 import model_zoo
 		from detectron2.engine import DefaultPredictor
 
-		dir_path = Path(os.path.dirname(os.path.realpath(__file__)))
-		category2id = json.load(open(dir_path / "category2id.json", "r"))
+		category2id = json.load(open(SCRIPTS_FOLDER / "category2id.json", "r"))
 		mask_ids = [category2id[c] for c in args.mask_categories]
 
 		cfg = get_cfg()
diff --git a/scripts/run.py b/scripts/run.py
index 16152cce42ddf2f2b00edeac8fe0e4088512f33b..d7b2054383b6eda31656629ce5dba0f2ed23514b 100644
--- a/scripts/run.py
+++ b/scripts/run.py
@@ -64,6 +64,7 @@ def parse_args():
 	parser.add_argument("--train", action="store_true", help="If the GUI is enabled, controls whether training starts immediately.")
 	parser.add_argument("--n_steps", type=int, default=-1, help="Number of steps to train for before quitting.")
 	parser.add_argument("--second_window", action="store_true", help="Open a second window containing a copy of the main output.")
+	parser.add_argument("--vr", action="store_true", help="Render to a VR headset.")
 
 	parser.add_argument("--sharpen", default=0, help="Set amount of sharpening applied to NeRF training images. Range 0.0 to 1.0.")
 
@@ -78,6 +79,8 @@ def get_scene(scene):
 
 if __name__ == "__main__":
 	args = parse_args()
+	if args.vr: # VR implies having the GUI running at the moment
+		args.gui = True
 
 	if args.mode:
 		print("Warning: the '--mode' argument is no longer in use. It has no effect. The mode is automatically chosen based on the scene.")
@@ -106,7 +109,9 @@ if __name__ == "__main__":
 		while sw * sh > 1920 * 1080 * 4:
 			sw = int(sw / 2)
 			sh = int(sh / 2)
-		testbed.init_window(sw, sh, second_window = args.second_window or False)
+		testbed.init_window(sw, sh, second_window=args.second_window)
+		if args.vr:
+			testbed.init_vr()
 
 
 	if args.load_snapshot:
@@ -159,10 +164,8 @@ if __name__ == "__main__":
 		# setting here.
 		testbed.nerf.cone_angle_constant = 0
 
-		# Optionally match nerf paper behaviour and train on a
-		# fixed white bg. We prefer training on random BG colors.
-		# testbed.background_color = [1.0, 1.0, 1.0, 1.0]
-		# testbed.nerf.training.random_bg_color = False
+		# Match nerf paper behaviour and train on a fixed bg.
+		testbed.nerf.training.random_bg_color = False
 
 	old_training_step = 0
 	n_steps = args.n_steps
@@ -223,53 +226,24 @@ if __name__ == "__main__":
 
 		testbed.nerf.render_min_transmittance = 1e-4
 
-		testbed.fov_axis = 0
-		testbed.fov = test_transforms["camera_angle_x"] * 180 / np.pi
 		testbed.shall_train = False
+		testbed.load_training_data(args.test_transforms)
 
-		with tqdm(list(enumerate(test_transforms["frames"])), unit="images", desc=f"Rendering test frame") as t:
-			for i, frame in t:
-				p = frame["file_path"]
-				if "." not in p:
-					p = p + ".png"
-				ref_fname = os.path.join(data_dir, p)
-				if not os.path.isfile(ref_fname):
-					ref_fname = os.path.join(data_dir, p + ".png")
-					if not os.path.isfile(ref_fname):
-						ref_fname = os.path.join(data_dir, p + ".jpg")
-						if not os.path.isfile(ref_fname):
-							ref_fname = os.path.join(data_dir, p + ".jpeg")
-							if not os.path.isfile(ref_fname):
-								ref_fname = os.path.join(data_dir, p + ".exr")
-
-				ref_image = read_image(ref_fname)
-
-				# NeRF blends with background colors in sRGB space, rather than first
-				# transforming to linear space, blending there, and then converting back.
-				# (See e.g. the PNG spec for more information on how the `alpha` channel
-				# is always a linear quantity.)
-				# The following lines of code reproduce NeRF's behavior (if enabled in
-				# testbed) in order to make the numbers comparable.
-				if testbed.color_space == ngp.ColorSpace.SRGB and ref_image.shape[2] == 4:
-					# Since sRGB conversion is non-linear, alpha must be factored out of it
-					ref_image[...,:3] = np.divide(ref_image[...,:3], ref_image[...,3:4], out=np.zeros_like(ref_image[...,:3]), where=ref_image[...,3:4] != 0)
-					ref_image[...,:3] = linear_to_srgb(ref_image[...,:3])
-					ref_image[...,:3] *= ref_image[...,3:4]
-					ref_image += (1.0 - ref_image[...,3:4]) * testbed.background_color
-					ref_image[...,:3] = srgb_to_linear(ref_image[...,:3])
+		with tqdm(range(testbed.nerf.training.dataset.n_images), unit="images", desc=f"Rendering test frame") as t:
+			for i in t:
+				resolution = testbed.nerf.training.dataset.metadata[i].resolution
+				testbed.render_ground_truth = True
+				testbed.set_camera_to_training_view(i)
+				ref_image = testbed.render(resolution[0], resolution[1], 1, True)
+				testbed.render_ground_truth = False
+				image = testbed.render(resolution[0], resolution[1], spp, True)
 
 				if i == 0:
-					write_image("ref.png", ref_image)
-
-				testbed.set_nerf_camera_matrix(np.matrix(frame["transform_matrix"])[:-1,:])
-				image = testbed.render(ref_image.shape[1], ref_image.shape[0], spp, True)
+					write_image(f"ref.png", ref_image)
+					write_image(f"out.png", image)
 
-				if i == 0:
-					write_image("out.png", image)
-
-				diffimg = np.absolute(image - ref_image)
-				diffimg[...,3:4] = 1.0
-				if i == 0:
+					diffimg = np.absolute(image - ref_image)
+					diffimg[...,3:4] = 1.0
 					write_image("diff.png", diffimg)
 
 				A = np.clip(linear_to_srgb(image[...,:3]), 0.0, 1.0)
diff --git a/src/camera_path.cu b/src/camera_path.cu
index 853f495a92af1056b3cab7183fd0bb82d36ff2ff..6b8953ce339c6dbff353012d9ca975afba970250 100644
--- a/src/camera_path.cu
+++ b/src/camera_path.cu
@@ -318,13 +318,11 @@ void visualize_nerf_camera(ImDrawList* list, const Matrix<float, 4, 4>& world2pr
 	add_debug_line(list, world2proj, d, a, col, thickness);
 }
 
-bool CameraPath::imgui_viz(ImDrawList* list, Matrix<float, 4, 4> &view2proj, Matrix<float, 4, 4> &world2proj, Matrix<float, 4, 4> &world2view, Vector2f focal, float aspect) {
+bool CameraPath::imgui_viz(ImDrawList* list, Matrix<float, 4, 4> &view2proj, Matrix<float, 4, 4> &world2proj, Matrix<float, 4, 4> &world2view, Vector2f focal, float aspect, float znear, float zfar) {
 	bool changed = false;
 	float flx = focal.x();
 	float fly = focal.y();
 	Matrix<float, 4, 4> view2proj_guizmo;
-	float zfar = 100.f;
-	float znear = 0.1f;
 	view2proj_guizmo <<
 		fly * 2.0f / aspect, 0, 0, 0,
 		0, -fly * 2.0f, 0, 0,
diff --git a/src/dlss.cu b/src/dlss.cu
index d81df9390f895cd0f4f47039e6db9a849081d805..28ac47e05998f6f3dac3aae74416bf128fafc8b0 100644
--- a/src/dlss.cu
+++ b/src/dlss.cu
@@ -79,36 +79,18 @@ std::string ngx_error_string(NVSDK_NGX_Result result) {
 			throw std::runtime_error(std::string(FILE_LINE " " #x " failed with error ") + ngx_error_string(result)); \
 	} while(0)
 
-static VkInstance vk_instance = VK_NULL_HANDLE;
-static VkDebugUtilsMessengerEXT vk_debug_messenger = VK_NULL_HANDLE;
-static VkPhysicalDevice vk_physical_device = VK_NULL_HANDLE;
-static VkDevice vk_device = VK_NULL_HANDLE;
-static VkQueue vk_queue = VK_NULL_HANDLE;
-static VkCommandPool vk_command_pool = VK_NULL_HANDLE;
-static VkCommandBuffer vk_command_buffer = VK_NULL_HANDLE;
-
-static bool ngx_initialized = false;
-static NVSDK_NGX_Parameter* ngx_parameters = nullptr;
-
-uint32_t vk_find_memory_type(uint32_t type_filter, VkMemoryPropertyFlags properties) {
-	VkPhysicalDeviceMemoryProperties mem_properties;
-	vkGetPhysicalDeviceMemoryProperties(vk_physical_device, &mem_properties);
-
-	for (uint32_t i = 0; i < mem_properties.memoryTypeCount; i++) {
-		if (type_filter & (1 << i) && (mem_properties.memoryTypes[i].propertyFlags & properties) == properties) {
-			return i;
-		}
-	}
-
-	throw std::runtime_error{"Failed to find suitable memory type."};
-}
-
 static VKAPI_ATTR VkBool32 VKAPI_CALL vk_debug_callback(
 	VkDebugUtilsMessageSeverityFlagBitsEXT message_severity,
 	VkDebugUtilsMessageTypeFlagsEXT message_type,
 	const VkDebugUtilsMessengerCallbackDataEXT* callback_data,
 	void* user_data
 ) {
+	// Ignore json files that couldn't be found... third party tools sometimes install bogus layers
+	// that manifest as warnings like this.
+	if (std::string{callback_data->pMessage}.find("Failed to open JSON file") != std::string::npos) {
+		return VK_FALSE;
+	}
+
 	if (message_severity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT) {
 		tlog::warning() << "Vulkan error: " << callback_data->pMessage;
 	} else if (message_severity & VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT) {
@@ -120,366 +102,449 @@ static VKAPI_ATTR VkBool32 VKAPI_CALL vk_debug_callback(
 	return VK_FALSE;
 }
 
-void vulkan_and_ngx_init() {
-	static bool already_initialized = false;
+class VulkanAndNgx : public IDlssProvider, public std::enable_shared_from_this<VulkanAndNgx> {
+public:
+	VulkanAndNgx() {
+		ScopeGuard cleanup_guard{[&]() { clear(); }};
 
-	if (already_initialized) {
-		return;
-	}
+		if (!glfwVulkanSupported()) {
+			throw std::runtime_error{"!glfwVulkanSupported()"};
+		}
 
-	already_initialized = true;
+		// -------------------------------
+		// Vulkan Instance
+		// -------------------------------
+		VkApplicationInfo app_info{};
+		app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+		app_info.pApplicationName = "NGP";
+		app_info.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
+		app_info.pEngineName = "No engine";
+		app_info.engineVersion = VK_MAKE_VERSION(1, 0, 0);
+		app_info.apiVersion = VK_API_VERSION_1_0;
+
+		VkInstanceCreateInfo instance_create_info = {};
+		instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+		instance_create_info.pApplicationInfo = &app_info;
+
+		uint32_t available_layer_count;
+		vkEnumerateInstanceLayerProperties(&available_layer_count, nullptr);
+
+		std::vector<VkLayerProperties> available_layers(available_layer_count);
+		vkEnumerateInstanceLayerProperties(&available_layer_count, available_layers.data());
+
+		std::vector<const char*> layers;
+		auto try_add_layer = [&](const char* layer) {
+			for (const auto& props : available_layers) {
+				if (strcmp(layer, props.layerName) == 0) {
+					layers.emplace_back(layer);
+					return true;
+				}
+			}
 
-	if (!glfwVulkanSupported()) {
-		throw std::runtime_error{"!glfwVulkanSupported()"};
-	}
+			return false;
+		};
 
-	// -------------------------------
-	// Vulkan Instance
-	// -------------------------------
-	VkApplicationInfo app_info{};
-	app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
-	app_info.pApplicationName = "NGP";
-	app_info.applicationVersion = VK_MAKE_VERSION(1, 0, 0);
-	app_info.pEngineName = "No engine";
-	app_info.engineVersion = VK_MAKE_VERSION(1, 0, 0);
-	app_info.apiVersion = VK_API_VERSION_1_0;
-
-	VkInstanceCreateInfo instance_create_info = {};
-	instance_create_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
-	instance_create_info.pApplicationInfo = &app_info;
-
-	uint32_t available_layer_count;
-	vkEnumerateInstanceLayerProperties(&available_layer_count, nullptr);
-
-	std::vector<VkLayerProperties> available_layers(available_layer_count);
-	vkEnumerateInstanceLayerProperties(&available_layer_count, available_layers.data());
-
-	std::vector<const char*> layers;
-	auto try_add_layer = [&](const char* layer) {
-		for (const auto& props : available_layers) {
-			if (strcmp(layer, props.layerName) == 0) {
-				layers.emplace_back(layer);
-				return true;
-			}
+		bool validation_layer_enabled = try_add_layer("VK_LAYER_KHRONOS_validation");
+		if (!validation_layer_enabled) {
+			tlog::warning() << "Vulkan validation layer is not available. Vulkan errors will be difficult to diagnose.";
 		}
 
-		return false;
-	};
+		instance_create_info.enabledLayerCount = static_cast<uint32_t>(layers.size());
+		instance_create_info.ppEnabledLayerNames = layers.empty() ? nullptr : layers.data();
 
-	bool validation_layer_enabled = try_add_layer("VK_LAYER_KHRONOS_validation");
-	if (!validation_layer_enabled) {
-		tlog::warning() << "Vulkan validation layer is not available. Vulkan errors will be difficult to diagnose.";
-	}
+		std::vector<const char*> instance_extensions;
+		std::vector<const char*> device_extensions;
 
-	instance_create_info.enabledLayerCount = static_cast<uint32_t>(layers.size());
-	instance_create_info.ppEnabledLayerNames = layers.empty() ? nullptr : layers.data();
+		uint32_t n_ngx_instance_extensions = 0;
+		const char** ngx_instance_extensions;
 
-	std::vector<const char*> instance_extensions;
-	std::vector<const char*> device_extensions;
+		uint32_t n_ngx_device_extensions = 0;
+		const char** ngx_device_extensions;
 
-	uint32_t n_ngx_instance_extensions = 0;
-	const char** ngx_instance_extensions;
+		NVSDK_NGX_VULKAN_RequiredExtensions(&n_ngx_instance_extensions, &ngx_instance_extensions, &n_ngx_device_extensions, &ngx_device_extensions);
 
-	uint32_t n_ngx_device_extensions = 0;
-	const char** ngx_device_extensions;
+		for (uint32_t i = 0; i < n_ngx_instance_extensions; ++i) {
+			instance_extensions.emplace_back(ngx_instance_extensions[i]);
+		}
 
-	NVSDK_NGX_VULKAN_RequiredExtensions(&n_ngx_instance_extensions, &ngx_instance_extensions, &n_ngx_device_extensions, &ngx_device_extensions);
+		instance_extensions.emplace_back(VK_KHR_DEVICE_GROUP_CREATION_EXTENSION_NAME);
+		instance_extensions.emplace_back(VK_KHR_EXTERNAL_FENCE_CAPABILITIES_EXTENSION_NAME);
+		instance_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME);
+		instance_extensions.emplace_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
 
-	for (uint32_t i = 0; i < n_ngx_instance_extensions; ++i) {
-		instance_extensions.emplace_back(ngx_instance_extensions[i]);
-	}
+		if (validation_layer_enabled) {
+			instance_extensions.emplace_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
+		}
 
-	instance_extensions.emplace_back(VK_KHR_DEVICE_GROUP_CREATION_EXTENSION_NAME);
-	instance_extensions.emplace_back(VK_KHR_EXTERNAL_FENCE_CAPABILITIES_EXTENSION_NAME);
-	instance_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME);
-	instance_extensions.emplace_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
+		for (uint32_t i = 0; i < n_ngx_device_extensions; ++i) {
+			device_extensions.emplace_back(ngx_device_extensions[i]);
+		}
 
-	if (validation_layer_enabled) {
-		instance_extensions.emplace_back(VK_EXT_DEBUG_UTILS_EXTENSION_NAME);
-	}
+		device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME);
+	#ifdef _WIN32
+		device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME);
+	#else
+		device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME);
+	#endif
+		device_extensions.emplace_back(VK_KHR_DEVICE_GROUP_EXTENSION_NAME);
+
+		instance_create_info.enabledExtensionCount = (uint32_t)instance_extensions.size();
+		instance_create_info.ppEnabledExtensionNames = instance_extensions.data();
+
+		VkDebugUtilsMessengerCreateInfoEXT debug_messenger_create_info = {};
+		debug_messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
+		debug_messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
+		debug_messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
+		debug_messenger_create_info.pfnUserCallback = vk_debug_callback;
+		debug_messenger_create_info.pUserData = nullptr;
+
+		if (validation_layer_enabled) {
+			instance_create_info.pNext = &debug_messenger_create_info;
+		}
 
-	for (uint32_t i = 0; i < n_ngx_device_extensions; ++i) {
-		device_extensions.emplace_back(ngx_device_extensions[i]);
-	}
+		VK_CHECK_THROW(vkCreateInstance(&instance_create_info, nullptr, &m_vk_instance));
 
-	device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME);
-#ifdef _WIN32
-	device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME);
-#else
-	device_extensions.emplace_back(VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME);
-#endif
-	device_extensions.emplace_back(VK_KHR_DEVICE_GROUP_EXTENSION_NAME);
+		if (validation_layer_enabled) {
+			auto CreateDebugUtilsMessengerEXT = [](VkInstance instance, const VkDebugUtilsMessengerCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugUtilsMessengerEXT* pDebugMessenger) {
+				auto func = (PFN_vkCreateDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkCreateDebugUtilsMessengerEXT");
+				if (func != nullptr) {
+					return func(instance, pCreateInfo, pAllocator, pDebugMessenger);
+				} else {
+					return VK_ERROR_EXTENSION_NOT_PRESENT;
+				}
+			};
 
-	instance_create_info.enabledExtensionCount = (uint32_t)instance_extensions.size();
-	instance_create_info.ppEnabledExtensionNames = instance_extensions.data();
+			if (CreateDebugUtilsMessengerEXT(m_vk_instance, &debug_messenger_create_info, nullptr, &m_vk_debug_messenger) != VK_SUCCESS) {
+				tlog::warning() << "Vulkan: could not initialize debug messenger.";
+			}
+		}
 
-	VkDebugUtilsMessengerCreateInfoEXT debug_messenger_create_info = {};
-	debug_messenger_create_info.sType = VK_STRUCTURE_TYPE_DEBUG_UTILS_MESSENGER_CREATE_INFO_EXT;
-	debug_messenger_create_info.messageSeverity = VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT;
-	debug_messenger_create_info.messageType = VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_PERFORMANCE_BIT_EXT;
-	debug_messenger_create_info.pfnUserCallback = vk_debug_callback;
-	debug_messenger_create_info.pUserData = nullptr;
+		// -------------------------------
+		// Vulkan Physical Device
+		// -------------------------------
+		uint32_t n_devices = 0;
+		vkEnumeratePhysicalDevices(m_vk_instance, &n_devices, nullptr);
 
-	if (validation_layer_enabled) {
-		instance_create_info.pNext = &debug_messenger_create_info;
-	}
+		if (n_devices == 0) {
+			throw std::runtime_error{"Failed to find GPUs with Vulkan support."};
+		}
 
-	VK_CHECK_THROW(vkCreateInstance(&instance_create_info, nullptr, &vk_instance));
+		std::vector<VkPhysicalDevice> devices(n_devices);
+		vkEnumeratePhysicalDevices(m_vk_instance, &n_devices, devices.data());
 
-	if (validation_layer_enabled) {
-		auto CreateDebugUtilsMessengerEXT = [](VkInstance instance, const VkDebugUtilsMessengerCreateInfoEXT* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkDebugUtilsMessengerEXT* pDebugMessenger) {
-			auto func = (PFN_vkCreateDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkCreateDebugUtilsMessengerEXT");
-			if (func != nullptr) {
-				return func(instance, pCreateInfo, pAllocator, pDebugMessenger);
-			} else {
-				return VK_ERROR_EXTENSION_NOT_PRESENT;
-			}
+		struct QueueFamilyIndices {
+			int graphics_family = -1;
+			int compute_family = -1;
+			int transfer_family = -1;
+			int all_family = -1;
 		};
 
-		if (CreateDebugUtilsMessengerEXT(vk_instance, &debug_messenger_create_info, nullptr, &vk_debug_messenger) != VK_SUCCESS) {
-			tlog::warning() << "Vulkan: could not initialize debug messenger.";
-		}
-	}
-
-	// -------------------------------
-	// Vulkan Physical Device
-	// -------------------------------
-	uint32_t n_devices = 0;
-	vkEnumeratePhysicalDevices(vk_instance, &n_devices, nullptr);
+		auto find_queue_families = [](VkPhysicalDevice device) {
+			QueueFamilyIndices indices;
 
-	if (n_devices == 0) {
-		throw std::runtime_error{"Failed to find GPUs with Vulkan support."};
-	}
+			uint32_t queue_family_count = 0;
+			vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr);
 
-	std::vector<VkPhysicalDevice> devices(n_devices);
-	vkEnumeratePhysicalDevices(vk_instance, &n_devices, devices.data());
+			std::vector<VkQueueFamilyProperties> queue_families(queue_family_count);
+			vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data());
 
-	struct QueueFamilyIndices {
-		int graphics_family = -1;
-		int compute_family = -1;
-		int transfer_family = -1;
-		int all_family = -1;
-	};
+			int i = 0;
+			for (const auto& queue_family : queue_families) {
+				if (queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) {
+					indices.graphics_family = i;
+				}
 
-	auto find_queue_families = [](VkPhysicalDevice device) {
-		QueueFamilyIndices indices;
+				if (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) {
+					indices.compute_family = i;
+				}
 
-		uint32_t queue_family_count = 0;
-		vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, nullptr);
+				if (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT) {
+					indices.transfer_family = i;
+				}
 
-		std::vector<VkQueueFamilyProperties> queue_families(queue_family_count);
-		vkGetPhysicalDeviceQueueFamilyProperties(device, &queue_family_count, queue_families.data());
+				if ((queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) && (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) && (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT)) {
+					indices.all_family = i;
+				}
 
-		int i = 0;
-		for (const auto& queue_family : queue_families) {
-			if (queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) {
-				indices.graphics_family = i;
+				i++;
 			}
 
-			if (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) {
-				indices.compute_family = i;
-			}
+			return indices;
+		};
 
-			if (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT) {
-				indices.transfer_family = i;
-			}
+		cudaDeviceProp cuda_device_prop;
+		CUDA_CHECK_THROW(cudaGetDeviceProperties(&cuda_device_prop, tcnn::cuda_device()));
 
-			if ((queue_family.queueFlags & VK_QUEUE_GRAPHICS_BIT) && (queue_family.queueFlags & VK_QUEUE_COMPUTE_BIT) && (queue_family.queueFlags & VK_QUEUE_TRANSFER_BIT)) {
-				indices.all_family = i;
-			}
+		auto is_same_as_cuda_device = [&](VkPhysicalDevice device) {
+			VkPhysicalDeviceIDProperties physical_device_id_properties = {};
+			physical_device_id_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES;
+			physical_device_id_properties.pNext = NULL;
 
-			i++;
-		}
+			VkPhysicalDeviceProperties2 physical_device_properties = {};
+			physical_device_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
+			physical_device_properties.pNext = &physical_device_id_properties;
 
-		return indices;
-	};
+			vkGetPhysicalDeviceProperties2(device, &physical_device_properties);
 
-	cudaDeviceProp cuda_device_prop;
-	CUDA_CHECK_THROW(cudaGetDeviceProperties(&cuda_device_prop, tcnn::cuda_device()));
+			return !memcmp(&cuda_device_prop.uuid, physical_device_id_properties.deviceUUID, VK_UUID_SIZE) && find_queue_families(device).all_family >= 0;
+		};
 
-	auto is_same_as_cuda_device = [&](VkPhysicalDevice device) {
-		VkPhysicalDeviceIDProperties physical_device_id_properties = {};
-		physical_device_id_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ID_PROPERTIES;
-		physical_device_id_properties.pNext = NULL;
+		uint32_t device_id = 0;
+		for (uint32_t i = 0; i < n_devices; ++i) {
+			if (is_same_as_cuda_device(devices[i])) {
+				m_vk_physical_device = devices[i];
+				device_id = i;
+				break;
+			}
+		}
 
-		VkPhysicalDeviceProperties2 physical_device_properties = {};
-		physical_device_properties.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
-		physical_device_properties.pNext = &physical_device_id_properties;
+		if (m_vk_physical_device == VK_NULL_HANDLE) {
+			throw std::runtime_error{"Failed to find Vulkan device corresponding to CUDA device."};
+		}
 
-		vkGetPhysicalDeviceProperties2(device, &physical_device_properties);
+		// -------------------------------
+		// Vulkan Logical Device
+		// -------------------------------
+		VkPhysicalDeviceProperties physical_device_properties;
+		vkGetPhysicalDeviceProperties(m_vk_physical_device, &physical_device_properties);
+
+		QueueFamilyIndices indices = find_queue_families(m_vk_physical_device);
+
+		VkDeviceQueueCreateInfo queue_create_info{};
+		queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+		queue_create_info.queueFamilyIndex = indices.all_family;
+		queue_create_info.queueCount = 1;
+
+		float queue_priority = 1.0f;
+		queue_create_info.pQueuePriorities = &queue_priority;
+
+		VkPhysicalDeviceFeatures device_features = {};
+		device_features.shaderStorageImageWriteWithoutFormat = true;
+
+		VkDeviceCreateInfo device_create_info = {};
+		device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
+		device_create_info.pQueueCreateInfos = &queue_create_info;
+		device_create_info.queueCreateInfoCount = 1;
+		device_create_info.pEnabledFeatures = &device_features;
+		device_create_info.enabledExtensionCount = (uint32_t)device_extensions.size();
+		device_create_info.ppEnabledExtensionNames = device_extensions.data();
+		device_create_info.enabledLayerCount = static_cast<uint32_t>(layers.size());
+		device_create_info.ppEnabledLayerNames = layers.data();
+
+	#ifdef VK_EXT_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME
+		VkPhysicalDeviceBufferDeviceAddressFeaturesEXT buffer_device_address_feature = {};
+		buffer_device_address_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES_EXT;
+		buffer_device_address_feature.bufferDeviceAddress = VK_TRUE;
+		device_create_info.pNext = &buffer_device_address_feature;
+	#else
+		throw std::runtime_error{"Buffer device address extension not available."};
+	#endif
+
+		VK_CHECK_THROW(vkCreateDevice(m_vk_physical_device, &device_create_info, nullptr, &m_vk_device));
+
+		// -----------------------------------------------
+		// Vulkan queue / command pool / command buffer
+		// -----------------------------------------------
+		vkGetDeviceQueue(m_vk_device, indices.all_family, 0, &m_vk_queue);
+
+		VkCommandPoolCreateInfo command_pool_info = {};
+		command_pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
+		command_pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
+		command_pool_info.queueFamilyIndex = indices.all_family;
+
+		VK_CHECK_THROW(vkCreateCommandPool(m_vk_device, &command_pool_info, nullptr, &m_vk_command_pool));
+
+		VkCommandBufferAllocateInfo command_buffer_alloc_info = {};
+		command_buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+		command_buffer_alloc_info.commandPool = m_vk_command_pool;
+		command_buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+		command_buffer_alloc_info.commandBufferCount = 1;
+
+		VK_CHECK_THROW(vkAllocateCommandBuffers(m_vk_device, &command_buffer_alloc_info, &m_vk_command_buffer));
+
+		// -------------------------------
+		// NGX init
+		// -------------------------------
+		std::wstring path;
+#ifdef _WIN32
+		path = fs::path::getcwd().wstr();
+#else
+		std::string tmp = fs::path::getcwd().str();
+		path = utf8_to_utf16(fs::path::getcwd().str());
+#endif
 
-		return !memcmp(&cuda_device_prop.uuid, physical_device_id_properties.deviceUUID, VK_UUID_SIZE) && find_queue_families(device).all_family >= 0;
-	};
+		NGX_CHECK_THROW(NVSDK_NGX_VULKAN_Init_with_ProjectID("ea75345e-5a42-4037-a5c9-59bf94dee157", NVSDK_NGX_ENGINE_TYPE_CUSTOM, "1.0.0", path.c_str(), m_vk_instance, m_vk_physical_device, m_vk_device));
+		m_ngx_initialized = true;
+
+		// -------------------------------
+		// Ensure DLSS capability
+		// -------------------------------
+		NGX_CHECK_THROW(NVSDK_NGX_VULKAN_GetCapabilityParameters(&m_ngx_parameters));
+
+		int needs_updated_driver = 0;
+		unsigned int min_driver_version_major = 0;
+		unsigned int min_driver_version_minor = 0;
+		NVSDK_NGX_Result result_updated_driver = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_NeedsUpdatedDriver, &needs_updated_driver);
+		NVSDK_NGX_Result result_min_driver_version_major = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMajor, &min_driver_version_major);
+		NVSDK_NGX_Result result_min_driver_version_minor = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMinor, &min_driver_version_minor);
+		if (result_updated_driver == NVSDK_NGX_Result_Success && result_min_driver_version_major == NVSDK_NGX_Result_Success && result_min_driver_version_minor == NVSDK_NGX_Result_Success) {
+			if (needs_updated_driver) {
+				throw std::runtime_error{fmt::format("Driver too old. Minimum version required is {}.{}", min_driver_version_major, min_driver_version_minor)};
+			}
+		}
 
-	uint32_t device_id = 0;
-	for (uint32_t i = 0; i < n_devices; ++i) {
-		if (is_same_as_cuda_device(devices[i])) {
-			vk_physical_device = devices[i];
-			device_id = i;
-			break;
+		int dlss_available  = 0;
+		NVSDK_NGX_Result ngx_result = m_ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_Available, &dlss_available);
+		if (ngx_result != NVSDK_NGX_Result_Success || !dlss_available) {
+			ngx_result = NVSDK_NGX_Result_Fail;
+			NVSDK_NGX_Parameter_GetI(m_ngx_parameters, NVSDK_NGX_Parameter_SuperSampling_FeatureInitResult, (int*)&ngx_result);
+			throw std::runtime_error{fmt::format("DLSS not available: {}", ngx_error_string(ngx_result))};
 		}
+
+		cleanup_guard.disarm();
+
+		tlog::success() << "Initialized Vulkan and NGX on GPU #" << device_id << ": " << physical_device_properties.deviceName;
 	}
 
-	if (vk_physical_device == VK_NULL_HANDLE) {
-		throw std::runtime_error{"Failed to find Vulkan device corresponding to CUDA device."};
+	virtual ~VulkanAndNgx() {
+		clear();
 	}
 
-	// -------------------------------
-	// Vulkan Logical Device
-	// -------------------------------
-	VkPhysicalDeviceProperties physical_device_properties;
-	vkGetPhysicalDeviceProperties(vk_physical_device, &physical_device_properties);
-
-	QueueFamilyIndices indices = find_queue_families(vk_physical_device);
-
-	VkDeviceQueueCreateInfo queue_create_info{};
-	queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
-	queue_create_info.queueFamilyIndex = indices.all_family;
-	queue_create_info.queueCount = 1;
-
-	float queue_priority = 1.0f;
-	queue_create_info.pQueuePriorities = &queue_priority;
-
-	VkPhysicalDeviceFeatures device_features = {};
-	device_features.shaderStorageImageWriteWithoutFormat = true;
-
-	VkDeviceCreateInfo device_create_info = {};
-	device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
-	device_create_info.pQueueCreateInfos = &queue_create_info;
-	device_create_info.queueCreateInfoCount = 1;
-	device_create_info.pEnabledFeatures = &device_features;
-	device_create_info.enabledExtensionCount = (uint32_t)device_extensions.size();
-	device_create_info.ppEnabledExtensionNames = device_extensions.data();
-	device_create_info.enabledLayerCount = static_cast<uint32_t>(layers.size());
-	device_create_info.ppEnabledLayerNames = layers.data();
-
-#ifdef VK_EXT_BUFFER_DEVICE_ADDRESS_EXTENSION_NAME
-	VkPhysicalDeviceBufferDeviceAddressFeaturesEXT buffer_device_address_feature = {};
-	buffer_device_address_feature.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_BUFFER_DEVICE_ADDRESS_FEATURES_EXT;
-	buffer_device_address_feature.bufferDeviceAddress = VK_TRUE;
-	device_create_info.pNext = &buffer_device_address_feature;
-#else
-	throw std::runtime_error{"Buffer device address extension not available."};
-#endif
+	void clear() {
+		if (m_ngx_parameters) {
+			NVSDK_NGX_VULKAN_DestroyParameters(m_ngx_parameters);
+			m_ngx_parameters = nullptr;
+		}
 
-	VK_CHECK_THROW(vkCreateDevice(vk_physical_device, &device_create_info, nullptr, &vk_device));
+		if (m_ngx_initialized) {
+			NVSDK_NGX_VULKAN_Shutdown();
+			m_ngx_initialized = false;
+		}
 
-	// -----------------------------------------------
-	// Vulkan queue / command pool / command buffer
-	// -----------------------------------------------
-	vkGetDeviceQueue(vk_device, indices.all_family, 0, &vk_queue);
+		if (m_vk_command_pool) {
+			vkDestroyCommandPool(m_vk_device, m_vk_command_pool, nullptr);
+			m_vk_command_pool = VK_NULL_HANDLE;
+		}
 
-	VkCommandPoolCreateInfo command_pool_info = {};
-	command_pool_info.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
-	command_pool_info.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT;
-	command_pool_info.queueFamilyIndex = indices.all_family;
+		if (m_vk_device) {
+			vkDestroyDevice(m_vk_device, nullptr);
+			m_vk_device = VK_NULL_HANDLE;
+		}
 
-	VK_CHECK_THROW(vkCreateCommandPool(vk_device, &command_pool_info, nullptr, &vk_command_pool));
+		if (m_vk_debug_messenger) {
+			auto DestroyDebugUtilsMessengerEXT = [](VkInstance instance, VkDebugUtilsMessengerEXT debugMessenger, const VkAllocationCallbacks* pAllocator) {
+				auto func = (PFN_vkDestroyDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkDestroyDebugUtilsMessengerEXT");
+				if (func != nullptr) {
+					func(instance, debugMessenger, pAllocator);
+				}
+			};
 
-	VkCommandBufferAllocateInfo command_buffer_alloc_info = {};
-	command_buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
-	command_buffer_alloc_info.commandPool = vk_command_pool;
-	command_buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
-	command_buffer_alloc_info.commandBufferCount = 1;
+			DestroyDebugUtilsMessengerEXT(m_vk_instance, m_vk_debug_messenger, nullptr);
+			m_vk_debug_messenger = VK_NULL_HANDLE;
+		}
 
-	VK_CHECK_THROW(vkAllocateCommandBuffers(vk_device, &command_buffer_alloc_info, &vk_command_buffer));
+		if (m_vk_instance) {
+			vkDestroyInstance(m_vk_instance, nullptr);
+			m_vk_instance = VK_NULL_HANDLE;
+		}
+	}
 
-	// -------------------------------
-	// NGX init
-	// -------------------------------
-	std::wstring path;
-#ifdef _WIN32
-	path = fs::path::getcwd().wstr();
-#else
-	std::string tmp = fs::path::getcwd().str();
-	std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> converter;
-	path = converter.from_bytes(tmp);
-#endif
+	uint32_t vk_find_memory_type(uint32_t type_filter, VkMemoryPropertyFlags properties) {
+		VkPhysicalDeviceMemoryProperties mem_properties;
+		vkGetPhysicalDeviceMemoryProperties(m_vk_physical_device, &mem_properties);
 
-	NGX_CHECK_THROW(NVSDK_NGX_VULKAN_Init_with_ProjectID("ea75345e-5a42-4037-a5c9-59bf94dee157", NVSDK_NGX_ENGINE_TYPE_CUSTOM, "1.0.0", path.c_str(), vk_instance, vk_physical_device, vk_device));
-	ngx_initialized = true;
-
-	// -------------------------------
-	// Ensure DLSS capability
-	// -------------------------------
-	NGX_CHECK_THROW(NVSDK_NGX_VULKAN_GetCapabilityParameters(&ngx_parameters));
-
-	int needs_updated_driver = 0;
-	unsigned int min_driver_version_major = 0;
-	unsigned int min_driver_version_minor = 0;
-	NVSDK_NGX_Result result_updated_driver = ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_NeedsUpdatedDriver, &needs_updated_driver);
-	NVSDK_NGX_Result result_min_driver_version_major = ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMajor, &min_driver_version_major);
-	NVSDK_NGX_Result result_min_driver_version_minor = ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_MinDriverVersionMinor, &min_driver_version_minor);
-	if (result_updated_driver == NVSDK_NGX_Result_Success && result_min_driver_version_major == NVSDK_NGX_Result_Success && result_min_driver_version_minor == NVSDK_NGX_Result_Success) {
-		if (needs_updated_driver) {
-			throw std::runtime_error{fmt::format("Driver too old. Minimum version required is {}.{}", min_driver_version_major, min_driver_version_minor)};
+		for (uint32_t i = 0; i < mem_properties.memoryTypeCount; i++) {
+			if (type_filter & (1 << i) && (mem_properties.memoryTypes[i].propertyFlags & properties) == properties) {
+				return i;
+			}
 		}
+
+		throw std::runtime_error{"Failed to find suitable memory type."};
 	}
 
-	int dlss_available  = 0;
-	NVSDK_NGX_Result ngx_result = ngx_parameters->Get(NVSDK_NGX_Parameter_SuperSampling_Available, &dlss_available);
-	if (ngx_result != NVSDK_NGX_Result_Success || !dlss_available) {
-		ngx_result = NVSDK_NGX_Result_Fail;
-		NVSDK_NGX_Parameter_GetI(ngx_parameters, NVSDK_NGX_Parameter_SuperSampling_FeatureInitResult, (int*)&ngx_result);
-		throw std::runtime_error{fmt::format("DLSS not available: {}", ngx_error_string(ngx_result))};
+	void vk_command_buffer_begin() {
+		VkCommandBufferBeginInfo begin_info = {};
+		begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+		begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+		begin_info.pInheritanceInfo = nullptr;
+
+		VK_CHECK_THROW(vkBeginCommandBuffer(m_vk_command_buffer, &begin_info));
 	}
 
-	tlog::success() << "Initialized Vulkan and NGX on GPU #" << device_id << ": " << physical_device_properties.deviceName;
-}
+	void vk_command_buffer_end() {
+		VK_CHECK_THROW(vkEndCommandBuffer(m_vk_command_buffer));
+	}
 
-size_t dlss_allocated_bytes() {
-	unsigned long long allocated_bytes = 0;
-	if (!ngx_parameters) {
-		return 0;
+	void vk_command_buffer_submit() {
+		VkSubmitInfo submit_info = { VK_STRUCTURE_TYPE_SUBMIT_INFO };
+		submit_info.commandBufferCount = 1;
+		submit_info.pCommandBuffers = &m_vk_command_buffer;
+
+		VK_CHECK_THROW(vkQueueSubmit(m_vk_queue, 1, &submit_info, VK_NULL_HANDLE));
 	}
 
-	try {
-		NGX_CHECK_THROW(NGX_DLSS_GET_STATS(ngx_parameters, &allocated_bytes));
-	} catch (...) {
-		return 0;
+	void vk_synchronize() {
+		VK_CHECK_THROW(vkDeviceWaitIdle(m_vk_device));
 	}
 
-	return allocated_bytes;
-}
+	void vk_command_buffer_submit_sync() {
+		vk_command_buffer_submit();
+		vk_synchronize();
+	}
 
-void vk_command_buffer_begin() {
-	VkCommandBufferBeginInfo begin_info = {};
-	begin_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
-	begin_info.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
-	begin_info.pInheritanceInfo = nullptr;
+	void vk_command_buffer_end_and_submit_sync() {
+		vk_command_buffer_end();
+		vk_command_buffer_submit_sync();
+	}
 
-	VK_CHECK_THROW(vkBeginCommandBuffer(vk_command_buffer, &begin_info));
-}
+	const VkCommandBuffer& vk_command_buffer() const {
+		return m_vk_command_buffer;
+	}
 
-void vk_command_buffer_end() {
-	VK_CHECK_THROW(vkEndCommandBuffer(vk_command_buffer));
-}
+	const VkDevice& vk_device() const {
+		return m_vk_device;
+	}
 
-void vk_command_buffer_submit() {
-	VkSubmitInfo submit_info = { VK_STRUCTURE_TYPE_SUBMIT_INFO };
-	submit_info.commandBufferCount = 1;
-	submit_info.pCommandBuffers = &vk_command_buffer;
+	NVSDK_NGX_Parameter* ngx_parameters() const {
+		return m_ngx_parameters;
+	}
 
-	VK_CHECK_THROW(vkQueueSubmit(vk_queue, 1, &submit_info, VK_NULL_HANDLE));
-}
+	size_t allocated_bytes() const override {
+		unsigned long long allocated_bytes = 0;
+		if (!m_ngx_parameters) {
+			return 0;
+		}
 
-void vk_synchronize() {
-	VK_CHECK_THROW(vkDeviceWaitIdle(vk_device));
-}
+		try {
+			NGX_CHECK_THROW(NGX_DLSS_GET_STATS(m_ngx_parameters, &allocated_bytes));
+		} catch (...) {
+			return 0;
+		}
 
-void vk_command_buffer_submit_sync() {
-	vk_command_buffer_submit();
-	vk_synchronize();
-}
+		return allocated_bytes;
+	}
+
+	std::unique_ptr<IDlss> init_dlss(const Eigen::Vector2i& out_resolution) override;
+
+private:
+	VkInstance m_vk_instance = VK_NULL_HANDLE;
+	VkDebugUtilsMessengerEXT m_vk_debug_messenger = VK_NULL_HANDLE;
+	VkPhysicalDevice m_vk_physical_device = VK_NULL_HANDLE;
+	VkDevice m_vk_device = VK_NULL_HANDLE;
+	VkQueue m_vk_queue = VK_NULL_HANDLE;
+	VkCommandPool m_vk_command_pool = VK_NULL_HANDLE;
+	VkCommandBuffer m_vk_command_buffer = VK_NULL_HANDLE;
+	NVSDK_NGX_Parameter* m_ngx_parameters = nullptr;
+	bool m_ngx_initialized = false;
+};
 
-void vk_command_buffer_end_and_submit_sync() {
-	vk_command_buffer_end();
-	vk_command_buffer_submit_sync();
+std::shared_ptr<IDlssProvider> init_vulkan_and_ngx() {
+	return std::make_shared<VulkanAndNgx>();
 }
 
 class VulkanTexture {
 public:
-	VulkanTexture(const Vector2i& size, uint32_t n_channels) : m_size{size}, m_n_channels{n_channels} {
+	VulkanTexture(std::shared_ptr<VulkanAndNgx> vk, const Vector2i& size, uint32_t n_channels) : m_vk{vk}, m_size{size}, m_n_channels{n_channels} {
 		VkImageCreateInfo image_info{};
 		image_info.sType = VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO;
 		image_info.imageType = VK_IMAGE_TYPE_2D;
@@ -515,17 +580,17 @@ public:
 
 		image_info.pNext = &ext_image_info;
 
-		VK_CHECK_THROW(vkCreateImage(vk_device, &image_info, nullptr, &m_vk_image));
+		VK_CHECK_THROW(vkCreateImage(m_vk->vk_device(), &image_info, nullptr, &m_vk_image));
 
 		// Create device memory to back up the image
 		VkMemoryRequirements mem_requirements = {};
 
-		vkGetImageMemoryRequirements(vk_device, m_vk_image, &mem_requirements);
+		vkGetImageMemoryRequirements(m_vk->vk_device(), m_vk_image, &mem_requirements);
 
 		VkMemoryAllocateInfo mem_alloc_info = {};
 		mem_alloc_info.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
 		mem_alloc_info.allocationSize = mem_requirements.size;
-		mem_alloc_info.memoryTypeIndex = vk_find_memory_type(mem_requirements.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
+		mem_alloc_info.memoryTypeIndex = m_vk->vk_find_memory_type(mem_requirements.memoryTypeBits, VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
 
 		VkExportMemoryAllocateInfoKHR export_info = {};
 		export_info.sType = VK_STRUCTURE_TYPE_EXPORT_MEMORY_ALLOCATE_INFO_KHR;
@@ -533,10 +598,10 @@ public:
 
 		mem_alloc_info.pNext = &export_info;
 
-		VK_CHECK_THROW(vkAllocateMemory(vk_device, &mem_alloc_info, nullptr, &m_vk_device_memory));
-		VK_CHECK_THROW(vkBindImageMemory(vk_device, m_vk_image, m_vk_device_memory, 0));
+		VK_CHECK_THROW(vkAllocateMemory(m_vk->vk_device(), &mem_alloc_info, nullptr, &m_vk_device_memory));
+		VK_CHECK_THROW(vkBindImageMemory(m_vk->vk_device(), m_vk_image, m_vk_device_memory, 0));
 
-		vk_command_buffer_begin();
+		m_vk->vk_command_buffer_begin();
 
 		VkImageMemoryBarrier barrier = {};
 		barrier.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER;
@@ -554,7 +619,7 @@ public:
 		barrier.dstAccessMask = VK_ACCESS_MEMORY_READ_BIT | VK_ACCESS_MEMORY_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_COLOR_ATTACHMENT_READ_BIT | VK_ACCESS_COLOR_ATTACHMENT_WRITE_BIT;
 
 		vkCmdPipelineBarrier(
-			vk_command_buffer,
+			m_vk->vk_command_buffer(),
 			VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, VK_PIPELINE_STAGE_ALL_COMMANDS_BIT,
 			0,
 			0, nullptr,
@@ -562,7 +627,7 @@ public:
 			1, &barrier
 		);
 
-		vk_command_buffer_end_and_submit_sync();
+		m_vk->vk_command_buffer_end_and_submit_sync();
 
 		// Image view
 		VkImageViewCreateInfo view_info = {};
@@ -572,7 +637,7 @@ public:
 		view_info.format = image_info.format;
 		view_info.subresourceRange = barrier.subresourceRange;
 
-		VK_CHECK_THROW(vkCreateImageView(vk_device, &view_info, nullptr, &m_vk_image_view));
+		VK_CHECK_THROW(vkCreateImageView(m_vk->vk_device(), &view_info, nullptr, &m_vk_image_view));
 
 		// Map to NGX
 		m_ngx_resource = NVSDK_NGX_Create_ImageView_Resource_VK(m_vk_image_view, m_vk_image, view_info.subresourceRange, image_info.format, m_size.x(), m_size.y(), true);
@@ -584,21 +649,21 @@ public:
 		handle_info.sType = VK_STRUCTURE_TYPE_MEMORY_GET_WIN32_HANDLE_INFO_KHR;
 		handle_info.memory = m_vk_device_memory;
 		handle_info.handleType = VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_BIT;
-		auto pfn_vkGetMemory = (PFN_vkGetMemoryWin32HandleKHR)vkGetDeviceProcAddr(vk_device, "vkGetMemoryWin32HandleKHR");
+		auto pfn_vkGetMemory = (PFN_vkGetMemoryWin32HandleKHR)vkGetDeviceProcAddr(m_vk->vk_device(), "vkGetMemoryWin32HandleKHR");
 #else
 		int handle = -1;
 		VkMemoryGetFdInfoKHR handle_info = {};
 		handle_info.sType = VK_STRUCTURE_TYPE_MEMORY_GET_FD_INFO_KHR;
 		handle_info.memory = m_vk_device_memory;
 		handle_info.handleType = VK_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD_BIT_KHR;
-		auto pfn_vkGetMemory = (PFN_vkGetMemoryFdKHR)vkGetDeviceProcAddr(vk_device, "vkGetMemoryFdKHR");
+		auto pfn_vkGetMemory = (PFN_vkGetMemoryFdKHR)vkGetDeviceProcAddr(m_vk->vk_device(), "vkGetMemoryFdKHR");
 #endif
 
 		if (!pfn_vkGetMemory) {
 			throw std::runtime_error{"Failed to locate pfn_vkGetMemory."};
 		}
 
-		VK_CHECK_THROW(pfn_vkGetMemory(vk_device, &handle_info, &handle));
+		VK_CHECK_THROW(pfn_vkGetMemory(m_vk->vk_device(), &handle_info, &handle));
 
 		// Map handle to CUDA memory
 		cudaExternalMemoryHandleDesc external_memory_handle_desc = {};
@@ -687,15 +752,15 @@ public:
 		}
 
 		if (m_vk_image_view) {
-			vkDestroyImageView(vk_device, m_vk_image_view, nullptr);
+			vkDestroyImageView(m_vk->vk_device(), m_vk_image_view, nullptr);
 		}
 
 		if (m_vk_image) {
-			vkDestroyImage(vk_device, m_vk_image, nullptr);
+			vkDestroyImage(m_vk->vk_device(), m_vk_image, nullptr);
 		}
 
 		if (m_vk_device_memory) {
-			vkFreeMemory(vk_device, m_vk_device_memory, nullptr);
+			vkFreeMemory(m_vk->vk_device(), m_vk_device_memory, nullptr);
 		}
 	}
 
@@ -720,6 +785,8 @@ public:
 	}
 
 private:
+	std::shared_ptr<VulkanAndNgx> m_vk;
+
 	Vector2i m_size;
 	uint32_t m_n_channels;
 
@@ -765,7 +832,7 @@ struct DlssFeatureSpecs {
 	}
 };
 
-DlssFeatureSpecs dlss_feature_specs(const Eigen::Vector2i& out_resolution, EDlssQuality quality) {
+DlssFeatureSpecs dlss_feature_specs(NVSDK_NGX_Parameter* ngx_parameters, const Eigen::Vector2i& out_resolution, EDlssQuality quality) {
 	DlssFeatureSpecs specs;
 	specs.quality = quality;
 	specs.out_resolution = out_resolution;
@@ -790,7 +857,7 @@ DlssFeatureSpecs dlss_feature_specs(const Eigen::Vector2i& out_resolution, EDlss
 
 class DlssFeature {
 public:
-	DlssFeature(const DlssFeatureSpecs& specs, bool is_hdr, bool sharpen) : m_specs{specs}, m_is_hdr{is_hdr}, m_sharpen{sharpen} {
+	DlssFeature(std::shared_ptr<VulkanAndNgx> vk_and_ngx, const DlssFeatureSpecs& specs, bool is_hdr, bool sharpen) : m_vk_and_ngx{vk_and_ngx}, m_specs{specs}, m_is_hdr{is_hdr}, m_sharpen{sharpen} {
 		// Initialize DLSS
 		unsigned int creation_node_mask = 1;
 		unsigned int visibility_node_mask = 1;
@@ -799,7 +866,7 @@ public:
 		dlss_create_feature_flags |= true ? NVSDK_NGX_DLSS_Feature_Flags_MVLowRes : 0;
 		dlss_create_feature_flags |= false ? NVSDK_NGX_DLSS_Feature_Flags_MVJittered : 0;
 		dlss_create_feature_flags |= is_hdr ? NVSDK_NGX_DLSS_Feature_Flags_IsHDR : 0;
-		dlss_create_feature_flags |= false ? NVSDK_NGX_DLSS_Feature_Flags_DepthInverted : 0;
+		dlss_create_feature_flags |= true ? NVSDK_NGX_DLSS_Feature_Flags_DepthInverted : 0;
 		dlss_create_feature_flags |= sharpen ? NVSDK_NGX_DLSS_Feature_Flags_DoSharpening : 0;
 		dlss_create_feature_flags |= false ? NVSDK_NGX_DLSS_Feature_Flags_AutoExposure : 0;
 
@@ -815,15 +882,15 @@ public:
 		dlss_create_params.InFeatureCreateFlags = dlss_create_feature_flags;
 
 		{
-			vk_command_buffer_begin();
-			ScopeGuard command_buffer_guard{[&]() { vk_command_buffer_end_and_submit_sync(); }};
+			m_vk_and_ngx->vk_command_buffer_begin();
+			ScopeGuard command_buffer_guard{[&]() { m_vk_and_ngx->vk_command_buffer_end_and_submit_sync(); }};
 
-			NGX_CHECK_THROW(NGX_VULKAN_CREATE_DLSS_EXT(vk_command_buffer, creation_node_mask, visibility_node_mask, &m_ngx_dlss, ngx_parameters, &dlss_create_params));
+			NGX_CHECK_THROW(NGX_VULKAN_CREATE_DLSS_EXT(m_vk_and_ngx->vk_command_buffer(), creation_node_mask, visibility_node_mask, &m_ngx_dlss, m_vk_and_ngx->ngx_parameters(), &dlss_create_params));
 		}
 	}
 
-	DlssFeature(const Eigen::Vector2i& out_resolution, bool is_hdr, bool sharpen, EDlssQuality quality)
-	: DlssFeature{dlss_feature_specs(out_resolution, quality), is_hdr, sharpen} {}
+	DlssFeature(std::shared_ptr<VulkanAndNgx> vk_and_ngx, const Eigen::Vector2i& out_resolution, bool is_hdr, bool sharpen, EDlssQuality quality)
+	: DlssFeature{vk_and_ngx, dlss_feature_specs(vk_and_ngx->ngx_parameters(), out_resolution, quality), is_hdr, sharpen} {}
 
 	~DlssFeature() {
 		cudaDeviceSynchronize();
@@ -832,7 +899,7 @@ public:
 			NVSDK_NGX_VULKAN_ReleaseFeature(m_ngx_dlss);
 		}
 
-		vk_synchronize();
+		m_vk_and_ngx->vk_synchronize();
 	}
 
 	void run(
@@ -850,7 +917,7 @@ public:
 			throw std::runtime_error{"May only specify non-zero sharpening, when DlssFeature has been created with sharpen option."};
 		}
 
-		vk_command_buffer_begin();
+		m_vk_and_ngx->vk_command_buffer_begin();
 
 		NVSDK_NGX_VK_DLSS_Eval_Params dlss_params;
 		memset(&dlss_params, 0, sizeof(dlss_params));
@@ -868,9 +935,9 @@ public:
 		dlss_params.InMVScaleY = 1.0f;
 		dlss_params.InRenderSubrectDimensions = {(uint32_t)in_resolution.x(), (uint32_t)in_resolution.y()};
 
-		NGX_CHECK_THROW(NGX_VULKAN_EVALUATE_DLSS_EXT(vk_command_buffer, m_ngx_dlss, ngx_parameters, &dlss_params));
+		NGX_CHECK_THROW(NGX_VULKAN_EVALUATE_DLSS_EXT(m_vk_and_ngx->vk_command_buffer(), m_ngx_dlss, m_vk_and_ngx->ngx_parameters(), &dlss_params));
 
-		vk_command_buffer_end_and_submit_sync();
+		m_vk_and_ngx->vk_command_buffer_end_and_submit_sync();
 	}
 
 	bool is_hdr() const {
@@ -898,6 +965,8 @@ public:
 	}
 
 private:
+	std::shared_ptr<VulkanAndNgx> m_vk_and_ngx;
+
 	NVSDK_NGX_Handle* m_ngx_dlss = {};
 	DlssFeatureSpecs m_specs;
 	bool m_is_hdr;
@@ -906,28 +975,29 @@ private:
 
 class Dlss : public IDlss {
 public:
-	Dlss(const Eigen::Vector2i& max_out_resolution)
+	Dlss(std::shared_ptr<VulkanAndNgx> vk_and_ngx, const Eigen::Vector2i& max_out_resolution)
 	:
+	m_vk_and_ngx{vk_and_ngx},
 	m_max_out_resolution{max_out_resolution},
 	// Allocate all buffers at output resolution and use dynamic sub-rects
 	// to use subsets of them. This avoids re-allocations when using DLSS
 	// with dynamically changing input resolution.
-	m_frame_buffer{max_out_resolution, 4},
-	m_depth_buffer{max_out_resolution, 1},
-	m_mvec_buffer{max_out_resolution, 2},
-	m_exposure_buffer{{1, 1}, 1},
-	m_output_buffer{max_out_resolution, 4}
+	m_frame_buffer{m_vk_and_ngx, max_out_resolution, 4},
+	m_depth_buffer{m_vk_and_ngx, max_out_resolution, 1},
+	m_mvec_buffer{m_vk_and_ngx, max_out_resolution, 2},
+	m_exposure_buffer{m_vk_and_ngx, {1, 1}, 1},
+	m_output_buffer{m_vk_and_ngx, max_out_resolution, 4}
 	{
 		// Various quality modes of DLSS
 		for (int i = 0; i < (int)EDlssQuality::NumDlssQualitySettings; ++i) {
 			try {
-				auto specs = dlss_feature_specs(max_out_resolution, (EDlssQuality)i);
+				auto specs = dlss_feature_specs(m_vk_and_ngx->ngx_parameters(), max_out_resolution, (EDlssQuality)i);
 
 				// Only emplace the specs if the feature can be created in practice!
-				DlssFeature{specs, true, true};
-				DlssFeature{specs, true, false};
-				DlssFeature{specs, false, true};
-				DlssFeature{specs, false, false};
+				DlssFeature{m_vk_and_ngx, specs, true, true};
+				DlssFeature{m_vk_and_ngx, specs, true, false};
+				DlssFeature{m_vk_and_ngx, specs, false, true};
+				DlssFeature{m_vk_and_ngx, specs, false, false};
 				m_dlss_specs.emplace_back(specs);
 			} catch (...) {}
 		}
@@ -943,13 +1013,13 @@ public:
 
 		for (const auto& out_resolution : reduced_out_resolutions) {
 			try {
-				auto specs = dlss_feature_specs(out_resolution, EDlssQuality::UltraPerformance);
+				auto specs = dlss_feature_specs(m_vk_and_ngx->ngx_parameters(), out_resolution, EDlssQuality::UltraPerformance);
 
 				// Only emplace the specs if the feature can be created in practice!
-				DlssFeature{specs, true, true};
-				DlssFeature{specs, true, false};
-				DlssFeature{specs, false, true};
-				DlssFeature{specs, false, false};
+				DlssFeature{m_vk_and_ngx, specs, true, true};
+				DlssFeature{m_vk_and_ngx, specs, true, false};
+				DlssFeature{m_vk_and_ngx, specs, false, true};
+				DlssFeature{m_vk_and_ngx, specs, false, false};
 				m_dlss_specs.emplace_back(specs);
 			} catch (...) {}
 		}
@@ -977,7 +1047,7 @@ public:
 		}
 
 		if (!m_dlss_feature || m_dlss_feature->is_hdr() != is_hdr || m_dlss_feature->sharpen() != sharpen || m_dlss_feature->quality() != specs.quality || m_dlss_feature->out_resolution() != specs.out_resolution) {
-			m_dlss_feature.reset(new DlssFeature{specs.out_resolution, is_hdr, sharpen, specs.quality});
+			m_dlss_feature.reset(new DlssFeature{m_vk_and_ngx, specs.out_resolution, is_hdr, sharpen, specs.quality});
 		}
 	}
 
@@ -1060,6 +1130,8 @@ public:
 	}
 
 private:
+	std::shared_ptr<VulkanAndNgx> m_vk_and_ngx;
+
 	std::unique_ptr<DlssFeature> m_dlss_feature;
 	std::vector<DlssFeatureSpecs> m_dlss_specs;
 
@@ -1072,47 +1144,8 @@ private:
 	Vector2i m_max_out_resolution;
 };
 
-std::shared_ptr<IDlss> dlss_init(const Eigen::Vector2i& out_resolution) {
-	return std::make_shared<Dlss>(out_resolution);
-}
-
-void vulkan_and_ngx_destroy() {
-	if (ngx_parameters) {
-		NVSDK_NGX_VULKAN_DestroyParameters(ngx_parameters);
-		ngx_parameters = nullptr;
-	}
-
-	if (ngx_initialized) {
-		NVSDK_NGX_VULKAN_Shutdown();
-		ngx_initialized = false;
-	}
-
-	if (vk_command_pool) {
-		vkDestroyCommandPool(vk_device, vk_command_pool, nullptr);
-		vk_command_pool = VK_NULL_HANDLE;
-	}
-
-	if (vk_device) {
-		vkDestroyDevice(vk_device, nullptr);
-		vk_device = VK_NULL_HANDLE;
-	}
-
-	if (vk_debug_messenger) {
-		auto DestroyDebugUtilsMessengerEXT = [](VkInstance instance, VkDebugUtilsMessengerEXT debugMessenger, const VkAllocationCallbacks* pAllocator) {
-			auto func = (PFN_vkDestroyDebugUtilsMessengerEXT)vkGetInstanceProcAddr(instance, "vkDestroyDebugUtilsMessengerEXT");
-			if (func != nullptr) {
-				func(instance, debugMessenger, pAllocator);
-			}
-		};
-
-		DestroyDebugUtilsMessengerEXT(vk_instance, vk_debug_messenger, nullptr);
-		vk_debug_messenger = VK_NULL_HANDLE;
-	}
-
-	if (vk_instance) {
-		vkDestroyInstance(vk_instance, nullptr);
-		vk_instance = VK_NULL_HANDLE;
-	}
+std::unique_ptr<IDlss> VulkanAndNgx::init_dlss(const Eigen::Vector2i& out_resolution) {
+	return std::make_unique<Dlss>(shared_from_this(), out_resolution);
 }
 
 NGP_NAMESPACE_END
diff --git a/src/main.cu b/src/main.cu
index f05b489598b3cb0208251230e359310954a8b2d3..ac79bd362f70f3e354fd2029bb4c907809dfa8aa 100644
--- a/src/main.cu
+++ b/src/main.cu
@@ -62,6 +62,13 @@ int main_func(const std::vector<std::string>& arguments) {
 		{"no-gui"},
 	};
 
+	Flag vr_flag{
+		parser,
+		"VR",
+		"Enables VR",
+		{"vr"}
+	};
+
 	Flag no_train_flag{
 		parser,
 		"NO_TRAIN",
@@ -170,6 +177,10 @@ int main_func(const std::vector<std::string>& arguments) {
 		testbed.init_window(width_flag ? get(width_flag) : 1920, height_flag ? get(height_flag) : 1080);
 	}
 
+	if (vr_flag) {
+		testbed.init_vr();
+	}
+
 	// Render/training loop
 	while (testbed.frame()) {
 		if (!gui) {
diff --git a/src/marching_cubes.cu b/src/marching_cubes.cu
index 28c28585ab86bd8e8104319ecde973a325fae723..2fc595405c5835268f7f83893e820bfe65cd714c 100644
--- a/src/marching_cubes.cu
+++ b/src/marching_cubes.cu
@@ -98,11 +98,11 @@ bool check_shader(uint32_t handle, const char* desc, bool program) {
 
 uint32_t compile_shader(bool pixel, const char* code) {
 	GLuint g_VertHandle = glCreateShader(pixel ? GL_FRAGMENT_SHADER : GL_VERTEX_SHADER );
-	const char* glsl_version = "#version 330\n";
+	const char* glsl_version = "#version 140\n";
 	const GLchar* strings[2] = { glsl_version, code};
 	glShaderSource(g_VertHandle, 2, strings, NULL);
 	glCompileShader(g_VertHandle);
-	if (!check_shader(g_VertHandle, pixel?"pixel":"vertex", false)) {
+	if (!check_shader(g_VertHandle, pixel? "pixel" : "vertex", false)) {
 		glDeleteShader(g_VertHandle);
 		return 0;
 	}
@@ -173,9 +173,9 @@ void draw_mesh_gl(
 
 	if (!program) {
 		vs = compile_shader(false, R"foo(
-layout (location = 0) in vec3 pos;
-layout (location = 1) in vec3 nor;
-layout (location = 2) in vec3 col;
+in vec3 pos;
+in vec3 nor;
+in vec3 col;
 out vec3 vtxcol;
 uniform mat4 camera;
 uniform vec2 f;
@@ -198,16 +198,11 @@ void main()
 }
 )foo");
 		ps = compile_shader(true, R"foo(
-layout (location = 0) out vec4 o;
+out vec4 o;
 in vec3 vtxcol;
 uniform int mode;
 void main() {
-	if (mode == 3) {
-		vec3 tricol = vec3((ivec3(923, 3572, 5423) * gl_PrimitiveID) & 255) * (1.0 / 255.0);
-		o = vec4(tricol, 1.0);
-	} else {
-		o = vec4(vtxcol, 1.0);
-	}
+	o = vec4(vtxcol, 1.0);
 }
 )foo");
 		program = glCreateProgram();
diff --git a/src/nerf_loader.cu b/src/nerf_loader.cu
index 3fd76ca003e845e183bdd3b914e3af12bc21c9c1..69af304f8e3aed090d0ffe81fd46d38dcfd4e1af 100644
--- a/src/nerf_loader.cu
+++ b/src/nerf_loader.cu
@@ -231,6 +231,10 @@ void read_lens(const nlohmann::json& json, Lens& lens, Vector2f& principal_point
 		mode = ELensMode::LatLong;
 	}
 
+	if (json.contains("equirectangular")) {
+		mode = ELensMode::Equirectangular;
+	}
+
 	// If there was an outer distortion mode, don't override it with nothing.
 	if (mode != ELensMode::Perspective) {
 		lens.mode = mode;
diff --git a/src/openxr_hmd.cu b/src/openxr_hmd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..960737237bb0f6ee1322da4ece538374400965da
--- /dev/null
+++ b/src/openxr_hmd.cu
@@ -0,0 +1,1240 @@
+/*
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * NVIDIA CORPORATION and its licensors retain all intellectual property
+ * and proprietary rights in and to this software, related documentation
+ * and any modifications thereto.  Any use, reproduction, disclosure or
+ * distribution of this software and related documentation without an express
+ * license agreement from NVIDIA CORPORATION is strictly prohibited.
+ */
+
+/** @file   openxr_hmd.cu
+ *  @author Thomas Müller & Ingo Esser & Robert Menzel, NVIDIA
+ *  @brief  Wrapper around the OpenXR API, providing access to
+ *          per-eye framebuffers, lens parameters, visible area,
+ *          view, hand, and eye poses, as well as controller inputs.
+ */
+
+#define NOMINMAX
+
+#include <neural-graphics-primitives/common_device.cuh>
+#include <neural-graphics-primitives/marching_cubes.h>
+#include <neural-graphics-primitives/openxr_hmd.h>
+#include <neural-graphics-primitives/render_buffer.h>
+
+#include <openxr/openxr_reflection.h>
+
+#include <fmt/format.h>
+
+#include <imgui/imgui.h>
+
+#include <tinylogger/tinylogger.h>
+
+#include <tiny-cuda-nn/common.h>
+
+#include <string>
+#include <vector>
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wmissing-field-initializers" //TODO: XR struct are uninitiaized apart from their type
+#endif
+
+using namespace Eigen;
+using namespace tcnn;
+
+NGP_NAMESPACE_BEGIN
+
+// function XrEnumStr turns enum into string for printing
+// uses expansion macro and data provided in openxr_reflection.h
+#define XR_ENUM_CASE_STR(name, val) \
+	case name:                      \
+		return #name;
+#define XR_ENUM_STR(enum_type)                                                     \
+	constexpr const char* XrEnumStr(enum_type e) {                                 \
+		switch (e) {                                                               \
+			XR_LIST_ENUM_##enum_type(XR_ENUM_CASE_STR) default : return "Unknown"; \
+		}                                                                          \
+	}
+
+XR_ENUM_STR(XrViewConfigurationType)
+XR_ENUM_STR(XrEnvironmentBlendMode)
+XR_ENUM_STR(XrReferenceSpaceType)
+XR_ENUM_STR(XrStructureType)
+XR_ENUM_STR(XrSessionState)
+
+/// Checks the result of a xrXXXXXX call and throws an error on failure
+#define XR_CHECK_THROW(x)                                                                                   \
+	do {                                                                                                              \
+		XrResult result = x;                                                                                          \
+		if (XR_FAILED(result)) {                                                                                      \
+			char buffer[XR_MAX_RESULT_STRING_SIZE];                                                                   \
+			XrResult result_to_string_result = xrResultToString(m_instance, result, buffer);                            \
+			if (XR_FAILED(result_to_string_result)) {                                                                 \
+				throw std::runtime_error{std::string(FILE_LINE " " #x " failed, but could not obtain error string")}; \
+			} else {                                                                                                  \
+				throw std::runtime_error{std::string(FILE_LINE " " #x " failed with error ") + buffer};               \
+			}                                                                                                         \
+		}                                                                                                             \
+	} while(0)
+
+OpenXRHMD::Swapchain::Swapchain(XrSwapchainCreateInfo& rgba_create_info, XrSwapchainCreateInfo& depth_create_info, XrSession& session, XrInstance& m_instance) {
+	ScopeGuard cleanup_guard{[&]() { clear(); }};
+
+	XR_CHECK_THROW(xrCreateSwapchain(session, &rgba_create_info, &handle));
+
+	width = rgba_create_info.width;
+	height = rgba_create_info.height;
+
+	{
+		uint32_t size;
+		XR_CHECK_THROW(xrEnumerateSwapchainImages(handle, 0, &size, nullptr));
+
+		images_gl.resize(size, {XR_TYPE_SWAPCHAIN_IMAGE_OPENGL_KHR});
+		XR_CHECK_THROW(xrEnumerateSwapchainImages(handle, size, &size, (XrSwapchainImageBaseHeader*)images_gl.data()));
+
+		// One framebuffer per swapchain image
+		framebuffers_gl.resize(size);
+	}
+
+	if (depth_create_info.format != 0) {
+		XR_CHECK_THROW(xrCreateSwapchain(session, &depth_create_info, &depth_handle));
+
+		uint32_t depth_size;
+		XR_CHECK_THROW(xrEnumerateSwapchainImages(depth_handle, 0, &depth_size, nullptr));
+
+		depth_images_gl.resize(depth_size, {XR_TYPE_SWAPCHAIN_IMAGE_OPENGL_KHR});
+		XR_CHECK_THROW(xrEnumerateSwapchainImages(depth_handle, depth_size, &depth_size, (XrSwapchainImageBaseHeader*)depth_images_gl.data()));
+
+		// We might have a different number of depth swapchain images as we have framebuffers,
+		// so we will need to bind an acquired depth image to the current framebuffer on the
+		// fly later on.
+	}
+
+	glGenFramebuffers(framebuffers_gl.size(), framebuffers_gl.data());
+
+	cleanup_guard.disarm();
+}
+
+OpenXRHMD::Swapchain::~Swapchain() {
+	clear();
+}
+
+void OpenXRHMD::Swapchain::clear() {
+	if (!framebuffers_gl.empty()) {
+		glDeleteFramebuffers(framebuffers_gl.size(), framebuffers_gl.data());
+	}
+
+	if (depth_handle != XR_NULL_HANDLE) {
+		xrDestroySwapchain(depth_handle);
+		depth_handle = XR_NULL_HANDLE;
+	}
+
+	if (handle != XR_NULL_HANDLE) {
+		xrDestroySwapchain(handle);
+		handle = XR_NULL_HANDLE;
+	}
+}
+
+#if defined(XR_USE_PLATFORM_WIN32)
+OpenXRHMD::OpenXRHMD(HDC hdc, HGLRC hglrc) {
+#elif defined(XR_USE_PLATFORM_XLIB)
+OpenXRHMD::OpenXRHMD(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext) {
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+OpenXRHMD::OpenXRHMD(wl_display* display) {
+#endif
+	ScopeGuard cleanup_guard{[&]() { clear(); }};
+
+	init_create_xr_instance();
+	init_get_xr_system();
+	init_configure_xr_views();
+	init_check_for_xr_blend_mode();
+#if defined(XR_USE_PLATFORM_WIN32)
+	init_open_gl(hdc, hglrc);
+#elif defined(XR_USE_PLATFORM_XLIB)
+	init_open_gl(xDisplay, visualid, glxFBConfig, glxDrawable, glxContext);
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	init_open_gl(display);
+#endif
+	init_xr_session();
+	init_xr_actions();
+	init_xr_spaces();
+	init_xr_swapchain_open_gl();
+	init_open_gl_shaders();
+
+	cleanup_guard.disarm();
+	tlog::success() << "Initialized OpenXR for " << m_system_properties.systemName;
+	// tlog::success() << " "
+	// 	<< " depth=" << (m_supports_composition_layer_depth ? "true" : "false")
+	// 	<< " mask=" << (m_supports_hidden_area_mask ? "true" : "false")
+	// 	<< " eye=" << (m_supports_hidden_area_mask ? "true" : "false")
+	// 	;
+}
+
+OpenXRHMD::~OpenXRHMD() {
+	clear();
+}
+
+void OpenXRHMD::clear() {
+	auto xr_destroy = [&](auto& handle, auto destroy_fun) {
+		if (handle != XR_NULL_HANDLE) {
+			destroy_fun(handle);
+			handle = XR_NULL_HANDLE;
+		}
+	};
+
+	xr_destroy(m_pose_action, xrDestroyAction);
+	xr_destroy(m_thumbstick_actions[0], xrDestroyAction);
+	xr_destroy(m_thumbstick_actions[1], xrDestroyAction);
+	xr_destroy(m_press_action, xrDestroyAction);
+	xr_destroy(m_grab_action, xrDestroyAction);
+
+	xr_destroy(m_action_set, xrDestroyActionSet);
+
+	m_swapchains.clear();
+	xr_destroy(m_space, xrDestroySpace);
+	xr_destroy(m_session, xrDestroySession);
+	xr_destroy(m_instance, xrDestroyInstance);
+}
+
+void OpenXRHMD::init_create_xr_instance() {
+	std::vector<const char*> layers = {};
+	std::vector<const char*> extensions = {
+		XR_KHR_OPENGL_ENABLE_EXTENSION_NAME,
+	};
+
+	auto print_extension_properties = [](const char* layer_name) {
+		uint32_t size;
+		xrEnumerateInstanceExtensionProperties(layer_name, 0, &size, nullptr);
+		std::vector<XrExtensionProperties> props(size, {XR_TYPE_EXTENSION_PROPERTIES});
+		xrEnumerateInstanceExtensionProperties(layer_name, size, &size, props.data());
+		tlog::info() << fmt::format("Extensions ({}):", props.size());
+		for (XrExtensionProperties extension : props) {
+			tlog::info() << fmt::format("\t{} (Version {})", extension.extensionName, extension.extensionVersion);
+		}
+	};
+
+	uint32_t size;
+	xrEnumerateApiLayerProperties(0, &size, nullptr);
+	m_api_layer_properties.clear();
+	m_api_layer_properties.resize(size, {XR_TYPE_API_LAYER_PROPERTIES});
+	xrEnumerateApiLayerProperties(size, &size, m_api_layer_properties.data());
+
+	if (m_print_api_layers) {
+		tlog::info() << fmt::format("API Layers ({}):", m_api_layer_properties.size());
+		for (auto p : m_api_layer_properties) {
+			tlog::info() << fmt::format(
+				"{} (v {}.{}.{}, {}) {}",
+				p.layerName,
+				XR_VERSION_MAJOR(p.specVersion),
+				XR_VERSION_MINOR(p.specVersion),
+				XR_VERSION_PATCH(p.specVersion),
+				p.layerVersion,
+				p.description
+			);
+			print_extension_properties(p.layerName);
+		}
+	}
+
+	if (layers.size() != 0) {
+		for (const auto& e : layers) {
+			bool found = false;
+			for (XrApiLayerProperties layer : m_api_layer_properties) {
+				if (strcmp(e, layer.layerName) == 0) {
+					found = true;
+					break;
+				}
+			}
+
+			if (!found) {
+				throw std::runtime_error{fmt::format("OpenXR API layer {} not found", e)};
+			}
+		}
+	}
+
+	xrEnumerateInstanceExtensionProperties(nullptr, 0, &size, nullptr);
+	m_instance_extension_properties.clear();
+	m_instance_extension_properties.resize(size, {XR_TYPE_EXTENSION_PROPERTIES});
+	xrEnumerateInstanceExtensionProperties(nullptr, size, &size, m_instance_extension_properties.data());
+
+	if (m_print_extensions) {
+		tlog::info() << fmt::format("Instance extensions ({}):", m_instance_extension_properties.size());
+		for (XrExtensionProperties extension : m_instance_extension_properties) {
+			tlog::info() << fmt::format("\t{} (Version {})", extension.extensionName, extension.extensionVersion);
+		}
+	}
+
+	auto has_extension = [&](const char* e) {
+		for (XrExtensionProperties extension : m_instance_extension_properties) {
+			if (strcmp(e, extension.extensionName) == 0) {
+				return true;
+			}
+		}
+
+		return false;
+	};
+
+	for (const auto& e : extensions) {
+		if (!has_extension(e)) {
+			throw std::runtime_error{fmt::format("Required OpenXR extension {} not found", e)};
+		}
+	}
+
+	auto add_extension_if_supported = [&](const char* extension) {
+		if (has_extension(extension)) {
+			extensions.emplace_back(extension);
+			return true;
+		}
+
+		return false;
+	};
+
+	if (add_extension_if_supported(XR_KHR_COMPOSITION_LAYER_DEPTH_EXTENSION_NAME)) {
+		m_supports_composition_layer_depth = true;
+	}
+
+	if (add_extension_if_supported(XR_KHR_VISIBILITY_MASK_EXTENSION_NAME)) {
+		m_supports_hidden_area_mask = true;
+	}
+
+	if (add_extension_if_supported(XR_EXT_EYE_GAZE_INTERACTION_EXTENSION_NAME)) {
+		m_supports_eye_tracking = true;
+	}
+
+	XrInstanceCreateInfo instance_create_info = {XR_TYPE_INSTANCE_CREATE_INFO};
+	instance_create_info.applicationInfo = {};
+	strncpy(instance_create_info.applicationInfo.applicationName, "Instant Neural Graphics Primitives v" NGP_VERSION, XR_MAX_APPLICATION_NAME_SIZE);
+	instance_create_info.applicationInfo.applicationVersion = 1;
+	strncpy(instance_create_info.applicationInfo.engineName, "Instant Neural Graphics Primitives v" NGP_VERSION, XR_MAX_ENGINE_NAME_SIZE);
+	instance_create_info.applicationInfo.engineVersion = 1;
+	instance_create_info.applicationInfo.apiVersion = XR_CURRENT_API_VERSION;
+	instance_create_info.enabledExtensionCount = (uint32_t)extensions.size();
+	instance_create_info.enabledExtensionNames = extensions.data();
+	instance_create_info.enabledApiLayerCount = (uint32_t)layers.size();
+	instance_create_info.enabledApiLayerNames = layers.data();
+
+	if (XR_FAILED(xrCreateInstance(&instance_create_info, &m_instance))) {
+		throw std::runtime_error{"Failed to create OpenXR instance"};
+	}
+
+	XR_CHECK_THROW(xrGetInstanceProperties(m_instance, &m_instance_properties));
+	if (m_print_instance_properties) {
+		tlog::info() << "Instance Properties";
+		tlog::info() << fmt::format("\t        runtime name: '{}'", m_instance_properties.runtimeName);
+		const auto& v = m_instance_properties.runtimeVersion;
+		tlog::info() << fmt::format(
+			"\t     runtime version: {}.{}.{}",
+			XR_VERSION_MAJOR(v),
+			XR_VERSION_MINOR(v),
+			XR_VERSION_PATCH(v)
+		);
+	}
+}
+
+void OpenXRHMD::init_get_xr_system() {
+	XrSystemGetInfo system_get_info = {XR_TYPE_SYSTEM_GET_INFO, nullptr, XR_FORM_FACTOR_HEAD_MOUNTED_DISPLAY};
+	XR_CHECK_THROW(xrGetSystem(m_instance, &system_get_info, &m_system_id));
+
+	XR_CHECK_THROW(xrGetSystemProperties(m_instance, m_system_id, &m_system_properties));
+	if (m_print_system_properties) {
+		tlog::info() << "System Properties";
+		tlog::info() << fmt::format("\t                name: '{}'", m_system_properties.systemName);
+		tlog::info() << fmt::format("\t            vendorId: {:#x}", m_system_properties.vendorId);
+		tlog::info() << fmt::format("\t            systemId: {:#x}", m_system_properties.systemId);
+		tlog::info() << fmt::format("\t     max layer count: {}", m_system_properties.graphicsProperties.maxLayerCount);
+		tlog::info() << fmt::format("\t       max img width: {}", m_system_properties.graphicsProperties.maxSwapchainImageWidth);
+		tlog::info() << fmt::format("\t      max img height: {}", m_system_properties.graphicsProperties.maxSwapchainImageHeight);
+		tlog::info() << fmt::format("\torientation tracking: {}", m_system_properties.trackingProperties.orientationTracking ? "YES" : "NO");
+		tlog::info() << fmt::format("\t   position tracking: {}", m_system_properties.trackingProperties.orientationTracking ? "YES" : "NO");
+	}
+}
+
+void OpenXRHMD::init_configure_xr_views() {
+	uint32_t size;
+	XR_CHECK_THROW(xrEnumerateViewConfigurations(m_instance, m_system_id, 0, &size, nullptr));
+	std::vector<XrViewConfigurationType> view_config_types(size);
+	XR_CHECK_THROW(xrEnumerateViewConfigurations(m_instance, m_system_id, size, &size, view_config_types.data()));
+
+	if (m_print_view_configuration_types) {
+		tlog::info() << fmt::format("View Configuration Types ({}):", view_config_types.size());
+		for (const auto& t : view_config_types) {
+			tlog::info() << fmt::format("\t{}", XrEnumStr(t));
+		}
+	}
+
+	// view configurations we support, in descending preference
+	const std::vector<XrViewConfigurationType> preferred_view_config_types = {
+		//XR_VIEW_CONFIGURATION_TYPE_PRIMARY_QUAD_VARJO,
+		XR_VIEW_CONFIGURATION_TYPE_PRIMARY_STEREO
+	};
+
+	bool found = false;
+	for (const auto& p : preferred_view_config_types) {
+		for (const auto& t : view_config_types) {
+			if (p == t) {
+				found = true;
+				m_view_configuration_type = t;
+			}
+		}
+	}
+
+	if (!found) {
+		throw std::runtime_error{"Could not find a suitable OpenXR view configuration type"};
+	}
+
+	// get view configuration properties
+	XR_CHECK_THROW(xrGetViewConfigurationProperties(m_instance, m_system_id, m_view_configuration_type, &m_view_configuration_properties));
+	if (m_print_view_configuration_properties) {
+		tlog::info() << "View Configuration Properties:";
+		tlog::info() << fmt::format("\t         Type: {}", XrEnumStr(m_view_configuration_type));
+		tlog::info() << fmt::format("\t         FOV Mutable: {}", m_view_configuration_properties.fovMutable ? "YES" : "NO");
+	}
+
+	// enumerate view configuration views
+	XR_CHECK_THROW(xrEnumerateViewConfigurationViews(m_instance, m_system_id, m_view_configuration_type, 0, &size, nullptr));
+	m_view_configuration_views.clear();
+	m_view_configuration_views.resize(size, {XR_TYPE_VIEW_CONFIGURATION_VIEW});
+	XR_CHECK_THROW(xrEnumerateViewConfigurationViews(
+		m_instance,
+		m_system_id,
+		m_view_configuration_type,
+		size,
+		&size,
+		m_view_configuration_views.data()
+	));
+
+	if (m_print_view_configuration_view) {
+		tlog::info() << "View Configuration Views, Width x Height x Samples";
+		for (size_t i = 0; i < m_view_configuration_views.size(); ++i) {
+			const auto& view = m_view_configuration_views[i];
+			tlog::info() << fmt::format(
+				"\tView {}\tRecommended: {}x{}x{}  Max: {}x{}x{}",
+				i,
+				view.recommendedImageRectWidth,
+				view.recommendedImageRectHeight,
+				view.recommendedSwapchainSampleCount,
+				view.maxImageRectWidth,
+				view.maxImageRectHeight,
+				view.maxSwapchainSampleCount
+			);
+		}
+	}
+}
+
+void OpenXRHMD::init_check_for_xr_blend_mode() {
+	// enumerate environment blend modes
+	uint32_t size;
+	XR_CHECK_THROW(xrEnumerateEnvironmentBlendModes(m_instance, m_system_id, m_view_configuration_type, 0, &size, nullptr));
+	m_environment_blend_modes.resize(size);
+	XR_CHECK_THROW(xrEnumerateEnvironmentBlendModes(
+		m_instance,
+		m_system_id,
+		m_view_configuration_type,
+		size,
+		&size,
+		m_environment_blend_modes.data()
+	));
+
+	if (m_print_environment_blend_modes) {
+		tlog::info() << fmt::format("Environment Blend Modes ({}):", m_environment_blend_modes.size());
+	}
+
+	bool found = false;
+	for (const auto& m : m_environment_blend_modes) {
+		if (m_print_environment_blend_modes) {
+			tlog::info() << fmt::format("\t{}", XrEnumStr(m));
+		}
+
+		if (m == m_environment_blend_mode) {
+			found = true;
+		}
+	}
+
+	if (!found) {
+		throw std::runtime_error{fmt::format("OpenXR environment blend mode {} not found", XrEnumStr(m_environment_blend_mode))};
+	}
+}
+
+void OpenXRHMD::init_xr_actions() {
+	// paths for left (0) and right (1) hands
+	XR_CHECK_THROW(xrStringToPath(m_instance, "/user/hand/left", &m_hand_paths[0]));
+	XR_CHECK_THROW(xrStringToPath(m_instance, "/user/hand/right", &m_hand_paths[1]));
+
+	// create action set
+	XrActionSetCreateInfo action_set_create_info{XR_TYPE_ACTION_SET_CREATE_INFO, nullptr, "actionset", "actionset", 0};
+	XR_CHECK_THROW(xrCreateActionSet(m_instance, &action_set_create_info, &m_action_set));
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"hand_pose",
+			XR_ACTION_TYPE_POSE_INPUT,
+			(uint32_t)m_hand_paths.size(),
+			m_hand_paths.data(),
+			"Hand pose"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_pose_action));
+	}
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"thumbstick_left",
+			XR_ACTION_TYPE_VECTOR2F_INPUT,
+			0,
+			nullptr,
+			"Left thumbstick"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_thumbstick_actions[0]));
+	}
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"thumbstick_right",
+			XR_ACTION_TYPE_VECTOR2F_INPUT,
+			0,
+			nullptr,
+			"Right thumbstick"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_thumbstick_actions[1]));
+	}
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"press",
+			XR_ACTION_TYPE_BOOLEAN_INPUT,
+			(uint32_t)m_hand_paths.size(),
+			m_hand_paths.data(),
+			"Press"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_press_action));
+	}
+
+	{
+		XrActionCreateInfo action_create_info{
+			XR_TYPE_ACTION_CREATE_INFO,
+			nullptr,
+			"grab",
+			XR_ACTION_TYPE_FLOAT_INPUT,
+			(uint32_t)m_hand_paths.size(),
+			m_hand_paths.data(),
+			"Grab"
+		};
+		XR_CHECK_THROW(xrCreateAction(m_action_set, &action_create_info, &m_grab_action));
+	}
+
+	auto create_binding = [&](XrAction action, const std::string& binding_path_str) {
+		XrPath binding;
+		XR_CHECK_THROW(xrStringToPath(m_instance, binding_path_str.c_str(), &binding));
+		return XrActionSuggestedBinding{action, binding};
+	};
+
+	auto suggest_bindings = [&](const std::string& interaction_profile_path_str, const std::vector<XrActionSuggestedBinding>& bindings) {
+		XrPath interaction_profile;
+		XR_CHECK_THROW(xrStringToPath(m_instance, interaction_profile_path_str.c_str(), &interaction_profile));
+		XrInteractionProfileSuggestedBinding suggested_binding{
+			XR_TYPE_INTERACTION_PROFILE_SUGGESTED_BINDING,
+			nullptr,
+			interaction_profile,
+			(uint32_t)bindings.size(),
+			bindings.data()
+		};
+		XR_CHECK_THROW(xrSuggestInteractionProfileBindings(m_instance, &suggested_binding));
+	};
+
+	suggest_bindings("/interaction_profiles/khr/simple_controller", {
+		create_binding(m_pose_action, "/user/hand/left/input/grip/pose"),
+		create_binding(m_pose_action, "/user/hand/right/input/grip/pose"),
+	});
+
+	auto suggest_controller_bindings = [&](const std::string& xy, const std::string& press, const std::string& grab, const std::string& interaction_profile_path_str) {
+		suggest_bindings(interaction_profile_path_str, {
+			create_binding(m_pose_action, "/user/hand/left/input/grip/pose"),
+			create_binding(m_pose_action, "/user/hand/right/input/grip/pose"),
+			create_binding(m_thumbstick_actions[0], std::string{"/user/hand/left/input/"} + xy),
+			create_binding(m_thumbstick_actions[1], std::string{"/user/hand/right/input/"} + xy),
+			create_binding(m_press_action, std::string{"/user/hand/left/input/"} + press),
+			create_binding(m_press_action, std::string{"/user/hand/right/input/"} + press),
+			create_binding(m_grab_action, std::string{"/user/hand/left/input/"} + grab),
+			create_binding(m_grab_action, std::string{"/user/hand/right/input/"} + grab),
+		});
+	};
+
+	suggest_controller_bindings("trackpad",   "select/click",     "trackpad/click", "/interaction_profiles/google/daydream_controller");
+	suggest_controller_bindings("trackpad",   "trackpad/click",   "trigger/click",  "/interaction_profiles/htc/vive_controller");
+	suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value",  "/interaction_profiles/microsoft/motion_controller");
+	suggest_controller_bindings("trackpad",   "trackpad/click",   "trigger/click",  "/interaction_profiles/oculus/go_controller");
+	suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value",  "/interaction_profiles/oculus/touch_controller");
+	suggest_controller_bindings("thumbstick", "thumbstick/click", "trigger/value",  "/interaction_profiles/valve/index_controller");
+
+	// Xbox controller (currently not functional)
+	suggest_bindings("/interaction_profiles/microsoft/xbox_controller", {
+		create_binding(m_thumbstick_actions[0], std::string{"/user/gamepad/input/thumbstick_left"}),
+		create_binding(m_thumbstick_actions[1], std::string{"/user/gamepad/input/thumbstick_right"}),
+	});
+}
+
+#if defined(XR_USE_PLATFORM_WIN32)
+void OpenXRHMD::init_open_gl(HDC hdc, HGLRC hglrc) {
+#elif defined(XR_USE_PLATFORM_XLIB)
+void OpenXRHMD::init_open_gl(Display* xDisplay, uint32_t visualid, GLXFBConfig glxFBConfig, GLXDrawable glxDrawable, GLXContext glxContext) {
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+void OpenXRHMD::init_open_gl(wl_display* display) {
+#endif
+	// GL graphics requirements
+	PFN_xrGetOpenGLGraphicsRequirementsKHR xrGetOpenGLGraphicsRequirementsKHR = nullptr;
+	XR_CHECK_THROW(xrGetInstanceProcAddr(
+		m_instance,
+		"xrGetOpenGLGraphicsRequirementsKHR",
+		reinterpret_cast<PFN_xrVoidFunction*>(&xrGetOpenGLGraphicsRequirementsKHR)
+	));
+
+	XrGraphicsRequirementsOpenGLKHR graphics_requirements{XR_TYPE_GRAPHICS_REQUIREMENTS_OPENGL_KHR};
+	xrGetOpenGLGraphicsRequirementsKHR(m_instance, m_system_id, &graphics_requirements);
+	XrVersion min_version = graphics_requirements.minApiVersionSupported;
+	GLint major = 0;
+	GLint minor = 0;
+	glGetIntegerv(GL_MAJOR_VERSION, &major);
+	glGetIntegerv(GL_MINOR_VERSION, &minor);
+	const XrVersion have_version = XR_MAKE_VERSION(major, minor, 0);
+
+	if (have_version < min_version) {
+		tlog::info() << fmt::format(
+			"Required OpenGL version: {}.{}, found OpenGL version: {}.{}",
+			XR_VERSION_MAJOR(min_version),
+			XR_VERSION_MINOR(min_version),
+			major,
+			minor
+		);
+
+		throw std::runtime_error{"Insufficient graphics API support"};
+	}
+
+#if defined(XR_USE_PLATFORM_WIN32)
+	m_graphics_binding.hDC = hdc;
+	m_graphics_binding.hGLRC = hglrc;
+#elif defined(XR_USE_PLATFORM_XLIB)
+	m_graphics_binding.xDisplay = xDisplay;
+	m_graphics_binding.visualid = visualid;
+	m_graphics_binding.glxFBConfig = glxFBConfig;
+	m_graphics_binding.glxDrawable = glxDrawable;
+	m_graphics_binding.glxContext = glxContext;
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+	m_graphics_binding.display = display;
+#endif
+}
+
+void OpenXRHMD::init_xr_session() {
+	// create session
+	XrSessionCreateInfo create_info{
+		XR_TYPE_SESSION_CREATE_INFO,
+		reinterpret_cast<const XrBaseInStructure*>(&m_graphics_binding),
+		0,
+		m_system_id
+	};
+
+	XR_CHECK_THROW(xrCreateSession(m_instance, &create_info, &m_session));
+
+	// tlog::info() << fmt::format("Created session {}", fmt::ptr(m_session));
+}
+
+void OpenXRHMD::init_xr_spaces() {
+	// reference space
+	uint32_t size;
+	XR_CHECK_THROW(xrEnumerateReferenceSpaces(m_session, 0, &size, nullptr));
+	m_reference_spaces.clear();
+	m_reference_spaces.resize(size);
+	XR_CHECK_THROW(xrEnumerateReferenceSpaces(m_session, size, &size, m_reference_spaces.data()));
+
+	if (m_print_reference_spaces) {
+		tlog::info() << fmt::format("Reference spaces ({}):", m_reference_spaces.size());
+		for (const auto& r : m_reference_spaces) {
+			tlog::info() << fmt::format("\t{}", XrEnumStr(r));
+		}
+	}
+
+	XrReferenceSpaceCreateInfo reference_space_create_info{XR_TYPE_REFERENCE_SPACE_CREATE_INFO};
+	reference_space_create_info.referenceSpaceType = XR_REFERENCE_SPACE_TYPE_LOCAL;
+	reference_space_create_info.poseInReferenceSpace = XrPosef{};
+	reference_space_create_info.poseInReferenceSpace.orientation.w = 1.0f;
+	XR_CHECK_THROW(xrCreateReferenceSpace(m_session, &reference_space_create_info, &m_space));
+	XR_CHECK_THROW(xrGetReferenceSpaceBoundsRect(m_session, reference_space_create_info.referenceSpaceType, &m_bounds));
+
+	if (m_print_reference_spaces) {
+		tlog::info() << fmt::format("Using reference space {}", XrEnumStr(reference_space_create_info.referenceSpaceType));
+		tlog::info() << fmt::format("Reference space boundaries: {} x {}", m_bounds.width, m_bounds.height);
+	}
+
+	// action space
+	XrActionSpaceCreateInfo action_space_create_info{XR_TYPE_ACTION_SPACE_CREATE_INFO};
+	action_space_create_info.action = m_pose_action;
+	action_space_create_info.poseInActionSpace.orientation.w = 1.0f;
+	action_space_create_info.subactionPath = m_hand_paths[0];
+	XR_CHECK_THROW(xrCreateActionSpace(m_session, &action_space_create_info, &m_hand_spaces[0]));
+	action_space_create_info.subactionPath = m_hand_paths[1];
+	XR_CHECK_THROW(xrCreateActionSpace(m_session, &action_space_create_info, &m_hand_spaces[1]));
+
+	// attach action set
+	XrSessionActionSetsAttachInfo attach_info{XR_TYPE_SESSION_ACTION_SETS_ATTACH_INFO};
+	attach_info.countActionSets = 1;
+	attach_info.actionSets = &m_action_set;
+	XR_CHECK_THROW(xrAttachSessionActionSets(m_session, &attach_info));
+}
+
+void OpenXRHMD::init_xr_swapchain_open_gl() {
+	// swap chains
+	uint32_t size;
+	XR_CHECK_THROW(xrEnumerateSwapchainFormats(m_session, 0, &size, nullptr));
+	std::vector<int64_t> swapchain_formats(size);
+	XR_CHECK_THROW(xrEnumerateSwapchainFormats(m_session, size, &size, swapchain_formats.data()));
+
+	if (m_print_available_swapchain_formats) {
+		tlog::info() << fmt::format("Swapchain formats ({}):", swapchain_formats.size());
+		for (const auto& f : swapchain_formats) {
+			tlog::info() << fmt::format("\t{:#x}", f);
+		}
+	}
+
+	auto find_compatible_swapchain_format = [&](const std::vector<int64_t>& candidates) {
+		for (auto format : candidates) {
+			if (std::find(std::begin(swapchain_formats), std::end(swapchain_formats), format) != std::end(swapchain_formats)) {
+				return format;
+			}
+		}
+
+		throw std::runtime_error{"No compatible OpenXR swapchain format found"};
+	};
+
+	m_swapchain_rgba_format = find_compatible_swapchain_format({
+		GL_SRGB8_ALPHA8,
+		GL_SRGB8,
+		GL_RGBA8,
+	});
+
+	if (m_supports_composition_layer_depth) {
+		m_swapchain_depth_format = find_compatible_swapchain_format({
+			GL_DEPTH_COMPONENT32F,
+			GL_DEPTH_COMPONENT24,
+			GL_DEPTH_COMPONENT16,
+		});
+	}
+
+	// tlog::info() << fmt::format("Chosen swapchain format: {:#x}", m_swapchain_rgba_format);
+	for (const auto& vcv : m_view_configuration_views) {
+		XrSwapchainCreateInfo rgba_swapchain_create_info{XR_TYPE_SWAPCHAIN_CREATE_INFO};
+		rgba_swapchain_create_info.usageFlags = XR_SWAPCHAIN_USAGE_SAMPLED_BIT | XR_SWAPCHAIN_USAGE_COLOR_ATTACHMENT_BIT;
+		rgba_swapchain_create_info.format = m_swapchain_rgba_format;
+		rgba_swapchain_create_info.sampleCount = 1;
+		rgba_swapchain_create_info.width = vcv.recommendedImageRectWidth;
+		rgba_swapchain_create_info.height = vcv.recommendedImageRectHeight;
+		rgba_swapchain_create_info.faceCount = 1;
+		rgba_swapchain_create_info.arraySize = 1;
+		rgba_swapchain_create_info.mipCount = 1;
+
+		XrSwapchainCreateInfo depth_swapchain_create_info = rgba_swapchain_create_info;
+		depth_swapchain_create_info.usageFlags = XR_SWAPCHAIN_USAGE_SAMPLED_BIT | XR_SWAPCHAIN_USAGE_DEPTH_STENCIL_ATTACHMENT_BIT;
+		depth_swapchain_create_info.format = m_swapchain_depth_format;
+
+		m_swapchains.emplace_back(rgba_swapchain_create_info, depth_swapchain_create_info, m_session, m_instance);
+	}
+}
+
+void OpenXRHMD::init_open_gl_shaders() {
+	// Hidden area mask program
+	{
+		static const char* shader_vert = R"(#version 140
+			in vec2 pos;
+			uniform mat4 project;
+			void main() {
+				vec4 pos = project * vec4(pos, -1.0, 1.0);
+				pos.xyz /= pos.w;
+				pos.y *= -1.0;
+				gl_Position = pos;
+			})";
+
+		static const char* shader_frag = R"(#version 140
+			out vec4 frag_color;
+			void main() {
+				frag_color = vec4(0.0, 0.0, 0.0, 1.0);
+			})";
+
+		GLuint vert = glCreateShader(GL_VERTEX_SHADER);
+		glShaderSource(vert, 1, &shader_vert, NULL);
+		glCompileShader(vert);
+		check_shader(vert, "OpenXR hidden area mask vertex shader", false);
+
+		GLuint frag = glCreateShader(GL_FRAGMENT_SHADER);
+		glShaderSource(frag, 1, &shader_frag, NULL);
+		glCompileShader(frag);
+		check_shader(frag, "OpenXR hidden area mask fragment shader", false);
+
+		m_hidden_area_mask_program = glCreateProgram();
+		glAttachShader(m_hidden_area_mask_program, vert);
+		glAttachShader(m_hidden_area_mask_program, frag);
+		glLinkProgram(m_hidden_area_mask_program);
+		check_shader(m_hidden_area_mask_program, "OpenXR hidden area mask shader program", true);
+
+		glDeleteShader(vert);
+		glDeleteShader(frag);
+	}
+}
+
+void OpenXRHMD::session_state_change(XrSessionState state, ControlFlow& flow) {
+	//tlog::info() << fmt::format("New session state {}", XrEnumStr(state));
+	switch (state) {
+		case XR_SESSION_STATE_READY: {
+			XrSessionBeginInfo sessionBeginInfo {XR_TYPE_SESSION_BEGIN_INFO};
+			sessionBeginInfo.primaryViewConfigurationType = m_view_configuration_type;
+			XR_CHECK_THROW(xrBeginSession(m_session, &sessionBeginInfo));
+			break;
+		}
+		case XR_SESSION_STATE_STOPPING: {
+			XR_CHECK_THROW(xrEndSession(m_session));
+			break;
+		}
+		case XR_SESSION_STATE_EXITING: {
+			flow = ControlFlow::QUIT;
+			break;
+		}
+		case XR_SESSION_STATE_LOSS_PENDING: {
+			flow = ControlFlow::RESTART;
+			break;
+		}
+		default: {
+			break;
+		}
+	}
+}
+
+OpenXRHMD::ControlFlow OpenXRHMD::poll_events() {
+	bool more = true;
+	ControlFlow flow = ControlFlow::CONTINUE;
+	while (more) {
+		// poll events
+		XrEventDataBuffer event {XR_TYPE_EVENT_DATA_BUFFER, nullptr};
+		XrResult result = xrPollEvent(m_instance, &event);
+
+		if (XR_FAILED(result)) {
+			tlog::error() << "xrPollEvent failed";
+		} else if (XR_SUCCESS == result) {
+			switch (event.type) {
+				case XR_TYPE_EVENT_DATA_SESSION_STATE_CHANGED: {
+					const XrEventDataSessionStateChanged& e = *reinterpret_cast<XrEventDataSessionStateChanged*>(&event);
+					//tlog::info() << "Session state change";
+					//tlog::info() << fmt::format("\t from {}\t   to {}", XrEnumStr(m_session_state), XrEnumStr(e.state));
+					//tlog::info() << fmt::format("\t session {}, time {}", fmt::ptr(e.session), e.time);
+					m_session_state = e.state;
+					session_state_change(e.state, flow);
+					break;
+				}
+
+				case XR_TYPE_EVENT_DATA_INSTANCE_LOSS_PENDING: {
+					flow = ControlFlow::RESTART;
+					break;
+				}
+
+				case XR_TYPE_EVENT_DATA_VISIBILITY_MASK_CHANGED_KHR: {
+					m_hidden_area_masks.clear();
+					break;
+				}
+
+				case XR_TYPE_EVENT_DATA_INTERACTION_PROFILE_CHANGED: {
+					break; // Can ignore
+				}
+
+				default: {
+					tlog::info() << fmt::format("Unhandled event type {}", XrEnumStr(event.type));
+					break;
+				}
+			}
+		} else if (XR_EVENT_UNAVAILABLE == result) {
+			more = false;
+		}
+	}
+	return flow;
+}
+
+__global__ void read_hidden_area_mask_kernel(const Vector2i resolution, cudaSurfaceObject_t surface, uint8_t* __restrict__ mask) {
+	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
+	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
+
+	if (x >= resolution.x() || y >= resolution.y()) {
+		return;
+	}
+
+	uint32_t idx = x + resolution.x() * y;
+	surf2Dread(&mask[idx], surface, x, y);
+}
+
+std::shared_ptr<Buffer2D<uint8_t>> OpenXRHMD::rasterize_hidden_area_mask(uint32_t view_index, const XrCompositionLayerProjectionView& view) {
+	if (!m_supports_hidden_area_mask) {
+		return {};
+	}
+
+	PFN_xrGetVisibilityMaskKHR xrGetVisibilityMaskKHR = nullptr;
+	XR_CHECK_THROW(xrGetInstanceProcAddr(
+		m_instance,
+		"xrGetVisibilityMaskKHR",
+		reinterpret_cast<PFN_xrVoidFunction*>(&xrGetVisibilityMaskKHR)
+	));
+
+	XrVisibilityMaskKHR visibility_mask{XR_TYPE_VISIBILITY_MASK_KHR};
+	XR_CHECK_THROW(xrGetVisibilityMaskKHR(m_session, m_view_configuration_type, view_index, XR_VISIBILITY_MASK_TYPE_HIDDEN_TRIANGLE_MESH_KHR, &visibility_mask));
+
+	if (visibility_mask.vertexCountOutput == 0 || visibility_mask.indexCountOutput == 0) {
+		return nullptr;
+	}
+
+	std::vector<XrVector2f> vertices(visibility_mask.vertexCountOutput);
+	std::vector<uint32_t> indices(visibility_mask.indexCountOutput);
+
+	visibility_mask.vertices = vertices.data();
+	visibility_mask.indices = indices.data();
+
+	visibility_mask.vertexCapacityInput = visibility_mask.vertexCountOutput;
+	visibility_mask.indexCapacityInput = visibility_mask.indexCountOutput;
+
+	XR_CHECK_THROW(xrGetVisibilityMaskKHR(m_session, m_view_configuration_type, view_index, XR_VISIBILITY_MASK_TYPE_HIDDEN_TRIANGLE_MESH_KHR, &visibility_mask));
+
+	CUDA_CHECK_THROW(cudaDeviceSynchronize());
+
+	Vector2i size = {view.subImage.imageRect.extent.width, view.subImage.imageRect.extent.height};
+
+	bool tex = glIsEnabled(GL_TEXTURE_2D);
+	bool depth = glIsEnabled(GL_DEPTH_TEST);
+	bool cull = glIsEnabled(GL_CULL_FACE);
+	GLint previous_texture_id;
+	glGetIntegerv(GL_TEXTURE_BINDING_2D, &previous_texture_id);
+
+	if (!tex) glEnable(GL_TEXTURE_2D);
+	if (depth) glDisable(GL_DEPTH_TEST);
+	if (cull) glDisable(GL_CULL_FACE);
+
+	// Generate texture to hold hidden area mask. Single channel, value of 1 means visible and 0 means masked away
+	ngp::GLTexture mask_texture;
+	mask_texture.resize(size, 1, true);
+	glBindTexture(GL_TEXTURE_2D, mask_texture.texture());
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+
+	GLuint framebuffer = 0;
+	glGenFramebuffers(1, &framebuffer);
+	glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);
+	glFramebufferTexture(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, mask_texture.texture(), 0);
+
+	GLenum draw_buffers[1] = {GL_COLOR_ATTACHMENT0};
+	glDrawBuffers(1, draw_buffers);
+
+	glViewport(0, 0, size.x(), size.y());
+
+	// Draw hidden area mask
+	GLuint vao;
+	glGenVertexArrays(1, &vao);
+	glBindVertexArray(vao);
+
+	GLuint vertex_buffer;
+	glGenBuffers(1, &vertex_buffer);
+	glEnableVertexAttribArray(0);
+	glBindBuffer(GL_ARRAY_BUFFER, vertex_buffer);
+	glBufferData(GL_ARRAY_BUFFER, sizeof(XrVector2f) * vertices.size(), vertices.data(), GL_STATIC_DRAW);
+	glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 0, (void*)0);
+
+	GLuint index_buffer;
+	glGenBuffers(1, &index_buffer);
+	glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, index_buffer);
+	glBufferData(GL_ELEMENT_ARRAY_BUFFER, sizeof(uint32_t) * indices.size(), indices.data(), GL_STATIC_DRAW);
+
+	glClearColor(1.0f, 1.0f, 1.0f, 1.0f);
+	glClear(GL_COLOR_BUFFER_BIT);
+	glUseProgram(m_hidden_area_mask_program);
+
+	XrMatrix4x4f proj;
+	XrMatrix4x4f_CreateProjectionFov(&proj, GRAPHICS_OPENGL, view.fov, 1.0f / 128.0f, 128.0f);
+
+	GLuint project_id = glGetUniformLocation(m_hidden_area_mask_program, "project");
+	glUniformMatrix4fv(project_id, 1, GL_FALSE, &proj.m[0]);
+
+	glDrawElements(GL_TRIANGLES, indices.size(), GL_UNSIGNED_INT, (void*)0);
+	glFinish();
+
+	glDisableVertexAttribArray(0);
+	glDeleteBuffers(1, &vertex_buffer);
+	glDeleteBuffers(1, &index_buffer);
+	glDeleteVertexArrays(1, &vao);
+	glDeleteFramebuffers(1, &framebuffer);
+
+	glBindVertexArray(0);
+	glUseProgram(0);
+
+	// restore old state
+	if (!tex) glDisable(GL_TEXTURE_2D);
+	if (depth) glEnable(GL_DEPTH_TEST);
+	if (cull) glEnable(GL_CULL_FACE);
+	glBindTexture(GL_TEXTURE_2D, previous_texture_id);
+	glBindFramebuffer(GL_FRAMEBUFFER, 0);
+
+	std::shared_ptr<Buffer2D<uint8_t>> mask = std::make_shared<Buffer2D<uint8_t>>(size);
+
+	const dim3 threads = { 16, 8, 1 };
+	const dim3 blocks = { div_round_up((uint32_t)size.x(), threads.x), div_round_up((uint32_t)size.y(), threads.y), 1 };
+
+	read_hidden_area_mask_kernel<<<blocks, threads>>>(size, mask_texture.surface(), mask->data());
+	CUDA_CHECK_THROW(cudaDeviceSynchronize());
+
+	return mask;
+}
+
+Matrix<float, 3, 4> convert_xr_matrix_to_eigen(const XrMatrix4x4f& m) {
+	Matrix<float, 3, 4> out;
+
+	for (size_t i = 0; i < 3; ++i) {
+		for (size_t j = 0; j < 4; ++j) {
+			out(i, j) = m.m[i + j * 4];
+		}
+	}
+
+	// Flip Y and Z axes to match NGP conventions
+	out(0, 1) *= -1.f;
+	out(1, 0) *= -1.f;
+
+	out(0, 2) *= -1.f;
+	out(2, 0) *= -1.f;
+
+	out(1, 3) *= -1.f;
+	out(2, 3) *= -1.f;
+
+	return out;
+}
+
+Matrix<float, 3, 4> convert_xr_pose_to_eigen(const XrPosef& pose) {
+	XrMatrix4x4f matrix;
+	XrVector3f unit_scale{1.0f, 1.0f, 1.0f};
+	XrMatrix4x4f_CreateTranslationRotationScale(&matrix, &pose.position, &pose.orientation, &unit_scale);
+	return convert_xr_matrix_to_eigen(matrix);
+}
+
+OpenXRHMD::FrameInfoPtr OpenXRHMD::begin_frame() {
+	XrFrameWaitInfo frame_wait_info{XR_TYPE_FRAME_WAIT_INFO};
+	XR_CHECK_THROW(xrWaitFrame(m_session, &frame_wait_info, &m_frame_state));
+
+	XrFrameBeginInfo frame_begin_info{XR_TYPE_FRAME_BEGIN_INFO};
+	XR_CHECK_THROW(xrBeginFrame(m_session, &frame_begin_info));
+
+	if (!m_frame_state.shouldRender) {
+		return std::make_shared<FrameInfo>();
+	}
+
+	uint32_t num_views = (uint32_t)m_swapchains.size();
+	// TODO assert m_view_configuration_views.size() == m_swapchains.size()
+
+	// locate views
+	std::vector<XrView> views(num_views, {XR_TYPE_VIEW});
+
+	XrViewState viewState{XR_TYPE_VIEW_STATE};
+
+	XrViewLocateInfo view_locate_info{XR_TYPE_VIEW_LOCATE_INFO};
+	view_locate_info.viewConfigurationType = m_view_configuration_type;
+	view_locate_info.displayTime = m_frame_state.predictedDisplayTime;
+	view_locate_info.space = m_space;
+
+	XR_CHECK_THROW(xrLocateViews(m_session, &view_locate_info, &viewState, uint32_t(views.size()), &num_views, views.data()));
+
+	if (!(viewState.viewStateFlags & XR_VIEW_STATE_POSITION_VALID_BIT) || !(viewState.viewStateFlags & XR_VIEW_STATE_ORIENTATION_VALID_BIT)) {
+		return std::make_shared<FrameInfo>();
+	}
+
+	m_hidden_area_masks.resize(num_views);
+
+	// Fill frame information
+	if (!m_previous_frame_info) {
+		m_previous_frame_info = std::make_shared<FrameInfo>();
+	}
+
+	FrameInfoPtr frame_info = std::make_shared<FrameInfo>(*m_previous_frame_info);
+	frame_info->views.resize(m_swapchains.size());
+
+	for (size_t i = 0; i < m_swapchains.size(); ++i) {
+		const auto& sc = m_swapchains[i];
+
+		XrSwapchainImageAcquireInfo image_acquire_info{XR_TYPE_SWAPCHAIN_IMAGE_ACQUIRE_INFO};
+		XrSwapchainImageWaitInfo image_wait_info{XR_TYPE_SWAPCHAIN_IMAGE_WAIT_INFO, nullptr, XR_INFINITE_DURATION};
+
+		uint32_t image_index;
+		XR_CHECK_THROW(xrAcquireSwapchainImage(sc.handle, &image_acquire_info, &image_index));
+		XR_CHECK_THROW(xrWaitSwapchainImage(sc.handle, &image_wait_info));
+
+		FrameInfo::View& v = frame_info->views[i];
+		v.framebuffer = sc.framebuffers_gl[image_index];
+		v.view.pose = views[i].pose;
+		v.view.fov = views[i].fov;
+		v.view.subImage.imageRect = XrRect2Di{{0, 0}, {sc.width, sc.height}};
+		v.view.subImage.imageArrayIndex = 0;
+		v.view.subImage.swapchain = sc.handle;
+
+		glBindFramebuffer(GL_FRAMEBUFFER, sc.framebuffers_gl[image_index]);
+		glClearColor(0.0f, 0.0f, 0.0f, 0.0f);
+		glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
+		glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, sc.images_gl.at(image_index).image, 0);
+
+		if (sc.depth_handle != XR_NULL_HANDLE) {
+			uint32_t depth_image_index;
+			XR_CHECK_THROW(xrAcquireSwapchainImage(sc.depth_handle, &image_acquire_info, &depth_image_index));
+			XR_CHECK_THROW(xrWaitSwapchainImage(sc.depth_handle, &image_wait_info));
+
+			glFramebufferTexture2D(GL_FRAMEBUFFER, GL_DEPTH_ATTACHMENT, GL_TEXTURE_2D, sc.depth_images_gl.at(depth_image_index).image, 0);
+
+			v.depth_info.subImage.imageRect = XrRect2Di{{0, 0}, {sc.width, sc.height}};
+			v.depth_info.subImage.imageArrayIndex = 0;
+			v.depth_info.subImage.swapchain = sc.depth_handle;
+			v.depth_info.minDepth = 0.0f;
+			v.depth_info.maxDepth = 1.0f;
+
+			// To be overwritten with actual near and far planes by end_frame
+			v.depth_info.nearZ = 1.0f / 128.0f;
+			v.depth_info.farZ = 128.0f;
+		}
+
+		glBindFramebuffer(GL_FRAMEBUFFER, 0);
+
+		if (!m_hidden_area_masks.at(i)) {
+			m_hidden_area_masks.at(i) = rasterize_hidden_area_mask(i, v.view);
+		}
+
+		v.hidden_area_mask = m_hidden_area_masks.at(i);
+		v.pose = convert_xr_pose_to_eigen(v.view.pose);
+	}
+
+	XrActiveActionSet active_action_set{m_action_set, XR_NULL_PATH};
+	XrActionsSyncInfo sync_info{XR_TYPE_ACTIONS_SYNC_INFO};
+	sync_info.countActiveActionSets = 1;
+	sync_info.activeActionSets = &active_action_set;
+	XR_CHECK_THROW(xrSyncActions(m_session, &sync_info));
+
+	for (size_t i = 0; i < 2; ++i) {
+		// Hand pose
+		{
+			XrActionStatePose pose_state{XR_TYPE_ACTION_STATE_POSE};
+			XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO};
+			get_info.action = m_pose_action;
+			get_info.subactionPath = m_hand_paths[i];
+			XR_CHECK_THROW(xrGetActionStatePose(m_session, &get_info, &pose_state));
+
+			frame_info->hands[i].pose_active = pose_state.isActive;
+			if (frame_info->hands[i].pose_active) {
+				XrSpaceLocation space_location{XR_TYPE_SPACE_LOCATION};
+				XR_CHECK_THROW(xrLocateSpace(m_hand_spaces[i], m_space, m_frame_state.predictedDisplayTime, &space_location));
+				frame_info->hands[i].pose = convert_xr_pose_to_eigen(space_location.pose);
+			}
+		}
+
+		// Stick
+		{
+			XrActionStateVector2f thumbstick_state{XR_TYPE_ACTION_STATE_VECTOR2F};
+			XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO};
+			get_info.action = m_thumbstick_actions[i];
+			XR_CHECK_THROW(xrGetActionStateVector2f(m_session, &get_info, &thumbstick_state));
+
+			if (thumbstick_state.isActive) {
+				frame_info->hands[i].thumbstick.x() = thumbstick_state.currentState.x;
+				frame_info->hands[i].thumbstick.y() = thumbstick_state.currentState.y;
+			} else {
+				frame_info->hands[i].thumbstick = Vector2f::Zero();
+			}
+		}
+
+		// Press
+		{
+			XrActionStateBoolean press_state{XR_TYPE_ACTION_STATE_BOOLEAN};
+			XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO};
+			get_info.action = m_press_action;
+			get_info.subactionPath = m_hand_paths[i];
+			XR_CHECK_THROW(xrGetActionStateBoolean(m_session, &get_info, &press_state));
+
+			if (press_state.isActive) {
+				frame_info->hands[i].pressing = press_state.currentState;
+			} else {
+				frame_info->hands[i].pressing = 0.0f;
+			}
+		}
+
+		// Grab
+		{
+			XrActionStateFloat grab_state{XR_TYPE_ACTION_STATE_FLOAT};
+			XrActionStateGetInfo get_info{XR_TYPE_ACTION_STATE_GET_INFO};
+			get_info.action = m_grab_action;
+			get_info.subactionPath = m_hand_paths[i];
+			XR_CHECK_THROW(xrGetActionStateFloat(m_session, &get_info, &grab_state));
+
+			if (grab_state.isActive) {
+				frame_info->hands[i].grab_strength = grab_state.currentState;
+			} else {
+				frame_info->hands[i].grab_strength = 0.0f;
+			}
+
+			bool was_grabbing = frame_info->hands[i].grabbing;
+			frame_info->hands[i].grabbing = frame_info->hands[i].grab_strength >= 0.5f;
+
+			if (frame_info->hands[i].grabbing) {
+				frame_info->hands[i].prev_grab_pos = was_grabbing ? frame_info->hands[i].grab_pos : frame_info->hands[i].pose.col(3);
+				frame_info->hands[i].grab_pos = frame_info->hands[i].pose.col(3);
+			}
+		}
+	}
+
+	m_previous_frame_info = frame_info;
+	return frame_info;
+}
+
+void OpenXRHMD::end_frame(FrameInfoPtr frame_info, float znear, float zfar) {
+	std::vector<XrCompositionLayerProjectionView> layer_projection_views(frame_info->views.size());
+	for (size_t i = 0; i < layer_projection_views.size(); ++i) {
+		auto& v = frame_info->views[i];
+		auto& view = layer_projection_views[i];
+
+		view = v.view;
+
+		// release swapchain image
+		XrSwapchainImageReleaseInfo release_info{XR_TYPE_SWAPCHAIN_IMAGE_RELEASE_INFO};
+		XR_CHECK_THROW(xrReleaseSwapchainImage(v.view.subImage.swapchain, &release_info));
+
+		if (v.depth_info.subImage.swapchain != XR_NULL_HANDLE) {
+			XR_CHECK_THROW(xrReleaseSwapchainImage(v.depth_info.subImage.swapchain, &release_info));
+			v.depth_info.nearZ = znear;
+			v.depth_info.farZ = zfar;
+			// The following line being commented means that our provided depth buffer
+			// _isn't_ actually passed to the runtime for reprojection. So far,
+			// experimentation has shown that runtimes do a better job at reprojection
+			// without getting a depth buffer from us, so we leave it disabled for now.
+			// view.next = &v.depth_info;
+		}
+	}
+
+	XrCompositionLayerProjection layer{XR_TYPE_COMPOSITION_LAYER_PROJECTION};
+	layer.space = m_space;
+	layer.viewCount = uint32_t(layer_projection_views.size());
+	layer.views = layer_projection_views.data();
+
+	std::vector<XrCompositionLayerBaseHeader*> layers;
+	if (layer.viewCount) {
+		layers.push_back(reinterpret_cast<XrCompositionLayerBaseHeader*>(&layer));
+	}
+
+	XrFrameEndInfo frame_end_info{XR_TYPE_FRAME_END_INFO};
+	frame_end_info.displayTime = m_frame_state.predictedDisplayTime;
+	frame_end_info.environmentBlendMode = m_environment_blend_mode;
+	frame_end_info.layerCount = (uint32_t)layers.size();
+	frame_end_info.layers = layers.data();
+	XR_CHECK_THROW(xrEndFrame(m_session, &frame_end_info));
+}
+
+NGP_NAMESPACE_END
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
diff --git a/src/python_api.cu b/src/python_api.cu
index f69056f3f46a65aff0b7f18e7cec7a5fa56f239b..8a8d436405d11488fd550c77f5ff6468300cde91 100644
--- a/src/python_api.cu
+++ b/src/python_api.cu
@@ -157,6 +157,7 @@ py::array_t<float> Testbed::render_to_cpu(int width, int height, int spp, bool l
 	}
 
 	auto end_cam_matrix = m_smoothed_camera;
+	auto prev_camera_matrix = m_smoothed_camera;
 
 	for (int i = 0; i < spp; ++i) {
 		float start_alpha = ((float)i)/(float)spp * shutter_fraction;
@@ -164,6 +165,9 @@ py::array_t<float> Testbed::render_to_cpu(int width, int height, int spp, bool l
 
 		auto sample_start_cam_matrix = start_cam_matrix;
 		auto sample_end_cam_matrix = log_space_lerp(start_cam_matrix, end_cam_matrix, shutter_fraction);
+		if (i == 0) {
+			prev_camera_matrix = sample_start_cam_matrix;
+		}
 
 		if (path_animation_enabled) {
 			set_camera_from_time(start_time + (end_time-start_time) * (start_alpha + end_alpha) / 2.0f);
@@ -174,7 +178,21 @@ py::array_t<float> Testbed::render_to_cpu(int width, int height, int spp, bool l
 			autofocus();
 		}
 
-		render_frame(sample_start_cam_matrix, sample_end_cam_matrix, Eigen::Vector4f::Zero(), m_windowless_render_surface, !linear);
+		render_frame(
+			m_stream.get(),
+			sample_start_cam_matrix,
+			sample_end_cam_matrix,
+			prev_camera_matrix,
+			m_screen_center,
+			m_relative_focal_length,
+			{0.0f, 0.0f, 0.0f, 1.0f},
+			{},
+			{},
+			m_visualized_dimension,
+			m_windowless_render_surface,
+			!linear
+		);
+		prev_camera_matrix = sample_start_cam_matrix;
 	}
 
 	// For cam smoothing when rendering the next frame.
@@ -303,6 +321,7 @@ PYBIND11_MODULE(pyngp, m) {
 		.value("FTheta", ELensMode::FTheta)
 		.value("LatLong", ELensMode::LatLong)
 		.value("OpenCVFisheye", ELensMode::OpenCVFisheye)
+		.value("Equirectangular", ELensMode::Equirectangular)
 		.export_values();
 
 	py::class_<BoundingBox>(m, "BoundingBox")
@@ -344,12 +363,13 @@ PYBIND11_MODULE(pyngp, m) {
 		.def("clear_training_data", &Testbed::clear_training_data, "Clears training data to free up GPU memory.")
 		// General control
 #ifdef NGP_GUI
-		.def("init_window", &Testbed::init_window, "Init a GLFW window that shows real-time progress and a GUI. 'second_window' creates a second copy of the output in its own window",
+		.def("init_window", &Testbed::init_window, "Init a GLFW window that shows real-time progress and a GUI. 'second_window' creates a second copy of the output in its own window.",
 			py::arg("width"),
 			py::arg("height"),
 			py::arg("hidden") = false,
 			py::arg("second_window") = false
 		)
+		.def("init_vr", &Testbed::init_vr, "Init rendering to a connected and active VR headset. Requires a GUI window to have been previously created via `init_window`.")
 		.def_readwrite("keyboard_event_callback", &Testbed::m_keyboard_event_callback)
 		.def("is_key_pressed", [](py::object& obj, int key) { return ImGui::IsKeyPressed(key); })
 		.def("is_key_down", [](py::object& obj, int key) { return ImGui::IsKeyDown(key); })
@@ -431,6 +451,7 @@ PYBIND11_MODULE(pyngp, m) {
 		.def_readwrite("dynamic_res_target_fps", &Testbed::m_dynamic_res_target_fps)
 		.def_readwrite("fixed_res_factor", &Testbed::m_fixed_res_factor)
 		.def_readwrite("background_color", &Testbed::m_background_color)
+		.def_readwrite("render_transparency_as_checkerboard", &Testbed::m_render_transparency_as_checkerboard)
 		.def_readwrite("shall_train", &Testbed::m_train)
 		.def_readwrite("shall_train_encoding", &Testbed::m_train_encoding)
 		.def_readwrite("shall_train_network", &Testbed::m_train_network)
@@ -493,7 +514,7 @@ PYBIND11_MODULE(pyngp, m) {
 		.def_property("dlss",
 			[](py::object& obj) { return obj.cast<Testbed&>().m_dlss; },
 			[](const py::object& obj, bool value) {
-				if (value && !obj.cast<Testbed&>().m_dlss_supported) {
+				if (value && !obj.cast<Testbed&>().m_dlss_provider) {
 					if (obj.cast<Testbed&>().m_render_window) {
 						throw std::runtime_error{"DLSS not supported."};
 					} else {
@@ -660,7 +681,6 @@ PYBIND11_MODULE(pyngp, m) {
 	image
 		.def_readonly("training", &Testbed::Image::training)
 		.def_readwrite("random_mode", &Testbed::Image::random_mode)
-		.def_readwrite("pos", &Testbed::Image::pos)
 		;
 
 	py::class_<Testbed::Image::Training>(image, "Training")
diff --git a/src/render_buffer.cu b/src/render_buffer.cu
index c6d06e250c5200c96c8455b8de2db5a375a3ff0b..aa0c10dd4950f743bb756c86bff5474d6e962f1c 100644
--- a/src/render_buffer.cu
+++ b/src/render_buffer.cu
@@ -47,30 +47,40 @@ void CudaSurface2D::free() {
 	m_surface = 0;
 	if (m_array) {
 		cudaFreeArray(m_array);
-		g_total_n_bytes_allocated -= m_size.prod() * sizeof(float4);
+		g_total_n_bytes_allocated -= m_size.prod() * sizeof(float) * m_n_channels;
 	}
 	m_array = nullptr;
+	m_size = Vector2i::Zero();
+	m_n_channels = 0;
 }
 
-void CudaSurface2D::resize(const Vector2i& size) {
-	if (size == m_size) {
+void CudaSurface2D::resize(const Vector2i& size, int n_channels) {
+	if (size == m_size && n_channels == m_n_channels) {
 		return;
 	}
 
 	free();
 
-	m_size = size;
-
-	cudaChannelFormatDesc desc = cudaCreateChannelDesc<float4>();
+	cudaChannelFormatDesc desc;
+	switch (n_channels) {
+		case 1: desc = cudaCreateChannelDesc<float>(); break;
+		case 2: desc = cudaCreateChannelDesc<float2>(); break;
+		case 3: desc = cudaCreateChannelDesc<float3>(); break;
+		case 4: desc = cudaCreateChannelDesc<float4>(); break;
+		default: throw std::runtime_error{fmt::format("CudaSurface2D: unsupported number of channels {}", n_channels)};
+	}
 	CUDA_CHECK_THROW(cudaMallocArray(&m_array, &desc, size.x(), size.y(), cudaArraySurfaceLoadStore));
 
-	g_total_n_bytes_allocated += m_size.prod() * sizeof(float4);
+	g_total_n_bytes_allocated += m_size.prod() * sizeof(float) * n_channels;
 
 	struct cudaResourceDesc resource_desc;
 	memset(&resource_desc, 0, sizeof(resource_desc));
 	resource_desc.resType = cudaResourceTypeArray;
 	resource_desc.res.array.array = m_array;
 	CUDA_CHECK_THROW(cudaCreateSurfaceObject(&m_surface, &resource_desc));
+
+	m_size = size;
+	m_n_channels = n_channels;
 }
 
 #ifdef NGP_GUI
@@ -91,14 +101,14 @@ GLuint GLTexture::texture() {
 
 cudaSurfaceObject_t GLTexture::surface() {
 	if (!m_cuda_mapping) {
-		m_cuda_mapping = std::make_unique<CUDAMapping>(texture(), m_size);
+		m_cuda_mapping = std::make_unique<CUDAMapping>(texture(), m_size, m_n_channels);
 	}
 	return m_cuda_mapping->surface();
 }
 
 cudaArray_t GLTexture::array() {
 	if (!m_cuda_mapping) {
-		m_cuda_mapping = std::make_unique<CUDAMapping>(texture(), m_size);
+		m_cuda_mapping = std::make_unique<CUDAMapping>(texture(), m_size, m_n_channels);
 	}
 	return m_cuda_mapping->array();
 }
@@ -108,12 +118,14 @@ void GLTexture::blit_from_cuda_mapping() {
 		return;
 	}
 
-	if (m_internal_format != GL_RGBA32F || m_format != GL_RGBA || m_is_8bit) {
-		throw std::runtime_error{"Can only blit from CUDA mapping if the texture is RGBA float."};
+	if (m_is_8bit) {
+		throw std::runtime_error{"Can only blit from CUDA mapping if the texture is float."};
 	}
 
 	const float* data_cpu = m_cuda_mapping->data_cpu();
-	glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA32F, m_size.x(), m_size.y(), 0, GL_RGBA, GL_FLOAT, data_cpu);
+
+	glBindTexture(GL_TEXTURE_2D, m_texture_id);
+	glTexImage2D(GL_TEXTURE_2D, 0, m_internal_format, m_size.x(), m_size.y(), 0, m_format, GL_FLOAT, data_cpu);
 }
 
 void GLTexture::load(const fs::path& path) {
@@ -173,8 +185,7 @@ void GLTexture::resize(const Vector2i& new_size, int n_channels, bool is_8bit) {
 	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
 }
 
-
-GLTexture::CUDAMapping::CUDAMapping(GLuint texture_id, const Vector2i& size) : m_size{size} {
+GLTexture::CUDAMapping::CUDAMapping(GLuint texture_id, const Vector2i& size, int n_channels) : m_size{size}, m_n_channels{n_channels} {
 	static bool s_is_cuda_interop_supported = !is_wsl();
 	if (s_is_cuda_interop_supported) {
 		cudaError_t err = cudaGraphicsGLRegisterImage(&m_graphics_resource, texture_id, GL_TEXTURE_2D, cudaGraphicsRegisterFlagsSurfaceLoadStore);
@@ -187,8 +198,8 @@ GLTexture::CUDAMapping::CUDAMapping(GLuint texture_id, const Vector2i& size) : m
 	if (!s_is_cuda_interop_supported) {
 		// falling back to a regular cuda surface + CPU copy of data
 		m_cuda_surface = std::make_unique<CudaSurface2D>();
-		m_cuda_surface->resize(size);
-		m_data_cpu.resize(m_size.prod() * 4);
+		m_cuda_surface->resize(size, n_channels);
+		m_data_cpu.resize(m_size.prod() * n_channels);
 		return;
 	}
 
@@ -212,7 +223,7 @@ GLTexture::CUDAMapping::~CUDAMapping() {
 }
 
 const float* GLTexture::CUDAMapping::data_cpu() {
-	CUDA_CHECK_THROW(cudaMemcpy2DFromArray(m_data_cpu.data(), m_size.x() * sizeof(float) * 4, array(), 0, 0, m_size.x() * sizeof(float) * 4, m_size.y(), cudaMemcpyDeviceToHost));
+	CUDA_CHECK_THROW(cudaMemcpy2DFromArray(m_data_cpu.data(), m_size.x() * sizeof(float) * m_n_channels, array(), 0, 0, m_size.x() * sizeof(float) * m_n_channels, m_size.y(), cudaMemcpyDeviceToHost));
 	return m_data_cpu.data();
 }
 #endif //NGP_GUI
@@ -362,11 +373,11 @@ __global__ void overlay_image_kernel(
 	float fx = x+0.5f;
 	float fy = y+0.5f;
 
-	fx-=resolution.x()*0.5f; fx/=zoom; fx+=screen_center.x() * resolution.x();
-	fy-=resolution.y()*0.5f; fy/=zoom; fy+=screen_center.y() * resolution.y();
+	fx -= resolution.x() * 0.5f; fx /= zoom; fx += screen_center.x() * resolution.x();
+	fy -= resolution.y() * 0.5f; fy /= zoom; fy += screen_center.y() * resolution.y();
 
-	float u = (fx-resolution.x()*0.5f) * scale  + image_resolution.x()*0.5f;
-	float v = (fy-resolution.y()*0.5f) * scale  + image_resolution.y()*0.5f;
+	float u = (fx - resolution.x() * 0.5f) * scale  + image_resolution.x() * 0.5f;
+	float v = (fy - resolution.y() * 0.5f) * scale  + image_resolution.y() * 0.5f;
 
 	int srcx = floorf(u);
 	int srcy = floorf(v);
@@ -431,7 +442,8 @@ __global__ void overlay_depth_kernel(
 	float depth_scale,
 	Vector2i image_resolution,
 	int fov_axis,
-	float zoom, Eigen::Vector2f screen_center,
+	float zoom,
+	Eigen::Vector2f screen_center,
 	cudaSurfaceObject_t surface
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
@@ -443,14 +455,14 @@ __global__ void overlay_depth_kernel(
 
 	float scale = image_resolution[fov_axis] / float(resolution[fov_axis]);
 
-	float fx = x+0.5f;
-	float fy = y+0.5f;
+	float fx = x + 0.5f;
+	float fy = y + 0.5f;
 
-	fx-=resolution.x()*0.5f; fx/=zoom; fx+=screen_center.x() * resolution.x();
-	fy-=resolution.y()*0.5f; fy/=zoom; fy+=screen_center.y() * resolution.y();
+	fx -= resolution.x() * 0.5f; fx /= zoom; fx += screen_center.x() * resolution.x();
+	fy -= resolution.y() * 0.5f; fy /= zoom; fy += screen_center.y() * resolution.y();
 
-	float u = (fx-resolution.x()*0.5f) * scale  + image_resolution.x()*0.5f;
-	float v = (fy-resolution.y()*0.5f) * scale  + image_resolution.y()*0.5f;
+	float u = (fx - resolution.x() * 0.5f) * scale + image_resolution.x() * 0.5f;
+	float v = (fy - resolution.y() * 0.5f) * scale + image_resolution.y() * 0.5f;
 
 	int srcx = floorf(u);
 	int srcy = floorf(v);
@@ -568,15 +580,42 @@ __global__ void dlss_splat_kernel(
 	surf2Dwrite(color, surface, x * sizeof(float4), y);
 }
 
+__global__ void depth_splat_kernel(
+	Vector2i resolution,
+	float znear,
+	float zfar,
+	float* __restrict__ depth_buffer,
+	cudaSurfaceObject_t surface
+) {
+	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
+	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
+
+	if (x >= resolution.x() || y >= resolution.y()) {
+		return;
+	}
+
+	uint32_t idx = x + resolution.x() * y;
+	surf2Dwrite(to_ndc_depth(depth_buffer[idx], znear, zfar), surface, x * sizeof(float), y);
+}
+
+void CudaRenderBufferView::clear(cudaStream_t stream) const {
+	size_t n_pixels = resolution.prod();
+	CUDA_CHECK_THROW(cudaMemsetAsync(frame_buffer, 0, n_pixels * sizeof(Array4f), stream));
+	CUDA_CHECK_THROW(cudaMemsetAsync(depth_buffer, 0, n_pixels * sizeof(float), stream));
+}
+
 void CudaRenderBuffer::resize(const Vector2i& res) {
 	m_in_resolution = res;
 	m_frame_buffer.enlarge(res.x() * res.y());
 	m_depth_buffer.enlarge(res.x() * res.y());
+	if (m_depth_target) {
+		m_depth_target->resize(res, 1);
+	}
 	m_accumulate_buffer.enlarge(res.x() * res.y());
 
 	Vector2i out_res = m_dlss ? m_dlss->out_resolution() : res;
 	auto prev_out_res = out_resolution();
-	m_surface_provider->resize(out_res);
+	m_rgba_target->resize(out_res, 4);
 
 	if (out_resolution() != prev_out_res) {
 		reset_accumulation();
@@ -584,8 +623,7 @@ void CudaRenderBuffer::resize(const Vector2i& res) {
 }
 
 void CudaRenderBuffer::clear_frame(cudaStream_t stream) {
-	CUDA_CHECK_THROW(cudaMemsetAsync(m_frame_buffer.data(), 0, m_frame_buffer.bytes(), stream));
-	CUDA_CHECK_THROW(cudaMemsetAsync(m_depth_buffer.data(), 0, m_depth_buffer.bytes(), stream));
+	view().clear(stream);
 }
 
 void CudaRenderBuffer::accumulate(float exposure, cudaStream_t stream) {
@@ -610,10 +648,10 @@ void CudaRenderBuffer::accumulate(float exposure, cudaStream_t stream) {
 	++m_spp;
 }
 
-void CudaRenderBuffer::tonemap(float exposure, const Array4f& background_color, EColorSpace output_color_space, cudaStream_t stream) {
+void CudaRenderBuffer::tonemap(float exposure, const Array4f& background_color, EColorSpace output_color_space, float znear, float zfar, cudaStream_t stream) {
 	assert(m_dlss || out_resolution() == in_resolution());
 
-	auto res = m_dlss ? in_resolution() : out_resolution();
+	auto res = in_resolution();
 	const dim3 threads = { 16, 8, 1 };
 	const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 	tonemap_kernel<<<blocks, threads, 0, stream>>>(
@@ -646,6 +684,10 @@ void CudaRenderBuffer::tonemap(float exposure, const Array4f& background_color,
 		const dim3 out_blocks = { div_round_up((uint32_t)out_res.x(), threads.x), div_round_up((uint32_t)out_res.y(), threads.y), 1 };
 		dlss_splat_kernel<<<out_blocks, threads, 0, stream>>>(out_res, m_dlss->output(), surface());
 	}
+
+	if (m_depth_target) {
+		depth_splat_kernel<<<blocks, threads, 0, stream>>>(res, znear, zfar, depth_buffer(), m_depth_target->surface());
+	}
 }
 
 void CudaRenderBuffer::overlay_image(
@@ -726,10 +768,10 @@ void CudaRenderBuffer::overlay_false_color(Vector2i training_resolution, bool to
 	);
 }
 
-void CudaRenderBuffer::enable_dlss(const Eigen::Vector2i& max_out_res) {
+void CudaRenderBuffer::enable_dlss(IDlssProvider& dlss_provider, const Eigen::Vector2i& max_out_res) {
 #ifdef NGP_VULKAN
 	if (!m_dlss || m_dlss->max_out_resolution() != max_out_res) {
-		m_dlss = dlss_init(max_out_res);
+		m_dlss = dlss_provider.init_dlss(max_out_res);
 	}
 
 	if (m_dlss) {
diff --git a/src/testbed.cu b/src/testbed.cu
index 23735633885daba0ad7550ee7edfe59a93c851de..9cd0282f04d5c9f3ded8a363a586afb81cbd7bd1 100644
--- a/src/testbed.cu
+++ b/src/testbed.cu
@@ -186,11 +186,30 @@ void Testbed::set_mode(ETestbedMode mode) {
 	m_distortion = {};
 	m_training_data_available = false;
 
+	// Clear device-owned data that might be mode-specific
+	for (auto&& device : m_devices) {
+		device.clear();
+	}
+
 	// Reset paths that might be attached to the chosen mode
 	m_data_path = {};
 
 	m_testbed_mode = mode;
 
+	// Set various defaults depending on mode
+	if (m_testbed_mode == ETestbedMode::Nerf) {
+		if (m_devices.size() > 1) {
+			m_use_aux_devices = true;
+		}
+
+		if (m_dlss_provider) {
+			m_dlss = true;
+		}
+	} else {
+		m_use_aux_devices = false;
+		m_dlss = false;
+	}
+
 	reset_camera();
 }
 
@@ -348,8 +367,8 @@ void Testbed::reset_accumulation(bool due_to_camera_movement, bool immediate_red
 
 	if (!due_to_camera_movement || !reprojection_available()) {
 		m_windowless_render_surface.reset_accumulation();
-		for (auto& tex : m_render_surfaces) {
-			tex.reset_accumulation();
+		for (auto& view : m_views) {
+			view.render_buffer->reset_accumulation();
 		}
 	}
 }
@@ -359,8 +378,13 @@ void Testbed::set_visualized_dim(int dim) {
 	reset_accumulation();
 }
 
-void Testbed::translate_camera(const Vector3f& rel) {
-	m_camera.col(3) += m_camera.block<3, 3>(0, 0) * rel * m_bounding_radius;
+void Testbed::translate_camera(const Vector3f& rel, const Matrix3f& rot, bool allow_up_down) {
+	Vector3f movement = rot * rel;
+	if (!allow_up_down) {
+		movement -= movement.dot(m_up_dir) * m_up_dir;
+	}
+
+	m_camera.col(3) += movement;
 	reset_accumulation(true);
 }
 
@@ -425,15 +449,28 @@ void Testbed::set_camera_to_training_view(int trainview) {
 	m_scale = std::max((old_look_at - view_pos()).dot(view_dir()), 0.1f);
 	m_nerf.render_with_lens_distortion = true;
 	m_nerf.render_lens = m_nerf.training.dataset.metadata[trainview].lens;
-	m_screen_center = Vector2f::Constant(1.0f) - m_nerf.training.dataset.metadata[0].principal_point;
+	if (!supports_dlss(m_nerf.render_lens.mode)) {
+		m_dlss = false;
+	}
+
+	m_screen_center = Vector2f::Constant(1.0f) - m_nerf.training.dataset.metadata[trainview].principal_point;
+	m_nerf.training.view = trainview;
 }
 
 void Testbed::reset_camera() {
 	m_fov_axis = 1;
-	set_fov(50.625f);
-	m_zoom = 1.f;
+	m_zoom = 1.0f;
 	m_screen_center = Vector2f::Constant(0.5f);
-	m_scale = m_testbed_mode == ETestbedMode::Image ? 1.0f : 1.5f;
+
+	if (m_testbed_mode == ETestbedMode::Image) {
+		// Make image full-screen at the given view distance
+		m_relative_focal_length = Vector2f::Ones();
+		m_scale = 1.0f;
+	} else {
+		set_fov(50.625f);
+		m_scale = 1.5f;
+	}
+
 	m_camera <<
 		1.0f, 0.0f, 0.0f, 0.5f,
 		0.0f, -1.0f, 0.0f, 0.5f,
@@ -630,7 +667,7 @@ void Testbed::imgui() {
 							m_smoothed_camera = m_camera;
 						}
 					} else {
-						m_pip_render_surface->reset_accumulation();
+						m_pip_render_buffer->reset_accumulation();
 					}
 				}
 			}
@@ -639,7 +676,7 @@ void Testbed::imgui() {
 				float w = ImGui::GetContentRegionAvail().x;
 				if (m_camera_path.update_cam_from_path) {
 					m_picture_in_picture_res = 0;
-					ImGui::Image((ImTextureID)(size_t)m_render_textures.front()->texture(), ImVec2(w, w * 9.0f / 16.0f));
+					ImGui::Image((ImTextureID)(size_t)m_rgba_render_textures.front()->texture(), ImVec2(w, w * 9.0f / 16.0f));
 				} else {
 					m_picture_in_picture_res = (float)std::min((int(w)+31)&(~31), 1920/4);
 					ImGui::Image((ImTextureID)(size_t)m_pip_render_texture->texture(), ImVec2(w, w * 9.0f / 16.0f));
@@ -684,7 +721,7 @@ void Testbed::imgui() {
 
 				auto elapsed = std::chrono::steady_clock::now() - m_camera_path.render_start_time;
 
-				uint32_t progress = m_camera_path.render_frame_idx * m_camera_path.render_settings.spp + m_render_surfaces.front().spp();
+				uint32_t progress = m_camera_path.render_frame_idx * m_camera_path.render_settings.spp + m_views.front().render_buffer->spp();
 				uint32_t goal = m_camera_path.render_settings.n_frames() * m_camera_path.render_settings.spp;
 				auto est_remaining = elapsed * (float)(goal - progress) / std::max(progress, 1u);
 
@@ -718,7 +755,11 @@ void Testbed::imgui() {
 
 	ImGui::Begin("instant-ngp v" NGP_VERSION);
 
-	size_t n_bytes = tcnn::total_n_bytes_allocated() + g_total_n_bytes_allocated + dlss_allocated_bytes();
+	size_t n_bytes = tcnn::total_n_bytes_allocated() + g_total_n_bytes_allocated;
+	if (m_dlss_provider) {
+		n_bytes += m_dlss_provider->allocated_bytes();
+	}
+
 	ImGui::Text("Frame: %.2f ms (%.1f FPS); Mem: %s", m_frame_ms.ema_val(), 1000.0f / m_frame_ms.ema_val(), bytes_to_string(n_bytes).c_str());
 	bool accum_reset = false;
 
@@ -728,41 +769,58 @@ void Testbed::imgui() {
 		if (imgui_colored_button(m_train ? "Stop training" : "Start training", 0.4)) {
 			set_train(!m_train);
 		}
+
+
+		ImGui::SameLine();
+		if (imgui_colored_button("Reset training", 0.f)) {
+			reload_network_from_file();
+		}
+
 		ImGui::SameLine();
-		ImGui::Checkbox("Train encoding", &m_train_encoding);
+		ImGui::Checkbox("encoding", &m_train_encoding);
 		ImGui::SameLine();
-		ImGui::Checkbox("Train network", &m_train_network);
+		ImGui::Checkbox("network", &m_train_network);
 		ImGui::SameLine();
-		ImGui::Checkbox("Random levels", &m_max_level_rand_training);
+		ImGui::Checkbox("rand levels", &m_max_level_rand_training);
 		if (m_testbed_mode == ETestbedMode::Nerf) {
-			ImGui::Checkbox("Train envmap", &m_nerf.training.train_envmap);
+			ImGui::Checkbox("envmap", &m_nerf.training.train_envmap);
 			ImGui::SameLine();
-			ImGui::Checkbox("Train extrinsics", &m_nerf.training.optimize_extrinsics);
+			ImGui::Checkbox("extrinsics", &m_nerf.training.optimize_extrinsics);
 			ImGui::SameLine();
-			ImGui::Checkbox("Train exposure", &m_nerf.training.optimize_exposure);
+			ImGui::Checkbox("exposure", &m_nerf.training.optimize_exposure);
 			ImGui::SameLine();
-			ImGui::Checkbox("Train distortion", &m_nerf.training.optimize_distortion);
+			ImGui::Checkbox("distortion", &m_nerf.training.optimize_distortion);
+
 			if (m_nerf.training.dataset.n_extra_learnable_dims) {
-				ImGui::Checkbox("Train latent codes", &m_nerf.training.optimize_extra_dims);
+				ImGui::SameLine();
+				ImGui::Checkbox("latents", &m_nerf.training.optimize_extra_dims);
 			}
+
+
 			static bool export_extrinsics_in_quat_format = true;
-			if (imgui_colored_button("Export extrinsics", 0.4f)) {
-				m_nerf.training.export_camera_extrinsics(m_imgui.extrinsics_path, export_extrinsics_in_quat_format);
+			static bool extrinsics_have_been_optimized = false;
+
+			if (m_nerf.training.optimize_extrinsics) {
+				extrinsics_have_been_optimized = true;
 			}
 
-			ImGui::SameLine();
-			ImGui::PushItemWidth(400.f);
-			ImGui::InputText("File##Extrinsics file path", m_imgui.extrinsics_path, sizeof(m_imgui.extrinsics_path));
-			ImGui::PopItemWidth();
-			ImGui::SameLine();
-			ImGui::Checkbox("Quaternion format", &export_extrinsics_in_quat_format);
-		}
-		if (imgui_colored_button("Reset training", 0.f)) {
-			reload_network_from_file();
+			if (extrinsics_have_been_optimized) {
+				if (imgui_colored_button("Export extrinsics", 0.4f)) {
+					m_nerf.training.export_camera_extrinsics(m_imgui.extrinsics_path, export_extrinsics_in_quat_format);
+				}
+
+				ImGui::SameLine();
+				ImGui::Checkbox("as quaternions", &export_extrinsics_in_quat_format);
+				ImGui::InputText("File##Extrinsics file path", m_imgui.extrinsics_path, sizeof(m_imgui.extrinsics_path));
+			}
 		}
+
+		ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f);
+		ImGui::SliderInt("Batch size", (int*)&m_training_batch_size, 1 << 12, 1 << 22, "%d", ImGuiSliderFlags_Logarithmic);
 		ImGui::SameLine();
 		ImGui::DragInt("Seed", (int*)&m_seed, 1.0f, 0, std::numeric_limits<int>::max());
-		ImGui::SliderInt("Batch size", (int*)&m_training_batch_size, 1 << 12, 1 << 22, "%d", ImGuiSliderFlags_Logarithmic);
+		ImGui::PopItemWidth();
+
 		m_training_batch_size = next_multiple(m_training_batch_size, batch_size_granularity);
 
 		if (m_train) {
@@ -778,9 +836,11 @@ void Testbed::imgui() {
 		} else {
 			ImGui::Text("Training paused");
 		}
+
 		if (m_testbed_mode == ETestbedMode::Nerf) {
 			ImGui::Text("Rays/batch: %d, Samples/ray: %.2f, Batch size: %d/%d", m_nerf.training.counters_rgb.rays_per_batch, (float)m_nerf.training.counters_rgb.measured_batch_size / (float)m_nerf.training.counters_rgb.rays_per_batch, m_nerf.training.counters_rgb.measured_batch_size, m_nerf.training.counters_rgb.measured_batch_size_before_compaction);
 		}
+
 		float elapsed_training = std::chrono::duration<float>(std::chrono::steady_clock::now() - m_training_start_time_point).count();
 		ImGui::Text("Steps: %d, Loss: %0.6f (%0.2f dB), Elapsed: %.1fs", m_training_step, m_loss_scalar.ema_val(), linear_to_db(m_loss_scalar.ema_val()), elapsed_training);
 		ImGui::PlotLines("loss graph", m_loss_graph.data(), std::min(m_loss_graph_samples, m_loss_graph.size()), (m_loss_graph_samples < m_loss_graph.size()) ? 0 : (m_loss_graph_samples % m_loss_graph.size()), 0, FLT_MAX, FLT_MAX, ImVec2(0, 50.f));
@@ -848,85 +908,75 @@ void Testbed::imgui() {
 	if (!m_training_data_available) { ImGui::EndDisabled(); }
 
 	if (ImGui::CollapsingHeader("Rendering", ImGuiTreeNodeFlags_DefaultOpen)) {
-		ImGui::Checkbox("Render", &m_render);
-		ImGui::SameLine();
-
-		const auto& render_tex = m_render_surfaces.front();
-		std::string spp_string = m_dlss ? std::string{""} : fmt::format("({} spp)", std::max(render_tex.spp(), 1u));
-		ImGui::Text(": %.01fms for %dx%d %s", m_render_ms.ema_val(), render_tex.in_resolution().x(), render_tex.in_resolution().y(), spp_string.c_str());
-
-		if (m_dlss_supported) {
-			if (!m_single_view) {
-				ImGui::BeginDisabled();
-				m_dlss = false;
-			}
-
-			if (ImGui::Checkbox("DLSS", &m_dlss)) {
-				accum_reset = true;
+		if (!m_hmd) {
+			if (ImGui::Button("Connect to VR/AR headset")) {
+				try {
+					init_vr();
+				} catch (const std::runtime_error& e) {
+					imgui_error_string = e.what();
+					ImGui::OpenPopup("Error");
+				}
 			}
-
-			if (render_tex.dlss()) {
-				ImGui::SameLine();
-				ImGui::Text("(automatic quality setting: %s)", DlssQualityStrArray[(int)render_tex.dlss()->quality()]);
-				ImGui::SliderFloat("DLSS sharpening", &m_dlss_sharpening, 0.0f, 1.0f, "%.02f");
+		} else if (ImGui::TreeNodeEx("VR/AR settings", ImGuiTreeNodeFlags_DefaultOpen)) {
+			if (m_devices.size() > 1 && m_testbed_mode == ETestbedMode::Nerf) {
+				ImGui::Checkbox("Multi-GPU rendering (one per eye)", &m_use_aux_devices);
 			}
 
-			if (!m_single_view) {
-				ImGui::EndDisabled();
+			accum_reset |= ImGui::Checkbox("Foveated rendering", &m_foveated_rendering) && !m_dlss;
+			if (m_foveated_rendering) {
+				accum_reset |= ImGui::SliderFloat("Maximum foveation", &m_foveated_rendering_max_scaling, 1.0f, 16.0f, "%.01f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat) && !m_dlss;
 			}
+			ImGui::TreePop();
 		}
 
-		ImGui::Checkbox("Dynamic resolution", &m_dynamic_res);
+		ImGui::Checkbox("Render", &m_render);
+		ImGui::SameLine();
+
+		const auto& render_buffer = m_views.front().render_buffer;
+		std::string spp_string = m_dlss ? std::string{""} : fmt::format("({} spp)", std::max(render_buffer->spp(), 1u));
+		ImGui::Text(": %.01fms for %dx%d %s", m_render_ms.ema_val(), render_buffer->in_resolution().x(), render_buffer->in_resolution().y(), spp_string.c_str());
+
+		ImGui::SameLine();
 		if (ImGui::Checkbox("VSync", &m_vsync)) {
 			glfwSwapInterval(m_vsync ? 1 : 0);
 		}
-		ImGui::SliderFloat("Target FPS", &m_dynamic_res_target_fps, 2.0f, 144.0f, "%.01f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
-		ImGui::SliderInt("Max spp", &m_max_spp, 0, 1024, "%d", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
 
-		if (!m_dynamic_res) {
-			ImGui::SliderInt("Fixed resolution factor", &m_fixed_res_factor, 8, 64);
-		}
-
-		if (m_testbed_mode == ETestbedMode::Nerf && m_nerf.training.dataset.has_light_dirs) {
-			Vector3f light_dir = m_nerf.light_dir.normalized();
-			if (ImGui::TreeNodeEx("Light Dir (Polar)", ImGuiTreeNodeFlags_DefaultOpen)) {
-				float phi = atan2f(m_nerf.light_dir.x(), m_nerf.light_dir.z());
-				float theta = asinf(m_nerf.light_dir.y());
-				bool spin = ImGui::SliderFloat("Light Dir Theta", &theta, -PI() / 2.0f, PI() / 2.0f);
-				spin |= ImGui::SliderFloat("Light Dir Phi", &phi, -PI(), PI());
-				if (spin) {
-					float sin_phi, cos_phi;
-					sincosf(phi, &sin_phi, &cos_phi);
-					float cos_theta=cosf(theta);
-					m_nerf.light_dir = {sin_phi * cos_theta,sinf(theta),cos_phi * cos_theta};
-					accum_reset = true;
-				}
-				ImGui::TreePop();
-			}
-			if (ImGui::TreeNode("Light Dir (Cartesian)")) {
-				accum_reset |= ImGui::SliderFloat("Light Dir X", ((float*)(&m_nerf.light_dir)) + 0, -1.0f, 1.0f);
-				accum_reset |= ImGui::SliderFloat("Light Dir Y", ((float*)(&m_nerf.light_dir)) + 1, -1.0f, 1.0f);
-				accum_reset |= ImGui::SliderFloat("Light Dir Z", ((float*)(&m_nerf.light_dir)) + 2, -1.0f, 1.0f);
-				ImGui::TreePop();
-			}
+		if (!m_dlss_provider) { ImGui::BeginDisabled(); }
+		accum_reset |= ImGui::Checkbox("DLSS", &m_dlss);
+
+		if (render_buffer->dlss()) {
+			ImGui::SameLine();
+			ImGui::Text("(%s)", DlssQualityStrArray[(int)render_buffer->dlss()->quality()]);
+			ImGui::SameLine();
+			ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f);
+			ImGui::SliderFloat("Sharpening", &m_dlss_sharpening, 0.0f, 1.0f, "%.02f");
+			ImGui::PopItemWidth();
 		}
-		if (m_testbed_mode == ETestbedMode::Nerf && m_nerf.training.dataset.n_extra_learnable_dims) {
-			accum_reset |= ImGui::SliderInt("training image latent code for inference", (int*)&m_nerf.extra_dim_idx_for_inference, 0, m_nerf.training.dataset.n_images-1);
+
+		if (!m_dlss_provider) {
+			ImGui::SameLine();
+			ImGui::Text("(unsupported on this system)");
+			ImGui::EndDisabled();
 		}
-		accum_reset |= ImGui::Combo("Render mode", (int*)&m_render_mode, RenderModeStr);
-		if (m_testbed_mode == ETestbedMode::Nerf)  {
-			accum_reset |= ImGui::Combo("Groundtruth Render mode", (int*)&m_ground_truth_render_mode, GroundTruthRenderModeStr);
-			accum_reset |= ImGui::SliderFloat("Groundtruth Alpha", &m_ground_truth_alpha, 0.0f, 1.0f, "%.02f", ImGuiSliderFlags_AlwaysClamp);
+
+		ImGui::Checkbox("Dynamic resolution", &m_dynamic_res);
+		ImGui::SameLine();
+		ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f);
+		if (m_dynamic_res) {
+			ImGui::SliderFloat("Target FPS", &m_dynamic_res_target_fps, 2.0f, 144.0f, "%.01f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
+		} else {
+			ImGui::SliderInt("Resolution factor", &m_fixed_res_factor, 8, 64);
 		}
-		accum_reset |= ImGui::Combo("Color space", (int*)&m_color_space, ColorSpaceStr);
+		ImGui::PopItemWidth();
+
+		accum_reset |= ImGui::Combo("Render mode", (int*)&m_render_mode, RenderModeStr);
 		accum_reset |= ImGui::Combo("Tonemap curve", (int*)&m_tonemap_curve, TonemapCurveStr);
 		accum_reset |= ImGui::ColorEdit4("Background", &m_background_color[0]);
+
 		if (ImGui::SliderFloat("Exposure", &m_exposure, -5.f, 5.f)) {
 			set_exposure(m_exposure);
 		}
 
-		accum_reset |= ImGui::Checkbox("Snap to pixel centers", &m_snap_to_pixel_centers);
-
 		float max_diam = (m_aabb.max-m_aabb.min).maxCoeff();
 		float render_diam = (m_render_aabb.max-m_render_aabb.min).maxCoeff();
 		float old_render_diam = render_diam;
@@ -988,11 +1038,52 @@ void Testbed::imgui() {
 			m_edit_render_aabb = false;
 		}
 
+		if (ImGui::TreeNode("Advanced rendering options")) {
+			ImGui::SliderInt("Max spp", &m_max_spp, 0, 1024, "%d", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
+			accum_reset |= ImGui::Checkbox("Render transparency as checkerboard", &m_render_transparency_as_checkerboard);
+			accum_reset |= ImGui::Combo("Color space", (int*)&m_color_space, ColorSpaceStr);
+			accum_reset |= ImGui::Checkbox("Snap to pixel centers", &m_snap_to_pixel_centers);
+
+			ImGui::TreePop();
+		}
+
 		if (m_testbed_mode == ETestbedMode::Nerf && ImGui::TreeNode("NeRF rendering options")) {
-			accum_reset |= ImGui::Checkbox("Apply lens distortion", &m_nerf.render_with_lens_distortion);
+			if (m_nerf.training.dataset.has_light_dirs) {
+				Vector3f light_dir = m_nerf.light_dir.normalized();
+				if (ImGui::TreeNodeEx("Light Dir (Polar)", ImGuiTreeNodeFlags_DefaultOpen)) {
+					float phi = atan2f(m_nerf.light_dir.x(), m_nerf.light_dir.z());
+					float theta = asinf(m_nerf.light_dir.y());
+					bool spin = ImGui::SliderFloat("Light Dir Theta", &theta, -PI() / 2.0f, PI() / 2.0f);
+					spin |= ImGui::SliderFloat("Light Dir Phi", &phi, -PI(), PI());
+					if (spin) {
+						float sin_phi, cos_phi;
+						sincosf(phi, &sin_phi, &cos_phi);
+						float cos_theta=cosf(theta);
+						m_nerf.light_dir = {sin_phi * cos_theta,sinf(theta),cos_phi * cos_theta};
+						accum_reset = true;
+					}
+					ImGui::TreePop();
+				}
+
+				if (ImGui::TreeNode("Light Dir (Cartesian)")) {
+					accum_reset |= ImGui::SliderFloat("Light Dir X", ((float*)(&m_nerf.light_dir)) + 0, -1.0f, 1.0f);
+					accum_reset |= ImGui::SliderFloat("Light Dir Y", ((float*)(&m_nerf.light_dir)) + 1, -1.0f, 1.0f);
+					accum_reset |= ImGui::SliderFloat("Light Dir Z", ((float*)(&m_nerf.light_dir)) + 2, -1.0f, 1.0f);
+					ImGui::TreePop();
+				}
+			}
+
+			if (m_nerf.training.dataset.n_extra_learnable_dims) {
+				accum_reset |= ImGui::SliderInt("training image latent code for inference", (int*)&m_nerf.extra_dim_idx_for_inference, 0, m_nerf.training.dataset.n_images-1);
+			}
+
+			accum_reset |= ImGui::Combo("Groundtruth render mode", (int*)&m_ground_truth_render_mode, GroundTruthRenderModeStr);
+			accum_reset |= ImGui::SliderFloat("Groundtruth alpha", &m_ground_truth_alpha, 0.0f, 1.0f, "%.02f", ImGuiSliderFlags_AlwaysClamp);
+
+			bool lens_changed = ImGui::Checkbox("Apply lens distortion", &m_nerf.render_with_lens_distortion);
 
 			if (m_nerf.render_with_lens_distortion) {
-				accum_reset |= ImGui::Combo("Lens mode", (int*)&m_nerf.render_lens.mode, LensModeStr);
+				lens_changed |= ImGui::Combo("Lens mode", (int*)&m_nerf.render_lens.mode, LensModeStr);
 				if (m_nerf.render_lens.mode == ELensMode::OpenCV) {
 					accum_reset |= ImGui::InputFloat("k1", &m_nerf.render_lens.params[0], 0.f, 0.f, "%.5f");
 					accum_reset |= ImGui::InputFloat("k2", &m_nerf.render_lens.params[1], 0.f, 0.f, "%.5f");
@@ -1012,6 +1103,12 @@ void Testbed::imgui() {
 					accum_reset |= ImGui::InputFloat("f_theta p3", &m_nerf.render_lens.params[3], 0.f, 0.f, "%.5f");
 					accum_reset |= ImGui::InputFloat("f_theta p4", &m_nerf.render_lens.params[4], 0.f, 0.f, "%.5f");
 				}
+
+				if (lens_changed && !supports_dlss(m_nerf.render_lens.mode)) {
+					m_dlss = false;
+				}
+
+				accum_reset |= lens_changed;
 			}
 
 			accum_reset |= ImGui::SliderFloat("Min transmittance", &m_nerf.render_min_transmittance, 0.0f, 1.0f, "%.3f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat);
@@ -1032,6 +1129,7 @@ void Testbed::imgui() {
 			}
 
 			accum_reset |= ImGui::Checkbox("Analytic normals", &m_sdf.analytic_normals);
+			accum_reset |= ImGui::Checkbox("Floor", &m_floor_enable);
 
 			accum_reset |= ImGui::SliderFloat("Normals epsilon", &m_sdf.fd_normals_epsilon, 0.00001f, 0.1f, "%.6g", ImGuiSliderFlags_Logarithmic);
 			accum_reset |= ImGui::SliderFloat("Maximum distance", &m_sdf.maximum_distance, 0.00001f, 0.1f, "%.6g", ImGuiSliderFlags_Logarithmic);
@@ -1131,24 +1229,29 @@ void Testbed::imgui() {
 	}
 
 	if (ImGui::CollapsingHeader("Camera", ImGuiTreeNodeFlags_DefaultOpen)) {
-		if (ImGui::SliderFloat("Aperture size", &m_aperture_size, 0.0f, 0.1f)) {
+		ImGui::Checkbox("First person controls", &m_fps_camera);
+		ImGui::SameLine();
+		ImGui::Checkbox("Smooth motion", &m_camera_smoothing);
+		ImGui::SameLine();
+		ImGui::Checkbox("Autofocus", &m_autofocus);
+		ImGui::PushItemWidth(ImGui::GetWindowWidth() * 0.3f);
+		if (ImGui::SliderFloat("Aperture size", &m_aperture_size, 0.0f, 1.0f, "%.3f", ImGuiSliderFlags_Logarithmic | ImGuiSliderFlags_NoRoundToFormat)) {
 			m_dlss = false;
 			accum_reset = true;
 		}
+		ImGui::SameLine();
+		accum_reset |= ImGui::SliderFloat("Focus depth", &m_slice_plane_z, -m_bounding_radius, m_bounding_radius);
+
 		float local_fov = fov();
 		if (ImGui::SliderFloat("Field of view", &local_fov, 0.0f, 120.0f)) {
 			set_fov(local_fov);
 			accum_reset = true;
 		}
+		ImGui::SameLine();
 		accum_reset |= ImGui::SliderFloat("Zoom", &m_zoom, 1.f, 10.f);
-		if (m_testbed_mode == ETestbedMode::Sdf) {
-			accum_reset |= ImGui::Checkbox("Floor", &m_floor_enable);
-			ImGui::SameLine();
-		}
+		ImGui::PopItemWidth();
+
 
-		ImGui::Checkbox("First person controls", &m_fps_camera);
-		ImGui::Checkbox("Smooth camera motion", &m_camera_smoothing);
-		ImGui::Checkbox("Autofocus", &m_autofocus);
 
 		if (ImGui::TreeNode("Advanced camera settings")) {
 			accum_reset |= ImGui::SliderFloat2("Screen center", &m_screen_center.x(), 0.f, 1.f);
@@ -1218,7 +1321,7 @@ void Testbed::imgui() {
 		}
 	}
 
-	if (ImGui::CollapsingHeader("Snapshot")) {
+	if (ImGui::CollapsingHeader("Snapshot", ImGuiTreeNodeFlags_DefaultOpen)) {
 		ImGui::Text("Snapshot");
 		ImGui::SameLine();
 		if (ImGui::Button("Save")) {
@@ -1329,7 +1432,7 @@ void Testbed::imgui() {
 			ImGui::Text("%dx%dx%d", res3d.x(), res3d.y(), res3d.z());
 			float thresh_range = (m_testbed_mode == ETestbedMode::Sdf) ? 0.5f : 10.f;
 			ImGui::SliderFloat("MC density threshold",&m_mesh.thresh, -thresh_range, thresh_range);
-			ImGui::Combo("Mesh render mode", (int*)&m_mesh_render_mode, "Off\0Vertex Colors\0Vertex Normals\0Face IDs\0");
+			ImGui::Combo("Mesh render mode", (int*)&m_mesh_render_mode, "Off\0Vertex Colors\0Vertex Normals\0\0");
 			ImGui::Checkbox("Unwrap mesh", &m_mesh.unwrap);
 			if (uint32_t tricount = m_mesh.indices.size()/3) {
 				ImGui::InputText("##OBJFile", m_imgui.mesh_path, sizeof(m_imgui.mesh_path));
@@ -1446,11 +1549,11 @@ void Testbed::draw_visualizations(ImDrawList* list, const Matrix<float, 3, 4>& c
 		view2world.setIdentity();
 		view2world.block<3,4>(0,0) = camera_matrix;
 
-		auto focal = calc_focal_length(Vector2i::Ones(), m_fov_axis, m_zoom);
+		auto focal = calc_focal_length(Vector2i::Ones(), m_relative_focal_length, m_fov_axis, m_zoom);
 		float zscale = 1.0f / focal[m_fov_axis];
 
 		float xyscale = (float)m_window_res[m_fov_axis];
-		Vector2f screen_center = render_screen_center();
+		Vector2f screen_center = render_screen_center(m_screen_center);
 		view2proj <<
 			xyscale, 0,       (float)m_window_res.x()*screen_center.x()*zscale, 0,
 			0,       xyscale, (float)m_window_res.y()*screen_center.y()*zscale, 0,
@@ -1478,12 +1581,12 @@ void Testbed::draw_visualizations(ImDrawList* list, const Matrix<float, 3, 4>& c
 				float flx = focal.x();
 				float fly = focal.y();
 				Matrix<float, 4, 4> view2proj_guizmo;
-				float zfar = 100.f;
-				float znear = 0.1f;
+				float zfar = m_ndc_zfar;
+				float znear = m_ndc_znear;
 				view2proj_guizmo <<
-					fly*2.f/aspect, 0, 0, 0,
-					0, -fly*2.f, 0, 0,
-					0, 0, (zfar+znear)/(zfar-znear), -(2.f*zfar*znear) / (zfar-znear),
+					fly * 2.f / aspect, 0, 0, 0,
+					0, -fly * 2.f, 0, 0,
+					0, 0, (zfar + znear) / (zfar - znear), -(2.f * zfar * znear) / (zfar - znear),
 					0, 0, 1, 0;
 				ImGuizmo::SetRect(0, 0, io.DisplaySize.x, io.DisplaySize.y);
 
@@ -1502,8 +1605,8 @@ void Testbed::draw_visualizations(ImDrawList* list, const Matrix<float, 3, 4>& c
 			}
 		}
 
-		if (m_camera_path.imgui_viz(list, view2proj, world2proj, world2view, focal, aspect)) {
-			m_pip_render_surface->reset_accumulation();
+		if (m_camera_path.imgui_viz(list, view2proj, world2proj, world2view, focal, aspect, m_ndc_znear, m_ndc_zfar)) {
+			m_pip_render_buffer->reset_accumulation();
 		}
 	}
 }
@@ -1675,46 +1778,66 @@ bool Testbed::keyboard_event() {
 
 	if (translate_vec != Vector3f::Zero()) {
 		m_fps_camera = true;
-		translate_camera(translate_vec);
+
+		// If VR is active, movement that isn't aligned with the current view
+		// direction is _very_ jarring to the user, so make keyboard-based
+		// movement aligned with the VR view, even though it is not an intended
+		// movement mechanism. (Users should use controllers.)
+		translate_camera(translate_vec, m_hmd && m_hmd->is_visible() ? m_views.front().camera0.block<3, 3>(0, 0) : m_camera.block<3, 3>(0, 0));
 	}
 
 	return false;
 }
 
-void Testbed::mouse_wheel(Vector2f m, float delta) {
+void Testbed::mouse_wheel() {
+	float delta = ImGui::GetIO().MouseWheel;
 	if (delta == 0) {
 		return;
 	}
 
-	if (!ImGui::GetIO().WantCaptureMouse) {
-		float scale_factor = pow(1.1f, -delta);
-		m_image.pos = (m_image.pos - m) / scale_factor + m;
-		set_scale(m_scale * scale_factor);
+	float scale_factor = pow(1.1f, -delta);
+	set_scale(m_scale * scale_factor);
+
+	// When in image mode, zoom around the hovered point.
+	if (m_testbed_mode == ETestbedMode::Image) {
+		Vector2i mouse = {ImGui::GetMousePos().x, ImGui::GetMousePos().y};
+		Vector3f offset = get_3d_pos_from_pixel(*m_views.front().render_buffer, mouse) - look_at();
+
+		// Don't center around infinitely distant points.
+		if (offset.norm() < 256.0f) {
+			m_camera.col(3) += offset * (1.0f - scale_factor);
+		}
 	}
 
 	reset_accumulation(true);
 }
 
-void Testbed::mouse_drag(const Vector2f& rel, int button) {
+Matrix3f Testbed::rotation_from_angles(const Vector2f& angles) const {
 	Vector3f up = m_up_dir;
 	Vector3f side = m_camera.col(0);
+	return (AngleAxisf(angles.x(), up) * AngleAxisf(angles.y(), side)).matrix();
+}
+
+void Testbed::mouse_drag() {
+	Vector2f rel = Vector2f{ImGui::GetIO().MouseDelta.x, ImGui::GetIO().MouseDelta.y} / (float)m_window_res[m_fov_axis];
+	Vector2i mouse = {ImGui::GetMousePos().x, ImGui::GetMousePos().y};
 
-	bool is_left_held = (button & 1) != 0;
-	bool is_right_held = (button & 2) != 0;
+	Vector3f up = m_up_dir;
+	Vector3f side = m_camera.col(0);
 
 	bool shift = ImGui::GetIO().KeyMods & ImGuiKeyModFlags_Shift;
-	if (is_left_held) {
+
+	// Left held
+	if (ImGui::GetIO().MouseDown[0]) {
 		if (shift) {
-			auto mouse = ImGui::GetMousePos();
-			determine_autofocus_target_from_pixel({mouse.x, mouse.y});
+			m_autofocus_target = get_3d_pos_from_pixel(*m_views.front().render_buffer, mouse);
+			m_autofocus = true;
+
 			reset_accumulation();
 		} else {
 			float rot_sensitivity = m_fps_camera ? 0.35f : 1.0f;
-			Matrix3f rot =
-				(AngleAxisf(static_cast<float>(-rel.x() * 2 * PI() * rot_sensitivity), up) * // Scroll sideways around up vector
-				AngleAxisf(static_cast<float>(-rel.y() * 2 * PI() * rot_sensitivity), side)).matrix(); // Scroll around side vector
+			Matrix3f rot = rotation_from_angles(-rel * 2 * PI() * rot_sensitivity);
 
-			m_image.pos += rel;
 			if (m_fps_camera) {
 				m_camera.block<3, 3>(0, 0) = rot * m_camera.block<3, 3>(0, 0);
 			} else {
@@ -1729,11 +1852,9 @@ void Testbed::mouse_drag(const Vector2f& rel, int button) {
 		}
 	}
 
-	if (is_right_held) {
-		Matrix3f rot =
-			(AngleAxisf(static_cast<float>(-rel.x() * 2 * PI()), up) * // Scroll sideways around up vector
-			AngleAxisf(static_cast<float>(-rel.y() * 2 * PI()), side)).matrix(); // Scroll around side vector
-
+	// Right held
+	if (ImGui::GetIO().MouseDown[1]) {
+		Matrix3f rot = rotation_from_angles(-rel * 2 * PI());
 		if (m_render_mode == ERenderMode::Shade) {
 			m_sun_dir = rot.transpose() * m_sun_dir;
 		}
@@ -1742,14 +1863,27 @@ void Testbed::mouse_drag(const Vector2f& rel, int button) {
 		reset_accumulation();
 	}
 
-	bool is_middle_held = (button & 4) != 0;
-	if (is_middle_held) {
-		translate_camera({-rel.x(), -rel.y(), 0.0f});
+	// Middle pressed
+	if (ImGui::GetIO().MouseClicked[2]) {
+		m_drag_depth = get_depth_from_renderbuffer(*m_views.front().render_buffer, mouse.cast<float>().cwiseQuotient(m_window_res.cast<float>()));
+	}
+
+	// Middle held
+	if (ImGui::GetIO().MouseDown[2]) {
+		Vector3f translation = Vector3f{-rel.x(), -rel.y(), 0.0f} / m_zoom;
+
+		// If we have a valid depth value, scale the scene translation by it such that the
+		// hovered point in 3D space stays under the cursor.
+		if (m_drag_depth < 256.0f) {
+			translation *= m_drag_depth / m_relative_focal_length[m_fov_axis];
+		}
+
+		translate_camera(translation, m_camera.block<3, 3>(0, 0));
 	}
 }
 
-bool Testbed::begin_frame_and_handle_user_input() {
-	if (glfwWindowShouldClose(m_glfw_window) || ImGui::IsKeyDown(GLFW_KEY_ESCAPE) || ImGui::IsKeyDown(GLFW_KEY_Q)) {
+bool Testbed::begin_frame() {
+	if (glfwWindowShouldClose(m_glfw_window) || ImGui::IsKeyPressed(GLFW_KEY_ESCAPE) || ImGui::IsKeyPressed(GLFW_KEY_Q)) {
 		destroy_window();
 		return false;
 	}
@@ -1769,21 +1903,18 @@ bool Testbed::begin_frame_and_handle_user_input() {
 	ImGui::NewFrame();
 	ImGuizmo::BeginFrame();
 
+	return true;
+}
+
+void Testbed::handle_user_input() {
 	if (ImGui::IsKeyPressed(GLFW_KEY_TAB) || ImGui::IsKeyPressed(GLFW_KEY_GRAVE_ACCENT)) {
 		m_imgui.enabled = !m_imgui.enabled;
 	}
 
-	ImVec2 m = ImGui::GetMousePos();
-	int mb = 0;
-	float mw = 0.f;
-	ImVec2 relm = {};
-	if (!ImGui::IsAnyItemActive() && !ImGuizmo::IsUsing() && !ImGuizmo::IsOver()) {
-		relm = ImGui::GetIO().MouseDelta;
-		if (ImGui::GetIO().MouseDown[0]) mb |= 1;
-		if (ImGui::GetIO().MouseDown[1]) mb |= 2;
-		if (ImGui::GetIO().MouseDown[2]) mb |= 4;
-		mw = ImGui::GetIO().MouseWheel;
-		relm = {relm.x / (float)m_window_res.y(), relm.y / (float)m_window_res.y()};
+	// Only respond to mouse inputs when not interacting with ImGui
+	if (!ImGui::IsAnyItemActive() && !ImGuizmo::IsUsing() && !ImGuizmo::IsOver() && !ImGui::GetIO().WantCaptureMouse) {
+		mouse_wheel();
+		mouse_drag();
 	}
 
 	if (m_testbed_mode == ETestbedMode::Nerf && (m_render_ground_truth || m_nerf.training.render_error_overlay)) {
@@ -1791,21 +1922,150 @@ bool Testbed::begin_frame_and_handle_user_input() {
 		int bestimage = find_best_training_view(-1);
 		if (bestimage >= 0) {
 			m_nerf.training.view = bestimage;
-			if (mb == 0) {// snap camera to ground truth view on mouse up
+			if (ImGui::GetIO().MouseReleased[0]) {// snap camera to ground truth view on mouse up
 				set_camera_to_training_view(m_nerf.training.view);
 			}
 		}
 	}
 
 	keyboard_event();
-	mouse_wheel({m.x / (float)m_window_res.y(), m.y / (float)m_window_res.y()}, mw);
-	mouse_drag({relm.x, relm.y}, mb);
 
 	if (m_imgui.enabled) {
 		imgui();
 	}
+}
 
-	return true;
+Vector3f Testbed::vr_to_world(const Vector3f& pos) const {
+	return m_camera.block<3, 3>(0, 0) * pos * m_scale + m_camera.col(3);
+}
+
+void Testbed::begin_vr_frame_and_handle_vr_input() {
+	if (!m_hmd) {
+		m_vr_frame_info = nullptr;
+		return;
+	}
+
+	m_hmd->poll_events();
+	if (!m_hmd->must_run_frame_loop()) {
+		m_vr_frame_info = nullptr;
+		return;
+	}
+
+	m_vr_frame_info = m_hmd->begin_frame();
+
+	const auto& views = m_vr_frame_info->views;
+	size_t n_views = views.size();
+	size_t n_devices = m_devices.size();
+	if (n_views > 0) {
+		set_n_views(n_views);
+
+		Vector2i total_size = Vector2i::Zero();
+		for (size_t i = 0; i < n_views; ++i) {
+			Vector2i view_resolution = {views[i].view.subImage.imageRect.extent.width, views[i].view.subImage.imageRect.extent.height};
+			total_size += view_resolution;
+
+			m_views[i].full_resolution = view_resolution;
+
+			// Apply the VR pose relative to the world camera transform.
+			m_views[i].camera0.block<3, 3>(0, 0) = m_camera.block<3, 3>(0, 0) * views[i].pose.block<3, 3>(0, 0);
+			m_views[i].camera0.col(3) = vr_to_world(views[i].pose.col(3));
+			m_views[i].camera1 = m_views[i].camera0;
+
+			m_views[i].visualized_dimension = m_visualized_dimension;
+
+			const auto& xr_fov = views[i].view.fov;
+
+			// Compute the distance on the image plane (1 unit away from the camera) that an angle of the respective FOV spans
+			Vector2f rel_focal_length_left_down = 0.5f * fov_to_focal_length(Vector2i::Ones(), Vector2f{360.0f * xr_fov.angleLeft / PI(), 360.0f * xr_fov.angleDown / PI()});
+			Vector2f rel_focal_length_right_up = 0.5f * fov_to_focal_length(Vector2i::Ones(), Vector2f{360.0f * xr_fov.angleRight / PI(), 360.0f * xr_fov.angleUp / PI()});
+
+			// Compute total distance (for X and Y) that is spanned on the image plane.
+			m_views[i].relative_focal_length = rel_focal_length_right_up - rel_focal_length_left_down;
+
+			// Compute fraction of that distance that is spanned by the right-up part and set screen center accordingly.
+			Vector2f ratio = rel_focal_length_right_up.cwiseQuotient(m_views[i].relative_focal_length);
+			m_views[i].screen_center = { 1.0f - ratio.x(), ratio.y() };
+
+			// Fix up weirdness in the rendering pipeline
+			m_views[i].relative_focal_length[(m_fov_axis+1)%2] *= (float)view_resolution[(m_fov_axis+1)%2] / (float)view_resolution[m_fov_axis];
+			m_views[i].render_buffer->set_hidden_area_mask(views[i].hidden_area_mask);
+
+			// Render each view on a different GPU (if available)
+			m_views[i].device = m_use_aux_devices ? &m_devices.at(i % m_devices.size()) : &primary_device();
+		}
+
+		// Put all the views next to each other, but at half size
+		glfwSetWindowSize(m_glfw_window, total_size.x() / 2, (total_size.y() / 2) / n_views);
+
+		// VR controller input
+		const auto& hands = m_vr_frame_info->hands;
+		m_fps_camera = true;
+
+		// TRANSLATE BY STICK (if not pressing the stick)
+		if (!hands[0].pressing) {
+			Vector3f translate_vec = Vector3f{hands[0].thumbstick.x(), 0.0f, hands[0].thumbstick.y()} * m_camera_velocity * m_frame_ms.val() / 1000.0f;
+			if (translate_vec != Vector3f::Zero()) {
+				translate_camera(translate_vec, m_views.front().camera0.block<3, 3>(0, 0), false);
+			}
+		}
+
+		// TURN BY STICK (if not pressing the stick)
+		if (!hands[1].pressing) {
+			auto prev_camera = m_camera;
+
+			// Turn around the up vector (equivalent to x-axis mouse drag) with right joystick left/right
+			float sensitivity = 0.35f;
+			m_camera.block<3, 3>(0, 0) = rotation_from_angles({-2.0f * PI() * sensitivity * hands[1].thumbstick.x() * m_frame_ms.val() / 1000.0f, 0.0f}) * m_camera.block<3, 3>(0, 0);
+
+			// Translate camera such that center of rotation was about the current view
+			m_camera.col(3) += prev_camera.block<3, 3>(0, 0) * views[0].pose.col(3) * m_scale - m_camera.block<3, 3>(0, 0) * views[0].pose.col(3) * m_scale;
+		}
+
+		// TRANSLATE, SCALE, AND ROTATE BY GRAB
+		{
+			bool both_grabbing = hands[0].grabbing && hands[1].grabbing;
+			float drag_factor = both_grabbing ? 0.5f : 1.0f;
+
+			if (both_grabbing) {
+				drag_factor = 0.5f;
+
+				Vector3f prev_diff = hands[0].prev_grab_pos - hands[1].prev_grab_pos;
+				Vector3f diff = hands[0].grab_pos - hands[1].grab_pos;
+				Vector3f center = 0.5f * (hands[0].grab_pos + hands[1].grab_pos);
+
+				Vector3f center_world = vr_to_world(0.5f * (hands[0].grab_pos + hands[1].grab_pos));
+
+				// Scale around center position of the two dragging hands. Makes the scaling feel similar to phone pinch-to-zoom
+				float scale = m_scale * prev_diff.norm() / diff.norm();
+				m_camera.col(3) = (view_pos() - center_world) * (scale / m_scale) + center_world;
+				m_scale = scale;
+
+				// Take rotational component and project it to the nearest rotation about the up vector.
+				// We don't want to rotate the scene about any other axis.
+				Vector3f rot = prev_diff.normalized().cross(diff.normalized());
+				float rot_radians = std::asin(m_up_dir.dot(rot));
+
+				auto prev_camera = m_camera;
+				m_camera.block<3, 3>(0, 0) = AngleAxisf(rot_radians, m_up_dir) * m_camera.block<3, 3>(0, 0);
+				m_camera.col(3) += prev_camera.block<3, 3>(0, 0) * center * m_scale - m_camera.block<3, 3>(0, 0) * center * m_scale;
+			}
+
+			for (const auto& hand : hands) {
+				if (hand.grabbing) {
+					m_camera.col(3) -= drag_factor * m_camera.block<3, 3>(0, 0) * hand.drag() * m_scale;
+				}
+			}
+		}
+
+		// ERASE OCCUPANCY WHEN PRESSING STICK/TRACKPAD
+		if (m_testbed_mode == ETestbedMode::Nerf) {
+			for (const auto& hand : hands) {
+				if (hand.pressing) {
+					mark_density_grid_in_sphere_empty(vr_to_world(hand.pose.col(3)), m_scale * 0.05f, m_stream.get());
+				}
+			}
+		}
+	}
 }
 
 void Testbed::SecondWindow::draw(GLuint texture) {
@@ -1834,11 +2094,164 @@ void Testbed::SecondWindow::draw(GLuint texture) {
 	glfwMakeContextCurrent(old_context);
 }
 
+void Testbed::init_opengl_shaders() {
+	static const char* shader_vert = R"(#version 140
+		out vec2 UVs;
+		void main() {
+			UVs = vec2((gl_VertexID << 1) & 2, gl_VertexID & 2);
+			gl_Position = vec4(UVs * 2.0 - 1.0, 0.0, 1.0);
+		})";
+
+	static const char* shader_frag = R"(#version 140
+		in vec2 UVs;
+		out vec4 frag_color;
+		uniform sampler2D rgba_texture;
+		uniform sampler2D depth_texture;
+
+		struct FoveationWarp {
+			float al, bl, cl;
+			float am, bm;
+			float ar, br, cr;
+			float switch_left, switch_right;
+			float inv_switch_left, inv_switch_right;
+		};
+
+		uniform FoveationWarp warp_x;
+		uniform FoveationWarp warp_y;
+
+		float unwarp(in FoveationWarp warp, float y) {
+			y = clamp(y, 0.0, 1.0);
+			if (y < warp.inv_switch_left) {
+				return (sqrt(-4.0 * warp.al * warp.cl + 4.0 * warp.al * y + warp.bl * warp.bl) - warp.bl) / (2.0 * warp.al);
+			} else if (y > warp.inv_switch_right) {
+				return (sqrt(-4.0 * warp.ar * warp.cr + 4.0 * warp.ar * y + warp.br * warp.br) - warp.br) / (2.0 * warp.ar);
+			} else {
+				return (y - warp.bm) / warp.am;
+			}
+		}
+
+		vec2 unwarp(in vec2 pos) {
+			return vec2(unwarp(warp_x, pos.x), unwarp(warp_y, pos.y));
+		}
+
+		void main() {
+			vec2 tex_coords = UVs;
+			tex_coords.y = 1.0 - tex_coords.y;
+			tex_coords = unwarp(tex_coords);
+			frag_color = texture(rgba_texture, tex_coords.xy);
+			//Uncomment the following line of code to visualize debug the depth buffer for debugging.
+			// frag_color = vec4(vec3(texture(depth_texture, tex_coords.xy).r), 1.0);
+			gl_FragDepth = texture(depth_texture, tex_coords.xy).r;
+		})";
+
+	GLuint vert = glCreateShader(GL_VERTEX_SHADER);
+	glShaderSource(vert, 1, &shader_vert, NULL);
+	glCompileShader(vert);
+	check_shader(vert, "Blit vertex shader", false);
+
+	GLuint frag = glCreateShader(GL_FRAGMENT_SHADER);
+	glShaderSource(frag, 1, &shader_frag, NULL);
+	glCompileShader(frag);
+	check_shader(frag, "Blit fragment shader", false);
+
+	m_blit_program = glCreateProgram();
+	glAttachShader(m_blit_program, vert);
+	glAttachShader(m_blit_program, frag);
+	glLinkProgram(m_blit_program);
+	check_shader(m_blit_program, "Blit shader program", true);
+
+	glDeleteShader(vert);
+	glDeleteShader(frag);
+
+	glGenVertexArrays(1, &m_blit_vao);
+}
+
+void Testbed::blit_texture(const Foveation& foveation, GLint rgba_texture, GLint rgba_filter_mode, GLint depth_texture, GLint framebuffer, const Vector2i& offset, const Vector2i& resolution) {
+	if (m_blit_program == 0) {
+		return;
+	}
+
+	// Blit image to OpenXR swapchain.
+	// Note that the OpenXR swapchain is 8bit while the rendering is in a float texture.
+	// As some XR runtimes do not support float swapchains, we can't render into it directly.
+
+	bool tex = glIsEnabled(GL_TEXTURE_2D);
+	bool depth = glIsEnabled(GL_DEPTH_TEST);
+	bool cull = glIsEnabled(GL_CULL_FACE);
+
+	if (!tex) glEnable(GL_TEXTURE_2D);
+	if (!depth) glEnable(GL_DEPTH_TEST);
+	if (cull) glDisable(GL_CULL_FACE);
+
+	glDepthFunc(GL_ALWAYS);
+	glDepthMask(GL_TRUE);
+
+	glBindVertexArray(m_blit_vao);
+	glUseProgram(m_blit_program);
+	glUniform1i(glGetUniformLocation(m_blit_program, "rgba_texture"), 0);
+	glUniform1i(glGetUniformLocation(m_blit_program, "depth_texture"), 1);
+
+	auto bind_warp = [&](const FoveationPiecewiseQuadratic& warp, const std::string& uniform_name) {
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".al").c_str()), warp.al);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".bl").c_str()), warp.bl);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".cl").c_str()), warp.cl);
+
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".am").c_str()), warp.am);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".bm").c_str()), warp.bm);
+
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".ar").c_str()), warp.ar);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".br").c_str()), warp.br);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".cr").c_str()), warp.cr);
+
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".switch_left").c_str()), warp.switch_left);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".switch_right").c_str()), warp.switch_right);
+
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".inv_switch_left").c_str()), warp.inv_switch_left);
+		glUniform1f(glGetUniformLocation(m_blit_program, (uniform_name + ".inv_switch_right").c_str()), warp.inv_switch_right);
+	};
+
+	bind_warp(foveation.warp_x, "warp_x");
+	bind_warp(foveation.warp_y, "warp_y");
+
+	glActiveTexture(GL_TEXTURE1);
+	glBindTexture(GL_TEXTURE_2D, depth_texture);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST);
+
+	glActiveTexture(GL_TEXTURE0);
+	glBindTexture(GL_TEXTURE_2D, rgba_texture);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, rgba_filter_mode);
+	glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, rgba_filter_mode);
+
+	glBindFramebuffer(GL_FRAMEBUFFER, framebuffer);
+	glViewport(offset.x(), offset.y(), resolution.x(), resolution.y());
+
+	glDrawArrays(GL_TRIANGLES, 0, 3);
+
+	glBindVertexArray(0);
+	glUseProgram(0);
+
+	glDepthFunc(GL_LESS);
+
+	// restore old state
+	if (!tex) glDisable(GL_TEXTURE_2D);
+	if (!depth) glDisable(GL_DEPTH_TEST);
+	if (cull) glEnable(GL_CULL_FACE);
+	glBindFramebuffer(GL_FRAMEBUFFER, 0);
+}
+
 void Testbed::draw_gui() {
 	// Make sure all the cuda code finished its business here
 	CUDA_CHECK_THROW(cudaDeviceSynchronize());
-	if (!m_render_textures.empty())
-		m_second_window.draw((GLuint)m_render_textures.front()->texture());
+
+	if (!m_rgba_render_textures.empty()) {
+		m_second_window.draw((GLuint)m_rgba_render_textures.front()->texture());
+	}
+
 	glfwMakeContextCurrent(m_glfw_window);
 	int display_w, display_h;
 	glfwGetFramebufferSize(m_glfw_window, &display_w, &display_h);
@@ -1846,56 +2259,42 @@ void Testbed::draw_gui() {
 	glClearColor(0.f, 0.f, 0.f, 0.f);
 	glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
 
+	glEnable(GL_BLEND);
+	glBlendEquationSeparate(GL_FUNC_ADD, GL_FUNC_ADD);
+	glBlendFuncSeparate(GL_ONE, GL_ONE_MINUS_SRC_ALPHA, GL_ONE, GL_ONE_MINUS_SRC_ALPHA);
 
-	ImDrawList* list = ImGui::GetBackgroundDrawList();
-	list->AddCallback([](const ImDrawList*, const ImDrawCmd*) {
-		glBlendEquationSeparate(GL_FUNC_ADD, GL_FUNC_ADD);
-		glBlendFuncSeparate(GL_ONE, GL_ONE_MINUS_SRC_ALPHA, GL_ONE, GL_ONE_MINUS_SRC_ALPHA);
-	}, nullptr);
-
-	if (m_single_view) {
-		list->AddImageQuad((ImTextureID)(size_t)m_render_textures.front()->texture(), ImVec2{0.f, 0.f}, ImVec2{(float)display_w, 0.f}, ImVec2{(float)display_w, (float)display_h}, ImVec2{0.f, (float)display_h}, ImVec2(0, 0), ImVec2(1, 0), ImVec2(1, 1), ImVec2(0, 1));
-	} else {
-		m_dlss = false;
+	Vector2i extent = Vector2f{(float)display_w / m_n_views.x(), (float)display_h / m_n_views.y()}.cast<int>();
 
-		int i = 0;
-		for (int y = 0; y < m_n_views.y(); ++y) {
-			for (int x = 0; x < m_n_views.x(); ++x) {
-				if (i >= m_render_surfaces.size()) {
-					break;
-				}
+	int i = 0;
+	for (int y = 0; y < m_n_views.y(); ++y) {
+		for (int x = 0; x < m_n_views.x(); ++x) {
+			if (i >= m_views.size()) {
+				break;
+			}
 
-				Vector2f top_left{x * m_view_size.x(), y * m_view_size.y()};
-
-				list->AddImageQuad(
-					(ImTextureID)(size_t)m_render_textures[i]->texture(),
-					ImVec2{top_left.x(),                          top_left.y()                         },
-					ImVec2{top_left.x() + (float)m_view_size.x(), top_left.y()                         },
-					ImVec2{top_left.x() + (float)m_view_size.x(), top_left.y() + (float)m_view_size.y()},
-					ImVec2{top_left.x(),                          top_left.y() + (float)m_view_size.y()},
-					ImVec2(0, 0),
-					ImVec2(1, 0),
-					ImVec2(1, 1),
-					ImVec2(0, 1)
-				);
+			auto& view = m_views[i];
+			Vector2i top_left{x * extent.x(), display_h - (y + 1) * extent.y()};
+			blit_texture(view.foveation, m_rgba_render_textures.at(i)->texture(), m_foveated_rendering ? GL_LINEAR : GL_NEAREST, m_depth_render_textures.at(i)->texture(), 0, top_left, extent);
 
-				++i;
-			}
+			++i;
 		}
 	}
+	glFinish();
+	glViewport(0, 0, display_w, display_h);
+
 
+	ImDrawList* list = ImGui::GetBackgroundDrawList();
 	list->AddCallback(ImDrawCallback_ResetRenderState, nullptr);
 
 	auto draw_mesh = [&]() {
 		glClear(GL_DEPTH_BUFFER_BIT);
 		Vector2i res(display_w, display_h);
-		Vector2f focal_length = calc_focal_length(res, m_fov_axis, m_zoom);
-		Vector2f screen_center = render_screen_center();
-		draw_mesh_gl(m_mesh.verts, m_mesh.vert_normals, m_mesh.vert_colors, m_mesh.indices, res, focal_length, m_smoothed_camera, screen_center, (int)m_mesh_render_mode);
+		Vector2f focal_length = calc_focal_length(res, m_relative_focal_length, m_fov_axis, m_zoom);
+		draw_mesh_gl(m_mesh.verts, m_mesh.vert_normals, m_mesh.vert_colors, m_mesh.indices, res, focal_length, m_smoothed_camera, render_screen_center(m_screen_center), (int)m_mesh_render_mode);
 	};
 
 	// Visualizations are only meaningful when rendering a single view
-	if (m_single_view) {
+	if (m_views.size() == 1) {
 		if (m_mesh.verts.size() != 0 && m_mesh.indices.size() != 0 && m_mesh_render_mode != EMeshRenderMode::Off) {
 			list->AddCallback([](const ImDrawList*, const ImDrawCmd* cmd) {
 				(*(decltype(draw_mesh)*)cmd->UserCallbackData)();
@@ -1955,7 +2354,7 @@ void Testbed::prepare_next_camera_path_frame() {
 	// If we're rendering a video, we'd like to accumulate multiple spp
 	// for motion blur. Hence dump the frame once the target spp has been reached
 	// and only reset _then_.
-	if (m_render_surfaces.front().spp() == m_camera_path.render_settings.spp) {
+	if (m_views.front().render_buffer->spp() == m_camera_path.render_settings.spp) {
 		auto tmp_dir = fs::path{"tmp"};
 		if (!tmp_dir.exists()) {
 			if (!fs::create_directory(tmp_dir)) {
@@ -1965,7 +2364,7 @@ void Testbed::prepare_next_camera_path_frame() {
 			}
 		}
 
-		Vector2i res = m_render_surfaces.front().out_resolution();
+		Vector2i res = m_views.front().render_buffer->out_resolution();
 		const dim3 threads = { 16, 8, 1 };
 		const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 
@@ -1973,7 +2372,7 @@ void Testbed::prepare_next_camera_path_frame() {
 		to_8bit_color_kernel<<<blocks, threads>>>(
 			res,
 			EColorSpace::SRGB, // the GUI always renders in SRGB
-			m_render_surfaces.front().surface(),
+			m_views.front().render_buffer->surface(),
 			image_data.data()
 		);
 
@@ -2047,7 +2446,7 @@ void Testbed::prepare_next_camera_path_frame() {
 	const auto& rs = m_camera_path.render_settings;
 	m_camera_path.play_time = (float)((double)m_camera_path.render_frame_idx / (double)rs.n_frames());
 
-	if (m_render_surfaces.front().spp() == 0) {
+	if (m_views.front().render_buffer->spp() == 0) {
 		set_camera_from_time(m_camera_path.play_time);
 		apply_camera_smoothing(rs.frame_milliseconds());
 
@@ -2109,134 +2508,204 @@ void Testbed::train_and_render(bool skip_rendering) {
 		autofocus();
 	}
 
-	if (m_single_view) {
-		// Should have been created when the window was created.
-		assert(!m_render_surfaces.empty());
+#ifdef NGP_GUI
+	if (m_hmd && m_hmd->is_visible()) {
+		for (auto& view : m_views) {
+			view.visualized_dimension = m_visualized_dimension;
+		}
 
-		auto& render_buffer = m_render_surfaces.front();
+		m_n_views = {m_views.size(), 1};
 
-		{
-			// Don't count the time being spent allocating buffers and resetting DLSS as part of the frame time.
-			// Otherwise the dynamic resolution calculations for following frames will be thrown out of whack
-			// and may even start oscillating.
-			auto skip_start = std::chrono::steady_clock::now();
-			ScopeGuard skip_timing_guard{[&]() {
-				start += std::chrono::steady_clock::now() - skip_start;
-			}};
-			if (m_dlss) {
-				render_buffer.enable_dlss(m_window_res);
-				m_aperture_size = 0.0f;
-			} else {
-				render_buffer.disable_dlss();
-			}
+		m_nerf.render_with_lens_distortion = false;
+		reset_accumulation(true);
+	} else if (m_single_view) {
+		set_n_views(1);
+		m_n_views = {1, 1};
 
-			auto render_res = render_buffer.in_resolution();
-			if (render_res.isZero() || (m_train && m_training_step == 0)) {
-				render_res = m_window_res/16;
-			} else {
-				render_res = render_res.cwiseMin(m_window_res);
-			}
+		auto& view = m_views.front();
 
-			float render_time_per_fullres_frame = m_render_ms.val() / (float)render_res.x() / (float)render_res.y() * (float)m_window_res.x() * (float)m_window_res.y();
+		view.full_resolution = m_window_res;
 
-			// Make sure we don't starve training with slow rendering
-			float factor = std::sqrt(1000.0f / m_dynamic_res_target_fps / render_time_per_fullres_frame);
-			if (!m_dynamic_res) {
-				factor = 8.f/(float)m_fixed_res_factor;
-			}
+		view.camera0 = m_smoothed_camera;
 
-			factor = tcnn::clamp(factor, 1.0f/16.0f, 1.0f);
+		// Motion blur over the fraction of time that the shutter is open. Interpolate in log-space to preserve rotations.
+		view.camera1 = m_camera_path.rendering ? log_space_lerp(m_smoothed_camera, m_camera_path.render_frame_end_camera, m_camera_path.render_settings.shutter_fraction) : view.camera0;
 
-			if (factor > m_last_render_res_factor * 1.2f || factor < m_last_render_res_factor * 0.8f || factor == 1.0f || !m_dynamic_res) {
-				render_res = (m_window_res.cast<float>() * factor).cast<int>().cwiseMin(m_window_res).cwiseMax(m_window_res/16);
-				m_last_render_res_factor = factor;
-			}
+		view.visualized_dimension = m_visualized_dimension;
+		view.relative_focal_length = m_relative_focal_length;
+		view.screen_center = m_screen_center;
+		view.render_buffer->set_hidden_area_mask(nullptr);
+		view.foveation = {};
+		view.device = &primary_device();
+	} else {
+		int n_views = n_dimensions_to_visualize()+1;
 
-			if (m_camera_path.rendering) {
-				render_res = m_camera_path.render_settings.resolution;
-				m_last_render_res_factor = 1.0f;
-			}
+		float d = std::sqrt((float)m_window_res.x() * (float)m_window_res.y() / (float)n_views);
+
+		int nx = (int)std::ceil((float)m_window_res.x() / d);
+		int ny = (int)std::ceil((float)n_views / (float)nx);
+
+		m_n_views = {nx, ny};
+		Vector2i view_size = {m_window_res.x() / nx, m_window_res.y() / ny};
+
+		set_n_views(n_views);
 
-			if (render_buffer.dlss()) {
-				render_res = render_buffer.dlss()->clamp_resolution(render_res);
-				render_buffer.dlss()->update_feature(render_res, render_buffer.dlss()->is_hdr(), render_buffer.dlss()->sharpen());
+		int i = 0;
+		for (int y = 0; y < ny; ++y) {
+			for (int x = 0; x < nx; ++x) {
+				if (i >= n_views) {
+					break;
+				}
+
+				m_views[i].full_resolution = view_size;
+
+				m_views[i].camera0 = m_views[i].camera1 = m_smoothed_camera;
+				m_views[i].visualized_dimension = i-1;
+				m_views[i].relative_focal_length = m_relative_focal_length;
+				m_views[i].screen_center = m_screen_center;
+				m_views[i].render_buffer->set_hidden_area_mask(nullptr);
+				m_views[i].foveation = {};
+				m_views[i].device = &primary_device();
+				++i;
 			}
+		}
+	}
 
-			render_buffer.resize(render_res);
+	if (m_dlss) {
+		m_aperture_size = 0.0f;
+		if (!supports_dlss(m_nerf.render_lens.mode)) {
+			m_nerf.render_with_lens_distortion = false;
 		}
+	}
 
-		render_frame(
-			m_smoothed_camera,
-			m_camera_path.rendering ? log_space_lerp(m_smoothed_camera, m_camera_path.render_frame_end_camera, m_camera_path.render_settings.shutter_fraction) : m_smoothed_camera,
-			{0.0f, 0.0f, 0.0f, 1.0f},
-			render_buffer
-		);
+	// Update dynamic res and DLSS
+	{
+		// Don't count the time being spent allocating buffers and resetting DLSS as part of the frame time.
+		// Otherwise the dynamic resolution calculations for following frames will be thrown out of whack
+		// and may even start oscillating.
+		auto skip_start = std::chrono::steady_clock::now();
+		ScopeGuard skip_timing_guard{[&]() {
+			start += std::chrono::steady_clock::now() - skip_start;
+		}};
 
-#ifdef NGP_GUI
-		m_render_textures.front()->blit_from_cuda_mapping();
+		size_t n_pixels = 0, n_pixels_full_res = 0;
+		for (const auto& view : m_views) {
+			n_pixels += view.render_buffer->in_resolution().prod();
+			n_pixels_full_res += view.full_resolution.prod();
+		}
 
-		if (m_picture_in_picture_res > 0) {
-			Vector2i res(m_picture_in_picture_res, m_picture_in_picture_res * 9/16);
-			m_pip_render_surface->resize(res);
-			if (m_pip_render_surface->spp() < 8) {
-				// a bit gross, but let's copy the keyframe's state into the global state in order to not have to plumb through the fov etc to render_frame.
-				CameraKeyframe backup = copy_camera_to_keyframe();
-				CameraKeyframe pip_kf = m_camera_path.eval_camera_path(m_camera_path.play_time);
-				set_camera_from_keyframe(pip_kf);
-				render_frame(pip_kf.m(), pip_kf.m(), Eigen::Vector4f::Zero(), *m_pip_render_surface);
-				set_camera_from_keyframe(backup);
+		float pixel_ratio = (n_pixels == 0 || (m_train && m_training_step == 0)) ? (1.0f / 256.0f) : ((float)n_pixels / (float)n_pixels_full_res);
 
-				m_pip_render_texture->blit_from_cuda_mapping();
-			}
+		float last_factor = std::sqrt(pixel_ratio);
+		float factor = std::sqrt(pixel_ratio / m_render_ms.val() * 1000.0f / m_dynamic_res_target_fps);
+		if (!m_dynamic_res) {
+			factor = 8.f / (float)m_fixed_res_factor;
 		}
-#endif
-	} else {
-#ifdef NGP_GUI
-		// Don't do DLSS when multi-view rendering
-		m_dlss = false;
-		m_render_surfaces.front().disable_dlss();
 
-		int n_views = n_dimensions_to_visualize()+1;
+		factor = tcnn::clamp(factor, 1.0f / 16.0f, 1.0f);
 
-		float d = std::sqrt((float)m_window_res.x() * (float)m_window_res.y() / (float)n_views);
+		for (auto&& view : m_views) {
+			if (m_dlss) {
+				view.render_buffer->enable_dlss(*m_dlss_provider, view.full_resolution);
+			} else {
+				view.render_buffer->disable_dlss();
+			}
 
-		int nx = (int)std::ceil((float)m_window_res.x() / d);
-		int ny = (int)std::ceil((float)n_views / (float)nx);
+			Vector2i render_res = view.render_buffer->in_resolution();
+			Vector2i new_render_res = (view.full_resolution.cast<float>() * factor).cast<int>().cwiseMin(view.full_resolution).cwiseMax(view.full_resolution / 16);
 
-		m_n_views = {nx, ny};
-		m_view_size = {m_window_res.x() / nx, m_window_res.y() / ny};
+			if (m_camera_path.rendering) {
+				new_render_res = m_camera_path.render_settings.resolution;
+			}
+
+			float ratio = std::sqrt((float)render_res.prod() / (float)new_render_res.prod());
+			if (ratio > 1.2f || ratio < 0.8f || factor == 1.0f || !m_dynamic_res || m_camera_path.rendering) {
+				render_res = new_render_res;
+			}
+
+			if (view.render_buffer->dlss()) {
+				render_res = view.render_buffer->dlss()->clamp_resolution(render_res);
+				view.render_buffer->dlss()->update_feature(render_res, view.render_buffer->dlss()->is_hdr(), view.render_buffer->dlss()->sharpen());
+			}
 
-		while (m_render_surfaces.size() > n_views) {
-			m_render_surfaces.pop_back();
+			view.render_buffer->resize(render_res);
+
+			if (m_foveated_rendering) {
+				float foveation_warped_full_res_diameter = 0.55f;
+				Vector2f resolution_scale = render_res.cast<float>().cwiseQuotient(view.full_resolution.cast<float>());
+
+				// Only start foveation when DLSS if off or if DLSS is asked to do more than 1.5x upscaling.
+				// The reason for the 1.5x threshold is that DLSS can do up to 3x upscaling, at which point a foveation
+				// factor of 2x = 3.0x/1.5x corresponds exactly to bilinear super sampling, which is helpful in
+				// suppressing DLSS's artifacts.
+				float foveation_begin_factor = m_dlss ? 1.5f : 1.0f;
+
+				resolution_scale = (resolution_scale * foveation_begin_factor).cwiseMin(1.0f).cwiseMax(1.0f / m_foveated_rendering_max_scaling);
+				view.foveation = {resolution_scale, Vector2f::Ones() - view.screen_center, Vector2f::Constant(foveation_warped_full_res_diameter * 0.5f)};
+			} else {
+				view.foveation = {};
+			}
 		}
+	}
 
-		m_render_textures.resize(n_views);
-		while (m_render_surfaces.size() < n_views) {
-			size_t idx = m_render_surfaces.size();
-			m_render_textures[idx] = std::make_shared<GLTexture>();
-			m_render_surfaces.emplace_back(m_render_textures[idx]);
+	// Make sure all in-use auxiliary GPUs have the latest model and bitfield
+	std::unordered_set<CudaDevice*> devices_in_use;
+	for (auto& view : m_views) {
+		if (!view.device || devices_in_use.count(view.device) != 0) {
+			continue;
 		}
 
-		int i = 0;
-		for (int y = 0; y < ny; ++y) {
-			for (int x = 0; x < nx; ++x) {
-				if (i >= n_views) {
-					return;
-				}
+		devices_in_use.insert(view.device);
+		sync_device(*view.render_buffer, *view.device);
+	}
 
-				m_visualized_dimension = i-1;
-				m_render_surfaces[i].resize(m_view_size);
+	{
+		SyncedMultiStream synced_streams{m_stream.get(), m_views.size()};
+
+		std::vector<std::future<void>> futures(m_views.size());
+		for (size_t i = 0; i < m_views.size(); ++i) {
+			auto& view = m_views[i];
+			futures[i] = view.device->enqueue_task([this, &view, stream=synced_streams.get(i)]() {
+				auto device_guard = use_device(stream, *view.render_buffer, *view.device);
+				render_frame_main(*view.device, view.camera0, view.camera1, view.screen_center, view.relative_focal_length, {0.0f, 0.0f, 0.0f, 1.0f}, view.foveation, view.visualized_dimension);
+			});
+		}
 
-				render_frame(m_smoothed_camera, m_smoothed_camera, Eigen::Vector4f::Zero(), m_render_surfaces[i]);
+		for (size_t i = 0; i < m_views.size(); ++i) {
+			auto& view = m_views[i];
 
-				m_render_textures[i]->blit_from_cuda_mapping();
-				++i;
+			if (futures[i].valid()) {
+				futures[i].get();
 			}
+
+			render_frame_epilogue(synced_streams.get(i), view.camera0, view.prev_camera, view.screen_center, view.relative_focal_length, view.foveation, view.prev_foveation, *view.render_buffer, true);
+			view.prev_camera = view.camera0;
+			view.prev_foveation = view.foveation;
 		}
-#else
-		throw std::runtime_error{"Multi-view rendering is only supported when compiling with NGP_GUI."};
-#endif
 	}
+
+	for (size_t i = 0; i < m_views.size(); ++i) {
+		m_rgba_render_textures.at(i)->blit_from_cuda_mapping();
+		m_depth_render_textures.at(i)->blit_from_cuda_mapping();
+	}
+
+	if (m_picture_in_picture_res > 0) {
+		Vector2i res(m_picture_in_picture_res, m_picture_in_picture_res * 9/16);
+		m_pip_render_buffer->resize(res);
+		if (m_pip_render_buffer->spp() < 8) {
+			// a bit gross, but let's copy the keyframe's state into the global state in order to not have to plumb through the fov etc to render_frame.
+			CameraKeyframe backup = copy_camera_to_keyframe();
+			CameraKeyframe pip_kf = m_camera_path.eval_camera_path(m_camera_path.play_time);
+			set_camera_from_keyframe(pip_kf);
+			render_frame(m_stream.get(), pip_kf.m(), pip_kf.m(), pip_kf.m(), m_screen_center, m_relative_focal_length, Eigen::Vector4f::Zero(), {}, {}, m_visualized_dimension, *m_pip_render_buffer);
+			set_camera_from_keyframe(backup);
+
+			m_pip_render_texture->blit_from_cuda_mapping();
+		}
+	}
+#endif
+
+	CUDA_CHECK_THROW(cudaStreamSynchronize(m_stream.get()));
 }
 
 
@@ -2262,7 +2731,6 @@ void Testbed::create_second_window() {
 		win_x = 0x40000000;
 		win_y = 0x40000000;
 		static const char* copy_shader_vert = "\
-			layout (location = 0)\n\
 			in vec2 vertPos_data;\n\
 			out vec2 texCoords;\n\
 			void main(){\n\
@@ -2300,7 +2768,8 @@ void Testbed::create_second_window() {
 		1.0f, 1.0f,
 		1.0f, 1.0f,
 		1.0f, -1.0f,
-		-1.0f, -1.0f};
+		-1.0f, -1.0f
+	};
 	glBindBuffer(GL_ARRAY_BUFFER, m_second_window.vbo);
 	glBufferData(GL_ARRAY_BUFFER, sizeof(fsquadVerts), fsquadVerts, GL_STATIC_DRAW);
 	glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 2 * sizeof(float), (void *)0);
@@ -2308,6 +2777,84 @@ void Testbed::create_second_window() {
 	glBindBuffer(GL_ARRAY_BUFFER, 0);
 	glBindVertexArray(0);
 }
+
+void Testbed::init_vr() {
+	try {
+		if (!m_glfw_window) {
+			throw std::runtime_error{"`init_window` must be called before `init_vr`"};
+		}
+
+#if defined(XR_USE_PLATFORM_WIN32)
+		m_hmd = std::make_unique<OpenXRHMD>(wglGetCurrentDC(), glfwGetWGLContext(m_glfw_window));
+#elif defined(XR_USE_PLATFORM_XLIB)
+		Display* xDisplay = glfwGetX11Display();
+		GLXContext glxContext = glfwGetGLXContext(m_glfw_window);
+
+		int glxFBConfigXID = 0;
+		glXQueryContext(xDisplay, glxContext, GLX_FBCONFIG_ID, &glxFBConfigXID);
+		int attributes[3] = { GLX_FBCONFIG_ID, glxFBConfigXID, 0 };
+		int nelements = 1;
+		GLXFBConfig* pglxFBConfig = glXChooseFBConfig(xDisplay, 0, attributes, &nelements);
+		if (nelements != 1 || !pglxFBConfig) {
+			throw std::runtime_error{"init_vr(): Couldn't obtain GLXFBConfig"};
+		}
+
+		GLXFBConfig glxFBConfig = *pglxFBConfig;
+
+		XVisualInfo* visualInfo = glXGetVisualFromFBConfig(xDisplay, glxFBConfig);
+		if (!visualInfo) {
+			throw std::runtime_error{"init_vr(): Couldn't obtain XVisualInfo"};
+		}
+
+		m_hmd = std::make_unique<OpenXRHMD>(xDisplay, visualInfo->visualid, glxFBConfig, glXGetCurrentDrawable(), glxContext);
+#elif defined(XR_USE_PLATFORM_WAYLAND)
+		m_hmd = std::make_unique<OpenXRHMD>(glfwGetWaylandDisplay());
+#endif
+
+		// DLSS + sharpening is instrumental in getting VR to look good.
+		if (m_dlss_provider) {
+			m_dlss = true;
+			m_foveated_rendering = true;
+
+			// VERY aggressive performance settings (detriment to quality)
+			// to allow maintaining VR-adequate frame rates.
+			m_nerf.render_min_transmittance = 0.2f;
+		}
+
+		// If multiple GPUs are available, shoot for 60 fps in VR.
+		// Otherwise, it wouldn't be realistic to expect more than 30.
+		m_dynamic_res_target_fps = m_devices.size() > 1 ? 60 : 30;
+
+		// Many VR runtimes perform optical flow for automatic reprojection / motion smoothing.
+		// This breaks down for solid-color background, sometimes leading to artifacts. Hence:
+		// set background color to transparent and, in spherical_checkerboard_kernel(...),
+		// blend a checkerboard. If the user desires a solid background nonetheless, they can
+		// set the background color to have an alpha value of 1.0 manually via the GUI or via Python.
+		m_background_color = {0.0f, 0.0f, 0.0f, 0.0f};
+		m_render_transparency_as_checkerboard = true;
+	} catch (const std::runtime_error& e) {
+		if (std::string{e.what()}.find("XR_ERROR_FORM_FACTOR_UNAVAILABLE") != std::string::npos) {
+			throw std::runtime_error{"Could not initialize VR. Ensure that SteamVR, OculusVR, or any other OpenXR-compatible runtime is running. Also set it as the active OpenXR runtime."};
+		} else {
+			throw std::runtime_error{fmt::format("Could not initialize VR: {}", e.what())};
+		}
+	}
+}
+
+void Testbed::set_n_views(size_t n_views) {
+	while (m_views.size() > n_views) {
+		m_views.pop_back();
+	}
+
+	m_rgba_render_textures.resize(n_views);
+	m_depth_render_textures.resize(n_views);
+	while (m_views.size() < n_views) {
+		size_t idx = m_views.size();
+		m_rgba_render_textures[idx] = std::make_shared<GLTexture>();
+		m_depth_render_textures[idx] = std::make_shared<GLTexture>();
+		m_views.emplace_back(View{std::make_shared<CudaRenderBuffer>(m_rgba_render_textures[idx], m_depth_render_textures[idx])});
+	}
+};
 #endif //NGP_GUI
 
 void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) {
@@ -2322,23 +2869,22 @@ void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) {
 	}
 
 #ifdef NGP_VULKAN
-	try {
-		vulkan_and_ngx_init();
-		m_dlss_supported = true;
-		if (m_testbed_mode == ETestbedMode::Nerf) {
-			m_dlss = true;
+	// Only try to initialize DLSS (Vulkan+NGX) if the
+	// GPU is sufficiently new. Older GPUs don't support
+	// DLSS, so it is preferable to not make a futile
+	// attempt and emit a warning that confuses users.
+	if (primary_device().compute_capability() >= 70) {
+		try {
+			m_dlss_provider = init_vulkan_and_ngx();
+			if (m_testbed_mode == ETestbedMode::Nerf) {
+				m_dlss = true;
+			}
+		} catch (const std::runtime_error& e) {
+			tlog::warning() << "Could not initialize Vulkan and NGX. DLSS not supported. (" << e.what() << ")";
 		}
-	} catch (const std::runtime_error& e) {
-		tlog::warning() << "Could not initialize Vulkan and NGX. DLSS not supported. (" << e.what() << ")";
 	}
-#else
-	m_dlss_supported = false;
 #endif
 
-	glfwWindowHint(GLFW_CONTEXT_VERSION_MAJOR, 3);
-	glfwWindowHint(GLFW_CONTEXT_VERSION_MINOR, 3);
-	glfwWindowHint(GLFW_OPENGL_PROFILE, GLFW_OPENGL_CORE_PROFILE);
-	glfwWindowHint(GLFW_OPENGL_FORWARD_COMPAT, GLFW_TRUE);
 	glfwWindowHint(GLFW_VISIBLE, hidden ? GLFW_FALSE : GLFW_TRUE);
 	std::string title = "Instant Neural Graphics Primitives";
 	m_glfw_window = glfwCreateWindow(m_window_res.x(), m_window_res.y(), title.c_str(), NULL, NULL);
@@ -2358,6 +2904,17 @@ void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) {
 #endif
 	glfwSwapInterval(0); // Disable vsync
 
+	GLint gl_version_minor, gl_version_major;
+	glGetIntegerv(GL_MINOR_VERSION, &gl_version_minor);
+	glGetIntegerv(GL_MAJOR_VERSION, &gl_version_major);
+
+	if (gl_version_major < 3 || (gl_version_major == 3 && gl_version_minor < 1)) {
+		throw std::runtime_error{fmt::format("Unsupported OpenGL version {}.{}. instant-ngp requires at least OpenGL 3.1", gl_version_major, gl_version_minor)};
+	}
+
+	tlog::success() << "Initialized OpenGL version " << glGetString(GL_VERSION);
+
+
 	glfwSetWindowUserPointer(m_glfw_window, this);
 	glfwSetDropCallback(m_glfw_window, [](GLFWwindow* window, int count, const char** paths) {
 		Testbed* testbed = (Testbed*)glfwGetWindowUserPointer(window);
@@ -2424,22 +2981,26 @@ void Testbed::init_window(int resw, int resh, bool hidden, bool second_window) {
 	io.ConfigInputTrickleEventQueue = false; // new ImGui event handling seems to make camera controls laggy if this is true.
 	ImGui::StyleColorsDark();
 	ImGui_ImplGlfw_InitForOpenGL(m_glfw_window, true);
-	ImGui_ImplOpenGL3_Init("#version 330 core");
+	ImGui_ImplOpenGL3_Init("#version 140");
 
 	ImGui::GetStyle().ScaleAllSizes(xscale);
 	ImFontConfig font_cfg;
 	font_cfg.SizePixels = 13.0f * xscale;
 	io.Fonts->AddFontDefault(&font_cfg);
 
+	init_opengl_shaders();
+
 	// Make sure there's at least one usable render texture
-	m_render_textures = { std::make_shared<GLTexture>() };
+	m_rgba_render_textures = { std::make_shared<GLTexture>() };
+	m_depth_render_textures = { std::make_shared<GLTexture>() };
 
-	m_render_surfaces.clear();
-	m_render_surfaces.emplace_back(m_render_textures.front());
-	m_render_surfaces.front().resize(m_window_res);
+	m_views.clear();
+	m_views.emplace_back(View{std::make_shared<CudaRenderBuffer>(m_rgba_render_textures.front(), m_depth_render_textures.front())});
+	m_views.front().full_resolution = m_window_res;
+	m_views.front().render_buffer->resize(m_views.front().full_resolution);
 
 	m_pip_render_texture = std::make_shared<GLTexture>();
-	m_pip_render_surface = std::make_unique<CudaRenderBuffer>(m_pip_render_texture);
+	m_pip_render_buffer = std::make_unique<CudaRenderBuffer>(m_pip_render_texture);
 
 	m_render_window = true;
 
@@ -2457,15 +3018,18 @@ void Testbed::destroy_window() {
 		throw std::runtime_error{"Window must be initialized to be destroyed."};
 	}
 
-	m_render_surfaces.clear();
-	m_render_textures.clear();
+	m_hmd.reset();
 
-	m_pip_render_surface.reset();
+	m_views.clear();
+	m_rgba_render_textures.clear();
+	m_depth_render_textures.clear();
+
+	m_pip_render_buffer.reset();
 	m_pip_render_texture.reset();
 
 #ifdef NGP_VULKAN
-	m_dlss_supported = m_dlss = false;
-	vulkan_and_ngx_destroy();
+	m_dlss = false;
+	m_dlss_provider.reset();
 #endif
 
 	ImGui_ImplOpenGL3_Shutdown();
@@ -2474,6 +3038,9 @@ void Testbed::destroy_window() {
 	glfwDestroyWindow(m_glfw_window);
 	glfwTerminate();
 
+	m_blit_program = 0;
+	m_blit_vao = 0;
+
 	m_glfw_window = nullptr;
 	m_render_window = false;
 #endif //NGP_GUI
@@ -2482,9 +3049,12 @@ void Testbed::destroy_window() {
 bool Testbed::frame() {
 #ifdef NGP_GUI
 	if (m_render_window) {
-		if (!begin_frame_and_handle_user_input()) {
+		if (!begin_frame()) {
 			return false;
 		}
+
+		begin_vr_frame_and_handle_vr_input();
+		handle_user_input();
 	}
 #endif
 
@@ -2496,7 +3066,7 @@ bool Testbed::frame() {
 	}
 	bool skip_rendering = m_render_skip_due_to_lack_of_camera_movement_counter++ != 0;
 
-	if (!m_dlss && m_max_spp > 0 && !m_render_surfaces.empty() && m_render_surfaces.front().spp() >= m_max_spp) {
+	if (!m_dlss && m_max_spp > 0 && !m_views.empty() && m_views.front().render_buffer->spp() >= m_max_spp) {
 		skip_rendering = true;
 		if (!m_train) {
 			std::this_thread::sleep_for(1ms);
@@ -2508,6 +3078,12 @@ bool Testbed::frame() {
 		skip_rendering = false;
 	}
 
+#ifdef NGP_GUI
+	if (m_hmd && m_hmd->is_visible()) {
+		skip_rendering = false;
+	}
+#endif
+
 	if (!skip_rendering || (std::chrono::steady_clock::now() - m_last_gui_draw_time_point) > 25ms) {
 		redraw_gui_next_frame();
 	}
@@ -2540,6 +3116,32 @@ bool Testbed::frame() {
 
 		ImGui::EndFrame();
 	}
+
+	if (m_vr_frame_info) {
+		// If HMD is visible to the user, splat rendered images to the HMD
+		if (m_hmd->is_visible()) {
+			size_t n_views = std::min(m_views.size(), m_vr_frame_info->views.size());
+
+			// Blit textures to the OpenXR-owned framebuffers (each corresponding to one eye)
+			for (size_t i = 0; i < n_views; ++i) {
+				const auto& vr_view = m_vr_frame_info->views.at(i);
+
+				Vector2i resolution = {
+					vr_view.view.subImage.imageRect.extent.width,
+					vr_view.view.subImage.imageRect.extent.height,
+				};
+
+				blit_texture(m_views.at(i).foveation, m_rgba_render_textures.at(i)->texture(), GL_LINEAR, m_depth_render_textures.at(i)->texture(), vr_view.framebuffer, Vector2i::Zero(), resolution);
+			}
+
+			glFinish();
+		}
+
+		// Far and near planes are intentionally reversed, because we map depth inversely
+		// to z. I.e. a window-space depth of 1 refers to the near plane and a depth of 0
+		// to the far plane. This results in much better numeric precision.
+		m_hmd->end_frame(m_vr_frame_info, m_ndc_zfar / m_scale, m_ndc_znear / m_scale);
+	}
 #endif
 
 	return true;
@@ -2579,8 +3181,10 @@ void Testbed::set_camera_from_keyframe(const CameraKeyframe& k) {
 }
 
 void Testbed::set_camera_from_time(float t) {
-	if (m_camera_path.keyframes.empty())
+	if (m_camera_path.keyframes.empty()) {
 		return;
+	}
+
 	set_camera_from_keyframe(m_camera_path.eval_camera_path(t));
 }
 
@@ -2711,6 +3315,8 @@ void Testbed::reset_network(bool clear_density_grid) {
 	if (clear_density_grid) {
 		m_nerf.density_grid.memset(0);
 		m_nerf.density_grid_bitfield.memset(0);
+
+		set_all_devices_dirty();
 	}
 
 	m_loss_graph_samples = 0;
@@ -2723,6 +3329,13 @@ void Testbed::reset_network(bool clear_density_grid) {
 	json& optimizer_config = config["optimizer"];
 	json& network_config = config["network"];
 
+	// If the network config is incomplete, avoid doing further work.
+	/*
+	if (config.is_null() || encoding_config.is_null() || loss_config.is_null() || optimizer_config.is_null() || network_config.is_null()) {
+		return;
+	}
+	*/
+
 	auto dims = network_dims();
 
 	if (m_testbed_mode == ETestbedMode::Nerf) {
@@ -2798,16 +3411,22 @@ void Testbed::reset_network(bool clear_density_grid) {
 
 		uint32_t n_dir_dims = 3;
 		uint32_t n_extra_dims = m_nerf.training.dataset.n_extra_dims();
-		m_network = m_nerf_network = std::make_shared<NerfNetwork<precision_t>>(
-			dims.n_pos,
-			n_dir_dims,
-			n_extra_dims,
-			dims.n_pos + 1, // The offset of 1 comes from the dt member variable of NerfCoordinate. HACKY
-			encoding_config,
-			dir_encoding_config,
-			network_config,
-			rgb_network_config
-		);
+
+		// Instantiate an additional model for each auxiliary GPU
+		for (auto& device : m_devices) {
+			device.set_nerf_network(std::make_shared<NerfNetwork<precision_t>>(
+				dims.n_pos,
+				n_dir_dims,
+				n_extra_dims,
+				dims.n_pos + 1, // The offset of 1 comes from the dt member variable of NerfCoordinate. HACKY
+				encoding_config,
+				dir_encoding_config,
+				network_config,
+				rgb_network_config
+			));
+		}
+
+		m_network = m_nerf_network = primary_device().nerf_network();
 
 		m_encoding = m_nerf_network->encoding();
 		n_encoding_params = m_encoding->n_params() + m_nerf_network->dir_encoding()->n_params();
@@ -2873,7 +3492,12 @@ void Testbed::reset_network(bool clear_density_grid) {
 			}
 		}
 
-		m_network = std::make_shared<NetworkWithInputEncoding<precision_t>>(m_encoding, dims.n_output, network_config);
+		for (auto& device : m_devices) {
+			device.set_network(std::make_shared<NetworkWithInputEncoding<precision_t>>(m_encoding, dims.n_output, network_config));
+		}
+
+		m_network = primary_device().network();
+
 		n_encoding_params = m_encoding->n_params();
 
 		tlog::info()
@@ -2910,6 +3534,7 @@ void Testbed::reset_network(bool clear_density_grid) {
 		}
 	}
 
+	set_all_devices_dirty();
 }
 
 Testbed::Testbed(ETestbedMode mode) {
@@ -2955,6 +3580,28 @@ Testbed::Testbed(ETestbedMode mode) {
 		tlog::warning() << "This program was compiled for >=" << MIN_GPU_ARCH << " and may thus behave unexpectedly.";
 	}
 
+	m_devices.emplace_back(active_device, true);
+
+	// Multi-GPU is only supported in NeRF mode for now
+	int n_devices = cuda_device_count();
+	for (int i = 0; i < n_devices; ++i) {
+		if (i == active_device) {
+			continue;
+		}
+
+		if (cuda_compute_capability(i) >= MIN_GPU_ARCH) {
+			m_devices.emplace_back(i, false);
+		}
+	}
+
+	if (m_devices.size() > 1) {
+		tlog::success() << "Detected auxiliary GPUs:";
+		for (size_t i = 1; i < m_devices.size(); ++i) {
+			const auto& device = m_devices[i];
+			tlog::success() << "  #" << device.id() << ": " << device.name() << " [" << device.compute_capability() << "]";
+		}
+	}
+
 	m_network_config = {
 		{"loss", {
 			{"otype", "L2"}
@@ -3032,6 +3679,8 @@ void Testbed::train(uint32_t batch_size) {
 		throw std::runtime_error{"Cannot train without a mode."};
 	}
 
+	set_all_devices_dirty();
+
 	// If we don't have a trainer, as can happen when having loaded training data or changed modes without having
 	// explicitly loaded a new neural network.
 	if (!m_trainer) {
@@ -3097,18 +3746,16 @@ void Testbed::train(uint32_t batch_size) {
 	}
 }
 
-Vector2f Testbed::calc_focal_length(const Vector2i& resolution, int fov_axis, float zoom) const {
-	return m_relative_focal_length * resolution[fov_axis] * zoom;
+Vector2f Testbed::calc_focal_length(const Vector2i& resolution, const Vector2f& relative_focal_length, int fov_axis, float zoom) const {
+	return relative_focal_length * resolution[fov_axis] * zoom;
 }
 
-Vector2f Testbed::render_screen_center() const {
-	// see pixel_to_ray for how screen center is used; 0.5,0.5 is 'normal'. we flip so that it becomes the point in the original image we want to center on.
-	auto screen_center = m_screen_center;
-	return {(0.5f-screen_center.x())*m_zoom + 0.5f, (0.5-screen_center.y())*m_zoom + 0.5f};
+Vector2f Testbed::render_screen_center(const Vector2f& screen_center) const {
+	// see pixel_to_ray for how screen center is used; 0.5, 0.5 is 'normal'. we flip so that it becomes the point in the original image we want to center on.
+	return (Vector2f::Constant(0.5f) - screen_center) * m_zoom + Vector2f::Constant(0.5f);
 }
 
 __global__ void dlss_prep_kernel(
-	ETestbedMode mode,
 	Vector2i resolution,
 	uint32_t sample_index,
 	Vector2f focal_length,
@@ -3116,18 +3763,16 @@ __global__ void dlss_prep_kernel(
 	Vector3f parallax_shift,
 	bool snap_to_pixel_centers,
 	float* depth_buffer,
+	const float znear,
+	const float zfar,
 	Matrix<float, 3, 4> camera,
 	Matrix<float, 3, 4> prev_camera,
 	cudaSurfaceObject_t depth_surface,
 	cudaSurfaceObject_t mvec_surface,
 	cudaSurfaceObject_t exposure_surface,
-	Lens lens,
-	const float view_dist,
-	const float prev_view_dist,
-	const Vector2f image_pos,
-	const Vector2f prev_image_pos,
-	const Vector2i image_resolution,
-	const Vector2i quilting_dims
+	Foveation foveation,
+	Foveation prev_foveation,
+	Lens lens
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -3141,26 +3786,11 @@ __global__ void dlss_prep_kernel(
 	uint32_t x_orig = x;
 	uint32_t y_orig = y;
 
-	if (quilting_dims != Vector2i::Ones()) {
-		apply_quilting(&x, &y, resolution, parallax_shift, quilting_dims);
-	}
-
 	const float depth = depth_buffer[idx];
-	Vector2f mvec = mode == ETestbedMode::Image ? motion_vector_2d(
-		sample_index,
-		{x, y},
-		resolution.cwiseQuotient(quilting_dims),
-		image_resolution,
-		screen_center,
-		view_dist,
-		prev_view_dist,
-		image_pos,
-		prev_image_pos,
-		snap_to_pixel_centers
-	) : motion_vector_3d(
+	Vector2f mvec = motion_vector(
 		sample_index,
 		{x, y},
-		resolution.cwiseQuotient(quilting_dims),
+		resolution,
 		focal_length,
 		camera,
 		prev_camera,
@@ -3168,13 +3798,16 @@ __global__ void dlss_prep_kernel(
 		parallax_shift,
 		snap_to_pixel_centers,
 		depth,
+		foveation,
+		prev_foveation,
 		lens
 	);
 
 	surf2Dwrite(make_float2(mvec.x(), mvec.y()), mvec_surface, x_orig * sizeof(float2), y_orig);
 
-	// Scale depth buffer to be guaranteed in [0,1].
-	surf2Dwrite(std::min(std::max(depth / 128.0f, 0.0f), 1.0f), depth_surface, x_orig * sizeof(float), y_orig);
+	// DLSS was trained on games, which presumably used standard normalized device coordinates (ndc)
+	// depth buffers. So: convert depth to NDC with reasonable near- and far planes.
+	surf2Dwrite(to_ndc_depth(depth, znear, zfar), depth_surface, x_orig * sizeof(float), y_orig);
 
 	// First thread write an exposure factor of 1. Since DLSS will run on tonemapped data,
 	// exposure is assumed to already have been applied to DLSS' inputs.
@@ -3183,22 +3816,202 @@ __global__ void dlss_prep_kernel(
 	}
 }
 
-void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matrix<float, 3, 4>& camera_matrix1, const Vector4f& nerf_rolling_shutter, CudaRenderBuffer& render_buffer, bool to_srgb) {
-	Vector2i max_res = m_window_res.cwiseMax(render_buffer.in_resolution());
+__global__ void spherical_checkerboard_kernel(
+	Vector2i resolution,
+	Vector2f focal_length,
+	Matrix<float, 3, 4> camera,
+	Vector2f screen_center,
+	Vector3f parallax_shift,
+	Foveation foveation,
+	Lens lens,
+	Array4f* frame_buffer
+) {
+	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
+	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
+
+	if (x >= resolution.x() || y >= resolution.y()) {
+		return;
+	}
+
+	Ray ray = pixel_to_ray(
+		0,
+		{x, y},
+		resolution,
+		focal_length,
+		camera,
+		screen_center,
+		parallax_shift,
+		false,
+		0.0f,
+		1.0f,
+		0.0f,
+		foveation,
+		{}, // No need for hidden area mask
+		lens
+	);
+
+	// Blend with checkerboard to break up reprojection weirdness in some VR runtimes
+	host_device_swap(ray.d.z(), ray.d.y());
+	Vector2f spherical = dir_to_spherical(ray.d.normalized()) * 32.0f / PI();
+	const Array4f dark_gray = {0.5f, 0.5f, 0.5f, 1.0f};
+	const Array4f light_gray = {0.55f, 0.55f, 0.55f, 1.0f};
+	Array4f checker = fabsf(fmodf(floorf(spherical.x()) + floorf(spherical.y()), 2.0f)) < 0.5f ? dark_gray : light_gray;
+
+	uint32_t idx = x + resolution.x() * y;
+	frame_buffer[idx] += (1.0f - frame_buffer[idx].w()) * checker;
+}
+
+__global__ void vr_overlay_hands_kernel(
+	Vector2i resolution,
+	Vector2f focal_length,
+	Matrix<float, 3, 4> camera,
+	Vector2f screen_center,
+	Vector3f parallax_shift,
+	Foveation foveation,
+	Lens lens,
+	Vector3f left_hand_pos,
+	float left_grab_strength,
+	Array4f left_hand_color,
+	Vector3f right_hand_pos,
+	float right_grab_strength,
+	Array4f right_hand_color,
+	float hand_radius,
+	EColorSpace output_color_space,
+	cudaSurfaceObject_t surface
+	// TODO: overwrite depth buffer
+) {
+	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
+	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
+
+	if (x >= resolution.x() || y >= resolution.y()) {
+		return;
+	}
+
+	Ray ray = pixel_to_ray(
+		0,
+		{x, y},
+		resolution,
+		focal_length,
+		camera,
+		screen_center,
+		parallax_shift,
+		false,
+		0.0f,
+		1.0f,
+		0.0f,
+		foveation,
+		{}, // No need for hidden area mask
+		lens
+	);
+
+	Array4f color = Array4f::Zero();
+	auto composit_hand = [&](Vector3f hand_pos, float grab_strength, Array4f hand_color) {
+		// Don't render the hand indicator if it's behind the ray origin.
+		if (ray.d.dot(hand_pos - ray.o) < 0.0f) {
+			return;
+		}
 
-	render_buffer.clear_frame(m_stream.get());
+		float distance = ray.distance_to(hand_pos);
 
-	Vector2f focal_length = calc_focal_length(render_buffer.in_resolution(), m_fov_axis, m_zoom);
-	Vector2f screen_center = render_screen_center();
+		Array4f base_color = Array4f::Zero();
+		const Array4f border_color = {0.4f, 0.4f, 0.4f, 0.4f};
+
+		// Divide hand radius into an inner part (4/5ths) and a border (1/5th).
+		float radius = hand_radius * 0.8f;
+		float border_width = hand_radius * 0.2f;
+
+		// When grabbing, shrink the inner part as a visual indicator.
+		radius *= 0.5f + 0.5f * (1.0f - grab_strength);
+
+		if (distance < radius) {
+			base_color = hand_color;
+		} else if (distance < radius + border_width) {
+			base_color = border_color;
+		} else {
+			return;
+		}
+
+		// Make hand color opaque when grabbing.
+		base_color.w() = grab_strength + (1.0f - grab_strength) * base_color.w();
+		color += base_color * (1.0f - color.w());
+	};
+
+	if (ray.d.dot(left_hand_pos - ray.o) < ray.d.dot(right_hand_pos - ray.o)) {
+		composit_hand(left_hand_pos, left_grab_strength, left_hand_color);
+		composit_hand(right_hand_pos, right_grab_strength, right_hand_color);
+	} else {
+		composit_hand(right_hand_pos, right_grab_strength, right_hand_color);
+		composit_hand(left_hand_pos, left_grab_strength, left_hand_color);
+	}
+
+	// Blend with existing color of pixel
+	Array4f prev_color;
+	surf2Dread((float4*)&prev_color, surface, x * sizeof(float4), y);
+	if (output_color_space == EColorSpace::SRGB) {
+		prev_color.head<3>() = srgb_to_linear(prev_color.head<3>());
+	}
+
+	color += (1.0f - color.w()) * prev_color;
+
+	if (output_color_space == EColorSpace::SRGB) {
+		color.head<3>() = linear_to_srgb(color.head<3>());
+	}
+
+	surf2Dwrite(to_float4(color), surface, x * sizeof(float4), y);
+}
+
+void Testbed::render_frame(
+	cudaStream_t stream,
+	const Matrix<float, 3, 4>& camera_matrix0,
+	const Matrix<float, 3, 4>& camera_matrix1,
+	const Matrix<float, 3, 4>& prev_camera_matrix,
+	const Vector2f& orig_screen_center,
+	const Vector2f& relative_focal_length,
+	const Vector4f& nerf_rolling_shutter,
+	const Foveation& foveation,
+	const Foveation& prev_foveation,
+	int visualized_dimension,
+	CudaRenderBuffer& render_buffer,
+	bool to_srgb,
+	CudaDevice* device
+) {
+	if (!device) {
+		device = &primary_device();
+	}
+
+	sync_device(render_buffer, *device);
+
+	{
+		auto device_guard = use_device(stream, render_buffer, *device);
+		render_frame_main(*device, camera_matrix0, camera_matrix1, orig_screen_center, relative_focal_length, nerf_rolling_shutter, foveation, visualized_dimension);
+	}
+
+	render_frame_epilogue(stream, camera_matrix0, prev_camera_matrix, orig_screen_center, relative_focal_length, foveation, prev_foveation, render_buffer, to_srgb);
+}
+
+void Testbed::render_frame_main(
+	CudaDevice& device,
+	const Matrix<float, 3, 4>& camera_matrix0,
+	const Matrix<float, 3, 4>& camera_matrix1,
+	const Vector2f& orig_screen_center,
+	const Vector2f& relative_focal_length,
+	const Vector4f& nerf_rolling_shutter,
+	const Foveation& foveation,
+	int visualized_dimension
+) {
+	device.render_buffer_view().clear(device.stream());
 
 	if (!m_network) {
 		return;
 	}
 
+	Vector2f focal_length = calc_focal_length(device.render_buffer_view().resolution, relative_focal_length, m_fov_axis, m_zoom);
+	Vector2f screen_center = render_screen_center(orig_screen_center);
+
 	switch (m_testbed_mode) {
 		case ETestbedMode::Nerf:
 			if (!m_render_ground_truth || m_ground_truth_alpha < 1.0f) {
-				render_nerf(render_buffer, max_res, focal_length, camera_matrix0, camera_matrix1, nerf_rolling_shutter, screen_center, m_stream.get());
+				render_nerf(device.stream(), device.render_buffer_view(), *device.nerf_network(), device.data().density_grid_bitfield_ptr, focal_length, camera_matrix0, camera_matrix1, nerf_rolling_shutter, screen_center, foveation, visualized_dimension);
 			}
 			break;
 		case ETestbedMode::Sdf:
@@ -3219,15 +4032,13 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 							m_sdf.brick_data.data(),
 							m_sdf.triangles_gpu.data(),
 							false,
-							m_stream.get()
+							device.stream()
 						);
 					}
 				}
+
 				distance_fun_t distance_fun =
 					m_render_ground_truth ? (distance_fun_t)[&](uint32_t n_elements, const Vector3f* positions, float* distances, cudaStream_t stream) {
-						if (n_elements == 0) {
-							return;
-						}
 						if (m_sdf.groundtruth_mode == ESDFGroundTruthMode::SDFBricks) {
 							// linear_kernel(sdf_brick_kernel, 0, stream,
 							// 	n_elements,
@@ -3244,17 +4055,14 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 							m_sdf.triangle_bvh->signed_distance_gpu(
 								n_elements,
 								m_sdf.mesh_sdf_mode,
-								(Vector3f*)positions,
+								positions,
 								distances,
 								m_sdf.triangles_gpu.data(),
 								false,
-								m_stream.get()
+								stream
 							);
 						}
 					} : (distance_fun_t)[&](uint32_t n_elements, const Vector3f* positions, float* distances, cudaStream_t stream) {
-						if (n_elements == 0) {
-							return;
-						}
 						n_elements = next_multiple(n_elements, tcnn::batch_size_granularity);
 						GPUMatrix<float> positions_matrix((float*)positions, 3, n_elements);
 						GPUMatrix<float, RM> distances_matrix(distances, 1, n_elements);
@@ -3265,53 +4073,64 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 					m_render_ground_truth ? (normals_fun_t)[&](uint32_t n_elements, const Vector3f* positions, Vector3f* normals, cudaStream_t stream) {
 						// NO-OP. Normals will automatically be populated by raytrace
 					} : (normals_fun_t)[&](uint32_t n_elements, const Vector3f* positions, Vector3f* normals, cudaStream_t stream) {
-						if (n_elements == 0) {
-							return;
-						}
-
 						n_elements = next_multiple(n_elements, tcnn::batch_size_granularity);
-
 						GPUMatrix<float> positions_matrix((float*)positions, 3, n_elements);
 						GPUMatrix<float> normals_matrix((float*)normals, 3, n_elements);
 						m_network->input_gradient(stream, 0, positions_matrix, normals_matrix);
 					};
 
 				render_sdf(
+					device.stream(),
 					distance_fun,
 					normals_fun,
-					render_buffer,
-					max_res,
+					device.render_buffer_view(),
 					focal_length,
 					camera_matrix0,
 					screen_center,
-					m_stream.get()
+					foveation,
+					visualized_dimension
 				);
 			}
 			break;
 		case ETestbedMode::Image:
-			render_image(render_buffer, m_stream.get());
+			render_image(device.stream(), device.render_buffer_view(), focal_length, camera_matrix0, screen_center, foveation, visualized_dimension);
 			break;
 		case ETestbedMode::Volume:
-			render_volume(render_buffer, focal_length, camera_matrix0, screen_center, m_stream.get());
+			render_volume(device.stream(), device.render_buffer_view(), focal_length, camera_matrix0, screen_center, foveation);
 			break;
 		default:
-			throw std::runtime_error{"Invalid render mode."};
+			// No-op if no mode is active
+			break;
 	}
+}
+
+void Testbed::render_frame_epilogue(
+	cudaStream_t stream,
+	const Matrix<float, 3, 4>& camera_matrix0,
+	const Matrix<float, 3, 4>& prev_camera_matrix,
+	const Vector2f& orig_screen_center,
+	const Vector2f& relative_focal_length,
+	const Foveation& foveation,
+	const Foveation& prev_foveation,
+	CudaRenderBuffer& render_buffer,
+	bool to_srgb
+) {
+	Vector2f focal_length = calc_focal_length(render_buffer.in_resolution(), relative_focal_length, m_fov_axis, m_zoom);
+	Vector2f screen_center = render_screen_center(orig_screen_center);
 
 	render_buffer.set_color_space(m_color_space);
 	render_buffer.set_tonemap_curve(m_tonemap_curve);
 
+	Lens lens = (m_testbed_mode == ETestbedMode::Nerf && m_nerf.render_with_lens_distortion) ? m_nerf.render_lens : Lens{};
+
 	// Prepare DLSS data: motion vectors, scaled depth, exposure
 	if (render_buffer.dlss()) {
 		auto res = render_buffer.in_resolution();
 
-		bool distortion = m_testbed_mode == ETestbedMode::Nerf && m_nerf.render_with_lens_distortion;
-
 		const dim3 threads = { 16, 8, 1 };
 		const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 
-		dlss_prep_kernel<<<blocks, threads, 0, m_stream.get()>>>(
-			m_testbed_mode,
+		dlss_prep_kernel<<<blocks, threads, 0, stream>>>(
 			res,
 			render_buffer.spp(),
 			focal_length,
@@ -3319,29 +4138,49 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 			m_parallax_shift,
 			m_snap_to_pixel_centers,
 			render_buffer.depth_buffer(),
+			m_ndc_znear,
+			m_ndc_zfar,
 			camera_matrix0,
-			m_prev_camera,
+			prev_camera_matrix,
 			render_buffer.dlss()->depth(),
 			render_buffer.dlss()->mvec(),
 			render_buffer.dlss()->exposure(),
-			distortion ? m_nerf.render_lens : Lens{},
-			m_scale,
-			m_prev_scale,
-			m_image.pos,
-			m_image.prev_pos,
-			m_image.resolution,
-			m_quilting_dims
+			foveation,
+			prev_foveation,
+			lens
 		);
 
 		render_buffer.set_dlss_sharpening(m_dlss_sharpening);
 	}
 
-	m_prev_camera = camera_matrix0;
-	m_prev_scale = m_scale;
-	m_image.prev_pos = m_image.pos;
+	EColorSpace output_color_space = to_srgb ? EColorSpace::SRGB : EColorSpace::Linear;
+
+	if (m_render_transparency_as_checkerboard) {
+		Matrix<float, 3, 4> checkerboard_transform = Matrix<float, 3, 4>::Identity();
 
-	render_buffer.accumulate(m_exposure, m_stream.get());
-	render_buffer.tonemap(m_exposure, m_background_color, to_srgb ? EColorSpace::SRGB : EColorSpace::Linear, m_stream.get());
+#if NGP_GUI
+		if (m_vr_frame_info && !m_vr_frame_info->views.empty()) {
+			checkerboard_transform = m_vr_frame_info->views[0].pose;
+		}
+#endif
+
+		auto res = render_buffer.in_resolution();
+		const dim3 threads = { 16, 8, 1 };
+		const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
+		spherical_checkerboard_kernel<<<blocks, threads, 0, stream>>>(
+			res,
+			focal_length,
+			checkerboard_transform,
+			screen_center,
+			m_parallax_shift,
+			foveation,
+			lens,
+			render_buffer.frame_buffer()
+		);
+	}
+
+	render_buffer.accumulate(m_exposure, stream);
+	render_buffer.tonemap(m_exposure, m_background_color, output_color_space, m_ndc_znear, m_ndc_zfar, stream);
 
 	if (m_testbed_mode == ETestbedMode::Nerf) {
 		// Overlay the ground truth image if requested
@@ -3352,14 +4191,14 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 					m_ground_truth_alpha,
 					Array3f::Constant(m_exposure) + m_nerf.training.cam_exposure[m_nerf.training.view].variable(),
 					m_background_color,
-					to_srgb ? EColorSpace::SRGB : EColorSpace::Linear,
+					output_color_space,
 					metadata.pixels,
 					metadata.image_data_type,
 					metadata.resolution,
 					m_fov_axis,
 					m_zoom,
 					Vector2f::Constant(0.5f),
-					m_stream.get()
+					stream
 				);
 			} else if (m_ground_truth_render_mode == EGroundTruthRenderMode::Depth && metadata.depth) {
 				render_buffer.overlay_depth(
@@ -3370,7 +4209,7 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 					m_fov_axis,
 					m_zoom,
 					Vector2f::Constant(0.5f),
-					m_stream.get()
+					stream
 				);
 			}
 		}
@@ -3385,39 +4224,67 @@ void Testbed::render_frame(const Matrix<float, 3, 4>& camera_matrix0, const Matr
 			}
 			size_t emap_size = error_map_res.x() * error_map_res.y();
 			err_data += emap_size * m_nerf.training.view;
-			static GPUMemory<float> average_error;
+
+			GPUMemory<float> average_error;
 			average_error.enlarge(1);
 			average_error.memset(0);
 			const float* aligned_err_data_s = (const float*)(((size_t)err_data)&~15);
 			const float* aligned_err_data_e = (const float*)(((size_t)(err_data+emap_size))&~15);
 			size_t reduce_size = aligned_err_data_e - aligned_err_data_s;
-			reduce_sum(aligned_err_data_s, [reduce_size] __device__ (float val) { return max(val,0.f) / (reduce_size); }, average_error.data(), reduce_size, m_stream.get());
+			reduce_sum(aligned_err_data_s, [reduce_size] __device__ (float val) { return max(val,0.f) / (reduce_size); }, average_error.data(), reduce_size, stream);
 			auto const &metadata = m_nerf.training.dataset.metadata[m_nerf.training.view];
-			render_buffer.overlay_false_color(metadata.resolution, to_srgb, m_fov_axis, m_stream.get(), err_data, error_map_res, average_error.data(), m_nerf.training.error_overlay_brightness, m_render_ground_truth);
+			render_buffer.overlay_false_color(metadata.resolution, to_srgb, m_fov_axis, stream, err_data, error_map_res, average_error.data(), m_nerf.training.error_overlay_brightness, m_render_ground_truth);
 		}
 	}
 
-	CUDA_CHECK_THROW(cudaStreamSynchronize(m_stream.get()));
-}
-
-void Testbed::determine_autofocus_target_from_pixel(const Vector2i& focus_pixel) {
-	float depth;
+#if NGP_GUI
+	// If in VR, indicate the hand position and render transparent background
+	if (m_vr_frame_info) {
+		auto& hands = m_vr_frame_info->hands;
 
-	const auto& surface = m_render_surfaces.front();
-	if (surface.depth_buffer()) {
-		auto res = surface.in_resolution();
-		Vector2i depth_pixel = focus_pixel.cast<float>().cwiseProduct(res.cast<float>()).cwiseQuotient(m_window_res.cast<float>()).cast<int>();
-		depth_pixel = depth_pixel.cwiseMin(res).cwiseMax(0);
+		auto res = render_buffer.out_resolution();
+		const dim3 threads = { 16, 8, 1 };
+		const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
+		vr_overlay_hands_kernel<<<blocks, threads, 0, stream>>>(
+			res,
+			focal_length.cwiseProduct(render_buffer.out_resolution().cast<float>()).cwiseQuotient(render_buffer.in_resolution().cast<float>()),
+			camera_matrix0,
+			screen_center,
+			m_parallax_shift,
+			foveation,
+			lens,
+			vr_to_world(hands[0].pose.col(3)),
+			hands[0].grab_strength,
+			{hands[0].pressing ? 0.8f : 0.0f, 0.0f, 0.0f, 0.8f},
+			vr_to_world(hands[1].pose.col(3)),
+			hands[1].grab_strength,
+			{hands[1].pressing ? 0.8f : 0.0f, 0.0f, 0.0f, 0.8f},
+			0.05f * m_scale, // Hand radius
+			output_color_space,
+			render_buffer.surface()
+		);
+	}
+#endif
+}
 
-		CUDA_CHECK_THROW(cudaMemcpy(&depth, surface.depth_buffer() + depth_pixel.x() + depth_pixel.y() * res.x(), sizeof(float), cudaMemcpyDeviceToHost));
-	} else {
-		depth = m_scale;
+float Testbed::get_depth_from_renderbuffer(const CudaRenderBuffer& render_buffer, const Vector2f& uv) {
+	if (!render_buffer.depth_buffer()) {
+		return m_scale;
 	}
 
-	auto ray = pixel_to_ray_pinhole(0, focus_pixel, m_window_res, calc_focal_length(m_window_res, m_fov_axis, m_zoom), m_smoothed_camera, render_screen_center());
+	float depth;
+	auto res = render_buffer.in_resolution();
+	Vector2i depth_pixel = uv.cwiseProduct(res.cast<float>()).cast<int>().cwiseMin(res).cwiseMax(0);
+	depth_pixel = depth_pixel.cwiseMin(res).cwiseMax(0);
+
+	CUDA_CHECK_THROW(cudaMemcpy(&depth, render_buffer.depth_buffer() + depth_pixel.x() + depth_pixel.y() * res.x(), sizeof(float), cudaMemcpyDeviceToHost));
+	return depth;
+}
 
-	m_autofocus_target = ray.o + ray.d * depth;
-	m_autofocus = true; // If someone shift-clicked, that means they want the AUTOFOCUS
+Vector3f Testbed::get_3d_pos_from_pixel(const CudaRenderBuffer& render_buffer, const Vector2i& pixel) {
+	float depth = get_depth_from_renderbuffer(render_buffer, pixel.cast<float>().cwiseQuotient(m_window_res.cast<float>()));
+	auto ray = pixel_to_ray_pinhole(0, pixel, m_window_res, calc_focal_length(m_window_res, m_relative_focal_length, m_fov_axis, m_zoom), m_smoothed_camera, render_screen_center(m_screen_center));
+	return ray(depth);
 }
 
 void Testbed::autofocus() {
@@ -3593,7 +4460,7 @@ void Testbed::load_snapshot(const fs::path& path) {
 			density_grid[i] = (float)density_grid_fp16[i];
 		});
 
-		if (m_nerf.density_grid.size() == NERF_GRIDSIZE() * NERF_GRIDSIZE() * NERF_GRIDSIZE() * (m_nerf.max_cascade + 1)) {
+		if (m_nerf.density_grid.size() == NERF_GRID_N_CELLS() * (m_nerf.max_cascade + 1)) {
 			update_density_grid_mean_and_bitfield(nullptr);
 		} else if (m_nerf.density_grid.size() != 0) {
 			// A size of 0 indicates that the density grid was never populated, which is a valid state of a (yet) untrained model.
@@ -3614,6 +4481,106 @@ void Testbed::load_snapshot(const fs::path& path) {
 	m_loss_scalar.set(m_network_config["snapshot"]["loss"]);
 
 	m_trainer->deserialize(m_network_config["snapshot"]);
+
+	set_all_devices_dirty();
+}
+
+void Testbed::CudaDevice::set_nerf_network(const std::shared_ptr<NerfNetwork<precision_t>>& nerf_network) {
+	m_network = m_nerf_network = nerf_network;
+}
+
+void Testbed::sync_device(CudaRenderBuffer& render_buffer, Testbed::CudaDevice& device) {
+	if (!device.dirty()) {
+		return;
+	}
+
+	if (device.is_primary()) {
+		device.data().density_grid_bitfield_ptr = m_nerf.density_grid_bitfield.data();
+		device.data().hidden_area_mask = render_buffer.hidden_area_mask();
+		device.set_dirty(false);
+		return;
+	}
+
+	m_stream.signal(device.stream());
+
+	int active_device = cuda_device();
+	auto guard = device.device_guard();
+
+	device.data().density_grid_bitfield.resize(m_nerf.density_grid_bitfield.size());
+	if (m_nerf.density_grid_bitfield.size() > 0) {
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(device.data().density_grid_bitfield.data(), device.id(), m_nerf.density_grid_bitfield.data(), active_device, m_nerf.density_grid_bitfield.bytes(), device.stream()));
+	}
+
+	device.data().density_grid_bitfield_ptr = device.data().density_grid_bitfield.data();
+
+	if (m_network) {
+		device.data().params.resize(m_network->n_params());
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(device.data().params.data(), device.id(), m_network->inference_params(), active_device, device.data().params.bytes(), device.stream()));
+		device.nerf_network()->set_params(device.data().params.data(), device.data().params.data(), nullptr);
+	}
+
+	if (render_buffer.hidden_area_mask()) {
+		auto ham = std::make_shared<Buffer2D<uint8_t>>(render_buffer.hidden_area_mask()->resolution());
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(ham->data(), device.id(), render_buffer.hidden_area_mask()->data(), active_device, ham->bytes(), device.stream()));
+		device.data().hidden_area_mask = ham;
+	} else {
+		device.data().hidden_area_mask = nullptr;
+	}
+
+	device.set_dirty(false);
+}
+
+// From https://stackoverflow.com/questions/20843271/passing-a-non-copyable-closure-object-to-stdfunction-parameter
+template <class F>
+auto make_copyable_function(F&& f) {
+	using dF = std::decay_t<F>;
+	auto spf = std::make_shared<dF>(std::forward<F>(f));
+	return [spf](auto&&... args) -> decltype(auto) {
+		return (*spf)( decltype(args)(args)... );
+	};
+}
+
+ScopeGuard Testbed::use_device(cudaStream_t stream, CudaRenderBuffer& render_buffer, Testbed::CudaDevice& device) {
+	device.wait_for(stream);
+
+	if (device.is_primary()) {
+		device.set_render_buffer_view(render_buffer.view());
+		return ScopeGuard{[&device, stream]() {
+			device.set_render_buffer_view({});
+			device.signal(stream);
+		}};
+	}
+
+	int active_device = cuda_device();
+	auto guard = device.device_guard();
+
+	size_t n_pixels = render_buffer.in_resolution().prod();
+
+	GPUMemoryArena::Allocation alloc;
+	auto scratch = allocate_workspace_and_distribute<Array4f, float>(device.stream(), &alloc, n_pixels, n_pixels);
+
+	device.set_render_buffer_view({
+		std::get<0>(scratch),
+		std::get<1>(scratch),
+		render_buffer.in_resolution(),
+		render_buffer.spp(),
+		device.data().hidden_area_mask,
+	});
+
+	return ScopeGuard{make_copyable_function([&render_buffer, &device, guard=std::move(guard), alloc=std::move(alloc), active_device, stream]() {
+		// Copy device's render buffer's data onto the original render buffer
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(render_buffer.frame_buffer(), active_device, device.render_buffer_view().frame_buffer, device.id(), render_buffer.in_resolution().prod() * sizeof(Array4f), device.stream()));
+		CUDA_CHECK_THROW(cudaMemcpyPeerAsync(render_buffer.depth_buffer(), active_device, device.render_buffer_view().depth_buffer, device.id(), render_buffer.in_resolution().prod() * sizeof(float), device.stream()));
+
+		device.set_render_buffer_view({});
+		device.signal(stream);
+	})};
+}
+
+void Testbed::set_all_devices_dirty() {
+	for (auto& device : m_devices) {
+		device.set_dirty(true);
+	}
 }
 
 void Testbed::load_camera_path(const fs::path& path) {
diff --git a/src/testbed_image.cu b/src/testbed_image.cu
index 5be8aa74bf6f11b99d70b8663835956b66137fd0..59d385a7840d561eb389fbdc8bca34cbff606ed2 100644
--- a/src/testbed_image.cu
+++ b/src/testbed_image.cu
@@ -77,14 +77,20 @@ __global__ void stratify2_kernel(uint32_t n_elements, uint32_t log2_batch_size,
 }
 
 __global__ void init_image_coords(
+	uint32_t sample_index,
 	Vector2f* __restrict__ positions,
+	float* __restrict__ depth_buffer,
 	Vector2i resolution,
-	Vector2i image_resolution,
-	float view_dist,
-	Vector2f image_pos,
+	float aspect,
+	Vector2f focal_length,
+	Matrix<float, 3, 4> camera_matrix,
 	Vector2f screen_center,
+	Vector3f parallax_shift,
 	bool snap_to_pixel_centers,
-	uint32_t sample_index
+	float plane_z,
+	float aperture_size,
+	Foveation foveation,
+	Buffer2DView<const uint8_t> hidden_area_mask
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -93,49 +99,47 @@ __global__ void init_image_coords(
 		return;
 	}
 
-	uint32_t idx = x + resolution.x() * y;
-	positions[idx] = pixel_to_image_uv(
+	// The image is displayed on the plane [0.5, 0.5, 0.5] + [X, Y, 0] to facilitate
+	// a top-down view by default, while permitting general camera movements (for
+	// motion vectors and code sharing with 3D tasks).
+	// Hence: generate rays and intersect that plane.
+	Ray ray = pixel_to_ray(
 		sample_index,
 		{x, y},
 		resolution,
-		image_resolution,
+		focal_length,
+		camera_matrix,
 		screen_center,
-		view_dist,
-		image_pos,
-		snap_to_pixel_centers
+		parallax_shift,
+		snap_to_pixel_centers,
+		0.0f, // near distance
+		plane_z,
+		aperture_size,
+		foveation,
+		hidden_area_mask
 	);
-}
 
-// #define COLOR_SPACE_CONVERT convert to ycrcb experiment - causes some color shift tho it does lead to very slightly sharper edges. not a net win if you like colors :)
-#define CHROMA_SCALE 0.2f
-
-__global__ void colorspace_convert_image_half(Vector2i resolution, const char* __restrict__ texture) {
-	uint32_t x = blockIdx.x * blockDim.x + threadIdx.x;
-	uint32_t y = blockIdx.y * blockDim.y + threadIdx.y;
-	if (x >= resolution.x() || y >= resolution.y()) return;
-	__half val[4];
-	*(int2*)&val[0] = ((int2*)texture)[y * resolution.x() + x];
-	float R=val[0],G=val[1],B=val[2];
-	val[0]=(0.2126f * R + 0.7152f * G + 0.0722f * B);
-	val[1]=((-0.1146f * R - 0.3845f * G + 0.5f * B)+0.f)*CHROMA_SCALE;
-	val[2]=((0.5f * R - 0.4542f * G - 0.0458f * B)+0.f)*CHROMA_SCALE;
-	((int2*)texture)[y * resolution.x() + x] = *(int2*)&val[0];
-}
+	// Intersect the Z=0.5 plane
+	float t = ray.is_valid() ? (0.5f - ray.o.z()) / ray.d.z() : -1.0f;
+
+	uint32_t idx = x + resolution.x() * y;
+	if (t <= 0.0f) {
+		depth_buffer[idx] = MAX_DEPTH();
+		positions[idx] = -Vector2f::Ones();
+		return;
+	}
+
+	Vector2f uv = ray(t).head<2>();
+
+	// Flip from world coordinates where Y goes up to image coordinates where Y goes down.
+	// Also, multiply the x-axis by the image's aspect ratio to make it have the right proportions.
+	uv = (uv - Vector2f::Constant(0.5f)).cwiseProduct(Vector2f{aspect, -1.0f}) + Vector2f::Constant(0.5f);
 
-__global__ void colorspace_convert_image_float(Vector2i resolution, const char* __restrict__ texture) {
-	uint32_t x = blockIdx.x * blockDim.x + threadIdx.x;
-	uint32_t y = blockIdx.y * blockDim.y + threadIdx.y;
-	if (x >= resolution.x() || y >= resolution.y()) return;
-	float val[4];
-	*(float4*)&val[0] = ((float4*)texture)[y * resolution.x() + x];
-	float R=val[0],G=val[1],B=val[2];
-	val[0]=(0.2126f * R + 0.7152f * G + 0.0722f * B);
-	val[1]=((-0.1146f * R - 0.3845f * G + 0.5f * B)+0.f)*CHROMA_SCALE;
-	val[2]=((0.5f * R - 0.4542f * G - 0.0458f * B)+0.f)*CHROMA_SCALE;
-	((float4*)texture)[y * resolution.x() + x] = *(float4*)&val[0];
+	depth_buffer[idx] = t;
+	positions[idx] = uv;
 }
 
-__global__ void shade_kernel_image(Vector2i resolution, const Vector2f* __restrict__ positions, const Array3f* __restrict__ colors, Array4f* __restrict__ frame_buffer, float* __restrict__ depth_buffer, bool linear_colors) {
+__global__ void shade_kernel_image(Vector2i resolution, const Vector2f* __restrict__ positions, const Array3f* __restrict__ colors, Array4f* __restrict__ frame_buffer, bool linear_colors) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
 
@@ -148,7 +152,6 @@ __global__ void shade_kernel_image(Vector2i resolution, const Vector2f* __restri
 	const Vector2f uv = positions[idx];
 	if (uv.x() < 0.0f || uv.x() > 1.0f || uv.y() < 0.0f || uv.y() > 1.0f) {
 		frame_buffer[idx] = Array4f::Zero();
-		depth_buffer[idx] = 1e10f;
 		return;
 	}
 
@@ -158,16 +161,7 @@ __global__ void shade_kernel_image(Vector2i resolution, const Vector2f* __restri
 		color = srgb_to_linear(color);
 	}
 
-#ifdef COLOR_SPACE_CONVERT
-	float Y=color.x(), Cb =color.y()*(1.f/CHROMA_SCALE) -0.f, Cr = color.z() * (1.f/CHROMA_SCALE) - 0.f;
-	float R = Y                + 1.5748f * Cr;
-	float G = Y - 0.1873f * Cb - 0.4681 * Cr;
-	float B = Y + 1.8556f * Cb;
-	frame_buffer[idx] = {R, G, B, 1.0f};
-#else
 	frame_buffer[idx] = {color.x(), color.y(), color.z(), 1.0f};
-#endif
-	depth_buffer[idx] = 1.0f;
 }
 
 template <typename T, uint32_t stride>
@@ -291,8 +285,16 @@ void Testbed::train_image(size_t target_batch_size, bool get_loss_scalar, cudaSt
 	m_training_step++;
 }
 
-void Testbed::render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream) {
-	auto res = render_buffer.in_resolution();
+void Testbed::render_image(
+	cudaStream_t stream,
+	const CudaRenderBufferView& render_buffer,
+	const Vector2f& focal_length,
+	const Matrix<float, 3, 4>& camera_matrix,
+	const Vector2f& screen_center,
+	const Foveation& foveation,
+	int visualized_dimension
+) {
+	auto res = render_buffer.resolution;
 
 	// Make sure we have enough memory reserved to render at the requested resolution
 	size_t n_pixels = (size_t)res.x() * res.y();
@@ -300,18 +302,27 @@ void Testbed::render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream)
 	m_image.render_coords.enlarge(n_elements);
 	m_image.render_out.enlarge(n_elements);
 
+	float plane_z = m_slice_plane_z + m_scale;
+	float aspect = (float)m_image.resolution.y() / (float)m_image.resolution.x();
+
 	// Generate 2D coords at which to query the network
 	const dim3 threads = { 16, 8, 1 };
 	const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 	init_image_coords<<<blocks, threads, 0, stream>>>(
+		render_buffer.spp,
 		m_image.render_coords.data(),
+		render_buffer.depth_buffer,
 		res,
-		m_image.resolution,
-		m_scale,
-		m_image.pos,
-		m_screen_center - Vector2f::Constant(0.5f),
+		aspect,
+		focal_length,
+		camera_matrix,
+		screen_center,
+		m_parallax_shift,
 		m_snap_to_pixel_centers,
-		render_buffer.spp()
+		plane_z,
+		m_aperture_size,
+		foveation,
+		render_buffer.hidden_area_mask ? render_buffer.hidden_area_mask->const_view() : Buffer2DView<const uint8_t>{}
 	);
 
 	// Obtain colors for each 2D coord
@@ -338,10 +349,10 @@ void Testbed::render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream)
 	}
 
 	if (!m_render_ground_truth) {
-		if (m_visualized_dimension >= 0) {
+		if (visualized_dimension >= 0) {
 			GPUMatrix<float> positions_matrix((float*)m_image.render_coords.data(), 2, n_elements);
 			GPUMatrix<float> colors_matrix((float*)m_image.render_out.data(), 3, n_elements);
-			m_network->visualize_activation(stream, m_visualized_layer, m_visualized_dimension, positions_matrix, colors_matrix);
+			m_network->visualize_activation(stream, m_visualized_layer, visualized_dimension, positions_matrix, colors_matrix);
 		} else {
 			GPUMatrix<float> positions_matrix((float*)m_image.render_coords.data(), 2, n_elements);
 			GPUMatrix<float> colors_matrix((float*)m_image.render_out.data(), 3, n_elements);
@@ -354,8 +365,7 @@ void Testbed::render_image(CudaRenderBuffer& render_buffer, cudaStream_t stream)
 		res,
 		m_image.render_coords.data(),
 		m_image.render_out.data(),
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer(),
+		render_buffer.frame_buffer,
 		m_image.training.linear_colors
 	);
 }
diff --git a/src/testbed_nerf.cu b/src/testbed_nerf.cu
index 33a663d4894ad43422b55af9fa2b8d97eec7dd4c..ba563a6d1115ebd38bbe864e1e99bc6c44d326a5 100644
--- a/src/testbed_nerf.cu
+++ b/src/testbed_nerf.cu
@@ -377,34 +377,60 @@ __global__ void mark_untrained_density_grid(const uint32_t n_elements,  float* _
 	uint32_t y = tcnn::morton3D_invert(pos_idx>>1);
 	uint32_t z = tcnn::morton3D_invert(pos_idx>>2);
 
-	Vector3f pos = ((Vector3f{(float)x+0.5f, (float)y+0.5f, (float)z+0.5f}) / NERF_GRIDSIZE() - Vector3f::Constant(0.5f)) * scalbnf(1.0f, level) + Vector3f::Constant(0.5f);
-	float voxel_radius = 0.5f*SQRT3()*scalbnf(1.0f, level) / NERF_GRIDSIZE();
-	int count = 0;
-	for (uint32_t j=0; j < n_training_images; ++j) {
-		if (metadata[j].lens.mode == ELensMode::FTheta || metadata[j].lens.mode == ELensMode::LatLong || metadata[j].lens.mode == ELensMode::OpenCVFisheye) {
-			// not supported for now
-			count++;
-			break;
+	float voxel_size = scalbnf(1.0f / NERF_GRIDSIZE(), level);
+	Vector3f pos = (Vector3f{(float)x, (float)y, (float)z} / NERF_GRIDSIZE() - Vector3f::Constant(0.5f)) * scalbnf(1.0f, level) + Vector3f::Constant(0.5f);
+
+	Vector3f corners[8] = {
+		pos + Vector3f{0.0f,       0.0f,       0.0f      },
+		pos + Vector3f{voxel_size, 0.0f,       0.0f      },
+		pos + Vector3f{0.0f,       voxel_size, 0.0f      },
+		pos + Vector3f{voxel_size, voxel_size, 0.0f      },
+		pos + Vector3f{0.0f,       0.0f,       voxel_size},
+		pos + Vector3f{voxel_size, 0.0f,       voxel_size},
+		pos + Vector3f{0.0f,       voxel_size, voxel_size},
+		pos + Vector3f{voxel_size, voxel_size, voxel_size},
+	};
+
+	// Number of training views that need to see a voxel cell
+	// at minimum for that cell to be marked trainable.
+	// Floaters can be reduced by increasing this value to 2,
+	// but at the cost of certain reconstruction artifacts.
+	const uint32_t min_count = 1;
+	uint32_t count = 0;
+
+	for (uint32_t j = 0; j < n_training_images && count < min_count; ++j) {
+		const auto& xform = training_xforms[j].start;
+		const auto& m = metadata[j];
+
+		if (m.lens.mode == ELensMode::FTheta || m.lens.mode == ELensMode::LatLong || m.lens.mode == ELensMode::Equirectangular) {
+			// FTheta lenses don't have a forward mapping, so are assumed seeing everything. Latlong and equirect lenses
+			// by definition see everything.
+			++count;
+			continue;
 		}
-		float half_resx = metadata[j].resolution.x() * 0.5f;
-		float half_resy = metadata[j].resolution.y() * 0.5f;
-		Matrix<float, 3, 4> xform = training_xforms[j].start;
-		Vector3f ploc = pos - xform.col(3);
-		float x = ploc.dot(xform.col(0));
-		float y = ploc.dot(xform.col(1));
-		float z = ploc.dot(xform.col(2));
-		if (z > 0.f) {
-			auto focal = metadata[j].focal_length;
-			// TODO - add a box / plane intersection to stop thomas from murdering me
-			if (fabsf(x) - voxel_radius < z / focal.x() * half_resx && fabsf(y) - voxel_radius < z / focal.y() * half_resy) {
-				count++;
-				if (count > 0) break;
+
+		for (uint32_t k = 0; k < 8; ++k) {
+			// Only consider voxel corners in front of the camera
+			Vector3f dir = (corners[k] - xform.col(3)).normalized();
+			if (dir.dot(xform.col(2)) < 1e-4f) {
+				continue;
+			}
+
+			// Check if voxel corner projects onto the image plane, i.e. uv must be in (0, 1)^2
+			Vector2f uv = pos_to_uv(corners[k], m.resolution, m.focal_length, xform, m.principal_point, Vector3f::Zero(), {}, m.lens);
+
+			// `pos_to_uv` is _not_ injective in the presence of lens distortion (which breaks down outside of the image plane).
+			// So we need to check whether the produced uv location generates a ray that matches the ray that we started with.
+			Ray ray = uv_to_ray(0.0f, uv, m.resolution, m.focal_length, xform, m.principal_point, Vector3f::Zero(), 0.0f, 1.0f, 0.0f, {}, {}, m.lens);
+			if ((ray.d.normalized() - dir).norm() < 1e-3f && uv.x() > 0.0f && uv.y() > 0.0f && uv.x() < 1.0f && uv.y() < 1.0f) {
+				++count;
+				break;
 			}
 		}
 	}
 
-	if (clear_visible_voxels || (grid_out[i] < 0) != (count <= 0)) {
-		grid_out[i] = (count > 0) ? 0.f : -1.f;
+	if (clear_visible_voxels || (grid_out[i] < 0) != (count < min_count)) {
+		grid_out[i] = (count >= min_count) ? 0.f : -1.f;
 	}
 }
 
@@ -525,9 +551,10 @@ __global__ void grid_samples_half_to_float(const uint32_t n_elements, BoundingBo
 		Vector3f pos = unwarp_position(coords_in[i].p, aabb);
 		float grid_density = cascaded_grid_at(pos, grid_in, mip_from_pos(pos, max_cascade));
 		if (grid_density < NERF_MIN_OPTICAL_THICKNESS()) {
-			mlp = -10000.f;
+			mlp = -10000.0f;
 		}
 	}
+
 	dst[i] = mlp;
 }
 
@@ -777,8 +804,6 @@ __global__ void composite_kernel_nerf(
 	BoundingBox aabb,
 	float glow_y_cutoff,
 	int glow_mode,
-	const uint32_t n_training_images,
-	const TrainingXForm* __restrict__ training_xforms,
 	Matrix<float, 3, 4> camera_matrix,
 	Vector2f focal_length,
 	float depth_scale,
@@ -1038,18 +1063,18 @@ inline __device__ float pdf_2d(Vector2f sample, uint32_t img, const Vector2i& re
 }
 
 inline __device__ Vector2f nerf_random_image_pos_training(default_rng_t& rng, const Vector2i& resolution, bool snap_to_pixel_centers, const float* __restrict__ cdf_x_cond_y, const float* __restrict__ cdf_y, const Vector2i& cdf_res, uint32_t img, float* __restrict__ pdf = nullptr) {
-	Vector2f xy = random_val_2d(rng);
+	Vector2f uv = random_val_2d(rng);
 
 	if (cdf_x_cond_y) {
-		xy = sample_cdf_2d(xy, img, cdf_res, cdf_x_cond_y, cdf_y, pdf);
+		uv = sample_cdf_2d(uv, img, cdf_res, cdf_x_cond_y, cdf_y, pdf);
 	} else if (pdf) {
 		*pdf = 1.0f;
 	}
 
 	if (snap_to_pixel_centers) {
-		xy = (xy.cwiseProduct(resolution.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(resolution - Vector2i::Ones()).cast<float>() + Vector2f::Constant(0.5f)).cwiseQuotient(resolution.cast<float>());
+		uv = (uv.cwiseProduct(resolution.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(resolution - Vector2i::Ones()).cast<float>() + Vector2f::Constant(0.5f)).cwiseQuotient(resolution.cast<float>());
 	}
-	return xy;
+	return uv;
 }
 
 inline __device__ uint32_t image_idx(uint32_t base_idx, uint32_t n_rays, uint32_t n_rays_total, uint32_t n_training_images, const float* __restrict__ cdf = nullptr, float* __restrict__ pdf = nullptr) {
@@ -1096,8 +1121,7 @@ __global__ void generate_training_samples_nerf(
 	bool snap_to_pixel_centers,
 	bool train_envmap,
 	float cone_angle_constant,
-	const float* __restrict__ distortion_data,
-	const Vector2i distortion_resolution,
+	Buffer2DView<const Vector2f> distortion,
 	const float* __restrict__ cdf_x_cond_y,
 	const float* __restrict__ cdf_y,
 	const float* __restrict__ cdf_img,
@@ -1112,11 +1136,11 @@ __global__ void generate_training_samples_nerf(
 	Eigen::Vector2i resolution = metadata[img].resolution;
 
 	rng.advance(i * N_MAX_RANDOM_SAMPLES_PER_RAY());
-	Vector2f xy = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, cdf_res, img);
+	Vector2f uv = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, cdf_res, img);
 
 	// Negative values indicate masked-away regions
-	size_t pix_idx = pixel_idx(xy, resolution, 0);
-	if (read_rgba(xy, resolution, metadata[img].pixels, metadata[img].image_data_type).x() < 0.0f) {
+	size_t pix_idx = pixel_idx(uv, resolution, 0);
+	if (read_rgba(uv, resolution, metadata[img].pixels, metadata[img].image_data_type).x() < 0.0f) {
 		return;
 	}
 
@@ -1129,7 +1153,7 @@ __global__ void generate_training_samples_nerf(
 	const float* extra_dims = extra_dims_gpu + img * n_extra_dims;
 	const Lens lens = metadata[img].lens;
 
-	const Matrix<float, 3, 4> xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, xy, motionblur_time);
+	const Matrix<float, 3, 4> xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, uv, motionblur_time);
 
 	Ray ray_unnormalized;
 	const Ray* rays_in_unnormalized = metadata[img].rays;
@@ -1138,16 +1162,16 @@ __global__ void generate_training_samples_nerf(
 		ray_unnormalized = rays_in_unnormalized[pix_idx];
 
 		/* DEBUG - compare the stored rays to the computed ones
-		const Matrix<float, 3, 4> xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, xy, 0.f);
+		const Matrix<float, 3, 4> xform = get_xform_given_rolling_shutter(training_xforms[img], metadata[img].rolling_shutter, uv, 0.f);
 		Ray ray2;
 		ray2.o = xform.col(3);
-		ray2.d = f_theta_distortion(xy, principal_point, lens);
+		ray2.d = f_theta_distortion(uv, principal_point, lens);
 		ray2.d = (xform.block<3, 3>(0, 0) * ray2.d).normalized();
 		if (i==1000) {
 			printf("\n%d uv %0.3f,%0.3f pixel %0.2f,%0.2f transform from [%0.5f %0.5f %0.5f] to [%0.5f %0.5f %0.5f]\n"
 				" origin    [%0.5f %0.5f %0.5f] vs [%0.5f %0.5f %0.5f]\n"
 				" direction [%0.5f %0.5f %0.5f] vs [%0.5f %0.5f %0.5f]\n"
-			, img,xy.x(), xy.y(), xy.x()*resolution.x(), xy.y()*resolution.y(),
+			, img,uv.x(), uv.y(), uv.x()*resolution.x(), uv.y()*resolution.y(),
 				training_xforms[img].start.col(3).x(),training_xforms[img].start.col(3).y(),training_xforms[img].start.col(3).z(),
 				training_xforms[img].end.col(3).x(),training_xforms[img].end.col(3).y(),training_xforms[img].end.col(3).z(),
 				ray_unnormalized.o.x(),ray_unnormalized.o.y(),ray_unnormalized.o.z(),
@@ -1157,31 +1181,10 @@ __global__ void generate_training_samples_nerf(
 		}
 		*/
 	} else {
-		// Rays need to be inferred from the camera matrix
-		ray_unnormalized.o = xform.col(3);
-		if (lens.mode == ELensMode::FTheta) {
-			ray_unnormalized.d = f_theta_undistortion(xy - principal_point, lens.params, {0.f, 0.f, 1.f});
-		} else if (lens.mode == ELensMode::LatLong) {
-			ray_unnormalized.d = latlong_to_dir(xy);
-		} else {
-			ray_unnormalized.d = {
-				(xy.x()-principal_point.x())*resolution.x() / focal_length.x(),
-				(xy.y()-principal_point.y())*resolution.y() / focal_length.y(),
-				1.0f,
-			};
-
-			if (lens.mode == ELensMode::OpenCV) {
-				iterative_opencv_lens_undistortion(lens.params, &ray_unnormalized.d.x(), &ray_unnormalized.d.y());
-			} else if (lens.mode == ELensMode::OpenCVFisheye) {
-				iterative_opencv_fisheye_lens_undistortion(lens.params, &ray_unnormalized.d.x(), &ray_unnormalized.d.y());
-			}
+		ray_unnormalized = uv_to_ray(0, uv, resolution, focal_length, xform, principal_point, Vector3f::Zero(), 0.0f, 1.0f, 0.0f, {}, {}, lens, distortion);
+		if (!ray_unnormalized.is_valid()) {
+			ray_unnormalized = {xform.col(3), xform.col(2)};
 		}
-
-		if (distortion_data) {
-			ray_unnormalized.d.head<2>() += read_image<2>(distortion_data, distortion_resolution, xy);
-		}
-
-		ray_unnormalized.d = (xform.block<3, 3>(0, 0) * ray_unnormalized.d); // NOT normalized
 	}
 
 	Eigen::Vector3f ray_d_normalized = ray_unnormalized.d.normalized();
@@ -1278,7 +1281,7 @@ __global__ void compute_loss_kernel_train_nerf(
 	const uint32_t* __restrict__ rays_counter,
 	float loss_scale,
 	int padded_output_width,
-	const float* __restrict__ envmap_data,
+	Buffer2DView<const Eigen::Array4f> envmap,
 	float* __restrict__ envmap_gradient,
 	const Vector2i envmap_resolution,
 	ELossType envmap_loss_type,
@@ -1374,8 +1377,8 @@ __global__ void compute_loss_kernel_train_nerf(
 	uint32_t img = image_idx(ray_idx, n_rays, n_rays_total, n_training_images, cdf_img, &img_pdf);
 	Eigen::Vector2i resolution = metadata[img].resolution;
 
-	float xy_pdf = 1.0f;
-	Vector2f xy = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, error_map_cdf_res, img, &xy_pdf);
+	float uv_pdf = 1.0f;
+	Vector2f uv = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, error_map_cdf_res, img, &uv_pdf);
 	float max_level = max_level_rand_training ? (random_val(rng) * 2.0f) : 1.0f; // Multiply by 2 to ensure 50% of training is at max level
 
 	if (train_with_random_bg_color) {
@@ -1386,16 +1389,16 @@ __global__ void compute_loss_kernel_train_nerf(
 	// Composit background behind envmap
 	Array4f envmap_value;
 	Vector3f dir;
-	if (envmap_data) {
+	if (envmap) {
 		dir = rays_in_unnormalized[i].d.normalized();
-		envmap_value = read_envmap(envmap_data, envmap_resolution, dir);
+		envmap_value = read_envmap(envmap, dir);
 		background_color = envmap_value.head<3>() + background_color * (1.0f - envmap_value.w());
 	}
 
 	Array3f exposure_scale = (0.6931471805599453f * exposure[img]).exp();
-	// Array3f rgbtarget = composit_and_lerp(xy, resolution, img, training_images, background_color, exposure_scale);
-	// Array3f rgbtarget = composit(xy, resolution, img, training_images, background_color, exposure_scale);
-	Array4f texsamp = read_rgba(xy, resolution, metadata[img].pixels, metadata[img].image_data_type);
+	// Array3f rgbtarget = composit_and_lerp(uv, resolution, img, training_images, background_color, exposure_scale);
+	// Array3f rgbtarget = composit(uv, resolution, img, training_images, background_color, exposure_scale);
+	Array4f texsamp = read_rgba(uv, resolution, metadata[img].pixels, metadata[img].image_data_type);
 
 	Array3f rgbtarget;
 	if (train_in_linear_colors || color_space == EColorSpace::Linear) {
@@ -1437,9 +1440,9 @@ __global__ void compute_loss_kernel_train_nerf(
 	dloss_doutput += compacted_base * padded_output_width;
 
 	LossAndGradient lg = loss_and_gradient(rgbtarget, rgb_ray, loss_type);
-	lg.loss /= img_pdf * xy_pdf;
+	lg.loss /= img_pdf * uv_pdf;
 
-	float target_depth = rays_in_unnormalized[i].d.norm() * ((depth_supervision_lambda > 0.0f && metadata[img].depth) ? read_depth(xy, resolution, metadata[img].depth) : -1.0f);
+	float target_depth = rays_in_unnormalized[i].d.norm() * ((depth_supervision_lambda > 0.0f && metadata[img].depth) ? read_depth(uv, resolution, metadata[img].depth) : -1.0f);
 	LossAndGradient lg_depth = loss_and_gradient(Array3f::Constant(target_depth), Array3f::Constant(depth_ray), depth_loss_type);
 	float depth_loss_gradient = target_depth > 0.0f ? depth_supervision_lambda * lg_depth.gradient.x() : 0;
 
@@ -1447,7 +1450,7 @@ __global__ void compute_loss_kernel_train_nerf(
 	// Essentially: variance reduction, but otherwise the same optimization.
 	// We _dont_ want that. If importance sampling is enabled, we _do_ actually want
 	// to change the weighting of the loss function. So don't divide.
-	// lg.gradient /= img_pdf * xy_pdf;
+	// lg.gradient /= img_pdf * uv_pdf;
 
 	float mean_loss = lg.loss.mean();
 	if (loss_output) {
@@ -1455,7 +1458,7 @@ __global__ void compute_loss_kernel_train_nerf(
 	}
 
 	if (error_map) {
-		const Vector2f pos = (xy.cwiseProduct(error_map_res.cast<float>()) - Vector2f::Constant(0.5f)).cwiseMax(0.0f).cwiseMin(error_map_res.cast<float>() - Vector2f::Constant(1.0f + 1e-4f));
+		const Vector2f pos = (uv.cwiseProduct(error_map_res.cast<float>()) - Vector2f::Constant(0.5f)).cwiseMax(0.0f).cwiseMin(error_map_res.cast<float>() - Vector2f::Constant(1.0f + 1e-4f));
 		const Vector2i pos_int = pos.cast<int>();
 		const Vector2f weight = pos - pos_int.cast<float>();
 
@@ -1466,7 +1469,7 @@ __global__ void compute_loss_kernel_train_nerf(
 		};
 
 		if (sharpness_data && aabb.contains(hitpoint)) {
-			Vector2i sharpness_pos = xy.cwiseProduct(sharpness_resolution.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(sharpness_resolution - Vector2i::Constant(1));
+			Vector2i sharpness_pos = uv.cwiseProduct(sharpness_resolution.cast<float>()).cast<int>().cwiseMax(0).cwiseMin(sharpness_resolution - Vector2i::Constant(1));
 			float sharp = sharpness_data[img * sharpness_resolution.prod() + sharpness_pos.y() * sharpness_resolution.x() + sharpness_pos.x()] + 1e-6f;
 
 			// The maximum value of positive floats interpreted in uint format is the same as the maximum value of the floats.
@@ -1549,7 +1552,7 @@ __global__ void compute_loss_kernel_train_nerf(
 
 	if (exposure_gradient) {
 		// Assume symmetric loss
-		Array3f dloss_by_dgt = -lg.gradient / xy_pdf;
+		Array3f dloss_by_dgt = -lg.gradient / uv_pdf;
 
 		if (!train_in_linear_colors) {
 			dloss_by_dgt /= srgb_to_linear_derivative(rgbtarget);
@@ -1656,9 +1659,9 @@ __global__ void compute_cam_gradient_train_nerf(
 	}
 
 	rng.advance(ray_idx * N_MAX_RANDOM_SAMPLES_PER_RAY());
-	float xy_pdf = 1.0f;
+	float uv_pdf = 1.0f;
 
-	Vector2f xy = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, error_map_res, img, &xy_pdf);
+	Vector2f uv = nerf_random_image_pos_training(rng, resolution, snap_to_pixel_centers, cdf_x_cond_y, cdf_y, error_map_res, img, &uv_pdf);
 
 	if (distortion_gradient) {
 		// Projection of the raydir gradient onto the plane normal to raydir,
@@ -1671,14 +1674,14 @@ __global__ void compute_cam_gradient_train_nerf(
 		Vector3f image_plane_gradient = xform.block<3,3>(0,0).inverse() * orthogonal_ray_gradient;
 
 		// Splat the resulting 2D image plane gradient into the distortion params
-		deposit_image_gradient<2>(image_plane_gradient.head<2>() / xy_pdf, distortion_gradient, distortion_gradient_weight, distortion_resolution, xy);
+		deposit_image_gradient<2>(image_plane_gradient.head<2>() / uv_pdf, distortion_gradient, distortion_gradient_weight, distortion_resolution, uv);
 	}
 
 	if (cam_pos_gradient) {
 		// Atomically reduce the ray gradient into the xform gradient
 		NGP_PRAGMA_UNROLL
 		for (uint32_t j = 0; j < 3; ++j) {
-			atomicAdd(&cam_pos_gradient[img][j], ray_gradient.o[j] / xy_pdf);
+			atomicAdd(&cam_pos_gradient[img][j], ray_gradient.o[j] / uv_pdf);
 		}
 	}
 
@@ -1692,7 +1695,7 @@ __global__ void compute_cam_gradient_train_nerf(
 		// Atomically reduce the ray gradient into the xform gradient
 		NGP_PRAGMA_UNROLL
 		for (uint32_t j = 0; j < 3; ++j) {
-			atomicAdd(&cam_rot_gradient[img][j], angle_axis[j] / xy_pdf);
+			atomicAdd(&cam_rot_gradient[img][j], angle_axis[j] / uv_pdf);
 		}
 	}
 }
@@ -1811,15 +1814,14 @@ __global__ void init_rays_with_payload_kernel_nerf(
 	float near_distance,
 	float plane_z,
 	float aperture_size,
+	Foveation foveation,
 	Lens lens,
-	const float* __restrict__ envmap_data,
-	const Vector2i envmap_resolution,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer,
-	const float* __restrict__ distortion_data,
-	const Vector2i distortion_resolution,
-	ERenderMode render_mode,
-	Vector2i quilting_dims
+	Buffer2DView<const Eigen::Array4f> envmap,
+	Array4f* __restrict__ frame_buffer,
+	float* __restrict__ depth_buffer,
+	Buffer2DView<const uint8_t> hidden_area_mask,
+	Buffer2DView<const Eigen::Vector2f> distortion,
+	ERenderMode render_mode
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -1834,34 +1836,37 @@ __global__ void init_rays_with_payload_kernel_nerf(
 		aperture_size = 0.0;
 	}
 
-	if (quilting_dims != Vector2i::Ones()) {
-		apply_quilting(&x, &y, resolution, parallax_shift, quilting_dims);
-	}
-
-	// TODO: pixel_to_ray also immediately computes u,v for the pixel, so this is somewhat redundant
-	float u = (x + 0.5f) * (1.f / resolution.x());
-	float v = (y + 0.5f) * (1.f / resolution.y());
-	float ray_time = rolling_shutter.x() + rolling_shutter.y() * u + rolling_shutter.z() * v + rolling_shutter.w() * ld_random_val(sample_index, idx * 72239731);
-	Ray ray = pixel_to_ray(
+	Vector2f pixel_offset = ld_random_pixel_offset(snap_to_pixel_centers ? 0 : sample_index);
+	Vector2f uv = Vector2f{(float)x + pixel_offset.x(), (float)y + pixel_offset.y()}.cwiseQuotient(resolution.cast<float>());
+	float ray_time = rolling_shutter.x() + rolling_shutter.y() * uv.x() + rolling_shutter.z() * uv.y() + rolling_shutter.w() * ld_random_val(sample_index, idx * 72239731);
+	Ray ray = uv_to_ray(
 		sample_index,
-		{x, y},
-		resolution.cwiseQuotient(quilting_dims),
+		uv,
+		resolution,
 		focal_length,
 		camera_matrix0 * ray_time + camera_matrix1 * (1.f - ray_time),
 		screen_center,
 		parallax_shift,
-		snap_to_pixel_centers,
 		near_distance,
 		plane_z,
 		aperture_size,
+		foveation,
+		hidden_area_mask,
 		lens,
-		distortion_data,
-		distortion_resolution
+		distortion
 	);
 
 	NerfPayload& payload = payloads[idx];
 	payload.max_weight = 0.0f;
 
+	depth_buffer[idx] = MAX_DEPTH();
+
+	if (!ray.is_valid()) {
+		payload.origin = ray.o;
+		payload.alive = false;
+		return;
+	}
+
 	if (plane_z < 0) {
 		float n = ray.d.norm();
 		payload.origin = ray.o;
@@ -1870,21 +1875,19 @@ __global__ void init_rays_with_payload_kernel_nerf(
 		payload.idx = idx;
 		payload.n_steps = 0;
 		payload.alive = false;
-		depthbuffer[idx] = -plane_z;
+		depth_buffer[idx] = -plane_z;
 		return;
 	}
 
-	depthbuffer[idx] = 1e10f;
-
 	ray.d = ray.d.normalized();
 
-	if (envmap_data) {
-		framebuffer[idx] = read_envmap(envmap_data, envmap_resolution, ray.d);
+	if (envmap) {
+		frame_buffer[idx] = read_envmap(envmap, ray.d);
 	}
 
 	float t = fmaxf(render_aabb.ray_intersect(render_aabb_to_local * ray.o, render_aabb_to_local * ray.d).x(), 0.0f) + 1e-6f;
 
-	if (!render_aabb.contains(render_aabb_to_local * (ray.o + ray.d * t))) {
+	if (!render_aabb.contains(render_aabb_to_local * ray(t))) {
 		payload.origin = ray.o;
 		payload.alive = false;
 		return;
@@ -1892,13 +1895,14 @@ __global__ void init_rays_with_payload_kernel_nerf(
 
 	if (render_mode == ERenderMode::Distortion) {
 		Vector2f offset = Vector2f::Zero();
-		if (distortion_data) {
-			offset += read_image<2>(distortion_data, distortion_resolution, Vector2f((float)x + 0.5f, (float)y + 0.5f).cwiseQuotient(resolution.cast<float>()));
+		if (distortion) {
+			offset += distortion.at_lerp(Vector2f{(float)x + 0.5f, (float)y + 0.5f}.cwiseQuotient(resolution.cast<float>()));
 		}
-		framebuffer[idx].head<3>() = to_rgb(offset * 50.0f);
-		framebuffer[idx].w() = 1.0f;
-		depthbuffer[idx] = 1.0f;
-		payload.origin = ray.o + ray.d * 10000.0f;
+
+		frame_buffer[idx].head<3>() = to_rgb(offset * 50.0f);
+		frame_buffer[idx].w() = 1.0f;
+		depth_buffer[idx] = 1.0f;
+		payload.origin = ray(MAX_DEPTH());
 		payload.alive = false;
 		return;
 	}
@@ -1987,21 +1991,20 @@ void Testbed::NerfTracer::init_rays_from_camera(
 	const Vector4f& rolling_shutter,
 	const Vector2f& screen_center,
 	const Vector3f& parallax_shift,
-	const Vector2i& quilting_dims,
 	bool snap_to_pixel_centers,
 	const BoundingBox& render_aabb,
 	const Matrix3f& render_aabb_to_local,
 	float near_distance,
 	float plane_z,
 	float aperture_size,
+	const Foveation& foveation,
 	const Lens& lens,
-	const float* envmap_data,
-	const Vector2i& envmap_resolution,
-	const float* distortion_data,
-	const Vector2i& distortion_resolution,
+	const Buffer2DView<const Array4f>& envmap,
+	const Buffer2DView<const Vector2f>& distortion,
 	Eigen::Array4f* frame_buffer,
 	float* depth_buffer,
-	uint8_t* grid,
+	const Buffer2DView<const uint8_t>& hidden_area_mask,
+	const uint8_t* grid,
 	int show_accel,
 	float cone_angle_constant,
 	ERenderMode render_mode,
@@ -2029,15 +2032,14 @@ void Testbed::NerfTracer::init_rays_from_camera(
 		near_distance,
 		plane_z,
 		aperture_size,
+		foveation,
 		lens,
-		envmap_data,
-		envmap_resolution,
+		envmap,
 		frame_buffer,
 		depth_buffer,
-		distortion_data,
-		distortion_resolution,
-		render_mode,
-		quilting_dims
+		hidden_area_mask,
+		distortion,
+		render_mode
 	);
 
 	m_n_rays_initialized = resolution.x() * resolution.y();
@@ -2064,8 +2066,6 @@ uint32_t Testbed::NerfTracer::trace(
 	const BoundingBox& render_aabb,
 	const Eigen::Matrix3f& render_aabb_to_local,
 	const BoundingBox& train_aabb,
-	const uint32_t n_training_images,
-	const TrainingXForm* training_xforms,
 	const Vector2f& focal_length,
 	float cone_angle_constant,
 	const uint8_t* grid,
@@ -2156,8 +2156,6 @@ uint32_t Testbed::NerfTracer::trace(
 			train_aabb,
 			glow_y_cutoff,
 			glow_mode,
-			n_training_images,
-			training_xforms,
 			camera_matrix,
 			focal_length,
 			depth_scale,
@@ -2261,51 +2259,59 @@ const float* Testbed::get_inference_extra_dims(cudaStream_t stream) const {
 	return dims_gpu;
 }
 
-void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_res, const Vector2f& focal_length, const Matrix<float, 3, 4>& camera_matrix0, const Matrix<float, 3, 4>& camera_matrix1, const Vector4f& rolling_shutter, const Vector2f& screen_center, cudaStream_t stream) {
+void Testbed::render_nerf(
+	cudaStream_t stream,
+	const CudaRenderBufferView& render_buffer,
+	NerfNetwork<precision_t>& nerf_network,
+	const uint8_t* density_grid_bitfield,
+	const Vector2f& focal_length,
+	const Matrix<float, 3, 4>& camera_matrix0,
+	const Matrix<float, 3, 4>& camera_matrix1,
+	const Vector4f& rolling_shutter,
+	const Vector2f& screen_center,
+	const Foveation& foveation,
+	int visualized_dimension
+) {
 	float plane_z = m_slice_plane_z + m_scale;
 	if (m_render_mode == ERenderMode::Slice) {
 		plane_z = -plane_z;
 	}
 
-	ERenderMode render_mode = m_visualized_dimension > -1 ? ERenderMode::EncodingVis : m_render_mode;
+	ERenderMode render_mode = visualized_dimension > -1 ? ERenderMode::EncodingVis : m_render_mode;
 
 	const float* extra_dims_gpu = get_inference_extra_dims(stream);
 
 	NerfTracer tracer;
 
-	// Our motion vector code can't undo f-theta and grid distortions -- so don't render these if DLSS is enabled.
-	bool render_opencv_lens = m_nerf.render_with_lens_distortion && (!render_buffer.dlss() || m_nerf.render_lens.mode == ELensMode::OpenCV || m_nerf.render_lens.mode == ELensMode::OpenCVFisheye);
-	bool render_grid_distortion = m_nerf.render_with_lens_distortion && !render_buffer.dlss();
-
-	Lens lens = render_opencv_lens ? m_nerf.render_lens : Lens{};
-
+	// Our motion vector code can't undo grid distortions -- so don't render grid distortion if DLSS is enabled
+	auto grid_distortion = m_nerf.render_with_lens_distortion && !m_dlss ? m_distortion.inference_view() : Buffer2DView<const Vector2f>{};
+	Lens lens = m_nerf.render_with_lens_distortion ? m_nerf.render_lens : Lens{};
 
 	tracer.init_rays_from_camera(
-		render_buffer.spp(),
-		m_network->padded_output_width(),
-		m_nerf_network->n_extra_dims(),
-		render_buffer.in_resolution(),
+		render_buffer.spp,
+		nerf_network.padded_output_width(),
+		nerf_network.n_extra_dims(),
+		render_buffer.resolution,
 		focal_length,
 		camera_matrix0,
 		camera_matrix1,
 		rolling_shutter,
 		screen_center,
 		m_parallax_shift,
-		m_quilting_dims,
 		m_snap_to_pixel_centers,
 		m_render_aabb,
 		m_render_aabb_to_local,
 		m_render_near_distance,
 		plane_z,
 		m_aperture_size,
+		foveation,
 		lens,
-		m_envmap.envmap->inference_params(),
-		m_envmap.resolution,
-		render_grid_distortion ? m_distortion.map->inference_params() : nullptr,
-		m_distortion.resolution,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer(),
-		m_nerf.density_grid_bitfield.data(),
+		m_envmap.inference_view(),
+		grid_distortion,
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer,
+		render_buffer.hidden_area_mask ? render_buffer.hidden_area_mask->const_view() : Buffer2DView<const uint8_t>{},
+		density_grid_bitfield,
 		m_nerf.show_accel,
 		m_nerf.cone_angle_constant,
 		render_mode,
@@ -2318,20 +2324,18 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 	} else {
 		float depth_scale = 1.0f / m_nerf.training.dataset.scale;
 		n_hit = tracer.trace(
-			*m_nerf_network,
+			nerf_network,
 			m_render_aabb,
 			m_render_aabb_to_local,
 			m_aabb,
-			m_nerf.training.n_images_for_training,
-			m_nerf.training.transforms.data(),
 			focal_length,
 			m_nerf.cone_angle_constant,
-			m_nerf.density_grid_bitfield.data(),
+			density_grid_bitfield,
 			render_mode,
 			camera_matrix1,
 			depth_scale,
 			m_visualized_layer,
-			m_visualized_dimension,
+			visualized_dimension,
 			m_nerf.rgb_activation,
 			m_nerf.density_activation,
 			m_nerf.show_accel,
@@ -2347,19 +2351,19 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 	if (m_render_mode == ERenderMode::Slice) {
 		// Store colors in the normal buffer
 		uint32_t n_elements = next_multiple(n_hit, tcnn::batch_size_granularity);
-		const uint32_t floats_per_coord = sizeof(NerfCoordinate) / sizeof(float) + m_nerf_network->n_extra_dims();
-		const uint32_t extra_stride = m_nerf_network->n_extra_dims() * sizeof(float); // extra stride on top of base NerfCoordinate struct
+		const uint32_t floats_per_coord = sizeof(NerfCoordinate) / sizeof(float) + nerf_network.n_extra_dims();
+		const uint32_t extra_stride = nerf_network.n_extra_dims() * sizeof(float); // extra stride on top of base NerfCoordinate struct
 
 		GPUMatrix<float> positions_matrix{floats_per_coord, n_elements, stream};
 		GPUMatrix<float> rgbsigma_matrix{4, n_elements, stream};
 
 		linear_kernel(generate_nerf_network_inputs_at_current_position, 0, stream, n_hit, m_aabb, rays_hit.payload, PitchedPtr<NerfCoordinate>((NerfCoordinate*)positions_matrix.data(), 1, 0, extra_stride), extra_dims_gpu );
 
-		if (m_visualized_dimension == -1) {
-			m_network->inference(stream, positions_matrix, rgbsigma_matrix);
+		if (visualized_dimension == -1) {
+			nerf_network.inference(stream, positions_matrix, rgbsigma_matrix);
 			linear_kernel(compute_nerf_rgba, 0, stream, n_hit, (Array4f*)rgbsigma_matrix.data(), m_nerf.rgb_activation, m_nerf.density_activation, 0.01f, false);
 		} else {
-			m_network->visualize_activation(stream, m_visualized_layer, m_visualized_dimension, positions_matrix, rgbsigma_matrix);
+			nerf_network.visualize_activation(stream, m_visualized_layer, visualized_dimension, positions_matrix, rgbsigma_matrix);
 		}
 
 		linear_kernel(shade_kernel_nerf, 0, stream,
@@ -2369,8 +2373,8 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 			rays_hit.payload,
 			m_render_mode,
 			m_nerf.training.linear_colors,
-			render_buffer.frame_buffer(),
-			render_buffer.depth_buffer()
+			render_buffer.frame_buffer,
+			render_buffer.depth_buffer
 		);
 		return;
 	}
@@ -2382,8 +2386,8 @@ void Testbed::render_nerf(CudaRenderBuffer& render_buffer, const Vector2i& max_r
 		rays_hit.payload,
 		m_render_mode,
 		m_nerf.training.linear_colors,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer()
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer
 	);
 
 	if (render_mode == ERenderMode::Cost) {
@@ -2673,7 +2677,20 @@ void Testbed::load_nerf(const fs::path& data_path) {
 			throw std::runtime_error{"NeRF data path must either be a json file or a directory containing json files."};
 		}
 
+		const auto prev_aabb_scale = m_nerf.training.dataset.aabb_scale;
+
 		m_nerf.training.dataset = ngp::load_nerf(json_paths, m_nerf.sharpen);
+
+		// Check if the NeRF network has been previously configured.
+		// If it has not, don't reset it.
+		bool previously_configured = !m_network_config["rgb_network"].is_null()
+		                          && !m_network_config["dir_encoding"].is_null();
+
+		if (m_nerf.training.dataset.aabb_scale != prev_aabb_scale && previously_configured) {
+			// The AABB scale affects network size indirectly. If it changed after loading,
+			// we need to reset the previously configured network to keep a consistent internal state.
+			reset_network();
+		}
 	}
 
 	load_nerf_post();
@@ -2785,6 +2802,40 @@ void Testbed::update_density_grid_mean_and_bitfield(cudaStream_t stream) {
 	for (uint32_t level = 1; level < NERF_CASCADES(); ++level) {
 		linear_kernel(bitfield_max_pool, 0, stream, n_elements/64, m_nerf.get_density_grid_bitfield_mip(level-1), m_nerf.get_density_grid_bitfield_mip(level));
 	}
+
+	set_all_devices_dirty();
+}
+
+__global__ void mark_density_grid_in_sphere_empty_kernel(const uint32_t n_elements, float* density_grid, Vector3f pos, float radius) {
+	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
+	if (i >= n_elements) return;
+
+	// Random position within that cellq
+	uint32_t level = i / NERF_GRID_N_CELLS();
+	uint32_t pos_idx = i % NERF_GRID_N_CELLS();
+
+	uint32_t x = tcnn::morton3D_invert(pos_idx>>0);
+	uint32_t y = tcnn::morton3D_invert(pos_idx>>1);
+	uint32_t z = tcnn::morton3D_invert(pos_idx>>2);
+
+	float cell_radius = scalbnf(SQRT3(), level) / NERF_GRIDSIZE();
+	Vector3f cell_pos = ((Vector3f{(float)x+0.5f, (float)y+0.5f, (float)z+0.5f}) / NERF_GRIDSIZE() - Vector3f::Constant(0.5f)) * scalbnf(1.0f, level) + Vector3f::Constant(0.5f);
+
+	// Disable if the cell touches the sphere (conservatively, by bounding the cell with a sphere)
+	if ((pos - cell_pos).norm() < radius + cell_radius) {
+		density_grid[i] = -1.0f;
+	}
+}
+
+void Testbed::mark_density_grid_in_sphere_empty(const Vector3f& pos, float radius, cudaStream_t stream) {
+	const uint32_t n_elements = NERF_GRID_N_CELLS() * (m_nerf.max_cascade + 1);
+	if (m_nerf.density_grid.size() != n_elements) {
+		return;
+	}
+
+	linear_kernel(mark_density_grid_in_sphere_empty_kernel, 0, stream, n_elements, m_nerf.density_grid.data(), pos, radius);
+
+	update_density_grid_mean_and_bitfield(stream);
 }
 
 void Testbed::NerfCounters::prepare_for_training_steps(cudaStream_t stream) {
@@ -3167,8 +3218,7 @@ void Testbed::train_nerf_step(uint32_t target_batch_size, Testbed::NerfCounters&
 		m_nerf.training.snap_to_pixel_centers,
 		m_nerf.training.train_envmap,
 		m_nerf.cone_angle_constant,
-		m_distortion.map->params(),
-		m_distortion.resolution,
+		m_distortion.view(),
 		sample_focal_plane_proportional_to_error ? m_nerf.training.error_map.cdf_x_cond_y.data() : nullptr,
 		sample_focal_plane_proportional_to_error ? m_nerf.training.error_map.cdf_y.data() : nullptr,
 		sample_image_proportional_to_error ? m_nerf.training.error_map.cdf_img.data() : nullptr,
@@ -3197,7 +3247,7 @@ void Testbed::train_nerf_step(uint32_t target_batch_size, Testbed::NerfCounters&
 		ray_counter,
 		LOSS_SCALE,
 		padded_output_width,
-		m_envmap.envmap->params(),
+		m_envmap.view(),
 		envmap_gradient,
 		m_envmap.resolution,
 		m_envmap.loss_type,
diff --git a/src/testbed_sdf.cu b/src/testbed_sdf.cu
index aced131525d5198518e14ed2a97ac75e180f6a6d..2332bec11f2b36773872657d6c941edab8f52a81 100644
--- a/src/testbed_sdf.cu
+++ b/src/testbed_sdf.cu
@@ -156,7 +156,7 @@ __global__ void advance_pos_kernel_sdf(
 	BoundingBox aabb,
 	float floor_y,
 	const TriangleOctreeNode* __restrict__ octree_nodes,
-	int max_depth,
+	int max_octree_depth,
 	float distance_scale,
 	float maximum_distance,
 	float k,
@@ -181,8 +181,8 @@ __global__ void advance_pos_kernel_sdf(
 	pos += distance * payload.dir;
 
 	// Skip over regions not covered by the octree
-	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_depth, pos)) {
-		float octree_distance = (TriangleOctree::ray_intersect(octree_nodes, max_depth, pos, payload.dir) + 1e-6f);
+	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_octree_depth, pos)) {
+		float octree_distance = (TriangleOctree::ray_intersect(octree_nodes, max_octree_depth, pos, payload.dir) + 1e-6f);
 		distance += octree_distance;
 		pos += octree_distance * payload.dir;
 	}
@@ -242,7 +242,7 @@ __global__ void prepare_shadow_rays(const uint32_t n_elements,
 	SdfPayload* __restrict__ payloads,
 	BoundingBox aabb,
 	const TriangleOctreeNode* __restrict__ octree_nodes,
-	int max_depth
+	int max_octree_depth
 ) {
 	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
 	if (i >= n_elements) return;
@@ -256,21 +256,21 @@ __global__ void prepare_shadow_rays(const uint32_t n_elements,
 	float t = fmaxf(aabb.ray_intersect(view_pos, dir).x() + 1e-6f, 0.0f);
 	view_pos += t * dir;
 
-	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_depth, view_pos)) {
-		t = fmaxf(0.0f, TriangleOctree::ray_intersect(octree_nodes, max_depth, view_pos, dir) + 1e-6f);
+	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_octree_depth, view_pos)) {
+		t = fmaxf(0.0f, TriangleOctree::ray_intersect(octree_nodes, max_octree_depth, view_pos, dir) + 1e-6f);
 		view_pos += t * dir;
 	}
 
 	positions[i] = view_pos;
 
 	if (!aabb.contains(view_pos)) {
-		distances[i] = 10000.0f;
+		distances[i] = MAX_DEPTH();
 		payload.alive = false;
 		min_visibility[i] = 1.0f;
 		return;
 	}
 
-	distances[i] = 10000.0f;
+	distances[i] = MAX_DEPTH();
 	payload.idx = i;
 	payload.dir = dir;
 	payload.n_steps = 0;
@@ -322,13 +322,13 @@ __global__ void shade_kernel_sdf(
 
 	// The normal in memory isn't normalized yet
 	Vector3f normal = normals[i].normalized();
-
 	Vector3f pos = positions[i];
 	bool floor = false;
-	if (pos.y() < floor_y+0.001f && payload.dir.y() < 0.f) {
+	if (pos.y() < floor_y + 0.001f && payload.dir.y() < 0.f) {
 		normal = Vector3f(0.f, 1.f, 0.f);
 		floor = true;
 	}
+
 	Vector3f cam_pos = camera_matrix.col(3);
 	Vector3f cam_fwd = camera_matrix.col(2);
 	float ao = powf(0.92f, payload.n_steps * 0.5f) * (1.f / 0.92f);
@@ -456,12 +456,12 @@ __global__ void scale_to_aabb_kernel(uint32_t n_elements, BoundingBox aabb, Vect
 	inout[i] = aabb.min + inout[i].cwiseProduct(aabb.diag());
 }
 
-__global__ void compare_signs_kernel(uint32_t n_elements, const Vector3f *positions, const float *distances_ref, const float *distances_model, uint32_t *counters, const TriangleOctreeNode* octree_nodes, int max_depth) {
+__global__ void compare_signs_kernel(uint32_t n_elements, const Vector3f *positions, const float *distances_ref, const float *distances_model, uint32_t *counters, const TriangleOctreeNode* octree_nodes, int max_octree_depth) {
 	uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
 	if (i >= n_elements) return;
 	bool inside1 = distances_ref[i]<=0.f;
 	bool inside2 = distances_model[i]<=0.f;
-	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_depth, positions[i])) {
+	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_octree_depth, positions[i])) {
 		inside2=inside1; // assume, when using the octree, that the model is always correct outside the octree.
 		atomicAdd(&counters[6],1); // outside the octree
 	} else {
@@ -506,12 +506,13 @@ __global__ void init_rays_with_payload_kernel_sdf(
 	float near_distance,
 	float plane_z,
 	float aperture_size,
-	const float* __restrict__ envmap_data,
-	const Vector2i envmap_resolution,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer,
+	Foveation foveation,
+	Buffer2DView<const Eigen::Array4f> envmap,
+	Array4f* __restrict__ frame_buffer,
+	float* __restrict__ depth_buffer,
+	Buffer2DView<const uint8_t> hidden_area_mask,
 	const TriangleOctreeNode* __restrict__ octree_nodes = nullptr,
-	int max_depth = 0
+	int max_octree_depth = 0
 ) {
 	uint32_t x = threadIdx.x + blockDim.x * blockIdx.x;
 	uint32_t y = threadIdx.y + blockDim.y * blockIdx.y;
@@ -526,30 +527,54 @@ __global__ void init_rays_with_payload_kernel_sdf(
 		aperture_size = 0.0;
 	}
 
-	Ray ray = pixel_to_ray(sample_index, {x, y}, resolution, focal_length, camera_matrix, screen_center, parallax_shift, snap_to_pixel_centers, near_distance, plane_z, aperture_size);
+	Ray ray = pixel_to_ray(
+		sample_index,
+		{x, y},
+		resolution,
+		focal_length,
+		camera_matrix,
+		screen_center,
+		parallax_shift,
+		snap_to_pixel_centers,
+		near_distance,
+		plane_z,
+		aperture_size,
+		foveation,
+		hidden_area_mask
+	);
+
+	distances[idx] = MAX_DEPTH();
+	depth_buffer[idx] = MAX_DEPTH();
+
+	SdfPayload& payload = payloads[idx];
 
-	distances[idx] = 10000.0f;
+	if (!ray.is_valid()) {
+		payload.dir = ray.d;
+		payload.idx = idx;
+		payload.n_steps = 0;
+		payload.alive = false;
+		positions[idx] = ray.o;
+		return;
+	}
 
 	if (plane_z < 0) {
 		float n = ray.d.norm();
-		SdfPayload& payload = payloads[idx];
 		payload.dir = (1.0f/n) * ray.d;
 		payload.idx = idx;
 		payload.n_steps = 0;
 		payload.alive = false;
 		positions[idx] = ray.o - plane_z * ray.d;
-		depthbuffer[idx] = -plane_z;
+		depth_buffer[idx] = -plane_z;
 		return;
 	}
 
-	depthbuffer[idx] = 1e10f;
-
 	ray.d = ray.d.normalized();
 	float t = max(aabb.ray_intersect(ray.o, ray.d).x(), 0.0f);
-	ray.o = ray.o + (t + 1e-6f) * ray.d;
 
-	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_depth, ray.o)) {
-		t = max(0.0f, TriangleOctree::ray_intersect(octree_nodes, max_depth, ray.o, ray.d));
+	ray.advance(t + 1e-6f);
+
+	if (octree_nodes && !TriangleOctree::contains(octree_nodes, max_octree_depth, ray.o)) {
+		t = max(0.0f, TriangleOctree::ray_intersect(octree_nodes, max_octree_depth, ray.o, ray.d));
 		if (ray.o.y() > floor_y && ray.d.y() < 0.f) {
 			float floor_dist = -(ray.o.y() - floor_y) / ray.d.y();
 			if (floor_dist > 0.f) {
@@ -557,16 +582,15 @@ __global__ void init_rays_with_payload_kernel_sdf(
 			}
 		}
 
-		ray.o = ray.o + (t + 1e-6f) * ray.d;
+		ray.advance(t + 1e-6f);
 	}
 
 	positions[idx] = ray.o;
 
-	if (envmap_data) {
-		framebuffer[idx] = read_envmap(envmap_data, envmap_resolution, ray.d);
+	if (envmap) {
+		frame_buffer[idx] = read_envmap(envmap, ray.d);
 	}
 
-	SdfPayload& payload = payloads[idx];
 	payload.dir = ray.d;
 	payload.idx = idx;
 	payload.n_steps = 0;
@@ -600,10 +624,11 @@ void Testbed::SphereTracer::init_rays_from_camera(
 	float near_distance,
 	float plane_z,
 	float aperture_size,
-	const float* envmap_data,
-	const Vector2i& envmap_resolution,
+	const Foveation& foveation,
+	const Buffer2DView<const Array4f>& envmap,
 	Array4f* frame_buffer,
 	float* depth_buffer,
+	const Buffer2DView<const uint8_t>& hidden_area_mask,
 	const TriangleOctree* octree,
 	uint32_t n_octree_levels,
 	cudaStream_t stream
@@ -630,10 +655,11 @@ void Testbed::SphereTracer::init_rays_from_camera(
 		near_distance,
 		plane_z,
 		aperture_size,
-		envmap_data,
-		envmap_resolution,
+		foveation,
+		envmap,
 		frame_buffer,
 		depth_buffer,
+		hidden_area_mask,
 		octree ? octree->nodes_gpu() : nullptr,
 		octree ? n_octree_levels : 0
 	);
@@ -840,14 +866,15 @@ void Testbed::FiniteDifferenceNormalsApproximator::normal(uint32_t n_elements, c
 }
 
 void Testbed::render_sdf(
+	cudaStream_t stream,
 	const distance_fun_t& distance_function,
 	const normals_fun_t& normals_function,
-	CudaRenderBuffer& render_buffer,
-	const Vector2i& max_res,
+	const CudaRenderBufferView& render_buffer,
 	const Vector2f& focal_length,
 	const Matrix<float, 3, 4>& camera_matrix,
 	const Vector2f& screen_center,
-	cudaStream_t stream
+	const Foveation& foveation,
+	int visualized_dimension
 ) {
 	float plane_z = m_slice_plane_z + m_scale;
 	if (m_render_mode == ERenderMode::Slice) {
@@ -865,8 +892,8 @@ void Testbed::render_sdf(
 	BoundingBox sdf_bounding_box = m_aabb;
 	sdf_bounding_box.inflate(m_sdf.zero_offset);
 	tracer.init_rays_from_camera(
-		render_buffer.spp(),
-		render_buffer.in_resolution(),
+		render_buffer.spp,
+		render_buffer.resolution,
 		focal_length,
 		camera_matrix,
 		screen_center,
@@ -877,10 +904,11 @@ void Testbed::render_sdf(
 		m_render_near_distance,
 		plane_z,
 		m_aperture_size,
-		m_envmap.envmap->inference_params(),
-		m_envmap.resolution,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer(),
+		foveation,
+		m_envmap.inference_view(),
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer,
+		render_buffer.hidden_area_mask ? render_buffer.hidden_area_mask->const_view() : Buffer2DView<const uint8_t>{},
 		octree_ptr,
 		n_octree_levels,
 		stream
@@ -912,10 +940,11 @@ void Testbed::render_sdf(
 	} else {
 		n_hit = trace(tracer);
 	}
+
 	RaysSdfSoa& rays_hit = m_render_mode == ERenderMode::Slice || gt_raytrace ? tracer.rays_init() : tracer.rays_hit();
 
 	if (m_render_mode == ERenderMode::Slice) {
-		if (m_visualized_dimension == -1) {
+		if (visualized_dimension == -1) {
 			distance_function(n_hit, rays_hit.pos, rays_hit.distance, stream);
 			extract_dimension_pos_neg_kernel<float><<<n_blocks_linear(n_hit*3), n_threads_linear, 0, stream>>>(n_hit*3, 0, 1, 3, rays_hit.distance, CM, (float*)rays_hit.normal);
 		} else {
@@ -924,11 +953,11 @@ void Testbed::render_sdf(
 
 			GPUMatrix<float> positions_matrix((float*)rays_hit.pos, 3, n_elements);
 			GPUMatrix<float> colors_matrix((float*)rays_hit.normal, 3, n_elements);
-			m_network->visualize_activation(stream, m_visualized_layer, m_visualized_dimension, positions_matrix, colors_matrix);
+			m_network->visualize_activation(stream, m_visualized_layer, visualized_dimension, positions_matrix, colors_matrix);
 		}
 	}
 
-	ERenderMode render_mode = (m_visualized_dimension > -1 || m_render_mode == ERenderMode::Slice) ? ERenderMode::EncodingVis : m_render_mode;
+	ERenderMode render_mode = (visualized_dimension > -1 || m_render_mode == ERenderMode::Slice) ? ERenderMode::EncodingVis : m_render_mode;
 	if (render_mode == ERenderMode::Shade || render_mode == ERenderMode::Normals) {
 		if (m_sdf.analytic_normals || gt_raytrace) {
 			normals_function(n_hit, rays_hit.pos, rays_hit.normal, stream);
@@ -964,6 +993,7 @@ void Testbed::render_sdf(
 				octree_ptr ? octree_ptr->nodes_gpu() : nullptr,
 				n_octree_levels
 			);
+
 			uint32_t n_hit_shadow = trace(shadow_tracer);
 			auto& shadow_rays_hit = gt_raytrace ? shadow_tracer.rays_init() : shadow_tracer.rays_hit();
 
@@ -984,7 +1014,7 @@ void Testbed::render_sdf(
 
 		GPUMatrix<float> positions_matrix((float*)rays_hit.pos, 3, n_elements);
 		GPUMatrix<float> colors_matrix((float*)rays_hit.normal, 3, n_elements);
-		m_network->visualize_activation(stream, m_visualized_layer, m_visualized_dimension, positions_matrix, colors_matrix);
+		m_network->visualize_activation(stream, m_visualized_layer, visualized_dimension, positions_matrix, colors_matrix);
 	}
 
 	linear_kernel(shade_kernel_sdf, 0, stream,
@@ -1000,8 +1030,8 @@ void Testbed::render_sdf(
 		rays_hit.normal,
 		rays_hit.distance,
 		rays_hit.payload,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer()
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer
 	);
 
 	if (render_mode == ERenderMode::Cost) {
diff --git a/src/testbed_volume.cu b/src/testbed_volume.cu
index 10306cb0ca6e0dbca0f59f32923c9c84df79b6a8..c8c7a09a7236b2740b0d6f5b5da5d1496a68673d 100644
--- a/src/testbed_volume.cu
+++ b/src/testbed_volume.cu
@@ -218,10 +218,11 @@ __global__ void init_rays_volume(
 	float near_distance,
 	float plane_z,
 	float aperture_size,
-	const float* __restrict__ envmap_data,
-	const Vector2i envmap_resolution,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer,
+	Foveation foveation,
+	Buffer2DView<const Array4f> envmap,
+	Array4f* __restrict__ frame_buffer,
+	float* __restrict__ depth_buffer,
+	Buffer2DView<const uint8_t> hidden_area_mask,
 	default_rng_t rng,
 	const uint8_t *bitgrid,
 	float distance_scale,
@@ -240,20 +241,42 @@ __global__ void init_rays_volume(
 	if (plane_z < 0) {
 		aperture_size = 0.0;
 	}
-	Ray ray = pixel_to_ray(sample_index, {x, y}, resolution, focal_length, camera_matrix, screen_center, parallax_shift, snap_to_pixel_centers, near_distance, plane_z, aperture_size);
+
+	Ray ray = pixel_to_ray(
+		sample_index,
+		{x, y},
+		resolution,
+		focal_length,
+		camera_matrix,
+		screen_center,
+		parallax_shift,
+		snap_to_pixel_centers,
+		near_distance,
+		plane_z,
+		aperture_size,
+		foveation,
+		hidden_area_mask
+	);
+
+	if (!ray.is_valid()) {
+		depth_buffer[idx] = MAX_DEPTH();
+		return;
+	}
+
 	ray.d = ray.d.normalized();
 	auto box_intersection = aabb.ray_intersect(ray.o, ray.d);
 	float t = max(box_intersection.x(), 0.0f);
-	ray.o = ray.o + (t + 1e-6f) * ray.d;
+	ray.advance(t + 1e-6f);
 	float scale = distance_scale / global_majorant;
+
 	if (t >= box_intersection.y() || !walk_to_next_event(rng, aabb, ray.o, ray.d, bitgrid, scale)) {
-		framebuffer[idx] = proc_envmap_render(ray.d, up_dir, sun_dir, sky_col);
-		depthbuffer[idx] = 1e10f;
+		frame_buffer[idx] = proc_envmap_render(ray.d, up_dir, sun_dir, sky_col);
+		depth_buffer[idx] = MAX_DEPTH();
 	} else {
 		uint32_t dstidx = atomicAdd(pixel_counter, 1);
 		positions[dstidx] = ray.o;
 		payloads[dstidx] = {ray.d, Array4f::Constant(0.f), idx};
-		depthbuffer[idx] = camera_matrix.col(2).dot(ray.o - camera_matrix.col(3));
+		depth_buffer[idx] = camera_matrix.col(2).dot(ray.o - camera_matrix.col(3));
 	}
 }
 
@@ -276,8 +299,7 @@ __global__ void volume_render_kernel_gt(
 	float distance_scale,
 	float albedo,
 	float scattering,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer
+	Array4f* __restrict__ frame_buffer
 ) {
 	uint32_t idx = threadIdx.x + blockDim.x * blockIdx.x;
 	if (idx>=n_pixels || idx>=pixel_counter_in[0])
@@ -325,7 +347,7 @@ __global__ void volume_render_kernel_gt(
 	} else {
 		col = proc_envmap_render(dir, up_dir, sun_dir, sky_col);
 	}
-	framebuffer[pixidx] = col;
+	frame_buffer[pixidx] = col;
 }
 
 __global__ void volume_render_kernel_step(
@@ -351,8 +373,7 @@ __global__ void volume_render_kernel_step(
 	float distance_scale,
 	float albedo,
 	float scattering,
-	Array4f* __restrict__ framebuffer,
-	float* __restrict__ depthbuffer,
+	Array4f* __restrict__ frame_buffer,
 	bool force_finish_ray
 ) {
 	uint32_t idx = threadIdx.x + blockDim.x * blockIdx.x;
@@ -382,23 +403,25 @@ __global__ void volume_render_kernel_step(
 	payload.col.w() += alpha;
 	if (payload.col.w() > 0.99f || !walk_to_next_event(rng, aabb, pos, dir, bitgrid, scale) || force_finish_ray) {
 		payload.col += (1.f-payload.col.w()) * proc_envmap_render(dir, up_dir, sun_dir, sky_col);
-		framebuffer[pixidx] = payload.col;
+		frame_buffer[pixidx] = payload.col;
 		return;
 	}
 	uint32_t dstidx = atomicAdd(pixel_counter_out, 1);
-	positions_out[dstidx]=pos;
-	payloads_out[dstidx]=payload;
+	positions_out[dstidx] = pos;
+	payloads_out[dstidx] = payload;
 }
 
-void Testbed::render_volume(CudaRenderBuffer& render_buffer,
+void Testbed::render_volume(
+	cudaStream_t stream,
+	const CudaRenderBufferView& render_buffer,
 	const Vector2f& focal_length,
 	const Matrix<float, 3, 4>& camera_matrix,
 	const Vector2f& screen_center,
-	cudaStream_t stream
+	const Foveation& foveation
 ) {
 	float plane_z = m_slice_plane_z + m_scale;
 	float distance_scale = 1.f/std::max(m_volume.inv_distance_scale,0.01f);
-	auto res = render_buffer.in_resolution();
+	auto res = render_buffer.resolution;
 
 	size_t n_pixels = (size_t)res.x() * res.y();
 	for (uint32_t i=0;i<2;++i) {
@@ -413,7 +436,7 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 	const dim3 threads = { 16, 8, 1 };
 	const dim3 blocks = { div_round_up((uint32_t)res.x(), threads.x), div_round_up((uint32_t)res.y(), threads.y), 1 };
 	init_rays_volume<<<blocks, threads, 0, stream>>>(
-		render_buffer.spp(),
+		render_buffer.spp,
 		m_volume.pos[0].data(),
 		m_volume.payload[0].data(),
 		m_volume.hit_counter.data(),
@@ -427,10 +450,11 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 		m_render_near_distance,
 		plane_z,
 		m_aperture_size,
-		m_envmap.envmap->inference_params(),
-		m_envmap.resolution,
-		render_buffer.frame_buffer(),
-		render_buffer.depth_buffer(),
+		foveation,
+		m_envmap.inference_view(),
+		render_buffer.frame_buffer,
+		render_buffer.depth_buffer,
+		render_buffer.hidden_area_mask ? render_buffer.hidden_area_mask->const_view() : Buffer2DView<const uint8_t>{},
 		m_rng,
 		m_volume.bitgrid.data(),
 		distance_scale,
@@ -466,8 +490,7 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 			distance_scale,
 			std::min(m_volume.albedo,0.995f),
 			m_volume.scattering,
-			render_buffer.frame_buffer(),
-			render_buffer.depth_buffer()
+			render_buffer.frame_buffer
 		);
 		m_rng.advance(n_pixels*256);
 	} else {
@@ -508,8 +531,7 @@ void Testbed::render_volume(CudaRenderBuffer& render_buffer,
 				distance_scale,
 				std::min(m_volume.albedo,0.995f),
 				m_volume.scattering,
-				render_buffer.frame_buffer(),
-				render_buffer.depth_buffer(),
+				render_buffer.frame_buffer,
 				(iter>=max_iter-1)
 			);
 			m_rng.advance(n_pixels*256);