From c7005ee2a8efc28a0fd4ad8ae9673a3e3e71ea4d Mon Sep 17 00:00:00 2001
From: Karl Stelzner <stelzner@cs.tu-darmstadt.de>
Date: Tue, 18 Apr 2023 17:43:06 +0200
Subject: [PATCH] Update README and rendering code

---
 README.md        |  6 +++-
 compile_video.py | 79 +++++++++++-------------------------------------
 render.py        | 21 +++++++++----
 3 files changed, 37 insertions(+), 69 deletions(-)

diff --git a/README.md b/README.md
index a34a8dc..ac2e7f7 100644
--- a/README.md
+++ b/README.md
@@ -52,11 +52,15 @@ Rendered frames and videos are placed in the run directory. Check the args of `r
 and `compile_video.py` for different ways of compiling videos.
 
 ## Results
+<img src="https://drive.google.com/uc?id=1UENZEp4OydMHDUOz8ySOfa0eTWyrwxYy" alt="MSN Rotation" width="750"/>
+
 We have found OSRT's object segmentation performance to be strongly dependent on the
 batch sizes used during training. Due to memory constraints, we were unable to match
 OSRT's setting on MSN-hard. Our largest and most successful run thus far utilized 2304
 target rays per scene as opposed to the 8192 specified in the paper. It reached a
 foreground ARI of around 0.73 and a PSNR of 22.8 after
-750k iterations. The checkpoint may be downloaded here:
+750k iterations. The checkpoint may be downloaded
+[here](https://drive.google.com/file/d/1EAxajGk0guvKtj0FLjza24pMbdV0p7br/view?usp=sharing).
+
 
 ## Citation
 
diff --git a/compile_video.py b/compile_video.py
index 5e0801b..8b6de1a 100644
--- a/compile_video.py
+++ b/compile_video.py
@@ -10,9 +10,9 @@ from os.path import join
 
 from osrt.utils.visualize import setup_axis, background_image
 
-def compile_video_plot(path, small=False, frames=False, num_frames=1000000000):
+def compile_video_plot(path, frames=False, num_frames=1000000000):
 
-    frame_output_dir = os.path.join(path, 'frames_small' if small else 'frames')
+    frame_output_dir = os.path.join(path, 'frames')
     if not os.path.exists(frame_output_dir):
         os.mkdir(frame_output_dir)
 
@@ -25,74 +25,33 @@ def compile_video_plot(path, small=False, frames=False, num_frames=1000000000):
         if not frames:
             break
-        if small:
-            fig, ax = plt.subplots(2, 2, figsize=(600/dpi, 480/dpi), dpi=dpi)
-        else:
-            fig, ax = plt.subplots(3, 4, figsize=(1280/dpi, 720/dpi), dpi=dpi)
+        fig, ax = plt.subplots(1, 3, figsize=(900/dpi, 350/dpi), dpi=dpi)
         plt.subplots_adjust(wspace=0.05, hspace=0.08, left=0.01, right=0.99,
                             top=0.995, bottom=0.035)
-        for row in ax:
-            for cell in row:
-                setup_axis(cell)
+        for cell in ax:
+            setup_axis(cell)
 
-        ax[0, 0].imshow(input_image)
-        ax[0, 0].set_xlabel('Input Image')
+        ax[0].imshow(input_image)
+        ax[0].set_xlabel('Input Image 1')
 
         try:
             render = imageio.imread(join(path, 'renders', f'{frame_id}.png'))
         except FileNotFoundError:
             break
 
-        ax[0, 1].imshow(bg_image)
-        ax[0, 1].imshow(render[..., :3])
-        ax[0, 1].set_xlabel('Rendered Scene')
-
-        try:
-            depths = imageio.imread(join(path, 'depths', f'{frame_id}.png'))
-            if small:
-                depths = depths.astype(np.float32) / 65536.
-                ax[1, 0].imshow(depths, cmap='viridis')
-                ax[1, 0].set_xlabel('Render Depths')
-            else:
-                depths = 1. - depths.astype(np.float32) / 65536.
-                ax[0, 2].imshow(depths, cmap='viridis')
-                ax[0, 2].set_xlabel('Render Depths')
-        except FileNotFoundError:
-            pass
-
-        """
+        ax[1].imshow(bg_image)
+        ax[1].imshow(render[..., :3])
+        ax[1].set_xlabel('Rendered Scene')
+
         segmentations = imageio.imread(join(path, 'segmentations', f'{frame_id}.png'))
-        if small:
-            ax[1, 1].imshow(segmentations)
-            ax[1, 1].set_xlabel('Segmentations')
-        else:
-            ax[0, 3].imshow(segmentations)
-            ax[0, 3].set_xlabel('Segmentations')
-
-        if small:
-            fig.savefig(join(frame_output_dir, f'{frame_id}.png'))
-            plt.close()
-
-            frame_id += 1
-            continue
-
-        for slot_id in range(8):
-            row = 1 + slot_id // 4
-            col = slot_id % 4
-            try:
-                slot_render = imageio.imread(join(path, 'slot_renders', f'{slot_id}-{frame_id}.png'))
-            except FileNotFoundError:
-                ax[row, col].axis('off')
-                continue
-            # if (slot_render[..., 3] > 0.1).astype(np.float32).mean() < 0.4:
-            ax[row, col].imshow(bg_image)
-            ax[row, col].imshow(slot_render)
-            ax[row, col].set_xlabel(f'Rendered Slot #{slot_id}')
-        """
+        ax[2].imshow(segmentations)
+        ax[2].set_xlabel('Segmentations')
         fig.savefig(join(frame_output_dir, f'{frame_id}.png'))
         plt.close()
 
+        frame_id += 1
+
     frame_placeholder = join(frame_output_dir, '%d.png')
-    video_out_file = join(path, 'video-small.mp4' if small else 'video.mp4')
+    video_out_file = join(path, 'video.mp4')
     print('rendering video to ', video_out_file)
     subprocess.call(['ffmpeg', '-y', '-framerate', '60', '-i', frame_placeholder,
                      '-pix_fmt', 'yuv420p', '-b:v', '1M', '-threads', '1', video_out_file])
@@ -110,15 +69,11 @@ if __name__ == '__main__':
     )
     parser.add_argument('path', type=str, help='Path to image files.')
    parser.add_argument('--plot', action='store_true', help='Plot available data, instead of just renders.')
-    parser.add_argument('--small', action='store_true', help='Create small 2x2 video.')
     parser.add_argument('--noframes', action='store_true', help="Assume frames already exist and don't rerender them.")
 
     args = parser.parse_args()
 
     if args.plot:
-        compile_video_plot(args.path, small=args.small, frames=not args.noframes)
+        compile_video_plot(args.path, frames=not args.noframes)
     else:
         compile_video_render(args.path)
-
-
-
diff --git a/render.py b/render.py
index 3d367cb..355553f 100644
--- a/render.py
+++ b/render.py
@@ -11,10 +11,10 @@ from osrt.data import get_dataset
 from osrt.checkpoint import Checkpoint
 from osrt.utils.visualize import visualize_2d_cluster, get_clustering_colors
 from osrt.utils.nerf import rotate_around_z_axis_torch, get_camera_rays, transform_points_torch, get_extrinsic_torch
-from osrt.model import SRT
+from osrt.model import OSRT
 from osrt.trainer import SRTTrainer
 
-from compile_video import compile_video_render
+from compile_video import compile_video_render, compile_video_plot
 
 
 def get_camera_rays_render(camera_pos, **kwargs):
@@ -117,6 +117,14 @@ def render3d(trainer, render_path, z, camera_pos, motion, transform=None, resolu
             depths = (depths / render_kwargs['max_dist'] * 255.).astype(np.uint8)
             imageio.imwrite(os.path.join(render_path, 'depths', f'{frame}.png'), depths)
 
+        if 'segmentation' in extras:
+            pred_seg = extras['segmentation'].squeeze(0).cpu()
+            colors = get_clustering_colors(pred_seg.shape[-1] + 1)
+            pred_seg = pred_seg.argmax(-1).numpy() + 1
+            pred_img = visualize_2d_cluster(pred_seg, colors)
+            pred_img = (pred_img * 255.).astype(np.uint8)
+            imageio.imwrite(os.path.join(render_path, 'segmentations', f'{frame}.png'), pred_img)
+
 
 def process_scene(sceneid):
     render_path = os.path.join(out_dir, 'render', args.name, str(sceneid))
@@ -124,7 +132,7 @@ def process_scene(sceneid):
         print(f'Warning: Path {render_path} exists. Contents will be overwritten.')
     os.makedirs(render_path, exist_ok=True)
 
-    subdirs = ['renders', 'depths']
+    subdirs = ['renders', 'depths', 'segmentations']
     for d in subdirs:
         os.makedirs(os.path.join(render_path, d), exist_ok=True)
 
@@ -155,12 +163,13 @@ def process_scene(sceneid):
     with torch.no_grad():
         z = model.encoder(input_images, input_camera_pos, input_rays)
-
+    print('Rendering frames...')
     render3d(trainer, render_path, z, input_camera_pos[:, 0], motion=args.motion,
              transform=transform, resolution=resolution, **render_kwargs)
 
     if not args.novideo:
-        compile_video_render(render_path)
+        print('Compiling plot video...')
+        compile_video_plot(render_path, frames=True, num_frames=args.num_frames)
 
 
 if __name__ == '__main__':
     # Arguments
@@ -201,7 +210,7 @@ if __name__ == '__main__':
     else:
         render_kwargs = dict()
 
-    model = SRT(cfg['model']).to(device)
+    model = OSRT(cfg['model']).to(device)
     model.eval()
 
     mode = 'train' if args.train else 'val'
-- 
GitLab