Alexandre Chapin / Segment-Object-Centric / Commits

Commit a037d337, authored 2 years ago by Alexandre Chapin

Add beginning of batch forward and multi-images

Parent: 61bf2c1e
Showing 3 changed files with 221 additions and 95 deletions:

automatic_mask_train.py    +44 −39
osrt/encoder.py            +174 −53
osrt/layers.py             +3 −3
automatic_mask_train.py  +44 −39  (view file @ a037d337)

```diff
@@ -68,50 +68,55 @@ if __name__ == '__main__':
     ])
     labels = [1 for i in range(len(sam_mask.points_grid))]
 
     with torch.no_grad():
-        for j in range(16):
-            image = images_path[j]
+        j = 0
+        images = []
+        for image in images_path:
             #import os
             #os.mkdir(f"./results/test_{j}")
             img = cv2.imread(image)
             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            h, w, _ = img.shape
-            points = sam_mask.points_grid
-            new_points = []
-            for val in points:
-                x, y = val[0], val[1]
-                x *= w
-                y *= h
-                new_points.append([x, y])
-            new_points = np.array(new_points)
-            img_batch = []
-            img_el = {}
-            img_el["image"] = img
-            img_el["original_size"] = (h, w)
-            img_batch.append(img_el)
-            start = time.time()
-            masks = sam_mask(img_batch, extract_embeddings=True)
-            end = time.time()
-            print(f"Inference time : {int((end - start) * 1000)} ms")
+            images.append(np.expand_dims(img, axis=0))
+        images_np = np.array(images)
+
+        h, w, _ = images_np[0][0].shape
+        points = sam_mask.points_grid
+        new_points = []
+        for val in points:
+            x, y = val[0], val[1]
+            x *= w
+            y *= h
+            new_points.append([x, y])
+        new_points = np.array(new_points)
+
+        start = time.time()
+        camera_pos = torch.tensor(np.expand_dims(np.expand_dims(np.array([0, 0, 0]), axis=0), axis=0))
+        # TODO : set the right direction
+        ray_dir = torch.tensor(np.expand_dims(np.expand_dims(np.expand_dims(np.expand_dims(np.array([0, 0, 0]), axis=0), axis=0), axis=0), axis=0))
+        print(camera_pos.shape)
+        print(ray_dir.shape)
+        masks = sam_mask(images_np, [(h, w)], extract_embeddings=True)
+        end = time.time()
+        print(f"Inference time : {int((end - start) * 1000)} ms")
         """
         plt.figure(figsize=(15,15))
         plt.imshow(img)
         show_anns(masks[0]["annotations"])
         show_points(new_points, plt.gca())
         #plt.savefig(f"./results/test_{j}/masks.png")
         plt.axis('off')
         plt.show()
         i = 0
         for mask in masks[0]["annotations"]:
             cm = matplotlib.cm.get_cmap('plasma')
             plt.imshow(mask["embeddings"])
             plt.show()
             im = cm(mask["embeddings"])
             #im = np.uint8(im * 255)
             plt.imshow(im)
             plt.show()
             #im = Image.fromarray(im)
             #im.save(f"./results/test_{j}/mask_{i}.png")
             i+=1
         j+=1
         """
```
osrt/encoder.py  +174 −53  (view file @ a037d337)

```diff
@@ -245,86 +245,140 @@ class SamAutomaticMask(nn.Module):
         self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                       ray_octaves=15)
 
     @property
     def device(self) -> Any:
         return self.pixel_mean.device
 
     def forward(
         self,
-        batched_input: List[Dict[str, Any]],
+        images,
+        orig_size,
         camera_pos=None,
         rays=None,
-        extract_embeddings: bool = False,
-    ) -> List[Dict[str, torch.Tensor]]:
+        extract_embeddings=False
+    ):
         """
-        Predicts masks end-to-end from provided images and prompts.
-        If prompts are not known in advance, using SamPredictor is
-        recommended over calling the model directly.
-
-        Arguments:
-          batched_input (list(dict)): A list over input images, each a
-            dictionary with the following keys. A prompt key can be
-            excluded if it is not present.
-              'image': The image as in 3xHxW format
-              'original_size': (tuple(int, int)) The original size of
-                the image before transformation, as (H, W).
-              'point_coords': (torch.Tensor) Batched point prompts for
-                this image, with shape BxNx2. Already transformed to the
-                input frame of the model.
-              'point_labels': (torch.Tensor) Batched labels for point prompts,
-                with shape BxN.
-              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
-                Already transformed to the input frame of the model.
-              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
-                in the form Bx1xHxW.
-          multimask_output (bool): Whether the model should predict multiple
-            disambiguating masks, or return a single mask.
+        Args:
+            images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
+            original_size: tuple(height, width) The original size of the image before transformation.
+            camera_pos: [batch_size, num_images, 3]
+            rays: [batch_size, num_images, height, width, 3]
 
         Returns:
           (list(dict)): A list over input images, where each element is
             as dictionary with the following keys.
-              'masks': (torch.Tensor) Batched binary mask predictions,
+              masks: (torch.Tensor) Batched binary mask predictions,
                 with shape BxCxHxW, where B is the number of input prompts,
                 C is determined by multimask_output, and (H, W) is the
                 original size of the image.
-              'iou_predictions': (torch.Tensor) The model's predictions
+              iou_predictions: (torch.Tensor) The model's predictions
                 of mask quality, in shape BxC.
-              'low_res_logits': (torch.Tensor) Low resolution logits with
+              low_res_logits: (torch.Tensor) Low resolution logits with
                 shape BxCxHxW, where H=W=256. Can be passed as mask input
                 to subsequent iterations of prediction.
+            scene representation: [batch_size, num_patches, channels_per_patch]
         """
-        # Extract image embeddings
         # TODO : handle the following to concatenate token with camera position
-        """
-        # Encode camera and position and direction following SRT's paper
-        camera_pos = camera_pos.flatten(0, 1)
-        rays = rays.flatten(0, 1)
-        ray_enc = self.ray_encoder(camera_pos, rays)
-        x = torch.cat((x, ray_enc), 1)
-        """
-        input_images = [self.preprocess(x["image"]) for x in batched_input][0]
-        with torch.no_grad():
-            image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True)
-
-        # Handle images
-        outputs = []
-        for image_record, curr_embedding, curr_emb_no_red in zip(batched_input, image_embeddings, embed_no_red):
-            # TODO : check if we've got the points given in the batch (to change the current point_grid !)
-            im_size = self.transform.apply_image(image_record["image"]).shape[:2]
+        if len(camera_pos) > 0 and len(rays) > 0 and extract_embeddings:
+            # Encode camera and position and direction following SRT's paper
+            camera_pos = camera_pos.flatten(0, 1)
+            rays = rays.flatten(0, 1)
+            ray_enc = self.ray_encoder(camera_pos, rays).to(self.device)
+            #x = torch.cat((x, ray_enc), 1)
+
+        B, N, H, W, C = images.shape
+        outputs = []
+        for batch in range(B):
+            input_images = torch.zeros(0, C, self.image_encoder.img_size, self.image_encoder.img_size).to(self.device)
+            for img in images[batch]:
+                input_images = torch.cat((input_images, self.preprocess(img)), dim=0)
+            with torch.no_grad():
+                image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True)
+            for n in range(len(input_images)):
+                curr_embedding = image_embeddings[n]
+                curr_emb_no_red = embed_no_red[n]
+                image_record = input_images[n]
+                im_size = self.transform.apply_image(image_record).shape[:2]
+                points_scale = np.array(im_size)[None, ::-1]
+                points_for_image = self.points_grid * points_scale
+                mask_data = MaskData()
+                for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+                    batch_data = self.process_batch(points, im_size, curr_embedding, orig_size)
+                    mask_data.cat(batch_data)
+                    del batch_data
+                del curr_embedding
+                # Remove duplicates
+                keep_by_nms = batched_nms(
+                    mask_data["boxes"].float(),
+                    mask_data["iou_preds"],
+                    torch.zeros_like(mask_data["boxes"][:, 0]),  # categories
+                    iou_threshold=self.box_nms_thresh,
+                )
+                mask_data.filter(keep_by_nms)
+                mask_data["segmentations"] = mask_data["masks"]
+                # Extract mask embeddings
+                if extract_embeddings:
+                    self.extract_mask_embedding(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
+                    """
+                    print(f"Before concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
+                    for tensor in mask_data['embeddings']:
+                        print(tensor.shape)
+                        print(ray_enc.shape)
+                        final = torch.cat((tensor, ray_enc), 1)
+                        print(final.shape)
+                        break
+                    """
+                    #mask_data['embeddings'] = [torch.cat((tensor, ray_enc), 1) for tensor in mask_data['embeddings']]
+                    #print(f"After concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
+                mask_data.to_numpy()
+                # Filter small disconnected regions and holes in masks NOT USED
+                if self.min_mask_region_area > 0:
+                    mask_data = self.postprocess_small_regions(
+                        mask_data,
+                        self.min_mask_region_area,
+                        self.box_nms_thresh,
+                    )
+                # Write mask records
+                curr_anns = []
+                for idx in range(len(mask_data["segmentations"])):
+                    ann = {
+                        "segmentation": mask_data["segmentations"][idx],
+                        "area": area_from_rle(mask_data["rles"][idx]),
+                        "predicted_iou": mask_data["iou_preds"][idx].item(),
+                        "point_coords": [mask_data["points"][idx].tolist()],
+                        "stability_score": mask_data["stability_score"][idx].item()
+                    }
+                    if extract_embeddings:
+                        ann["embeddings"] = mask_data["embeddings"][idx]
+                    curr_anns.append(ann)
+                outputs.append(
+                    {
+                        "annotations": curr_anns
+                    }
+                )
+
+        # Extract masks for individual images
+        """
+        for image_record, orig_size, curr_embedding, curr_emb_no_red in zip(images, orinal_size_img, image_embeddings, embed_no_red):
+            im_size = self.transform.apply_image(image_record).shape[:2]
             points_scale = np.array(im_size)[None, ::-1]
             points_for_image = self.points_grid * points_scale
             mask_data = MaskData()
             for (points,) in batch_iterator(self.points_per_batch, points_for_image):
-                batch_data = self.process_batch(points, im_size, curr_embedding, image_record["original_size"])
+                batch_data = self.process_batch(points, im_size, curr_embedding, orig_size)
                 mask_data.cat(batch_data)
                 del batch_data
             del curr_embedding
-            # Remove duplicates within this crop.
+            # Remove duplicates
             keep_by_nms = batched_nms(
                 mask_data["boxes"].float(),
                 mask_data["iou_preds"],
...
```
```diff
@@ -335,8 +389,19 @@ class SamAutomaticMask(nn.Module):
             mask_data["segmentations"] = mask_data["masks"]
 
+            # Extract mask embeddings
             if extract_embeddings:
                 self.extract_mask_embedding(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
+                print(f"Before concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
+                for tensor in mask_data['embeddings']:
+                    print(tensor.shape)
+                    print(ray_enc.shape)
+                    final = torch.cat((tensor, ray_enc), 1)
+                    print(final.shape)
+                    break
+                #mask_data['embeddings'] = [torch.cat((tensor, ray_enc), 1) for tensor in mask_data['embeddings']]
+                #print(f"After concat : {mask_data['embeddings'][0].shape}, len {len(mask_data['embeddings'])}")
 
             mask_data.to_numpy()
...
```
```diff
@@ -366,7 +431,7 @@ class SamAutomaticMask(nn.Module):
                     "annotations": curr_anns
                 }
             )
-
+        """
         return outputs
 
     def postprocess_masks(
...
```
```diff
@@ -645,7 +710,7 @@ class SamAutomaticMask(nn.Module):
                 mask_embed += pos_embedding
 
             # Apply mask to image embedding
-            mask_data["embeddings"].append(mask_embed)  # [token_dim]
+            mask_data["embeddings"].append(torch.tensor(mask_embed, device=self.device))  # [token_dim]
 
     def complete_holes(self,
                        masks):
...
```
```diff
@@ -690,4 +755,60 @@ class SamAutomaticMask(nn.Module):
         new_masks_data["rles"] = mask_to_rle_pytorch(new_masks_data["masks"])
 
-        return new_masks_data.to_numpy()
\ No newline at end of file
+        return new_masks_data.to_numpy()
+
+    def position_embeding_3d(self, img_feats, camera_info):
+        # TODO : adapter cette fonction à notre usage
+        """
+        3D position embedding on image features following PETR's work in :
+        https://github.com/megvii-research/PETR/blob/main/projects/mmdet3d_plugin/models/dense_heads/petr_head.py#L282
+        """
+        eps = 1e-5
+        B, N, C, H, W = img_feats.shape
+        coords_h = torch.arange(H, device=self.device).float()
+        coords_w = torch.arange(W, device=self.device).float()
+
+        # TODO : checker ces deux nombres et voir quoi faire avec
+        depth_num = 64
+        position_range = [-65, -65, -8.0, 65, 65, 8.0]  # TODO : set cette valeur avec les bon ranges
+        # [xmin, ymin zmin, xmax, ymax, zmax] ROI 3D world space
+        depth_start = 1
+        # END TODO
+
+        index = torch.arange(start=0, end=depth_num, device=self.device).float()
+        bin_size = (position_range[3] - depth_start) / depth_num
+        coords_d = depth_start + bin_size * index
+        D = coords_d.shape[0]
+        coords = torch.stack(torch.meshgrid([coords_w, coords_h, coords_d])).permute(1, 2, 3, 0)  # W, H, D, 3
+        coords = torch.cat((coords, torch.ones_like(coords[..., :1])), -1)
+        coords[..., :2] = coords[..., :2] * torch.maximum(coords[..., 2:3], torch.ones_like(coords[..., 2:3]) * eps)
+
+        # EXTRACT EXTRINSICS
+        camera_extrinsic = camera_info["extrinsics"]  # (B, N, 4, 4) # TODO : prendre cette valeur avec infos camera
+
+        # Apply Transform
+        coords = coords.view(1, 1, W, H, D, 4, 1).repeat(B, N, 1, 1, 1, 1, 1)
+        camera_extrinsic = camera_extrinsic.view(B, N, 1, 1, 1, 4, 4).repeat(1, 1, W, H, D, 1, 1)
+        coords3d = torch.matmul(camera_extrinsic, coords).squeeze(-1)[..., :3]
+
+        # Normalize
+        coords3d[..., 0:1] = (coords3d[..., 0:1] - position_range[0]) / (position_range[3] - position_range[0])
+        coords3d[..., 1:2] = (coords3d[..., 1:2] - position_range[1]) / (position_range[4] - position_range[1])
+        coords3d[..., 2:3] = (coords3d[..., 2:3] - position_range[2]) / (position_range[5] - position_range[2])
+
+        # Final embedding
+        coords3d = coords3d.permute(0, 1, 4, 5, 3, 2).contiguous().view(B * N, -1, H, W)
+        position_dim = 3 * depth_num
+        embed_dims = 256
+        position_encoder = nn.Sequential(
+            nn.Conv2d(position_dim, embed_dims * 4, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(),
+            nn.Conv2d(embed_dims * 4, embed_dims, kernel_size=1, stride=1, padding=0),
+        )
+        coords_position_embeding = position_encoder(coords3d)
+        return coords_position_embeding.view(B, N, embed_dims, H, W)
\ No newline at end of file
```
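The depth binning and coordinate normalization in position_embeding_3d are plain linear maps. The short sketch below only re-derives them with the constants hard-coded above (depth_num = 64, depth_start = 1, position_range = [-65, -65, -8.0, 65, 65, 8.0]); the concrete numbers are purely illustrative:

```python
import torch

depth_num, depth_start = 64, 1
position_range = [-65, -65, -8.0, 65, 65, 8.0]   # [xmin, ymin, zmin, xmax, ymax, zmax]

# Depth bins along each camera ray: bin_size = (65 - 1) / 64 = 1.0,
# so coords_d is 1.0, 2.0, ..., 64.0 within the assumed ROI.
index = torch.arange(0, depth_num).float()
bin_size = (position_range[3] - depth_start) / depth_num
coords_d = depth_start + bin_size * index

# Normalization of a world coordinate into [0, 1], shown here for x.
x = torch.tensor([-65.0, 0.0, 65.0])
x_norm = (x - position_range[0]) / (position_range[3] - position_range[0])

print(coords_d[:3])   # tensor([1., 2., 3.])
print(x_norm)         # tensor([0.0000, 0.5000, 1.0000])
```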
osrt/layers.py  +3 −3  (view file @ a037d337)

```diff
@@ -58,7 +58,7 @@ class PositionalEncoding(nn.Module):
         octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves)
         octaves = octaves.float().to(coords)
         multipliers = 2**octaves * math.pi
         coords = coords.unsqueeze(-1)
 
         while len(multipliers.shape) < len(coords.shape):
             multipliers = multipliers.unsqueeze(0)
```
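For reference, the octave multipliers built above are just powers of two times π. A tiny standalone check (start_octave = 0 and num_octaves = 15 are assumed values, matching the pos_octaves=15 / ray_octaves=15 used when osrt/encoder.py builds its RayEncoder):

```python
import math
import torch

start_octave, num_octaves = 0, 15                    # assumed values
octaves = torch.arange(start_octave, start_octave + num_octaves).float()
multipliers = 2 ** octaves * math.pi                 # pi, 2*pi, 4*pi, ...

print(multipliers[:4])   # tensor([ 3.1416,  6.2832, 12.5664, 25.1327])
```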
```diff
@@ -79,12 +79,12 @@ class RayEncoder(nn.Module):
     def forward(self, pos, rays):
         if len(rays.shape) == 4:
-            batchsize, height, width, dims = rays.shape
+            batchsize, height, width, _ = rays.shape
             pos_enc = self.pos_encoding(pos.unsqueeze(1))
             pos_enc = pos_enc.view(batchsize, pos_enc.shape[-1], 1, 1)
             pos_enc = pos_enc.repeat(1, 1, height, width)
 
             rays = rays.flatten(1, 2)
             ray_enc = self.ray_encoding(rays)
             ray_enc = ray_enc.view(batchsize, height, width, ray_enc.shape[-1])
             ray_enc = ray_enc.permute((0, 3, 1, 2))
...
```
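To see what the 4-D branch above produces, here is a minimal shape check. The octave counts mirror the RayEncoder(pos_octaves=15, ..., ray_octaves=15) construction in osrt/encoder.py, and the printed channel count assumes the encoder concatenates its position and ray encodings (2 sin/cos terms x 15 octaves x 3 coordinates = 90 channels each), so treat it as an expectation rather than a guarantee:

```python
import torch
from osrt.layers import RayEncoder

ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=0, ray_octaves=15)

batchsize, height, width = 2, 8, 8
camera_pos = torch.zeros(batchsize, 3)             # one camera position per image
rays = torch.zeros(batchsize, height, width, 3)    # one ray direction per pixel

enc = ray_encoder(camera_pos, rays)
# Position encoding is broadcast over the H x W grid and concatenated with the
# per-pixel ray encoding along the channel dimension.
print(enc.shape)   # expected: torch.Size([2, 180, 8, 8])
```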