Skip to content
Snippets Groups Projects
Commit c7005ee2 authored by Karl Stelzner's avatar Karl Stelzner
Browse files

Update README and rendering code

parent 11419cc0
No related branches found
No related tags found
No related merge requests found
...@@ -52,11 +52,15 @@ Rendered frames and videos are placed in the run directory. Check the args of `r ...@@ -52,11 +52,15 @@ Rendered frames and videos are placed in the run directory. Check the args of `r
and `compile_video.py` for different ways of compiling videos. and `compile_video.py` for different ways of compiling videos.
## Results ## Results
<img src="https://drive.google.com/uc?id=1UENZEp4OydMHDUOz8ySOfa0eTWyrwxYy" alt="MSN Rotation" width="750"/>
We have found OSRT's object segmentation performance to be strongly dependent on the batch sizes We have found OSRT's object segmentation performance to be strongly dependent on the batch sizes
used during training. Due to memory constraint, we were unable to match OSRT's setting on MSN-hard. used during training. Due to memory constraint, we were unable to match OSRT's setting on MSN-hard.
Our largest and most successful run thus far utilized 2304 target rays per scene as opposed to the Our largest and most successful run thus far utilized 2304 target rays per scene as opposed to the
8192 specified in the paper. It reached a foreground ARI of around 0.73 and a PSNR of 22.8 after 8192 specified in the paper. It reached a foreground ARI of around 0.73 and a PSNR of 22.8 after
750k iterations. The checkpoint may be downloaded here: 750k iterations. The checkpoint may be downloaded
[here](https://drive.google.com/file/d/1EAxajGk0guvKtj0FLjza24pMbdV0p7br/view?usp=sharing).
## Citation ## Citation
......
...@@ -10,9 +10,9 @@ from os.path import join ...@@ -10,9 +10,9 @@ from os.path import join
from osrt.utils.visualize import setup_axis, background_image from osrt.utils.visualize import setup_axis, background_image
def compile_video_plot(path, small=False, frames=False, num_frames=1000000000): def compile_video_plot(path, frames=False, num_frames=1000000000):
frame_output_dir = os.path.join(path, 'frames_small' if small else 'frames') frame_output_dir = os.path.join(path, 'frames')
if not os.path.exists(frame_output_dir): if not os.path.exists(frame_output_dir):
os.mkdir(frame_output_dir) os.mkdir(frame_output_dir)
...@@ -25,74 +25,33 @@ def compile_video_plot(path, small=False, frames=False, num_frames=1000000000): ...@@ -25,74 +25,33 @@ def compile_video_plot(path, small=False, frames=False, num_frames=1000000000):
if not frames: if not frames:
break break
if small: fig, ax = plt.subplots(1, 3, figsize=(900/dpi, 350/dpi), dpi=dpi)
fig, ax = plt.subplots(2, 2, figsize=(600/dpi, 480/dpi), dpi=dpi)
else:
fig, ax = plt.subplots(3, 4, figsize=(1280/dpi, 720/dpi), dpi=dpi)
plt.subplots_adjust(wspace=0.05, hspace=0.08, left=0.01, right=0.99, top=0.995, bottom=0.035) plt.subplots_adjust(wspace=0.05, hspace=0.08, left=0.01, right=0.99, top=0.995, bottom=0.035)
for row in ax: for cell in ax:
for cell in row: setup_axis(cell)
setup_axis(cell)
ax[0, 0].imshow(input_image) ax[0].imshow(input_image)
ax[0, 0].set_xlabel('Input Image') ax[0].set_xlabel('Input Image 1')
try: try:
render = imageio.imread(join(path, 'renders', f'{frame_id}.png')) render = imageio.imread(join(path, 'renders', f'{frame_id}.png'))
except FileNotFoundError: except FileNotFoundError:
break break
ax[0, 1].imshow(bg_image) ax[1].imshow(bg_image)
ax[0, 1].imshow(render[..., :3]) ax[1].imshow(render[..., :3])
ax[0, 1].set_xlabel('Rendered Scene') ax[1].set_xlabel('Rendered Scene')
try:
depths = imageio.imread(join(path, 'depths', f'{frame_id}.png'))
if small:
depths = depths.astype(np.float32) / 65536.
ax[1, 0].imshow(depths, cmap='viridis')
ax[1, 0].set_xlabel('Render Depths')
else:
depths = 1. - depths.astype(np.float32) / 65536.
ax[0, 2].imshow(depths, cmap='viridis')
ax[0, 2].set_xlabel('Render Depths')
except FileNotFoundError:
pass
"""
segmentations = imageio.imread(join(path, 'segmentations', f'{frame_id}.png')) segmentations = imageio.imread(join(path, 'segmentations', f'{frame_id}.png'))
if small: ax[2].imshow(segmentations)
ax[1, 1].imshow(segmentations) ax[2].set_xlabel('Segmentations')
ax[1, 1].set_xlabel('Segmentations')
else:
ax[0, 3].imshow(segmentations)
ax[0, 3].set_xlabel('Segmentations')
if small:
fig.savefig(join(frame_output_dir, f'{frame_id}.png'))
plt.close()
frame_id += 1
continue
for slot_id in range(8):
row = 1 + slot_id // 4
col = slot_id % 4
try:
slot_render = imageio.imread(join(path, 'slot_renders', f'{slot_id}-{frame_id}.png'))
except FileNotFoundError:
ax[row, col].axis('off')
continue
# if (slot_render[..., 3] > 0.1).astype(np.float32).mean() < 0.4:
ax[row, col].imshow(bg_image)
ax[row, col].imshow(slot_render)
ax[row, col].set_xlabel(f'Rendered Slot #{slot_id}')
"""
fig.savefig(join(frame_output_dir, f'{frame_id}.png')) fig.savefig(join(frame_output_dir, f'{frame_id}.png'))
plt.close() plt.close()
frame_id += 1
frame_placeholder = join(frame_output_dir, '%d.png') frame_placeholder = join(frame_output_dir, '%d.png')
video_out_file = join(path, 'video-small.mp4' if small else 'video.mp4') video_out_file = join(path, 'video.mp4')
print('rendering video to ', video_out_file) print('rendering video to ', video_out_file)
subprocess.call(['ffmpeg', '-y', '-framerate', '60', '-i', frame_placeholder, subprocess.call(['ffmpeg', '-y', '-framerate', '60', '-i', frame_placeholder,
'-pix_fmt', 'yuv420p', '-b:v', '1M', '-threads', '1', video_out_file]) '-pix_fmt', 'yuv420p', '-b:v', '1M', '-threads', '1', video_out_file])
...@@ -110,15 +69,11 @@ if __name__ == '__main__': ...@@ -110,15 +69,11 @@ if __name__ == '__main__':
) )
parser.add_argument('path', type=str, help='Path to image files.') parser.add_argument('path', type=str, help='Path to image files.')
parser.add_argument('--plot', action='store_true', help='Plot available data, instead of just renders.') parser.add_argument('--plot', action='store_true', help='Plot available data, instead of just renders.')
parser.add_argument('--small', action='store_true', help='Create small 2x2 video.')
parser.add_argument('--noframes', action='store_true', help="Assume frames already exist and don't rerender them.") parser.add_argument('--noframes', action='store_true', help="Assume frames already exist and don't rerender them.")
args = parser.parse_args() args = parser.parse_args()
if args.plot: if args.plot:
compile_video_plot(args.path, small=args.small, frames=not args.noframes) compile_video_plot(args.path, frames=not args.noframes)
else: else:
compile_video_render(args.path) compile_video_render(args.path)
...@@ -11,10 +11,10 @@ from osrt.data import get_dataset ...@@ -11,10 +11,10 @@ from osrt.data import get_dataset
from osrt.checkpoint import Checkpoint from osrt.checkpoint import Checkpoint
from osrt.utils.visualize import visualize_2d_cluster, get_clustering_colors from osrt.utils.visualize import visualize_2d_cluster, get_clustering_colors
from osrt.utils.nerf import rotate_around_z_axis_torch, get_camera_rays, transform_points_torch, get_extrinsic_torch from osrt.utils.nerf import rotate_around_z_axis_torch, get_camera_rays, transform_points_torch, get_extrinsic_torch
from osrt.model import SRT from osrt.model import OSRT
from osrt.trainer import SRTTrainer from osrt.trainer import SRTTrainer
from compile_video import compile_video_render from compile_video import compile_video_render, compile_video_plot
def get_camera_rays_render(camera_pos, **kwargs): def get_camera_rays_render(camera_pos, **kwargs):
...@@ -117,6 +117,14 @@ def render3d(trainer, render_path, z, camera_pos, motion, transform=None, resolu ...@@ -117,6 +117,14 @@ def render3d(trainer, render_path, z, camera_pos, motion, transform=None, resolu
depths = (depths / render_kwargs['max_dist'] * 255.).astype(np.uint8) depths = (depths / render_kwargs['max_dist'] * 255.).astype(np.uint8)
imageio.imwrite(os.path.join(render_path, 'depths', f'{frame}.png'), depths) imageio.imwrite(os.path.join(render_path, 'depths', f'{frame}.png'), depths)
if 'segmentation' in extras:
pred_seg = extras['segmentation'].squeeze(0).cpu()
colors = get_clustering_colors(pred_seg.shape[-1] + 1)
pred_seg = pred_seg.argmax(-1).numpy() + 1
pred_img = visualize_2d_cluster(pred_seg, colors)
pred_img = (pred_img * 255.).astype(np.uint8)
imageio.imwrite(os.path.join(render_path, 'segmentations', f'{frame}.png'), pred_img)
def process_scene(sceneid): def process_scene(sceneid):
render_path = os.path.join(out_dir, 'render', args.name, str(sceneid)) render_path = os.path.join(out_dir, 'render', args.name, str(sceneid))
...@@ -124,7 +132,7 @@ def process_scene(sceneid): ...@@ -124,7 +132,7 @@ def process_scene(sceneid):
print(f'Warning: Path {render_path} exists. Contents will be overwritten.') print(f'Warning: Path {render_path} exists. Contents will be overwritten.')
os.makedirs(render_path, exist_ok=True) os.makedirs(render_path, exist_ok=True)
subdirs = ['renders', 'depths'] subdirs = ['renders', 'depths', 'segmentations']
for d in subdirs: for d in subdirs:
os.makedirs(os.path.join(render_path, d), exist_ok=True) os.makedirs(os.path.join(render_path, d), exist_ok=True)
...@@ -155,12 +163,13 @@ def process_scene(sceneid): ...@@ -155,12 +163,13 @@ def process_scene(sceneid):
with torch.no_grad(): with torch.no_grad():
z = model.encoder(input_images, input_camera_pos, input_rays) z = model.encoder(input_images, input_camera_pos, input_rays)
print('Rendering frames...')
render3d(trainer, render_path, z, input_camera_pos[:, 0], render3d(trainer, render_path, z, input_camera_pos[:, 0],
motion=args.motion, transform=transform, resolution=resolution, **render_kwargs) motion=args.motion, transform=transform, resolution=resolution, **render_kwargs)
if not args.novideo: if not args.novideo:
compile_video_render(render_path) print('Compiling plot video...')
compile_video_plot(render_path, frames=True, num_frames=args.num_frames)
if __name__ == '__main__': if __name__ == '__main__':
# Arguments # Arguments
...@@ -201,7 +210,7 @@ if __name__ == '__main__': ...@@ -201,7 +210,7 @@ if __name__ == '__main__':
else: else:
render_kwargs = dict() render_kwargs = dict()
model = SRT(cfg['model']).to(device) model = OSRT(cfg['model']).to(device)
model.eval() model.eval()
mode = 'train' if args.train else 'val' mode = 'train' if args.train else 'val'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment