Skip to content
Snippets Groups Projects
Commit 7d67128a authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Solve issue render input

parent 90b18c2f
No related branches found
No related tags found
No related merge requests found
......@@ -192,7 +192,7 @@ def train(
if fabric.is_global_zero:
if it % visualize_every == 0 and vis_data != None:
print('Visualizing...')
visualize(model, vis_data, out_dir=out_dir, render_args=config['render_args'], fabric=fabric)
visualize(model, vis_data, out_dir=out_dir, render_args=config['render_args'], fabric=fabric, config=config)
fabric.print(f"Visualization process")
dt = time.time() - t0
### LOGGING
......@@ -212,7 +212,7 @@ def train(
fabric.save(f"{out_dir}/model_last.pt", state)
fabric.print(f'Last model validation metric : (loss {metric_val_best:.6f})')
def visualize(model, vis_data, out_dir, render_args, fabric):
def visualize(model, vis_data, out_dir, render_args, fabric, config):
model.eval()
with torch.no_grad():
vis_data = fabric.to_device(vis_data)
......@@ -270,8 +270,9 @@ def visualize(model, vis_data, out_dir, render_args, fabric):
camera_pos_rot = nerf.transform_points_torch(camera_pos_rot, transform)
rays_rot = nerf.transform_points_torch(
rays_rot, transform.unsqueeze(1).unsqueeze(2), translate=False)
img, extras = render_image(z, camera_pos_rot, rays_rot, model, **render_args)
img, extras = render_image(z, camera_pos_rot, rays_rot,
config['training']['num_gpu'], batch_size,
config['data']['num_points'], model, **render_args)
columns.append((f'render {angle_deg}°', img.cpu().numpy(), 'image'))
if 'depth' in extras:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment