From ecd02e13f8497ed9467b7816ff1393f88a7e88f7 Mon Sep 17 00:00:00 2001
From: Thomas Pickles <thomas.pickles@ens-lyon.fr>
Date: Fri, 3 Mar 2023 15:53:28 +0100
Subject: [PATCH] Training data is now treated as greyscale, so NeRF learns the
 alpha channel rather than the rgb channels

---
 src/testbed_nerf.cu | 51 ++++++++++++++++++++++++++++++---------------
 1 file changed, 34 insertions(+), 17 deletions(-)

diff --git a/src/testbed_nerf.cu b/src/testbed_nerf.cu
index 974f28e..bf08d18 100644
--- a/src/testbed_nerf.cu
+++ b/src/testbed_nerf.cu
@@ -214,25 +214,27 @@ inline __device__ float advance_to_next_voxel(float t, float cone_angle, const V
 }
 
 __device__ float network_to_rgb(float val, ENerfActivation activation) {
-	switch (activation) {
-		case ENerfActivation::None: return val;
-		case ENerfActivation::ReLU: return val > 0.0f ? val : 0.0f;
-		case ENerfActivation::Logistic: return tcnn::logistic(val);
-		case ENerfActivation::Exponential: return __expf(tcnn::clamp(val, -10.0f, 10.0f));
-		default: assert(false);
-	}
-	return 0.0f;
+	return 1.f; // colour is fixed at white; only density/alpha is learned
+	// switch (activation) {
+	// 	case ENerfActivation::None: return val;
+	// 	case ENerfActivation::ReLU: return val > 0.0f ? val : 0.0f;
+	// 	case ENerfActivation::Logistic: return tcnn::logistic(val);
+	// 	case ENerfActivation::Exponential: return __expf(tcnn::clamp(val, -10.0f, 10.0f));
+	// 	default: assert(false);
+	// }
+	// return 0.0f;
 }
 
 // No way to modify the derivative for rgb
 __device__ float network_to_rgb_derivative(float val, ENerfActivation activation) {
-	switch (activation) {
-		case ENerfActivation::None: return 1.0f;
-		case ENerfActivation::ReLU: return val > 0.0f ? 1.0f : 0.0f;
-		case ENerfActivation::Logistic: { float density = tcnn::logistic(val); return density * (1 - density); };
-		case ENerfActivation::Exponential: return __expf(tcnn::clamp(val, -10.0f, 10.0f));
-		default: assert(false);
-	}
+	return 0.f; // zero gradient: the rgb outputs are no longer trained
+	// switch (activation) {
+	// 	case ENerfActivation::None: return 1.0f;
+	// 	case ENerfActivation::ReLU: return val > 0.0f ? 1.0f : 0.0f;
+	// 	case ENerfActivation::Logistic: { float density = tcnn::logistic(val); return density * (1 - density); };
+	// 	case ENerfActivation::Exponential: return __expf(tcnn::clamp(val, -10.0f, 10.0f));
+	// 	default: assert(false);
+	// }
 }
 
 __device__ float network_to_density(float val, ENerfActivation activation) {
@@ -1405,11 +1407,21 @@ __global__ void compute_loss_kernel_train_nerf(
 	Array3f exposure_scale = (0.6931471805599453f * exposure[img]).exp();
 	// Array3f rgbtarget = composit_and_lerp(uv, resolution, img, training_images, background_color, exposure_scale);
 	// Array3f rgbtarget = composit(uv, resolution, img, training_images, background_color, exposure_scale);
+
+	// The image pixel value, sampled from the training image.
 	Array4f texsamp = read_rgba(uv, resolution, metadata[img].pixels, metadata[img].image_data_type);
+	// The .w() and .head<3>() are just Eigen accessors
+	// for the alpha and the rgb components, respectively.
+
+	// The most important change: we discard the rgb values of
+	// the training data and ask the network to learn only the
+	// alpha. rgb values are now irrelevant, so the shade will
+	// always be [1,1,1].
+	Array3f grey = Array3f::Constant(texsamp.w());
 
 	Array3f rgbtarget;
 	if (train_in_linear_colors || color_space == EColorSpace::Linear) {
-		rgbtarget = exposure_scale * texsamp.head<3>() + (1.0f - texsamp.w()) * background_color;
+		rgbtarget = exposure_scale * grey + (1.0f - texsamp.w()) * background_color;
 
 		if (!train_in_linear_colors) {
 			rgbtarget = linear_to_srgb(rgbtarget);
@@ -1418,7 +1430,7 @@ __global__ void compute_loss_kernel_train_nerf(
 	} else if (color_space == EColorSpace::SRGB) {
 		background_color = linear_to_srgb(background_color);
 		if (texsamp.w() > 0) {
-			rgbtarget = linear_to_srgb(exposure_scale * texsamp.head<3>() / texsamp.w()) * texsamp.w() + (1.0f - texsamp.w()) * background_color;
+			rgbtarget = linear_to_srgb(exposure_scale * grey / texsamp.w()) * texsamp.w() + (1.0f - texsamp.w()) * background_color;
 		} else {
 			rgbtarget = background_color;
 		}
@@ -1446,6 +1458,9 @@ __global__ void compute_loss_kernel_train_nerf(
 
 	dloss_doutput += compacted_base * padded_output_width;
 
+	// We don't care about colour: the target is built from alpha only.
+	// Setting rgbtarget to [1,1,1] doesn't work - it just results in all
+	// of the image being white.
 	LossAndGradient lg = loss_and_gradient(rgbtarget, rgb_ray, loss_type);
 	lg.loss /= img_pdf * uv_pdf;
 
@@ -1529,6 +1544,8 @@ __global__ void compute_loss_kernel_train_nerf(
 
 		tcnn::vector_t<tcnn::network_precision_t, 4> local_dL_doutput;
 
+		// TURN OFF COLOUR-BASED TRAINING:
+
 		// chain rule to go from dloss/drgb to dloss/dmlp_output
 		local_dL_doutput[0] = loss_scale * (dloss_by_drgb.x() * network_to_rgb_derivative(local_network_output[0], rgb_activation) + fmaxf(0.0f, output_l2_reg * (float)local_network_output[0])); // Penalize way too large color values
 		local_dL_doutput[1] = loss_scale * (dloss_by_drgb.y() * network_to_rgb_derivative(local_network_output[1], rgb_activation) + fmaxf(0.0f, output_l2_reg * (float)local_network_output[1]));
-- 
GitLab