diff --git a/osrt/model.py b/osrt/model.py index 955d19885ec6642630a668d6914fd16fb68fa4aa..a653271b4437b557a97cbe34a25d11fa4ecb3cd1 100644 --- a/osrt/model.py +++ b/osrt/model.py @@ -118,7 +118,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): x = self.decoder_pos(x) x = self.decoder_cnn(x.movedim(-1, 1)) - x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode) + x = F.interpolate(x, image.shape[-2:], mode='bilinear') x = x.unflatten(0, (len(image), len(x) // len(image))) @@ -142,7 +142,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): x = self.decoder_pos(x) x = self.decoder_cnn(x.movedim(-1, 1)) - x = F.interpolate(x, image.shape[-2:], mode = self.interpolate_mode) + x = F.interpolate(x, image.shape[-2:], mode='bilinear') x = x.unflatten(0, (len(image), len(x) // len(image)))