Commit bf9c2973 authored by Karl Stelzner

Turn MixingBlock into full attention layer

parent 78846365
@@ -7,7 +7,8 @@ from osrt.utils import nerf
 class RayPredictor(nn.Module):
-    def __init__(self, num_att_blocks=2, pos_start_octave=0, out_dims=3, input_mlp=None, output_mlp=None):
+    def __init__(self, num_att_blocks=2, pos_start_octave=0, out_dims=3, input_mlp=None, output_mlp=None,
+                 z_dim=1536):
         super().__init__()

         if input_mlp is not None:  # Input MLP added with OSRT
             self.input_mlp = nn.Sequential(
@@ -19,13 +20,24 @@ class RayPredictor(nn.Module):
         self.query_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                         ray_octaves=15)
-        self.transformer = Transformer(180, depth=num_att_blocks, heads=12, dim_head=128,
-                                       mlp_dim=3072, selfatt=False, kv_dim=1536)
+        self.transformer = Transformer(180, depth=num_att_blocks, heads=12, dim_head=z_dim // 12,
+                                       mlp_dim=z_dim * 2, selfatt=False, kv_dim=z_dim)

         if output_mlp is not None:
             self.output_mlp = nn.Sequential(
-                nn.Linear(180, 128),
+                nn.Linear(180, 1536),
                 nn.ReLU(),
+                nn.Linear(1536, 1536),
+                nn.ReLU(),
+                nn.Linear(1536, 1536),
+                nn.ReLU(),
-                nn.Linear(128, out_dims))
+                nn.Linear(1536, 3),
+            )
+            #self.output_mlp = nn.Sequential(
+                #nn.Linear(180, 128),
+                #nn.ReLU(),
+                #nn.Linear(128, out_dims))
         else:
             self.output_mlp = None
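Note on the hunk above: the cross-attention transformer's per-head width and feed-forward width are now derived from the new `z_dim` argument instead of being hard-coded to the 1536-dimensional scene latent. A minimal sketch of the derived sizes, assuming the `Transformer(dim, depth, heads, dim_head, mlp_dim, selfatt, kv_dim)` call seen above; the helper name is illustrative and not part of the codebase.

```python
# Illustrative only: shows which widths follow from z_dim in the call above.
def cross_attention_dims(z_dim, heads=12):
    return {
        'dim': 180,                  # query width from the ray encoding (unchanged)
        'heads': heads,
        'dim_head': z_dim // heads,  # per-head width now tracks the scene latent
        'mlp_dim': z_dim * 2,        # feed-forward width now tracks the scene latent
        'kv_dim': z_dim,             # keys/values are projected from the scene latent
    }

print(cross_attention_dims(1536))  # SlotMixerDecoder: dim_head=128, mlp_dim=3072 (the old hard-coded values)
print(cross_attention_dims(768))   # SRTDecoder below: dim_head=64, mlp_dim=1536
```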
@@ -36,22 +48,22 @@ class RayPredictor(nn.Module):
         x: query camera positions [batch_size, num_rays, 3]
         rays: query ray directions [batch_size, num_rays, 3]
         """
-        orig_queries = queries = self.query_encoder(x, rays)
+        queries = self.query_encoder(x, rays)
         if self.input_mlp is not None:
             queries = self.input_mlp(queries)

         output = self.transformer(queries, z)
         if self.output_mlp is not None:
             output = self.output_mlp(output)
-        return output, orig_queries
+        return output, queries


 class SRTDecoder(nn.Module):
     def __init__(self, num_att_blocks=2, pos_start_octave=0):
         super().__init__()
         self.ray_predictor = RayPredictor(num_att_blocks=num_att_blocks,
-                                          pos_start_octave=pos_start_octave,
-                                          out_dims=3, output_mlp=True)
+                                          pos_start_octave=pos_start_octave, input_mlp=True,
+                                          out_dims=3, output_mlp=True, z_dim=768)

     def forward(self, z, x, rays, **kwargs):
         output, _ = self.ray_predictor(z, x, rays)
@@ -63,14 +75,22 @@ class MixingBlock(nn.Module):
         super().__init__()
         self.to_q = nn.Linear(input_dim, att_dim, bias=False)
         self.to_k = nn.Linear(slot_dim, att_dim, bias=False)
+        self.norm1 = nn.LayerNorm(input_dim)
+        self.norm2 = nn.LayerNorm(slot_dim)
+        self.scale = att_dim ** -0.5

     def forward(self, x, slot_latents):
+        x = self.norm1(x)
         q = self.to_q(x)
         k = self.to_k(slot_latents)
-        w = torch.einsum('bid,bsd->bis', q, k).softmax(dim=2)
+        dots = torch.einsum('bid,bsd->bis', q, k) * self.scale
+        w = dots.softmax(dim=2)

         s = (w.unsqueeze(-1) * slot_latents.unsqueeze(1)).sum(2)
+        s = self.norm2(s)
         return s, w
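This hunk is the change the commit message refers to: MixingBlock now behaves like a standard single-head attention layer, with a pre-norm on the ray queries, 1/sqrt(d) scaling on the dot products, and a layer norm on the mixed slot features, while the slot latents themselves still act as the values. Below is a self-contained sketch with a shape check; the default dimensions are assumptions consistent with the 180-dim ray queries and 1536-dim slot latents used elsewhere in this diff, not values taken from the hidden `__init__` signature.

```python
import torch
from torch import nn

class MixingBlockSketch(nn.Module):
    # Standalone sketch of the updated MixingBlock; default dims are assumptions.
    def __init__(self, input_dim=180, slot_dim=1536, att_dim=1536):
        super().__init__()
        self.to_q = nn.Linear(input_dim, att_dim, bias=False)
        self.to_k = nn.Linear(slot_dim, att_dim, bias=False)
        self.norm1 = nn.LayerNorm(input_dim)   # pre-norm on the ray queries
        self.norm2 = nn.LayerNorm(slot_dim)    # norm on the mixed slot features
        self.scale = att_dim ** -0.5           # standard 1/sqrt(d) attention scaling

    def forward(self, x, slot_latents):
        x = self.norm1(x)
        q = self.to_q(x)                                   # [b, num_rays, att_dim]
        k = self.to_k(slot_latents)                        # [b, num_slots, att_dim]
        dots = torch.einsum('bid,bsd->bis', q, k) * self.scale
        w = dots.softmax(dim=2)                            # per-ray weights over slots
        s = (w.unsqueeze(-1) * slot_latents.unsqueeze(1)).sum(2)  # weighted slot mixture
        return self.norm2(s), w

# Shape check: 2 scenes, 4096 rays, 7 slots
block = MixingBlockSketch()
s, w = block(torch.randn(2, 4096, 180), torch.randn(2, 7, 1536))
print(s.shape, w.shape)  # torch.Size([2, 4096, 1536]) torch.Size([2, 4096, 7])
```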
@@ -80,7 +100,7 @@ class SlotMixerDecoder(nn.Module):
         super().__init__()
         self.allocation_transformer = RayPredictor(num_att_blocks=num_att_blocks,
                                                    pos_start_octave=pos_start_octave,
-                                                   input_mlp=True)
+                                                   input_mlp=True, z_dim=1536)
         self.mixing_block = MixingBlock()
         self.render_mlp = nn.Sequential(
             nn.Linear(1536 + 180, 1536),
@@ -28,7 +28,7 @@ class TweakedSRTEncoder(nn.Module):
     """
     Scene Representation Transformer Encoder with the tweaks from Appendix A.4 in the OSRT paper.
     """
-    def __init__(self, num_conv_blocks=4, num_att_blocks=10, pos_start_octave=0):
+    def __init__(self, num_conv_blocks=3, num_att_blocks=5, pos_start_octave=0):
         super().__init__()
         self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                       ray_octaves=15)
@@ -66,6 +66,7 @@ class PreNorm(nn.Module):
         super().__init__()
         self.norm = nn.LayerNorm(dim)
         self.fn = fn
+
     def forward(self, x, **kwargs):
         return self.fn(self.norm(x), **kwargs)
 from torch import nn

-from osrt.encoder import OSRTEncoder
+from osrt.encoder import OSRTEncoder, TweakedSRTEncoder
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, SRTDecoder


 class OSRT(nn.Module):
     def __init__(self, cfg):
         super().__init__()
-        self.encoder = OSRTEncoder(**cfg['encoder_kwargs'])
+        encoder_type = cfg['encoder']
         decoder_type = cfg['decoder']
+
+        if encoder_type == 'srt':
+            self.encoder = TweakedSRTEncoder(**cfg['encoder_kwargs'])
+        elif encoder_type == 'osrt':
+            self.encoder = OSRTEncoder(**cfg['encoder_kwargs'])
+        else:
+            raise ValueError(f'Unknown encoder type: {encoder_type}')

         if decoder_type == 'spatial_broadcast':
             self.decoder = SpatialBroadcastDecoder(**cfg['decoder_kwargs'])
         elif decoder_type == 'srt':
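With the dispatch above, the encoder and decoder are chosen entirely from the config. A hedged usage sketch follows; the exact kwargs each class accepts and the config slice passed to OSRT are assumptions beyond what this diff shows.

```python
# Illustrative config; only 'pos_start_octave' and 'num_slots' appear in the
# configs touched by this commit, everything else here is an assumption.
model_cfg = {
    'encoder': 'srt',                 # 'srt' -> TweakedSRTEncoder, 'osrt' -> OSRTEncoder
    'encoder_kwargs': {'pos_start_octave': -5},
    'decoder': 'srt',                 # 'spatial_broadcast' -> SpatialBroadcastDecoder, 'srt' -> SRTDecoder (presumably)
    'decoder_kwargs': {},
}

model = OSRT(model_cfg)               # unknown encoder names raise ValueError
```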
@@ -200,6 +200,7 @@ class SRTTrainer:
             columns.append((f'render {angle_deg}°', img.cpu().numpy(), 'image'))

         for i, extras in enumerate(all_extras):
             angle_deg = (i * 360) // num_angles
+            if 'depth' in extras:
                 depth_img = extras['depth'].unsqueeze(-1) / self.render_kwargs['max_dist']
                 depth_img = depth_img.view(batch_size, height, width, 1)
@@ -5,6 +5,7 @@ data:
   kwargs:
     downsample: 1
 model:
+  encoder: osrt
   encoder_kwargs:
     pos_start_octave: -5
     num_slots: 7
@@ -192,13 +192,15 @@ if __name__ == '__main__':
     num_encoder_params = sum(p.numel() for p in model.encoder.parameters())
     num_decoder_params = sum(p.numel() for p in model.decoder.parameters())
-    num_srt_encoder_params = sum(p.numel() for p in model.encoder.module.srt_encoder.parameters())
-    num_slotatt_params = sum(p.numel() for p in model.encoder.module.slot_attention.parameters())

     print('Number of parameters:')
     print(f'\tEncoder: {num_encoder_params}')
-    print(f'\t\tSRT Encoder: {num_srt_encoder_params}.')
-    print(f'\t\tSlot Attention: {num_slotatt_params}.')
+    if cfg['model']['encoder'] == 'osrt':
+        num_srt_encoder_params = sum(p.numel() for p in model.encoder.module.srt_encoder.parameters())
+        num_slotatt_params = sum(p.numel() for p in model.encoder.module.slot_attention.parameters())
+        print(f'\t\tSRT Encoder: {num_srt_encoder_params}.')
+        print(f'\t\tSlot Attention: {num_slotatt_params}.')
     print(f'\tDecoder: {num_decoder_params}')
     print(f'Total: {num_encoder_params + num_decoder_params}')