diff --git a/eval_sa.py b/eval_sa.py
new file mode 100644
index 0000000000000000000000000000000000000000..858462dc755c04c79e1a287b98734f28e606d21d
--- /dev/null
+++ b/eval_sa.py
@@ -0,0 +1,55 @@
+from osrt import data
+from osrt.model import SlotAttentionAutoEncoder
+import torch
+import matplotlib.pyplot as plt
+from PIL import Image as Image
+import argparse
+import yaml
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+from osrt.utils.visualize import visualize_slot_attention
+
+if __name__ == "__main__":
+    # Arguments
+    parser = argparse.ArgumentParser(
+        description='Evaluate a Slot Attention autoencoder.'
+    )
+    parser.add_argument('config', type=str, help="Path to the config file.")
+    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
+    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
+    parser.add_argument('--ckpt', type=str, default='./ckpt.pth', help='Model checkpoint path.')
+
+    args = parser.parse_args()
+    with open(args.config, 'r') as f:
+        cfg = yaml.load(f, Loader=yaml.CLoader)
+
+    # Hyperparameters.
+    seed = 0
+    batch_size = 1
+    num_slots = 7
+    num_iterations = 3
+    resolution = (128, 128)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations)
+    model = torch.load(args.ckpt, map_location=device)['network']
+    print(model)
+    model.eval()
+
+
+
+    eval_dataset = data.get_dataset('train', cfg['data'])
+    eval_loader = DataLoader(
+        eval_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
+        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+
+    model = model.to(device)
+
+    image = torch.squeeze(next(iter(eval_loader)).get('input_images').to(device), dim=1)
+    image = F.interpolate(image, size=128)
+    image = image.to(device)
+    recon_combined, recons, masks, slots = model(image)
+
+    visualize_slot_attention(num_slots, image, recon_combined, recons, masks)
+
+
+
diff --git a/osrt/layers.py b/osrt/layers.py
index 4e9f0d1fa345f72d1c50e17bc47417d4e9f062b7..0daf93ccfb3560a2c083cccc3d2624a0cee61149 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -267,31 +267,58 @@ class SlotAttention(nn.Module):
         self.num_slots = num_slots
         self.initial_slots = nn.Parameter(torch.randn(num_slots, self.slot_dim))

-#############################################
-#############################################
-#############################################
-#############################################
-#############################################
-#############################################
-### New implementations
+### Utils for SlotAttentionAutoEncoder
+def build_grid(resolution):
+    ranges = [np.linspace(0., 1., num=res) for res in resolution]
+    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
+    grid = np.stack(grid, axis=-1)
+    grid = np.reshape(grid, [resolution[0], resolution[1], -1])
+    grid = np.expand_dims(grid, axis=0)
+    grid = grid.astype(np.float32)
+    return np.concatenate([grid, 1.0 - grid], axis=-1)
+
+class SoftPositionEmbed(nn.Module):
+    """Adds soft positional embedding with learnable projection.
+    Implementation extracted from the official repo: https://github.com/google-research/google-research/blob/master/slot_attention/model.py"""
+
+    def __init__(self, hidden_size, resolution):
+        """Builds the soft position embedding layer.
+        Args:
+            hidden_size: Size of input feature dimension.
+            resolution: Tuple of integers specifying width and height of grid.
+ """ + super().__init__() + self.dense = JaxLinear(4, hidden_size) + self.grid = build_grid(resolution) + + def forward(self, inputs): + return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w] + +### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py class TransformerSlotAttention(nn.Module): """ An extension of Slot Attention using self-attention """ - def __init__(self, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, iters=3, eps=1e-8, + def __init__(self, depth, heads, dim_head, mlp_dim, num_slots, input_dim=768, slot_dim=1536, hidden_dim=3072, eps=1e-8, randomize_initial_slots=False): super().__init__() self.num_slots = num_slots self.batch_slots = [] - self.iters = iters self.scale = slot_dim ** -0.5 self.slot_dim = slot_dim + self.depth = depth + self.num_heads = 8 + self.randomize_initial_slots = randomize_initial_slots self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim)) + #def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., selfatt=True, kv_dim=None): + self.transformer_stage_1 = Transformer(dim=384, depth=2, heads=8) + self.transformer_stage_2 = Transformer(dim=384, depth=2, heads=8) + self.eps = eps self.to_q = JaxLinear(slot_dim, slot_dim, bias=False) @@ -345,31 +372,3 @@ class TransformerSlotAttention(nn.Module): slots = slots + self.mlp(self.norm_pre_mlp(slots)) return slots # [batch_size, num_slots, dim] - - -def build_grid(resolution): - ranges = [np.linspace(0., 1., num=res) for res in resolution] - grid = np.meshgrid(*ranges, sparse=False, indexing="ij") - grid = np.stack(grid, axis=-1) - grid = np.reshape(grid, [resolution[0], resolution[1], -1]) - grid = np.expand_dims(grid, axis=0) - grid = grid.astype(np.float32) - return np.concatenate([grid, 1.0 - grid], axis=-1).transpose(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w] - -class SoftPositionEmbed(nn.Module): - """Adds soft positional embedding with learnable projection. - Implementation extracted from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py""" - - def __init__(self, hidden_size, resolution): - """Builds the soft position embedding layer. - - Args: - hidden_size: Size of input feature dimension. - resolution: Tuple of integers specifying width and height of grid. 
- """ - super().__init__() - self.dense = JaxLinear(4, hidden_size) - self.grid = build_grid(resolution) - - def forward(self, inputs): - return inputs + self.dense(torch.tensor(self.grid).cuda()) \ No newline at end of file diff --git a/osrt/model.py b/osrt/model.py index 45821355ff17b56ebf0769373ce5211de3999cbf..938d990377e61104dd90a39b2069917d1b0babab 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -90,20 +90,20 @@ class SlotAttentionAutoEncoder(nn.Module): self.decoder_initial_size = (8, 8) self.decoder_cnn = nn.Sequential( - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=2, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(2, 2), padding=2, output_padding=1), nn.ReLU(), - nn.ConvTranspose2d(64, 64, kernel_size=5, stride=1, padding=2), + nn.ConvTranspose2d(64, 64, kernel_size=5, stride=(1, 1), padding=2), nn.ReLU(), - nn.ConvTranspose2d(64, 4, kernel_size=3, stride=1, padding=2) + nn.ConvTranspose2d(64, 4, kernel_size=3, stride=(1, 1), padding=1) ) - self.encoder_pos = SoftPositionEmbed(64, (32, 32)) + self.encoder_pos = SoftPositionEmbed(64, self.resolution) self.decoder_pos = SoftPositionEmbed(64, self.decoder_initial_size) self.layer_norm = nn.LayerNorm(64) @@ -115,17 +115,16 @@ class SlotAttentionAutoEncoder(nn.Module): self.slot_attention = SlotAttention( num_slots=self.num_slots, + input_dim=64, slot_dim=64, hidden_dim=128, iters=self.num_iterations) def forward(self, image): # `image` has shape: [batch_size, num_channels, width, height]. - print(f"Shape input {image.shape}") # Convolutional encoder with position embedding. x = self.encoder_cnn(image) # CNN Backbone. - print(f"Shape after encoder {x.shape}") - x = self.encoder_pos(x) # Position embedding. + x = self.encoder_pos(x).permute(0, 2, 3, 1) # Position embedding. x = spatial_flatten(x) # Flatten spatial dimensions (treat image as set). x = self.mlp(self.layer_norm(x)) # Feedforward network on set. # `x` has shape: [batch_size, width*height, input_size]. @@ -135,10 +134,11 @@ class SlotAttentionAutoEncoder(nn.Module): # `slots` has shape: [batch_size, num_slots, slot_size]. # Spatial broadcast decoder. - x = spatial_broadcast(slots, self.decoder_initial_size) + x = spatial_broadcast(slots, self.decoder_initial_size).permute(0, 3, 1, 2) + # `x` has shape: [batch_size*num_slots, width_init, height_init, slot_size]. - #x = self.decoder_pos(x) - x = self.decoder_cnn(x) + x = self.decoder_pos(x) + x = self.decoder_cnn(x).permute(0, 2, 3, 1) # `x` has shape: [batch_size*num_slots, width, height, num_channels+1]. # Undo combination of slot and batch dimension; split alpha masks. 
diff --git a/osrt/utils/visualize.py b/osrt/utils/visualize.py
index 53c59485edf7584ef3ef822b94903d7bd9f49f20..af5129083898c32e2ee13f24b9c3ee8c7584c9f2 100644
--- a/osrt/utils/visualize.py
+++ b/osrt/utils/visualize.py
@@ -88,4 +88,30 @@ def draw_visualization_grid(columns, outfile, row_labels=None, name=None):
     plt.savefig(f'{outfile}.png')
     plt.close()

-
+def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, save_file=False):
+    fig, ax = plt.subplots(1, num_slots + 2, figsize=(15, 2))
+    image = image.squeeze(0)
+    recon_combined = recon_combined.squeeze(0)
+    recons = recons.squeeze(0)
+    masks = masks.squeeze(0)
+    image = image.permute(1, 2, 0).cpu().numpy()
+    recon_combined = recon_combined.cpu().detach().numpy()
+    recons = recons.cpu().detach().numpy()
+    masks = masks.cpu().detach().numpy()
+
+    if not save_file:
+        ax[0].imshow(image)
+        ax[0].set_title('Image')
+        ax[1].imshow(recon_combined)
+        ax[1].set_title('Recon.')
+        for i in range(num_slots):
+            picture = recons[i] * masks[i] + (1 - masks[i])
+            ax[i + 2].imshow(picture)
+            ax[i + 2].set_title('Slot %s' % str(i + 1))
+        for i in range(len(ax)):
+            ax[i].grid(False)
+            ax[i].axis('off')
+        plt.show()
+    else:
+        # TODO: save the figure to a PNG file
+        pass
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 44592b758dea1ca61e4a92be44eb3202cbe97763..aade5ddcd330c1fd1cdb69c88ae3504ae140970e 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -5,7 +5,7 @@ model:
   iters: 3
 training:
   num_workers: 2
-  batch_size: 64
+  batch_size: 32
   visualize_every: 5000
   validate_every: 5000
   checkpoint_every: 1000
diff --git a/train_sa.py b/train_sa.py
index 3afbecd956ba117968bda77e3b69fdde9f6e1eab..35aa44dac636980dc38c8e7aa3896940bceb2316 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -9,6 +9,7 @@ from osrt.model import SlotAttentionAutoEncoder
 from osrt import data

 from torch.utils.data import DataLoader
+import torch.nn.functional as F


 def l2_loss(prediction, target):
@@ -17,10 +18,12 @@ def l2_loss(prediction, target):

 def train_step(batch, model, optimizer, device):
     """Perform a single training step."""
     input_image = torch.squeeze(batch.get('input_images').to(device), dim=1)
+    input_image = F.interpolate(input_image, size=128)

     # Get the prediction of the model and compute the loss.
     preds = model(input_image)
     recon_combined, recons, masks, slots = preds
+    input_image = input_image.permute(0, 2, 3, 1)
     loss_value = l2_loss(input_image, recon_combined)
     del recons, masks, slots  # Unused.