Skip to content
Snippets Groups Projects
Commit c7005ee2 authored by Karl Stelzner's avatar Karl Stelzner
Browse files

Update README and rendering code

parent 11419cc0
No related branches found
No related tags found
No related merge requests found
......@@ -52,11 +52,15 @@ Rendered frames and videos are placed in the run directory. Check the args of `r
and `compile_video.py` for different ways of compiling videos.
## Results
<img src="https://drive.google.com/uc?id=1UENZEp4OydMHDUOz8ySOfa0eTWyrwxYy" alt="MSN Rotation" width="750"/>
We have found OSRT's object segmentation performance to be strongly dependent on the batch sizes
used during training. Due to memory constraints, we were unable to match OSRT's setting on MSN-hard.
Our largest and most successful run thus far utilized 2304 target rays per scene as opposed to the
8192 specified in the paper. It reached a foreground ARI of around 0.73 and a PSNR of 22.8 after
750k iterations. The checkpoint may be downloaded here:
750k iterations. The checkpoint may be downloaded
[here](https://drive.google.com/file/d/1EAxajGk0guvKtj0FLjza24pMbdV0p7br/view?usp=sharing).
## Citation
......
......@@ -10,9 +10,9 @@ from os.path import join
from osrt.utils.visualize import setup_axis, background_image
def compile_video_plot(path, small=False, frames=False, num_frames=1000000000):
def compile_video_plot(path, frames=False, num_frames=1000000000):
frame_output_dir = os.path.join(path, 'frames_small' if small else 'frames')
frame_output_dir = os.path.join(path, 'frames')
if not os.path.exists(frame_output_dir):
os.mkdir(frame_output_dir)
......@@ -25,74 +25,33 @@ def compile_video_plot(path, small=False, frames=False, num_frames=1000000000):
if not frames:
break
if small:
fig, ax = plt.subplots(2, 2, figsize=(600/dpi, 480/dpi), dpi=dpi)
else:
fig, ax = plt.subplots(3, 4, figsize=(1280/dpi, 720/dpi), dpi=dpi)
fig, ax = plt.subplots(1, 3, figsize=(900/dpi, 350/dpi), dpi=dpi)
plt.subplots_adjust(wspace=0.05, hspace=0.08, left=0.01, right=0.99, top=0.995, bottom=0.035)
for row in ax:
for cell in row:
setup_axis(cell)
for cell in ax:
setup_axis(cell)
ax[0, 0].imshow(input_image)
ax[0, 0].set_xlabel('Input Image')
ax[0].imshow(input_image)
ax[0].set_xlabel('Input Image 1')
try:
render = imageio.imread(join(path, 'renders', f'{frame_id}.png'))
except FileNotFoundError:
break
ax[0, 1].imshow(bg_image)
ax[0, 1].imshow(render[..., :3])
ax[0, 1].set_xlabel('Rendered Scene')
try:
depths = imageio.imread(join(path, 'depths', f'{frame_id}.png'))
if small:
depths = depths.astype(np.float32) / 65536.
ax[1, 0].imshow(depths, cmap='viridis')
ax[1, 0].set_xlabel('Render Depths')
else:
depths = 1. - depths.astype(np.float32) / 65536.
ax[0, 2].imshow(depths, cmap='viridis')
ax[0, 2].set_xlabel('Render Depths')
except FileNotFoundError:
pass
"""
ax[1].imshow(bg_image)
ax[1].imshow(render[..., :3])
ax[1].set_xlabel('Rendered Scene')
segmentations = imageio.imread(join(path, 'segmentations', f'{frame_id}.png'))
if small:
ax[1, 1].imshow(segmentations)
ax[1, 1].set_xlabel('Segmentations')
else:
ax[0, 3].imshow(segmentations)
ax[0, 3].set_xlabel('Segmentations')
if small:
fig.savefig(join(frame_output_dir, f'{frame_id}.png'))
plt.close()
frame_id += 1
continue
for slot_id in range(8):
row = 1 + slot_id // 4
col = slot_id % 4
try:
slot_render = imageio.imread(join(path, 'slot_renders', f'{slot_id}-{frame_id}.png'))
except FileNotFoundError:
ax[row, col].axis('off')
continue
# if (slot_render[..., 3] > 0.1).astype(np.float32).mean() < 0.4:
ax[row, col].imshow(bg_image)
ax[row, col].imshow(slot_render)
ax[row, col].set_xlabel(f'Rendered Slot #{slot_id}')
"""
ax[2].imshow(segmentations)
ax[2].set_xlabel('Segmentations')
fig.savefig(join(frame_output_dir, f'{frame_id}.png'))
plt.close()
frame_id += 1
frame_placeholder = join(frame_output_dir, '%d.png')
video_out_file = join(path, 'video-small.mp4' if small else 'video.mp4')
video_out_file = join(path, 'video.mp4')
print('rendering video to ', video_out_file)
subprocess.call(['ffmpeg', '-y', '-framerate', '60', '-i', frame_placeholder,
'-pix_fmt', 'yuv420p', '-b:v', '1M', '-threads', '1', video_out_file])
......@@ -110,15 +69,11 @@ if __name__ == '__main__':
)
parser.add_argument('path', type=str, help='Path to image files.')
parser.add_argument('--plot', action='store_true', help='Plot available data, instead of just renders.')
parser.add_argument('--small', action='store_true', help='Create small 2x2 video.')
parser.add_argument('--noframes', action='store_true', help="Assume frames already exist and don't rerender them.")
args = parser.parse_args()
if args.plot:
compile_video_plot(args.path, small=args.small, frames=not args.noframes)
compile_video_plot(args.path, frames=not args.noframes)
else:
compile_video_render(args.path)
......@@ -11,10 +11,10 @@ from osrt.data import get_dataset
from osrt.checkpoint import Checkpoint
from osrt.utils.visualize import visualize_2d_cluster, get_clustering_colors
from osrt.utils.nerf import rotate_around_z_axis_torch, get_camera_rays, transform_points_torch, get_extrinsic_torch
from osrt.model import SRT
from osrt.model import OSRT
from osrt.trainer import SRTTrainer
from compile_video import compile_video_render
from compile_video import compile_video_render, compile_video_plot
def get_camera_rays_render(camera_pos, **kwargs):
......@@ -117,6 +117,14 @@ def render3d(trainer, render_path, z, camera_pos, motion, transform=None, resolu
depths = (depths / render_kwargs['max_dist'] * 255.).astype(np.uint8)
imageio.imwrite(os.path.join(render_path, 'depths', f'{frame}.png'), depths)
if 'segmentation' in extras:
pred_seg = extras['segmentation'].squeeze(0).cpu()
colors = get_clustering_colors(pred_seg.shape[-1] + 1)
pred_seg = pred_seg.argmax(-1).numpy() + 1
pred_img = visualize_2d_cluster(pred_seg, colors)
pred_img = (pred_img * 255.).astype(np.uint8)
imageio.imwrite(os.path.join(render_path, 'segmentations', f'{frame}.png'), pred_img)
def process_scene(sceneid):
render_path = os.path.join(out_dir, 'render', args.name, str(sceneid))
......@@ -124,7 +132,7 @@ def process_scene(sceneid):
print(f'Warning: Path {render_path} exists. Contents will be overwritten.')
os.makedirs(render_path, exist_ok=True)
subdirs = ['renders', 'depths']
subdirs = ['renders', 'depths', 'segmentations']
for d in subdirs:
os.makedirs(os.path.join(render_path, d), exist_ok=True)
......@@ -155,12 +163,13 @@ def process_scene(sceneid):
with torch.no_grad():
z = model.encoder(input_images, input_camera_pos, input_rays)
print('Rendering frames...')
render3d(trainer, render_path, z, input_camera_pos[:, 0],
motion=args.motion, transform=transform, resolution=resolution, **render_kwargs)
if not args.novideo:
compile_video_render(render_path)
print('Compiling plot video...')
compile_video_plot(render_path, frames=True, num_frames=args.num_frames)
if __name__ == '__main__':
# Arguments
......@@ -201,7 +210,7 @@ if __name__ == '__main__':
else:
render_kwargs = dict()
model = SRT(cfg['model']).to(device)
model = OSRT(cfg['model']).to(device)
model.eval()
mode = 'train' if args.train else 'val'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment