diff --git a/models/model.py b/models/model.py index b601c6431bec7fc838a9b0bf86fb586af5837388..3fde4b152815984884db77f031ae8e15f6dcb164 100644 --- a/models/model.py +++ b/models/model.py @@ -324,7 +324,7 @@ class Classification_model_duo_pretrained(nn.Module): if model =='ResNet50': self.im_encoder = torch.hub.load('pretrained_model/pytorch_vision_v0.10.0/', 'resnet50', pretrained=True,source='local') - self.im_encoder.fc = torch.nn.Linear(2000, self.n_class) + self.im_encoder.fc = torch.nn.Linear(2048, self.n_class) #freeze backbone for n, p in self.im_encoder.named_parameters(): if n not in ['fc.weight', 'fc.bias']: