From 6d55e14b29651ebc85d2b0675a3d6ca5c3e30a5b Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 24 Jul 2023 10:02:51 +0200
Subject: [PATCH] Fix tqdm import

---
 osrt/layers.py | 1 +
 osrt/model.py  | 3 ---
 train_sa.py    | 2 +-
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/osrt/layers.py b/osrt/layers.py
index 0cfcff1..ef2d03e 100644
--- a/osrt/layers.py
+++ b/osrt/layers.py
@@ -295,6 +295,7 @@ class SoftPositionEmbed(nn.Module):
 
     def forward(self, inputs):
         return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w]
 
+### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py
 class TransformerSlotAttention(nn.Module):
     """
diff --git a/osrt/model.py b/osrt/model.py
index 938d990..4c46fc6 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -54,14 +54,11 @@ def spatial_broadcast(slots, resolution):
   # `grid` has shape: [batch_size*num_slots, width, height, slot_size].
   return grid
 
-
-# TODO : adapt this model
 class SlotAttentionAutoEncoder(nn.Module):
     """
     Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings.
 
     Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py
-
     """
 
     def __init__(self, resolution, num_slots, num_iterations):
diff --git a/train_sa.py b/train_sa.py
index 376f02d..d973ec9 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -13,7 +13,7 @@ from osrt.utils.common import mse2psnr
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
 
-import tqdm
+from tqdm import tqdm
 
 def train_step(batch, model, optimizer, device):
     """Perform a single training step."""
--
GitLab
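
Note on the import fix, as a minimal sketch (not part of the patch): the change assumes train_sa.py calls tqdm(...) directly on an iterable. With `import tqdm`, the name `tqdm` is bound to the module, so a direct call raises "TypeError: 'module' object is not callable" unless written as `tqdm.tqdm(...)`; importing the callable class with `from tqdm import tqdm` matches the direct-call style. The loop below is purely illustrative and not taken from the repository.

    from tqdm import tqdm  # binds the callable tqdm class, not the module

    # Hypothetical training-style loop for illustration only.
    for batch_idx in tqdm(range(10), desc="toy loop"):
        pass  # a real loop would run train_step(...) here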