diff --git a/.visualisation_1639.png b/.visualisation_1639.png
new file mode 100644
index 0000000000000000000000000000000000000000..fece79284e2d145e76f53cc37c395ef07a98cebf
Binary files /dev/null and b/.visualisation_1639.png differ
diff --git a/osrt/layers.py b/osrt/layers.py
index 39f246a20754e7ec7c048abb7002644d144f8608..deaf9d8419b9e18adc6ffffccabeac5afc687f64 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -5,6 +5,7 @@ import numpy as np
 import math
 
 from einops import rearrange, repeat
+import torch.nn.functional as F
 
 __USE_DEFAULT_INIT__ = False
 
@@ -194,10 +195,11 @@ class SlotAttention(nn.Module):
     @edit : we changed the code as to make it possible to handle a different number of slots depending on the input images
     """
     def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8,
-                 randomize_initial_slots=False):
+                 randomize_initial_slots=False, gain = 1, temperature_factor = 1):
         super().__init__()
 
         self.num_slots = num_slots
+        self.temperature_factor = temperature_factor
         self.batch_slots = []
         self.iters = iters
         self.scale = slot_dim ** -0.5
@@ -207,24 +209,31 @@ class SlotAttention(nn.Module):
         self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
         self.eps = eps
+        self.slots_mu = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 1, self.slot_dim)))
+        self.slots_log_sigma = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, 1, self.slot_dim)))
+
+
+        self.to_q = nn.Linear(slot_dim, slot_dim, bias=False)
+        self.to_k = nn.Linear(input_dim, slot_dim, bias=False)
+        self.to_v = nn.Linear(input_dim, slot_dim, bias=False)
-        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
-        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
-        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)
+        nn.init.xavier_uniform_(self.to_q.weight, gain = gain)
+        nn.init.xavier_uniform_(self.to_k.weight, gain = gain)
+        nn.init.xavier_uniform_(self.to_v.weight, gain = gain)
 
         self.gru = nn.GRUCell(slot_dim, slot_dim)
 
         self.mlp = nn.Sequential(
-            JaxLinear(slot_dim, hidden_dim),
+            nn.Linear(slot_dim, hidden_dim),
             nn.ReLU(inplace=True),
-            JaxLinear(hidden_dim, slot_dim)
+            nn.Linear(hidden_dim, slot_dim)
         )
 
         self.norm_input = nn.LayerNorm(input_dim)
         self.norm_slots = nn.LayerNorm(slot_dim)
         self.norm_pre_mlp = nn.LayerNorm(slot_dim)
 
-    def forward(self, inputs, masks=None):
+    def forward(self, inputs, slots=None):
         """
         Args:
             inputs: set-latent representation [batch_size, num_inputs, dim]
@@ -232,74 +241,56 @@ class SlotAttention(nn.Module):
         batch_size, num_inputs, dim = inputs.shape
         inputs = self.norm_input(inputs)
-
-        # Initialize the slots. Shape: [batch_size, num_slots, slot_dim].
-        if self.randomize_initial_slots:
-            slot_means = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device) # from [num_slots, slot_dim] to [batch_size, num_slots, slot_dim]
-            slots = torch.distributions.Normal(slot_means, self.embedding_stdev).rsample()
-        else:
-            slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1).to(inputs.device)
 
         k, v = self.to_k(inputs), self.to_v(inputs)
 
+        if slots is None:
+            slots = self.slots_mu + torch.exp(self.slots_log_sigma) * torch.randn(len(inputs), self.num_slots, self.slot_dim, device = self.slots_mu.device)
+
         # Multiple rounds of attention.
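+        # Each iteration below: normalise the slots, let them compete for the inputs through a
+        # temperature-scaled softmax over the slot dimension, renormalise the attention per slot,
+        # then update every slot with the shared GRU followed by a residual MLP.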
         for _ in range(self.iters):
             slots_prev = slots
-            norm_slots = self.norm_slots(slots)
+            slots = self.norm_slots(slots)
-            q = self.to_q(norm_slots)
+            q = self.to_q(slots)
+            q *= self.scale
-            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale # Dot product and normalization
+            attn_logits = torch.bmm(q, k.transpose(-1, -2))
+
+            attn_pixelwise = F.softmax(attn_logits / self.temperature_factor, dim = 1)
-            if masks != None:
-                temp_masks = masks.unsqueeze(1)
-                attention_masking = torch.where(temp_masks == 1.0, float("-inf"), temp_masks).to(device=dots.device)
-                dots += attention_masking # shape: [batch_size, num_slots, num_inputs]
-            attn = dots.softmax(dim=1) + self.eps
+            attn_slotwise = F.normalize(attn_pixelwise + self.eps, p = 1, dim = -1)
-            # Weighted mean
-            attn = attn / attn.sum(dim=-1, keepdim=True)
-            updates = torch.einsum('bjd,bij->bid', v, attn) # shape: [batch_size, num_inputs, slot_dim]
+            # shape: [batch_size, num_slots, slot_dim]
+            updates = torch.bmm(attn_slotwise, v)
 
             # Slot update
             slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
             slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
             slots = slots + self.mlp(self.norm_pre_mlp(slots))
 
-        return slots # [batch_size, num_slots, dim]
+        return slots, attn_logits, attn_slotwise # [batch_size, num_slots, dim]
 
     def change_slots_number(self, num_slots):
         self.num_slots = num_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))
 
-### Utils for SlotAttentionAutoEncoder
-def build_grid(resolution):
-    ranges = [np.linspace(0., 1., num=res) for res in resolution]
-    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
-    grid = np.stack(grid, axis=-1)
-    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
-    grid = np.expand_dims(grid, axis=0)
-    grid = grid.astype(np.float32)
-    return np.concatenate([grid, 1.0 - grid], axis=-1)
-
-class SoftPositionEmbed(nn.Module):
-    """Adds soft positional embedding with learnable projection.
-    Implementation extracted from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py"""
-
-    def __init__(self, hidden_size, resolution):
-        """Builds the soft position embedding layer.
-        Args:
-            hidden_size: Size of input feature dimension.
-            resolution: Tuple of integers specifying width and height of grid.
- """ +class PositionEmbeddingImplicit(nn.Module): + """ + Position embedding extracted from + https://github.com/vadimkantorov/yet_another_pytorch_slot_attention/blob/master/models.py + """ + def __init__(self, hidden_dim): super().__init__() - self.dense = JaxLinear(4, hidden_size) - self.grid = build_grid(resolution) + self.dense = nn.Linear(4, hidden_dim) - def forward(self, inputs): - return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w] + def forward(self, x): + spatial_shape = x.shape[-3:-1] + grid = torch.stack(torch.meshgrid(*[torch.linspace(0., 1., r, device = x.device) for r in spatial_shape]), dim = -1) + grid = torch.cat([grid, 1 - grid], dim = -1) + return x + self.dense(grid) def fourier_encode(x, max_freq, num_bands = 4): x = x.unsqueeze(-1) @@ -313,7 +304,6 @@ def fourier_encode(x, max_freq, num_bands = 4): x = torch.cat((x, orig_x), dim = -1) return x - ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py class TransformerSlotAttention(nn.Module): """ @@ -393,4 +383,4 @@ class TransformerSlotAttention(nn.Module): x_d = self_attn(inputs) + inputs inputs = self_ff(x_d) + x_d - return slots # [batch_size, num_slots, dim] + return slots, None, None # [batch_size, num_slots, dim] diff --git a/osrt/model.py b/osrt/model.py index 0102d6153d02118b3866e8bb6fd4adb8b1d0dcb9..3ecc2c26c107c32ab5bb4189565393639566b5b2 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -1,13 +1,16 @@ from torch import nn import torch +import torch.nn.functional as F + import numpy as np from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder -from osrt.layers import SlotAttention, JaxLinear, SoftPositionEmbed, TransformerSlotAttention - +from osrt.layers import SlotAttention, PositionEmbeddingImplicit, TransformerSlotAttention import osrt.layers as layers + + class OSRT(nn.Module): def __init__(self, cfg): super().__init__() @@ -75,39 +78,30 @@ class SlotAttentionAutoEncoder(nn.Module): self.num_iterations = num_iterations self.encoder_cnn = nn.Sequential( - nn.Conv2d(3, 64, kernel_size=5, padding=2), - nn.ReLU(), - nn.Conv2d(64, 64, kernel_size=5, padding=2), - nn.ReLU(), - nn.Conv2d(64, 64, kernel_size=5, padding=2), - nn.ReLU(), - nn.Conv2d(64, 64, kernel_size=5, padding=2), - nn.ReLU() + nn.Conv2d(3, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True), + nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(inplace=True) ) self.decoder_initial_size = (8, 8) self.decoder_cnn = nn.Sequential( - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), - nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2), - nn.ReLU(), - nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1) + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, padding=1), nn.ReLU(inplace=True), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, 
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2, output_padding=1), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 64, kernel_size=5), nn.ReLU(inplace=True),
+            nn.ConvTranspose2d(64, 4, kernel_size=3)
         )
 
-        self.encoder_pos = SoftPositionEmbed(64, self.resolution)
-        self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size)
+        self.encoder_pos = PositionEmbeddingImplicit(64)
+        self.decoder_pos = PositionEmbeddingImplicit(64)
 
         self.layer_norm = nn.LayerNorm(64)
         self.mlp = nn.Sequential(
-            JaxLinear(64, 64),
-            nn.ReLU(),
-            JaxLinear(64, 64)
+            nn.Linear(64, 64),
+            nn.ReLU(inplace=True),
+            nn.Linear(64, 64)
         )
 
         model_type = cfg['model']['model_type']
@@ -128,35 +122,22 @@ class SlotAttentionAutoEncoder(nn.Module):
                                                      depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
 
     def forward(self, image):
-        # `image` has shape: [batch_size, num_channels, width, height].
-        # Convolutional encoder with position embedding.
-        x = self.encoder_cnn(image) # CNN Backbone.
-        x = self.encoder_pos(x).permute(0, 2, 3, 1) # Position embedding.
-        x = spatial_flatten(x) # Flatten spatial dimensions (treat image as set).
-        x = self.mlp(self.layer_norm(x)) # Feedforward network on set.
-        # `x` has shape: [batch_size, width*height, input_size].
-
-        # Slot Attention module.
-        slots = self.slot_attention(x)
-        # `slots` has shape: [batch_size, num_slots, slot_size].
-
-        # Spatial broadcast decoder.
-        x = spatial_broadcast(slots, self.decoder_initial_size).permute(0, 3, 1, 2)
-
-        # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size].
+        x = self.encoder_cnn(image).movedim(1, -1)
+        x = self.encoder_pos(x)
+        x = self.mlp(self.layer_norm(x))
+
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
+        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
 
         x = self.decoder_pos(x)
-        x = self.decoder_cnn(x).permute(0, 2, 3, 1)
-        # `x` has shape: [batch_size*num_slots, width, height, num_channels+1].
+        x = self.decoder_cnn(x.movedim(-1, 1))
+
+        x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode)
 
-        # Undo combination of slot and batch dimension; split alpha masks.
-        recons, masks = unstack_and_split(x, batch_size=image.shape[0])
-        # `recons` has shape: [batch_size, num_slots, width, height, num_channels].
-        # `masks` has shape: [batch_size, num_slots, width, height, 1].
+        x = x.unflatten(0, (len(image), len(x) // len(image)))
 
-        # Normalize alpha masks over slots.
-        masks = torch.softmax(masks, dim=1)
-        recon_combined = torch.sum(recons * masks, dim=1) # Recombine image.
-        # `recon_combined` has shape: [batch_size, width, height, num_channels].
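+        # Split the 4 decoder channels into per-slot RGB and an alpha logit, normalise the alphas
+        # across slots, and composite the per-slot reconstructions into a single image.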
+        recons, masks = x.split((3, 1), dim = 2)
+        masks = masks.softmax(dim = 1)
+        recon_combined = (recons * masks).sum(dim = 1)
 
-        return recon_combined, recons, masks, slots
+        return recon_combined, recons, masks, slots, (attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None)
diff --git a/visualise.py b/visualise.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c7d2ea84833a8fefb9855379961b6e9f1b6cb0
--- /dev/null
+++ b/visualise.py
@@ -0,0 +1,81 @@
+import datetime
+import time
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import argparse
+import yaml
+
+from osrt.model import SlotAttentionAutoEncoder
+from osrt import data
+from osrt.utils.visualize import visualize_slot_attention
+from osrt.utils.common import mse2psnr
+
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+from tqdm import tqdm
+
+def main():
+    # Arguments
+    parser = argparse.ArgumentParser(
+        description='Visualise slot-attention reconstructions from a trained model.'
+    )
+    parser.add_argument('config', type=str, help="Path to the config file.")
+    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
+    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
+    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
+
+    args = parser.parse_args()
+    with open(args.config, 'r') as f:
+        cfg = yaml.load(f, Loader=yaml.CLoader)
+
+    ### Set random seed.
+    torch.manual_seed(args.seed)
+
+    ### Hyperparameters of the model.
+    num_slots = cfg["model"]["num_slots"]
+    num_iterations = cfg["model"]["iters"]
+    base_learning_rate = 0.0004
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    resolution = (128, 128)
+
+    #### Create datasets
+
+    vis_dataset = data.get_dataset('test', cfg['data'])
+    vis_loader = DataLoader(
+        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
+
+    #### Create model
+    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg).to(device)
+    num_params = sum(p.numel() for p in model.parameters())
+
+    print('Number of parameters:')
+    print(f'Model slot attention: {num_params}')
+
+    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)
+
+    ckpt = {
+        'network': model,
+        'optimizer': optimizer,
+        'global_step': 1639
+    }
+    #ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
+    """ckpt = torch.load('~/ckpt.pth')
+    model = ckpt['network']"""
+    model.load_state_dict(torch.load('/home/achapin/ckpt.pth')["model_state_dict"])
+
+    image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
+    image = F.interpolate(image, size=128)
+    image = image.to(device)
+    recon_combined, recons, masks, slots, _ = model(image)
+    loss = nn.MSELoss()
+    loss_value = loss(recon_combined, image)  # both tensors are [batch_size, 3, H, W]
+    psnr = mse2psnr(loss_value)
+    print(f"MSE: {loss_value} | PSNR: {psnr}")
+    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=1639, save_file=True)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
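Assuming the osrt package from this branch is importable, a minimal smoke test of the revised SlotAttention interface (three return values, optional externally supplied slots) could look like the sketch below; the batch size, token count and feature dimensions are placeholders, not values taken from the repository's configs:

    import torch
    from osrt.layers import SlotAttention

    slot_attn = SlotAttention(num_slots=6, input_dim=768, slot_dim=1536)
    feats = torch.rand(2, 320, 768)                       # [batch_size, num_inputs, input_dim]
    slots, attn_logits, attn_slotwise = slot_attn(feats)  # slots: [2, 6, 1536]
    # A second pass can be seeded with the slots produced by the first one:
    slots, _, _ = slot_attn(feats, slots=slots)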