Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
S
Segment-Object-Centric
Manage
Activity
Members
Labels
Plan
Issues
3
Issue boards
Milestones
Wiki
Code
Merge requests
0
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Alexandre Chapin
Segment-Object-Centric
Commits
c7005ee2
Commit
c7005ee2
authored
2 years ago
by
Karl Stelzner
Browse files
Options
Downloads
Patches
Plain Diff
Update README and rendering code
parent
11419cc0
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
README.md
+5
-1
5 additions, 1 deletion
README.md
compile_video.py
+17
-62
17 additions, 62 deletions
compile_video.py
render.py
+15
-6
15 additions, 6 deletions
render.py
with
37 additions
and
69 deletions
README.md
+
5
−
1
View file @
c7005ee2
...
...
@@ -52,11 +52,15 @@ Rendered frames and videos are placed in the run directory. Check the args of `r
and
`compile_video.py`
for different ways of compiling videos.
## Results
<img
src=
"https://drive.google.com/uc?id=1UENZEp4OydMHDUOz8ySOfa0eTWyrwxYy"
alt=
"MSN Rotation"
width=
"750"
/>
We have found OSRT's object segmentation performance to be strongly dependent on the batch sizes
used during training. Due to memory constraint, we were unable to match OSRT's setting on MSN-hard.
Our largest and most successful run thus far utilized 2304 target rays per scene as opposed to the
8192 specified in the paper. It reached a foreground ARI of around 0.73 and a PSNR of 22.8 after
750k iterations. The checkpoint may be downloaded here:
750k iterations. The checkpoint may be downloaded
[
here
](
https://drive.google.com/file/d/1EAxajGk0guvKtj0FLjza24pMbdV0p7br/view?usp=sharing
)
.
## Citation
...
...
This diff is collapsed.
Click to expand it.
compile_video.py
+
17
−
62
View file @
c7005ee2
...
...
@@ -10,9 +10,9 @@ from os.path import join
from
osrt.utils.visualize
import
setup_axis
,
background_image
def
compile_video_plot
(
path
,
small
=
False
,
frames
=
False
,
num_frames
=
1000000000
):
def
compile_video_plot
(
path
,
frames
=
False
,
num_frames
=
1000000000
):
frame_output_dir
=
os
.
path
.
join
(
path
,
'
frames_small
'
if
small
else
'
frames
'
)
frame_output_dir
=
os
.
path
.
join
(
path
,
'
frames
'
)
if
not
os
.
path
.
exists
(
frame_output_dir
):
os
.
mkdir
(
frame_output_dir
)
...
...
@@ -25,74 +25,33 @@ def compile_video_plot(path, small=False, frames=False, num_frames=1000000000):
if
not
frames
:
break
if
small
:
fig
,
ax
=
plt
.
subplots
(
2
,
2
,
figsize
=
(
600
/
dpi
,
480
/
dpi
),
dpi
=
dpi
)
else
:
fig
,
ax
=
plt
.
subplots
(
3
,
4
,
figsize
=
(
1280
/
dpi
,
720
/
dpi
),
dpi
=
dpi
)
fig
,
ax
=
plt
.
subplots
(
1
,
3
,
figsize
=
(
900
/
dpi
,
350
/
dpi
),
dpi
=
dpi
)
plt
.
subplots_adjust
(
wspace
=
0.05
,
hspace
=
0.08
,
left
=
0.01
,
right
=
0.99
,
top
=
0.995
,
bottom
=
0.035
)
for
row
in
ax
:
for
cell
in
row
:
setup_axis
(
cell
)
for
cell
in
ax
:
setup_axis
(
cell
)
ax
[
0
,
0
].
imshow
(
input_image
)
ax
[
0
,
0
].
set_xlabel
(
'
Input Image
'
)
ax
[
0
].
imshow
(
input_image
)
ax
[
0
].
set_xlabel
(
'
Input Image
1
'
)
try
:
render
=
imageio
.
imread
(
join
(
path
,
'
renders
'
,
f
'
{
frame_id
}
.png
'
))
except
FileNotFoundError
:
break
ax
[
0
,
1
].
imshow
(
bg_image
)
ax
[
0
,
1
].
imshow
(
render
[...,
:
3
])
ax
[
0
,
1
].
set_xlabel
(
'
Rendered Scene
'
)
try
:
depths
=
imageio
.
imread
(
join
(
path
,
'
depths
'
,
f
'
{
frame_id
}
.png
'
))
if
small
:
depths
=
depths
.
astype
(
np
.
float32
)
/
65536.
ax
[
1
,
0
].
imshow
(
depths
,
cmap
=
'
viridis
'
)
ax
[
1
,
0
].
set_xlabel
(
'
Render Depths
'
)
else
:
depths
=
1.
-
depths
.
astype
(
np
.
float32
)
/
65536.
ax
[
0
,
2
].
imshow
(
depths
,
cmap
=
'
viridis
'
)
ax
[
0
,
2
].
set_xlabel
(
'
Render Depths
'
)
except
FileNotFoundError
:
pass
"""
ax
[
1
].
imshow
(
bg_image
)
ax
[
1
].
imshow
(
render
[...,
:
3
])
ax
[
1
].
set_xlabel
(
'
Rendered Scene
'
)
segmentations
=
imageio
.
imread
(
join
(
path
,
'
segmentations
'
,
f
'
{
frame_id
}
.png
'
))
if small:
ax[1, 1].imshow(segmentations)
ax[1, 1].set_xlabel(
'
Segmentations
'
)
else:
ax[0, 3].imshow(segmentations)
ax[0, 3].set_xlabel(
'
Segmentations
'
)
if small:
fig.savefig(join(frame_output_dir, f
'
{frame_id}.png
'
))
plt.close()
frame_id += 1
continue
for slot_id in range(8):
row = 1 + slot_id // 4
col = slot_id % 4
try:
slot_render = imageio.imread(join(path,
'
slot_renders
'
, f
'
{slot_id}-{frame_id}.png
'
))
except FileNotFoundError:
ax[row, col].axis(
'
off
'
)
continue
# if (slot_render[..., 3] > 0.1).astype(np.float32).mean() < 0.4:
ax[row, col].imshow(bg_image)
ax[row, col].imshow(slot_render)
ax[row, col].set_xlabel(f
'
Rendered Slot #{slot_id}
'
)
"""
ax
[
2
].
imshow
(
segmentations
)
ax
[
2
].
set_xlabel
(
'
Segmentations
'
)
fig
.
savefig
(
join
(
frame_output_dir
,
f
'
{
frame_id
}
.png
'
))
plt
.
close
()
frame_id
+=
1
frame_placeholder
=
join
(
frame_output_dir
,
'
%d.png
'
)
video_out_file
=
join
(
path
,
'
video-small.mp4
'
if
small
else
'
video.mp4
'
)
video_out_file
=
join
(
path
,
'
video.mp4
'
)
print
(
'
rendering video to
'
,
video_out_file
)
subprocess
.
call
([
'
ffmpeg
'
,
'
-y
'
,
'
-framerate
'
,
'
60
'
,
'
-i
'
,
frame_placeholder
,
'
-pix_fmt
'
,
'
yuv420p
'
,
'
-b:v
'
,
'
1M
'
,
'
-threads
'
,
'
1
'
,
video_out_file
])
...
...
@@ -110,15 +69,11 @@ if __name__ == '__main__':
)
parser
.
add_argument
(
'
path
'
,
type
=
str
,
help
=
'
Path to image files.
'
)
parser
.
add_argument
(
'
--plot
'
,
action
=
'
store_true
'
,
help
=
'
Plot available data, instead of just renders.
'
)
parser
.
add_argument
(
'
--small
'
,
action
=
'
store_true
'
,
help
=
'
Create small 2x2 video.
'
)
parser
.
add_argument
(
'
--noframes
'
,
action
=
'
store_true
'
,
help
=
"
Assume frames already exist and don
'
t rerender them.
"
)
args
=
parser
.
parse_args
()
if
args
.
plot
:
compile_video_plot
(
args
.
path
,
small
=
args
.
small
,
frames
=
not
args
.
noframes
)
compile_video_plot
(
args
.
path
,
frames
=
not
args
.
noframes
)
else
:
compile_video_render
(
args
.
path
)
This diff is collapsed.
Click to expand it.
render.py
+
15
−
6
View file @
c7005ee2
...
...
@@ -11,10 +11,10 @@ from osrt.data import get_dataset
from
osrt.checkpoint
import
Checkpoint
from
osrt.utils.visualize
import
visualize_2d_cluster
,
get_clustering_colors
from
osrt.utils.nerf
import
rotate_around_z_axis_torch
,
get_camera_rays
,
transform_points_torch
,
get_extrinsic_torch
from
osrt.model
import
SRT
from
osrt.model
import
O
SRT
from
osrt.trainer
import
SRTTrainer
from
compile_video
import
compile_video_render
from
compile_video
import
compile_video_render
,
compile_video_plot
def
get_camera_rays_render
(
camera_pos
,
**
kwargs
):
...
...
@@ -117,6 +117,14 @@ def render3d(trainer, render_path, z, camera_pos, motion, transform=None, resolu
depths
=
(
depths
/
render_kwargs
[
'
max_dist
'
]
*
255.
).
astype
(
np
.
uint8
)
imageio
.
imwrite
(
os
.
path
.
join
(
render_path
,
'
depths
'
,
f
'
{
frame
}
.png
'
),
depths
)
if
'
segmentation
'
in
extras
:
pred_seg
=
extras
[
'
segmentation
'
].
squeeze
(
0
).
cpu
()
colors
=
get_clustering_colors
(
pred_seg
.
shape
[
-
1
]
+
1
)
pred_seg
=
pred_seg
.
argmax
(
-
1
).
numpy
()
+
1
pred_img
=
visualize_2d_cluster
(
pred_seg
,
colors
)
pred_img
=
(
pred_img
*
255.
).
astype
(
np
.
uint8
)
imageio
.
imwrite
(
os
.
path
.
join
(
render_path
,
'
segmentations
'
,
f
'
{
frame
}
.png
'
),
pred_img
)
def
process_scene
(
sceneid
):
render_path
=
os
.
path
.
join
(
out_dir
,
'
render
'
,
args
.
name
,
str
(
sceneid
))
...
...
@@ -124,7 +132,7 @@ def process_scene(sceneid):
print
(
f
'
Warning: Path
{
render_path
}
exists. Contents will be overwritten.
'
)
os
.
makedirs
(
render_path
,
exist_ok
=
True
)
subdirs
=
[
'
renders
'
,
'
depths
'
]
subdirs
=
[
'
renders
'
,
'
depths
'
,
'
segmentations
'
]
for
d
in
subdirs
:
os
.
makedirs
(
os
.
path
.
join
(
render_path
,
d
),
exist_ok
=
True
)
...
...
@@ -155,12 +163,13 @@ def process_scene(sceneid):
with
torch
.
no_grad
():
z
=
model
.
encoder
(
input_images
,
input_camera_pos
,
input_rays
)
print
(
'
Rendering frames...
'
)
render3d
(
trainer
,
render_path
,
z
,
input_camera_pos
[:,
0
],
motion
=
args
.
motion
,
transform
=
transform
,
resolution
=
resolution
,
**
render_kwargs
)
if
not
args
.
novideo
:
compile_video_render
(
render_path
)
print
(
'
Compiling plot video...
'
)
compile_video_plot
(
render_path
,
frames
=
True
,
num_frames
=
args
.
num_frames
)
if
__name__
==
'
__main__
'
:
# Arguments
...
...
@@ -201,7 +210,7 @@ if __name__ == '__main__':
else
:
render_kwargs
=
dict
()
model
=
SRT
(
cfg
[
'
model
'
]).
to
(
device
)
model
=
O
SRT
(
cfg
[
'
model
'
]).
to
(
device
)
model
.
eval
()
mode
=
'
train
'
if
args
.
train
else
'
val
'
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment