Spaces:
Running
Running
| import torch.nn as nn | |
| from torchvision import models | |
class EncoderEfficientNetB0(nn.Module):
    """Image encoder built on EfficientNet-B0 with two linear heads.

    The ``features`` stack of torchvision's ``efficientnet_b0`` (loaded with
    ``EfficientNet_B0_Weights.DEFAULT``) serves as the backbone, and all of
    its parameters are frozen (``requires_grad = False``). The pooled
    backbone output feeds two independent linear heads:

    * ``classifier`` -- maps to ``num_classes`` outputs (classification).
    * ``projector`` -- maps to ``embed_size`` outputs (captioning features).

    Args:
        num_classes: Output size of the classification head.
        embed_size: Output size of the projection head.

    NOTE(review): ``requires_grad = False`` only stops gradient updates.
    Any normalization or stochastic layers inside the backbone still follow
    this module's ``train()`` / ``eval()`` mode -- confirm whether the
    backbone should be kept in ``eval()`` while the heads are trained.
    """

    def __init__(self, num_classes=50, embed_size=512):
        super().__init__()
        model = models.efficientnet_b0(
            weights=models.EfficientNet_B0_Weights.DEFAULT
        )

        # Keep only the convolutional feature extractor; the stock
        # classifier is replaced by the two heads created below.
        self.backbone = model.features
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Freeze the backbone so that only the heads receive gradients.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Width of the pooled feature vector, read from the linear layer
        # of the stock classifier so it always matches the backbone.
        in_features = model.classifier[1].in_features

        self.classifier = nn.Linear(in_features, num_classes)
        self.projector = nn.Linear(in_features, embed_size)

    def forward(
        self,
        images,
        return_features=False
    ):
        """Run the backbone and exactly one of the two heads.

        Args:
            images: Batch of input images; dimension 0 is the batch
                dimension (presumably ``(batch, 3, H, W)`` -- confirm
                against the caller).
            return_features: When false (default) return classification
                logits; when true return the projected feature vector.

        Returns:
            ``classifier`` output of shape ``(batch, num_classes)`` when
            ``return_features`` is false, otherwise ``projector`` output of
            shape ``(batch, embed_size)``.
        """
        pooled = self.pool(self.backbone(images))

        # Collapse everything after the batch dimension into one vector
        # per sample (the adaptive pool leaves a 1x1 spatial map).
        flat = pooled.flatten(1)

        # Only the requested head is evaluated; the other head's output
        # would be discarded, so computing it is wasted work.
        if return_features:
            # captioning
            return self.projector(flat)

        # classification
        return self.classifier(flat)