From 74ec94a7281c4f60531181d51294ec5069520bc6 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Mon, 24 Jul 2023 16:57:36 +0200 Subject: [PATCH] Fix interpolate --- osrt/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/osrt/model.py b/osrt/model.py index 955d198..a653271 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -118,7 +118,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): x = self.decoder_pos(x) x = self.decoder_cnn(x.movedim(-1, 1)) - x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode) + x = F.interpolate(x, image.shape[-2:], mode='bilinear') x = x.unflatten(0, (len(image), len(x) // len(image))) @@ -142,7 +142,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): x = self.decoder_pos(x) x = self.decoder_cnn(x.movedim(-1, 1)) - x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode) + x = F.interpolate(x, image.shape[-2:], mode='bilinear') x = x.unflatten(0, (len(image), len(x) // len(image))) -- GitLab