Skip to content
Snippets Groups Projects
Commit 6d55e14b authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Fix tqdm import

parent 3a4fb34b
No related branches found
No related tags found
No related merge requests found
...@@ -295,6 +295,7 @@ class SoftPositionEmbed(nn.Module): ...@@ -295,6 +295,7 @@ class SoftPositionEmbed(nn.Module):
def forward(self, inputs): def forward(self, inputs):
return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w] return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
class TransformerSlotAttention(nn.Module): class TransformerSlotAttention(nn.Module):
""" """
......
...@@ -54,14 +54,11 @@ def spatial_broadcast(slots, resolution): ...@@ -54,14 +54,11 @@ def spatial_broadcast(slots, resolution):
# `grid` has shape: [batch_size*num_slots, width, height, slot_size]. # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
return grid return grid
# TODO : adapt this model
class SlotAttentionAutoEncoder(nn.Module): class SlotAttentionAutoEncoder(nn.Module):
""" """
Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings. Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
""" """
def __init__(self, resolution, num_slots, num_iterations): def __init__(self, resolution, num_slots, num_iterations):
......
...@@ -13,7 +13,7 @@ from osrt.utils.common import mse2psnr ...@@ -13,7 +13,7 @@ from osrt.utils.common import mse2psnr
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import torch.nn.functional as F import torch.nn.functional as F
import tqdm from tqdm import tqdm
def train_step(batch, model, optimizer, device): def train_step(batch, model, optimizer, device):
"""Perform a single training step.""" """Perform a single training step."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment