diff --git a/scripts/colmap2nerf.py b/scripts/colmap2nerf.py
index a8f3fa3f103201c0d33b91a59817bede2f9a714e..4a56df36243e0e8e3877dd8b4af1ed898b9a2218 100644
--- a/scripts/colmap2nerf.py
+++ b/scripts/colmap2nerf.py
@@ -33,6 +33,7 @@ def parse_args():
 	parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
 	parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
 	parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
+	parser.add_argument("--keep_colmap_coords", action="store_true", help="keep transforms.json in COLMAP's original frame of reference (this will avoid reorienting and repositioning the scene for preview and rendering)")
 	parser.add_argument("--out", default="transforms.json", help="output path")
 	args = parser.parse_args()
 	return args
@@ -253,50 +254,64 @@ if __name__ == "__main__":
 				t = tvec.reshape([3,1])
 				m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
 				c2w = np.linalg.inv(m)
-				c2w[0:3,2] *= -1 # flip the y and z axis
-				c2w[0:3,1] *= -1
-				c2w = c2w[[1,0,2,3],:] # swap y and z
-				c2w[2,:] *= -1 # flip whole world upside down
+				if not args.keep_colmap_coords:
+					c2w[0:3,2] *= -1 # flip the y and z axis
+					c2w[0:3,1] *= -1
+					c2w = c2w[[1,0,2,3],:] # swap y and z
+					c2w[2,:] *= -1 # flip whole world upside down
 
-				up += c2w[0:3,1]
+					up += c2w[0:3,1]
 
 				frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
 				out["frames"].append(frame)
 	nframes = len(out["frames"])
-	up = up / np.linalg.norm(up)
-	print("up vector was", up)
-	R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
-	R = np.pad(R,[0,1])
-	R[-1, -1] = 1
 
+	if args.keep_colmap_coords:
+		flip_mat = np.array([
+			[1, 0, 0, 0],
+			[0, -1, 0, 0],
+			[0, 0, -1, 0],
+			[0, 0, 0, 1]
+		])
 
-	for f in out["frames"]:
-		f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis
+		for f in out["frames"]:
+			f["transform_matrix"] = np.matmul(f["transform_matrix"], flip_mat) # flip cameras (it just works)
+	else:
+		# don't keep colmap coords - reorient the scene to be easier to work with
+		
+		up = up / np.linalg.norm(up)
+		print("up vector was", up)
+		R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
+		R = np.pad(R,[0,1])
+		R[-1, -1] = 1
 
-	# find a central point they are all looking at
-	print("computing center of attention...")
-	totw = 0.0
-	totp = np.array([0.0, 0.0, 0.0])
-	for f in out["frames"]:
-		mf = f["transform_matrix"][0:3,:]
-		for g in out["frames"]:
-			mg = g["transform_matrix"][0:3,:]
-			p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
-			if w > 0.01:
-				totp += p*w
-				totw += w
-	totp /= totw
-	print(totp) # the cameras are looking at totp
-	for f in out["frames"]:
-		f["transform_matrix"][0:3,3] -= totp
+		for f in out["frames"]:
+			f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis
+		
+		# find a central point they are all looking at
+		print("computing center of attention...")
+		totw = 0.0
+		totp = np.array([0.0, 0.0, 0.0])
+		for f in out["frames"]:
+			mf = f["transform_matrix"][0:3,:]
+			for g in out["frames"]:
+				mg = g["transform_matrix"][0:3,:]
+				p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
+				if w > 0.01:
+					totp += p*w
+					totw += w
+		totp /= totw
+		print(totp) # the cameras are looking at totp
+		for f in out["frames"]:
+			f["transform_matrix"][0:3,3] -= totp
 
-	avglen = 0.
-	for f in out["frames"]:
-		avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
-	avglen /= nframes
-	print("avg camera distance from origin", avglen)
-	for f in out["frames"]:
-		f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"
+		avglen = 0.
+		for f in out["frames"]:
+			avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
+		avglen /= nframes
+		print("avg camera distance from origin", avglen)
+		for f in out["frames"]:
+			f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"
 
 	for f in out["frames"]:
 		f["transform_matrix"] = f["transform_matrix"].tolist()