Skip to content
Snippets Groups Projects
Commit 0f7f245c authored by Thomas Müller's avatar Thomas Müller
Browse files

Robustify NeRF dataset loader

parent 97178a3f
No related branches found
No related tags found
No related merge requests found
......@@ -66,7 +66,6 @@ jobs:
working-directory: ${{ env.build_dir }}
run: cmake --build . --target all --verbose -j `nproc`
build_windows:
name: Build on Windows
runs-on: ${{ matrix.os }}
......
......@@ -326,10 +326,29 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
}
);
std::vector<std::string> supported_image_formats = {
"png", "jpg", "jpeg", "bmp", "gif", "tga", "pic", "pnm", "psd", "exr",
};
auto resolve_path = [&supported_image_formats](const fs::path& base_path, const fs::path& local_path) {
fs::path path = local_path.is_absolute() ? local_path : (base_path / local_path);
if (path.extension().empty() && !path.exists()) {
for (const auto& format : supported_image_formats) {
if (path.with_extension(format).exists()) {
return path.with_extension(format);
}
}
}
return path;
};
result.n_images = 0;
for (size_t i = 0; i < jsons.size(); ++i) {
auto& json = jsons[i];
fs::path basepath = jsonpaths[i].parent_path();
fs::path base_path = jsonpaths[i].parent_path();
if (!json.contains("frames") || !json["frames"].is_array()) {
tlog::warning() << " " << jsonpaths[i] << " does not contain any frames. Skipping.";
continue;
......@@ -343,6 +362,11 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
return frame1["file_path"] < frame2["file_path"];
});
for (auto&& frame : frames) {
// Compatibility with Windows paths on Linux. (Breaks linux filenames with "\\" in them, which is acceptable for us.)
frame["file_path"] = replace_all(frame["file_path"], "\\", "/");
}
if (json.contains("n_frames")) {
size_t cull_idx = std::min(frames.size(), (size_t)json["n_frames"]);
frames.get_ptr<nlohmann::json::array_t*>()->resize(cull_idx);
......@@ -357,20 +381,18 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
for (int i = 0; i < (int)frames_copy.size(); ++i) {
float mean_sharpness = 0.0f;
int mean_start = std::max(0, i-neighborhood_size);
int mean_end = std::min(i+neighborhood_size, (int)frames_copy.size()-1);
int mean_end = std::min(i + neighborhood_size, (int)frames_copy.size() - 1);
for (int j = mean_start; j < mean_end; ++j) {
mean_sharpness += float(frames_copy[j]["sharpness"]);
mean_sharpness += float(frames_copy[j].value("sharpness", 1.0));
}
mean_sharpness /= (mean_end - mean_start);
// Compatibility with Windows paths on Linux. (Breaks linux filenames with "\\" in them, which is acceptable for us.)
frames_copy[i]["file_path"] = replace_all(frames_copy[i]["file_path"], "\\", "/");
mean_sharpness /= (mean_end - mean_start);
if ((basepath / fs::path(std::string(frames_copy[i]["file_path"]))).exists() && frames_copy[i]["sharpness"] > sharpness_discard_threshold * mean_sharpness) {
if (resolve_path(base_path, frames_copy[i]["file_path"]).exists() && frames_copy[i].value("sharpness", 1.0) > sharpness_discard_threshold * mean_sharpness) {
frames.emplace_back(frames_copy[i]);
} else {
// tlog::info() << "discarding frame " << frames_copy[i]["file_path"];
// fs::remove(basepath / fs::path(std::string(frames_copy[i]["file_path"])));
// fs::remove(resolve_path(base_path, frames_copy[i]["file_path"]));
}
}
}
......@@ -395,7 +417,8 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
std::vector<std::future<void>> futures;
size_t image_idx = 0;
if (result.n_images==0) {
if (result.n_images == 0) {
throw std::invalid_argument{"No training images were found for NeRF training!"};
}
......@@ -410,10 +433,10 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
for (size_t i = 0; i < jsons.size(); ++i) {
auto& json = jsons[i];
fs::path basepath = jsonpaths[i].parent_path();
fs::path base_path = jsonpaths[i].parent_path();
std::string jp = jsonpaths[i].str();
auto lastdot = jp.find_last_of('.'); if (lastdot==std::string::npos) lastdot=jp.length();
auto lastunderscore = jp.find_last_of('_'); if (lastunderscore==std::string::npos) lastunderscore=lastdot; else lastunderscore++;
auto lastdot = jp.find_last_of('.'); if (lastdot==std::string::npos) lastdot = jp.length();
auto lastunderscore = jp.find_last_of('_'); if (lastunderscore == std::string::npos) lastunderscore=lastdot; else lastunderscore++;
std::string part_after_underscore(jp.begin()+lastunderscore,jp.begin()+lastdot);
if (json.contains("enable_ray_loading")) {
......@@ -517,8 +540,7 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
}
if (json.contains("envmap") && result.envmap_resolution.isZero()) {
std::string json_provided_path = json["envmap"];
fs::path envmap_path = basepath / json_provided_path;
fs::path envmap_path = resolve_path(base_path, json["envmap"]);
if (!envmap_path.exists()) {
throw std::runtime_error{fmt::format("Environment map {} does not exist.", envmap_path.str())};
}
......@@ -531,28 +553,23 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
}
}
if (json.contains("frames") && json["frames"].is_array()) pool.parallel_for_async<size_t>(0, json["frames"].size(), [&progress, &n_loaded, &result, &images, &json, basepath, image_idx, info, rolling_shutter, principal_point, lens, part_after_underscore, fix_premult, enable_depth_loading, enable_ray_loading](size_t i) {
if (json.contains("frames") && json["frames"].is_array()) pool.parallel_for_async<size_t>(0, json["frames"].size(), [&progress, &n_loaded, &result, &images, &json, &resolve_path, &supported_image_formats, base_path, image_idx, info, rolling_shutter, principal_point, lens, part_after_underscore, fix_premult, enable_depth_loading, enable_ray_loading](size_t i) {
size_t i_img = i + image_idx;
auto& frame = json["frames"][i];
LoadedImageInfo& dst = images[i_img];
dst = info; // copy defaults
std::string json_provided_path(frame["file_path"]);
std::string json_provided_path = frame["file_path"];
if (json_provided_path == "") {
char buf[256];
snprintf(buf, 256, "%s_%03d/rgba.png", part_after_underscore.c_str(), (int)i);
json_provided_path = buf;
}
fs::path path = basepath / json_provided_path;
if (path.extension() == "") {
path = path.with_extension("png");
if (!path.exists()) {
path = path.with_extension("exr");
}
if (!path.exists()) {
throw std::runtime_error{"Could not find image file: " + path.str()};
}
fs::path path = resolve_path(base_path, json_provided_path);
if (!path.exists()) {
throw std::runtime_error{fmt::format("Could not find image file '{}'.", path.str())};
}
int comp = 0;
......@@ -568,34 +585,38 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
throw std::runtime_error{"Could not open image file: "s + std::string{stbi_failure_reason()}};
}
fs::path alphapath = basepath / fmt::format("{}.alpha.{}", frame["file_path"], path.extension());
fs::path alphapath = resolve_path(base_path, fmt::format("{}.alpha.{}", frame["file_path"], path.extension()));
if (alphapath.exists()) {
int wa = 0, ha = 0;
uint8_t* alpha_img = stbi_load(alphapath.str().c_str(), &wa, &ha, &comp, 4);
if (!alpha_img) {
throw std::runtime_error{"Could not load alpha image "s + alphapath.str()};
}
ScopeGuard mem_guard{[&]() { stbi_image_free(alpha_img); }};
if (wa != dst.res.x() || ha != dst.res.y()) {
throw std::runtime_error{fmt::format("Alpha image {} has wrong resolution.", alphapath.str())};
}
tlog::success() << "Alpha loaded from " << alphapath;
for (int i = 0; i < dst.res.prod(); ++i) {
img[i*4+3] = (uint8_t)(255.0f*srgb_to_linear(alpha_img[i*4]*(1.f/255.f))); // copy red channel of alpha to alpha.png to our alpha channel
}
}
fs::path maskpath = path.parent_path()/(fmt::format("dynamic_mask_{}.png", path.basename()));
fs::path maskpath = path.parent_path() / fmt::format("dynamic_mask_{}.png", path.basename());
if (maskpath.exists()) {
int wa = 0, ha = 0;
uint8_t* mask_img = stbi_load(maskpath.str().c_str(), &wa, &ha, &comp, 4);
if (!mask_img) {
throw std::runtime_error{fmt::format("Dynamic mask {} could not be loaded.", maskpath.str())};
}
ScopeGuard mem_guard{[&]() { stbi_image_free(mask_img); }};
if (wa != dst.res.x() || ha != dst.res.y()) {
throw std::runtime_error{fmt::format("Dynamic mask {} has wrong resolution.", maskpath.str())};
}
dst.mask_color = 0x00FF00FF; // HOT PINK
for (int i = 0; i < dst.res.prod(); ++i) {
if (mask_img[i*4] != 0 || mask_img[i*4+1] != 0 || mask_img[i*4+2] != 0) {
......@@ -609,25 +630,25 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
}
if (!dst.pixels) {
throw std::runtime_error{"Could not load image: " + path.str()};
throw std::runtime_error{fmt::format("Could not load image file '{}'.", path.str())};
}
if (enable_depth_loading && info.depth_scale > 0.f && frame.contains("depth_path")) {
fs::path depthpath = basepath / std::string{frame["depth_path"]};
fs::path depthpath = resolve_path(base_path, frame["depth_path"]);
if (depthpath.exists()) {
int wa=0,ha=0;
int wa = 0, ha = 0;
dst.depth_pixels = stbi_load_16(depthpath.str().c_str(), &wa, &ha, &comp, 1);
if (!dst.depth_pixels) {
throw std::runtime_error{"Could not load depth image "s + depthpath.str()};
throw std::runtime_error{fmt::format("Could not load depth image '{}'.", depthpath.str())};
}
if (wa != dst.res.x() || ha != dst.res.y()) {
throw std::runtime_error{fmt::format("Depth image {} has wrong resolution.", depthpath.str())};
}
//tlog::success() << "Depth loaded from " << depthpath;
}
}
fs::path rayspath = path.parent_path()/(fmt::format("rays_{}.dat", path.basename()));
fs::path rayspath = path.parent_path() / fmt::format("rays_{}.dat", path.basename());
if (enable_ray_loading && rayspath.exists()) {
uint32_t n_pixels = dst.res.prod();
dst.rays = (Ray*)malloc(n_pixels * sizeof(Ray));
......@@ -647,6 +668,7 @@ NerfDataset load_nerf(const std::vector<filesystem::path>& jsonpaths, float shar
for (uint32_t px = 0; px < n_pixels; ++px) {
result.nerf_ray_to_ngp(dst.rays[px]);
}
result.has_rays = true;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment