diff --git a/osrt/layers.py b/osrt/layers.py index 0cfcff18aff1d98d7d7ee8759df3fed3acf20fd9..ef2d03ee10e5b06f666b57cf94c8082c65227e85 100644 --- a/osrt/layers.py +++ b/osrt/layers.py @@ -295,6 +295,7 @@ class SoftPositionEmbed(nn.Module): def forward(self, inputs): return inputs + self.dense(torch.tensor(self.grid).cuda()).permute(0, 3, 1, 2) # from [b, h, w, c] to [b, c, h, w] + ### New transformer implementation of SlotAttention inspired from https://github.com/ThomasMrY/VCT/blob/master/models/visual_concept_tokenizor.py class TransformerSlotAttention(nn.Module): """ diff --git a/osrt/model.py b/osrt/model.py index 938d990377e61104dd90a39b2069917d1b0babab..4c46fc6d43fcbe5c4b391414c941ac3431685092 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -54,14 +54,11 @@ def spatial_broadcast(slots, resolution): # `grid` has shape: [batch_size*num_slots, width, height, slot_size]. return grid - -# TODO : adapt this model class SlotAttentionAutoEncoder(nn.Module): """ Slot Attention as introduced by Locatello et al. but with the AutoEncoder part to extract image embeddings. Implementation inspired from official repo : https://github.com/google-research/google-research/blob/master/slot_attention/model.py - """ def __init__(self, resolution, num_slots, num_iterations): diff --git a/train_sa.py b/train_sa.py index 376f02d6e73b5a2b42393b8c50982b0aaf8d57da..d973ec9e8d6f854d436ea249751dfe4b768f83a2 100644 --- a/train_sa.py +++ b/train_sa.py @@ -13,7 +13,7 @@ from osrt.utils.common import mse2psnr from torch.utils.data import DataLoader import torch.nn.functional as F -import tqdm +from tqdm import tqdm def train_step(batch, model, optimizer, device): """Perform a single training step."""