From c7005ee2a8efc28a0fd4ad8ae9673a3e3e71ea4d Mon Sep 17 00:00:00 2001
From: Karl Stelzner <stelzner@cs.tu-darmstadt.de>
Date: Tue, 18 Apr 2023 17:43:06 +0200
Subject: [PATCH] Update README and rendering code

---
 README.md        |  6 +++-
 compile_video.py | 79 +++++++++++-------------------------------------
 render.py        | 21 +++++++++----
 3 files changed, 37 insertions(+), 69 deletions(-)

diff --git a/README.md b/README.md
index a34a8dc..ac2e7f7 100644
--- a/README.md
+++ b/README.md
@@ -52,11 +52,15 @@ Rendered frames and videos are placed in the run directory. Check the args of `r
 and `compile_video.py` for different ways of compiling videos.
 
 ## Results
+<img src="https://drive.google.com/uc?id=1UENZEp4OydMHDUOz8ySOfa0eTWyrwxYy" alt="MSN Rotation" width="750"/>
+
 We have found OSRT's object segmentation performance to be strongly dependent on the batch sizes
 used during training. Due to memory constraint, we were unable to match OSRT's setting on MSN-hard.
 Our largest and most successful run thus far utilized 2304 target rays per scene as opposed to the
 8192 specified in the paper. It reached a foreground ARI of around 0.73 and a PSNR of 22.8 after
-750k iterations. The checkpoint may be downloaded here: 
+750k iterations. The checkpoint may be downloaded
+[here](https://drive.google.com/file/d/1EAxajGk0guvKtj0FLjza24pMbdV0p7br/view?usp=sharing).
+
 
 ## Citation
 
diff --git a/compile_video.py b/compile_video.py
index 5e0801b..8b6de1a 100644
--- a/compile_video.py
+++ b/compile_video.py
@@ -10,9 +10,9 @@ from os.path import join
 
 from osrt.utils.visualize import setup_axis, background_image
 
-def compile_video_plot(path, small=False, frames=False, num_frames=1000000000):
+def compile_video_plot(path, frames=False, num_frames=1000000000):
 
-    frame_output_dir = os.path.join(path, 'frames_small' if small else 'frames')
+    frame_output_dir = os.path.join(path, 'frames')
     if not os.path.exists(frame_output_dir):
         os.mkdir(frame_output_dir)
 
@@ -25,74 +25,33 @@ def compile_video_plot(path, small=False, frames=False, num_frames=1000000000):
         if not frames:
             break
 
-        if small:
-            fig, ax = plt.subplots(2, 2, figsize=(600/dpi, 480/dpi), dpi=dpi)
-        else:
-            fig, ax = plt.subplots(3, 4, figsize=(1280/dpi, 720/dpi), dpi=dpi)
+        fig, ax = plt.subplots(1, 3, figsize=(900/dpi, 350/dpi), dpi=dpi)
 
         plt.subplots_adjust(wspace=0.05, hspace=0.08, left=0.01, right=0.99, top=0.995, bottom=0.035)
-        for row in ax:
-            for cell in row:
-                setup_axis(cell)
+        for cell in ax:
+            setup_axis(cell)
 
-        ax[0, 0].imshow(input_image)
-        ax[0, 0].set_xlabel('Input Image')
+        ax[0].imshow(input_image)
+        ax[0].set_xlabel('Input Image 1')
         try:
             render = imageio.imread(join(path, 'renders', f'{frame_id}.png'))
         except FileNotFoundError:
             break
-        ax[0, 1].imshow(bg_image)
-        ax[0, 1].imshow(render[..., :3])
-        ax[0, 1].set_xlabel('Rendered Scene')
-        
-        try:
-            depths = imageio.imread(join(path, 'depths', f'{frame_id}.png'))
-            if small:
-                depths = depths.astype(np.float32) / 65536.
-                ax[1, 0].imshow(depths, cmap='viridis')
-                ax[1, 0].set_xlabel('Render Depths')
-            else:
-                depths = 1. - depths.astype(np.float32) / 65536.
-                ax[0, 2].imshow(depths, cmap='viridis')
-                ax[0, 2].set_xlabel('Render Depths')
-        except FileNotFoundError:
-            pass
-
-        """
+        ax[1].imshow(bg_image)
+        ax[1].imshow(render[..., :3])
+        ax[1].set_xlabel('Rendered Scene')
+
         segmentations = imageio.imread(join(path, 'segmentations', f'{frame_id}.png'))
-        if small:
-            ax[1, 1].imshow(segmentations)
-            ax[1, 1].set_xlabel('Segmentations')
-        else:
-            ax[0, 3].imshow(segmentations)
-            ax[0, 3].set_xlabel('Segmentations')
-
-        if small:
-            fig.savefig(join(frame_output_dir, f'{frame_id}.png'))
-            plt.close()
-
-            frame_id += 1
-            continue
-
-        for slot_id in range(8):
-            row = 1 + slot_id // 4
-            col = slot_id % 4
-            try:
-                slot_render = imageio.imread(join(path, 'slot_renders', f'{slot_id}-{frame_id}.png'))
-            except FileNotFoundError:
-                ax[row, col].axis('off')
-                continue
-            # if (slot_render[..., 3] > 0.1).astype(np.float32).mean() < 0.4:
-            ax[row, col].imshow(bg_image)
-            ax[row, col].imshow(slot_render)
-            ax[row, col].set_xlabel(f'Rendered Slot #{slot_id}')
-        """
+        ax[2].imshow(segmentations)
+        ax[2].set_xlabel('Segmentations')
 
         fig.savefig(join(frame_output_dir, f'{frame_id}.png'))
         plt.close()
 
+        frame_id += 1
+
     frame_placeholder = join(frame_output_dir, '%d.png')
-    video_out_file = join(path, 'video-small.mp4' if small else 'video.mp4')
+    video_out_file = join(path, 'video.mp4')
     print('rendering video to ', video_out_file)
     subprocess.call(['ffmpeg', '-y', '-framerate', '60', '-i', frame_placeholder,
                      '-pix_fmt', 'yuv420p', '-b:v', '1M', '-threads', '1', video_out_file])
@@ -110,15 +69,11 @@ if __name__ == '__main__':
     )
     parser.add_argument('path', type=str, help='Path to image files.')
     parser.add_argument('--plot', action='store_true', help='Plot available data, instead of just renders.')
-    parser.add_argument('--small', action='store_true', help='Create small 2x2 video.')
     parser.add_argument('--noframes', action='store_true', help="Assume frames already exist and don't rerender them.")
     args = parser.parse_args()
 
     if args.plot:
-        compile_video_plot(args.path, small=args.small, frames=not args.noframes)
+        compile_video_plot(args.path, frames=not args.noframes)
     else:
         compile_video_render(args.path)
 
-
-
-
diff --git a/render.py b/render.py
index 3d367cb..355553f 100644
--- a/render.py
+++ b/render.py
@@ -11,10 +11,10 @@ from osrt.data import get_dataset
 from osrt.checkpoint import Checkpoint
 from osrt.utils.visualize import visualize_2d_cluster, get_clustering_colors
 from osrt.utils.nerf import rotate_around_z_axis_torch, get_camera_rays, transform_points_torch, get_extrinsic_torch
-from osrt.model import SRT
+from osrt.model import OSRT
 from osrt.trainer import SRTTrainer
 
-from compile_video import compile_video_render
+from compile_video import compile_video_render, compile_video_plot
 
 
 def get_camera_rays_render(camera_pos, **kwargs):
@@ -117,6 +117,14 @@ def render3d(trainer, render_path, z, camera_pos, motion, transform=None, resolu
             depths = (depths / render_kwargs['max_dist'] * 255.).astype(np.uint8)
             imageio.imwrite(os.path.join(render_path, 'depths', f'{frame}.png'), depths)
 
+        if 'segmentation' in extras:
+            pred_seg = extras['segmentation'].squeeze(0).cpu()
+            colors = get_clustering_colors(pred_seg.shape[-1] + 1)
+            pred_seg = pred_seg.argmax(-1).numpy() + 1
+            pred_img = visualize_2d_cluster(pred_seg, colors)
+            pred_img = (pred_img * 255.).astype(np.uint8)
+            imageio.imwrite(os.path.join(render_path, 'segmentations', f'{frame}.png'), pred_img)
+
 
 def process_scene(sceneid):
     render_path = os.path.join(out_dir, 'render', args.name, str(sceneid))
@@ -124,7 +132,7 @@ def process_scene(sceneid):
         print(f'Warning: Path {render_path} exists. Contents will be overwritten.')
 
     os.makedirs(render_path, exist_ok=True)
-    subdirs = ['renders', 'depths']
+    subdirs = ['renders', 'depths', 'segmentations']
     for d in subdirs:
         os.makedirs(os.path.join(render_path, d), exist_ok=True)
 
@@ -155,12 +163,13 @@ def process_scene(sceneid):
 
     with torch.no_grad():
         z = model.encoder(input_images, input_camera_pos, input_rays)
-                                          
+        print('Rendering frames...')
         render3d(trainer, render_path, z, input_camera_pos[:, 0],
                  motion=args.motion, transform=transform, resolution=resolution, **render_kwargs)
 
     if not args.novideo:
-        compile_video_render(render_path)
+        print('Compiling plot video...')
+        compile_video_plot(render_path, frames=True, num_frames=args.num_frames)
 
 if __name__ == '__main__':
     # Arguments
@@ -201,7 +210,7 @@ if __name__ == '__main__':
     else:
         render_kwargs = dict()
 
-    model = SRT(cfg['model']).to(device)
+    model = OSRT(cfg['model']).to(device)
     model.eval()
 
     mode = 'train' if args.train else 'val'
-- 
GitLab