diff --git a/models/model.py b/models/model.py index 5853484a01d4e3bbe0b6923922b5c7327f4eded6..192120347961c6c5e1555d3e80c67f2389228206 100644 --- a/models/model.py +++ b/models/model.py @@ -265,7 +265,7 @@ class Classification_model(nn.Module): def __init__(self, model, n_class, *args, **kwargs): super().__init__(*args, **kwargs) self.n_class = n_class - if model =='Resnet18': + if model =='ResNet18': self.im_encoder = resnet18(num_classes=self.n_class)