Alexandre Chapin / Segment-Object-Centric / Commits

Commit 2c7468b2, authored 2 years ago by Alexandre Chapin
Check fsdp
Parent: 43291e42
Showing 3 changed files with 30 additions and 185 deletions:

    osrt/trainer.py       +27  −184
    osrt/utils/common.py   +2    −0
    segment-anything       +1    −1
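Context for the commit message: the rewritten trainer below mirrors the per-rank train() loop from the PyTorch FSDP tutorial (a per-process ddp_loss accumulator that is summed across ranks with all_reduce). A minimal sketch of the setup such a loop assumes, with the NCCL backend and one GPU per rank; this wiring is illustrative and not part of the commit:

    # Sketch only, not from this repository. Assumes MASTER_ADDR/MASTER_PORT
    # are set in the environment and rank/world_size come from the launcher.
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def setup(rank, world_size):
        # One process per GPU; NCCL is the usual backend for CUDA tensors.
        dist.init_process_group('nccl', rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

    def wrap(model, rank):
        # FSDP shards parameters, gradients, and optimizer state across ranks,
        # gathering full parameters only around each forward/backward pass.
        return FSDP(model.to(rank))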
osrt/trainer.py  (+27 −184)
 import torch
+import torch.distributed as dist
 import numpy as np
 from tqdm import tqdm
...
@@ -11,80 +12,28 @@ import os
 import math
 from collections import defaultdict


 class OSRTSamTrainer:
     def __init__(self, model, optimizer, cfg, device, out_dir, render_kwargs):
         self.model = model
         self.optimizer = optimizer
         self.config = cfg
         self.device = device
         self.out_dir = out_dir
         self.render_kwargs = render_kwargs
         if 'num_coarse_samples' in cfg['training']:
             self.render_kwargs['num_coarse_samples'] = cfg['training']['num_coarse_samples']
         if 'num_fine_samples' in cfg['training']:
             self.render_kwargs['num_fine_samples'] = cfg['training']['num_fine_samples']

-    def evaluate(self, val_loader, **kwargs):
-        ''' Performs an evaluation.
-        Args:
-            val_loader (dataloader): pytorch dataloader
-        '''
-        self.model.eval()
-        eval_lists = defaultdict(list)
+def train(args, model, rank, world_size, train_loader, optimizer, epoch):
+    ddp_loss = torch.zeros(2).to(rank)
+    for batch_idx, (data, target) in enumerate(train_loader):
+        model.train()
+        optimizer.zero_grad()

-        loader = val_loader if get_rank() > 0 else tqdm(val_loader)
-        sceneids = []
+        input_images = data.get('input_images').to(rank)
+        input_camera_pos = data.get('input_camera_pos').to(rank)
+        input_rays = data.get('input_rays').to(rank)
+        target_pixels = data.get('target_pixels').to(rank)

-        for data in loader:
-            sceneids.append(data['sceneid'])
-            eval_step_dict = self.eval_step(data, **kwargs)
-            for k, v in eval_step_dict.items():
-                eval_lists[k].append(v)
-
-        sceneids = torch.cat(sceneids, 0).cuda()
-        sceneids = torch.cat(gather_all(sceneids), 0)
-
-        print(f'Evaluated {len(torch.unique(sceneids))} unique scenes.')
-
-        eval_dict = {k: torch.cat(v, 0) for k, v in eval_lists.items()}
-        eval_dict = reduce_dict(eval_dict, average=True)  # Average across processes
-        eval_dict = {k: v.mean().item() for k, v in eval_dict.items()}  # Average across batch_size
-
-        print('Evaluation results:')
-        print(eval_dict)
-        return eval_dict
-
-    def train_step(self, data, it):
-        self.model.train()
-        self.optimizer.zero_grad()
-        loss, loss_terms = self.compute_loss(data, it)
-        loss = loss.mean(0)
-        loss_terms = {k: v.mean(0).item() for k, v in loss_terms.items()}
-        loss.backward()
-        self.optimizer.step()
-        return loss.item(), loss_terms
-
-    def compute_loss(self, data, it):
-        device = self.device
-        input_images = data.get('input_images').to(device)
-        input_camera_pos = data.get('input_camera_pos').to(device)
-        input_rays = data.get('input_rays').to(device)
-        target_pixels = data.get('target_pixels').to(device)
-
-        input_images = input_images.permute(0, 1, 3, 4, 2)  # from [b, k, c, h, w] to [b, k, h, w, c]
-        h, w, c = input_images[0][0].shape
-        with torch.cuda.amp.autocast():
-            z = self.model.encoder(input_images, (h, w), input_camera_pos, input_rays)
+        z = model.encoder(input_images, input_camera_pos, input_rays)

-        target_camera_pos = data.get('target_camera_pos').to(device)
-        target_rays = data.get('target_rays').to(device)
+        target_camera_pos = data.get('target_camera_pos').to(rank)
+        target_rays = data.get('target_rays').to(rank)

         loss = 0.
         loss_terms = dict()

-        with torch.cuda.amp.autocast():
-            pred_pixels, extras = self.model.decoder(z, target_camera_pos, target_rays, **self.render_kwargs)
+        pred_pixels, extras = model.decoder(z, target_camera_pos, target_rays)  #, **self.render_kwargs)

         loss = loss + ((pred_pixels - target_pixels) ** 2).mean((1, 2))
         loss_terms['mse'] = loss
...
@@ -95,7 +44,7 @@ class OSRTSamTrainer:
         if 'segmentation' in extras:
             pred_seg = extras['segmentation']
-            true_seg = data['target_masks'].to(device).float()
+            true_seg = data['target_masks'].to(rank).float()

             # These are not actually used as part of the training loss.
             # We just add the to the dict to report them.
...
@@ -103,126 +52,20 @@ class OSRTSamTrainer:
                 pred_seg.transpose(1, 2))
             loss_terms['fg_ari'] = compute_adjusted_rand_index(
                 true_seg.transpose(1, 2)[:, 1:],
                 pred_seg.transpose(1, 2))

-        return loss, loss_terms
-
-    def eval_step(self, data, full_scale=False):
-        with torch.no_grad():
-            loss, loss_terms = self.compute_loss(data, 1000000)
-        mse = loss_terms['mse']
-        psnr = mse2psnr(mse)
-        return {'psnr': psnr, 'mse': mse, **loss_terms}
-
-    def render_image(self, z, camera_pos, rays, **render_kwargs):
-        """
-        Args:
-            z [n, k, c]: set structured latent variables
-            camera_pos [n, 3]: camera position
-            rays [n, h, w, 3]: ray directions
-            render_kwargs: kwargs passed on to decoder
-        """
-        batch_size, height, width = rays.shape[:3]
-        rays = rays.flatten(1, 2)
-        camera_pos = camera_pos.unsqueeze(1).repeat(1, rays.shape[1], 1)
-
-        max_num_rays = self.config['data']['num_points'] * \
-                self.config['training']['batch_size'] // (rays.shape[0] * get_world_size())
-        num_rays = rays.shape[1]
-        img = torch.zeros_like(rays)
-        all_extras = []
-        for i in range(0, num_rays, max_num_rays):
-            img[:, i:i+max_num_rays], extras = self.model.decoder(
-                z, camera_pos[:, i:i+max_num_rays], rays[:, i:i+max_num_rays], **render_kwargs)
-            all_extras.append(extras)
-
-        agg_extras = {}
-        for key in all_extras[0]:
-            agg_extras[key] = torch.cat([extras[key] for extras in all_extras], 1)
-            agg_extras[key] = agg_extras[key].view(batch_size, height, width, -1)
-
-        img = img.view(img.shape[0], height, width, 3)
-        return img, agg_extras
-
-    def visualize(self, data, mode='val'):
-        self.model.eval()
-
-        with torch.no_grad():
-            device = self.device
-            input_images = data.get('input_images').to(device)
-            input_camera_pos = data.get('input_camera_pos').to(device)
-            input_rays = data.get('input_rays').to(device)
-
-            camera_pos_base = input_camera_pos[:, 0]
-            input_rays_base = input_rays[:, 0]
-
-            if 'transform' in data:
-                # If the data is transformed in some different coordinate system, where
-                # rotating around the z axis doesn't make sense, we first undo this transform,
-                # then rotate, and then reapply it.
-                transform = data['transform'].to(device)
-                inv_transform = torch.inverse(transform)
-                camera_pos_base = nerf.transform_points_torch(camera_pos_base, inv_transform)
-                input_rays_base = nerf.transform_points_torch(
-                    input_rays_base, inv_transform.unsqueeze(1).unsqueeze(2), translate=False)
-            else:
-                transform = None
-
-            input_images_np = np.transpose(input_images.cpu().numpy(), (0, 1, 3, 4, 2))
-
-            z = self.model.encoder(input_images, input_camera_pos, input_rays)
-
-            batch_size, num_input_images, height, width, _ = input_rays.shape
-
-            num_angles = 6
-            columns = []
-            for i in range(num_input_images):
-                header = 'input' if num_input_images == 1 else f'input {i+1}'
-                columns.append((header, input_images_np[:, i], 'image'))
-
-            if 'input_masks' in data:
-                input_mask = data['input_masks'][:, 0]
-                columns.append(('true seg 0°', input_mask.argmax(-1), 'clustering'))
-
-            row_labels = None
-            for i in range(num_angles):
-                angle = i * (2 * math.pi / num_angles)
-                angle_deg = (i * 360) // num_angles
-
-                camera_pos_rot = nerf.rotate_around_z_axis_torch(camera_pos_base, angle)
-                rays_rot = nerf.rotate_around_z_axis_torch(input_rays_base, angle)
-
-                if transform is not None:
-                    camera_pos_rot = nerf.transform_points_torch(camera_pos_rot, transform)
-                    rays_rot = nerf.transform_points_torch(
-                        rays_rot, transform.unsqueeze(1).unsqueeze(2), translate=False)
-
-                img, extras = self.render_image(z, camera_pos_rot, rays_rot, **self.render_kwargs)
-                columns.append((f'render {angle_deg}°', img.cpu().numpy(), 'image'))
-
-                if 'depth' in extras:
-                    depth_img = extras['depth'].unsqueeze(-1) / self.render_kwargs['max_dist']
-                    depth_img = depth_img.view(batch_size, height, width, 1)
-                    columns.append((f'depths {angle_deg}°', depth_img.cpu().numpy(), 'image'))
-
-                if 'segmentation' in extras:
-                    pred_seg = extras['segmentation'].cpu()
-                    columns.append((f'pred seg {angle_deg}°', pred_seg.argmax(-1).numpy(), 'clustering'))
-                    if i == 0:
-                        ari = compute_adjusted_rand_index(
-                            input_mask.flatten(1, 2).transpose(1, 2)[:, 1:],
-                            pred_seg.flatten(1, 2).transpose(1, 2))
-                        row_labels = ['2D Fg-ARI={:.1f}'.format(x.item() * 100) for x in ari]
-
-            output_img_path = os.path.join(self.out_dir, f'renders-{mode}')
-            vis.draw_visualization_grid(columns, output_img_path, row_labels=row_labels)
+        loss = loss.mean(0)
+        loss_terms = {k: v.mean(0).item() for k, v in loss_terms.items()}
+        loss.backward()
+        optimizer.step()
+        ddp_loss[0] += loss.item()
+        ddp_loss[1] += len(input_images)

+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+    if rank == 0:
+        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

 class SRTTrainer:
     def __init__(self, model, optimizer, cfg, device, out_dir, render_kwargs):
         self.model = model
...
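For reference, a sketch of how the new train() above would be driven per epoch; build_model and build_dataset are hypothetical placeholders rather than functions from this repository, while cleanup() is the helper added to osrt/utils/common.py below:

    # Hypothetical per-process entry point for the train() loop above.
    import torch
    import torch.multiprocessing as mp
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def fsdp_main(rank, world_size, args):
        setup(rank, world_size)                   # see the setup sketch above
        dataset = build_dataset(args)             # placeholder dataset factory
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
        train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=sampler)
        model = FSDP(build_model(args).to(rank))  # placeholder model factory
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        for epoch in range(1, args.epochs + 1):
            sampler.set_epoch(epoch)              # reshuffle shards each epoch
            train(args, model, rank, world_size, train_loader, optimizer, epoch)
        cleanup()                                 # dist.destroy_process_group()

    # Typically launched with:
    # mp.spawn(fsdp_main, args=(world_size, args), nprocs=world_size)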
osrt/utils/common.py  (+2 −0)

...
@@ -189,6 +189,8 @@ def init_ddp():
     setup_dist_print(local_rank == 0)
     return local_rank, world_size

+def cleanup():
+    dist.destroy_process_group()

 def setup_dist_print(is_main):
     import builtins as __builtin__
...
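The added cleanup() pairs with the existing init_ddp() shown above; a minimal lifecycle sketch (the run() wrapper and the torchrun launch line are assumptions, not part of the commit):

    # Sketch: typical lifecycle around the helpers in osrt/utils/common.py,
    # launched e.g. with `torchrun --nproc_per_node=4 train.py` (assumed).
    from osrt.utils.common import init_ddp, cleanup

    def run():
        local_rank, world_size = init_ddp()  # init process group; rank-0-only printing
        try:
            pass  # training / evaluation goes here
        finally:
            cleanup()  # tear down the process group cleanly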
segment-anything @ f7b29ba9  (+1 −1)

Compare 6fdee8f2 ... f7b29ba9

-Subproject commit 6fdee8f2727f4506cfbbe553e23b895e27956588
+Subproject commit f7b29ba9df1496489af8c71a4bdabed7e8b017b1