From 6d55e14b29651ebc85d2b0675a3d6ca5c3e30a5b Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 10:02:51 +0200
Subject: [PATCH] Fix tqdm import

---
 osrt/layers.py | 1 +
 osrt/model.py  | 3 ---
 train_sa.py    | 2 +-
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index 0cfcff1..ef2d03e 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -295,6 +295,7 @@ class SoftPositionEmbed(nn.Module):
     def forward(self, inputs):
         return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) #  from [b, h, w, c] to [b, c, h, w]
     
+    
 ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
diff --git a/osrt/model.py b/osrt/model.py
index 938d990..4c46fc6 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -54,14 +54,11 @@ def spatial_broadcast(slots, resolution):
     # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
     return grid
 
-
-# TODO : adapt this model
 class SlotAttentionAutoEncoder(nn.Module):
     """
     Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
 
     Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
-
     """
 
     def __init__(self, resolution, num_slots, num_iterations):
diff --git a/train_sa.py b/train_sa.py
index 376f02d..d973ec9 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -13,7 +13,7 @@ from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
-import tqdm
+from tqdm import tqdm
 
 def train_step(batch, model, optimizer, device):
     """Perform a single training step."""
-- 
GitLab