Commit 6d55e14b authored by Alexandre Chapin

Fix tqdm import

parent 3a4fb34b

@@ -295,6 +295,7 @@ class SoftPositionEmbed(nn.Module):
    def forward(self, inputs):
        return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2)  # from [b, h, w, c] to [b, c, h, w]

### New transformer implementation of SlotAttention, inspired by https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
class TransformerSlotAttention(nn.Module):
    """
...
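For context, the `forward` shown in this hunk is the soft position embedding used in the original Slot Attention code: a linear layer projects a fixed coordinate grid and the result is added to the feature map. Below is a minimal sketch of how such a module is typically written; the `build_grid` helper, the `nn.Linear(4, hidden_size)` behind `self.dense`, and the device handling are assumptions, since only the `forward` line appears in this diff.

import numpy as np
import torch
import torch.nn as nn

def build_grid(resolution):
    # Normalized coordinates and their complements, shape [1, h, w, 4].
    ranges = [np.linspace(0.0, 1.0, num=res) for res in resolution]
    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
    grid = np.stack(grid, axis=-1).astype(np.float32)
    grid = np.expand_dims(grid, axis=0)
    return np.concatenate([grid, 1.0 - grid], axis=-1)

class SoftPositionEmbed(nn.Module):
    """Add a learned linear projection of a fixed coordinate grid to a feature map."""
    def __init__(self, hidden_size, resolution):
        super().__init__()
        self.dense = nn.Linear(4, hidden_size)
        self.grid = build_grid(resolution)

    def forward(self, inputs):
        # inputs: [b, c, h, w]; the projected grid is [1, h, w, c], permuted to match.
        emb = self.dense(torch.tensor(self.grid, device=inputs.device))
        return inputs + emb.permute(0, 3, 1, 2)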

@@ -54,14 +54,11 @@ def spatial_broadcast(slots, resolution):
    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
    return grid

# TODO : adapt this model
class SlotAttentionAutoEncoder(nn.Module):
    """
    Slot Attention as introduced by Locatello et al., but with the AutoEncoder part to extract image embeddings.
    Implementation inspired by the official repo: https://github.com/google-research/google-research/blob/master/slot_attention/model.py
    """
    def __init__(self, resolution, num_slots, num_iterations):
...
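The hunk only shows the tail of `spatial_broadcast`. In the official Slot Attention code this step tiles each slot vector over the output resolution so the decoder can run once per slot; the sketch below illustrates that idea under the same assumption, and the reshape/expand details are not taken from this diff.

import torch

def spatial_broadcast(slots, resolution):
    # slots: [batch_size, num_slots, slot_size] -> [batch_size*num_slots, 1, 1, slot_size]
    slots = slots.reshape(-1, slots.shape[-1])[:, None, None, :]
    # Tile each slot vector over every spatial position of the target resolution.
    grid = slots.expand(-1, resolution[0], resolution[1], -1)
    # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
    return grid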

@@ -13,7 +13,7 @@ from osrt.utils.common import mse2psnr
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
-import tqdm
+from tqdm import tqdm

 def train_step(batch, model, optimizer, device):
     """Perform a single training step."""
...
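This one-line change is the actual fix: with `import tqdm` the name is bound to the module, so wrapping an iterable with `tqdm(...)` in the training loop raises `TypeError: 'module' object is not callable`. Importing the class directly restores the usual progress-bar call. A minimal, self-contained illustration (not taken from the repo):

from tqdm import tqdm

# With `import tqdm`, the module itself would be bound to the name and
# `tqdm(range(3))` would raise TypeError: 'module' object is not callable.
for _ in tqdm(range(3), desc="train"):
    pass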