Spaces:
Running
Running
| import torch.nn as nn | |
| import timm | |
class EncoderDeiTTiny(nn.Module):
    """Frozen DeiT-Tiny image encoder with a classification head and a projection head.

    The timm ``deit_tiny_patch16_224`` backbone has its original classifier
    head replaced by ``nn.Identity`` and all of its parameters frozen
    (``requires_grad = False``), so only the two linear heads defined here
    are trainable:

    * ``classifier`` maps backbone features to ``num_classes`` logits
      (classification path).
    * ``projector`` maps backbone features to ``embed_size`` vectors
      (captioning path).
    """

    def __init__(self, num_classes=50, embed_size=512, pretrained=True):
        """Build the frozen backbone and the two trainable heads.

        Args:
            num_classes: Output width of the classification head.
            embed_size: Output width of the projection head.
            pretrained: Forwarded to ``timm.create_model``. ``True`` (the
                previously hard-coded behaviour) loads pretrained weights;
                ``False`` skips loading them, e.g. for offline tests.
        """
        super().__init__()
        self.backbone = timm.create_model(
            "deit_tiny_patch16_224",
            pretrained=pretrained,
        )
        # Use the backbone as a fixed feature extractor. NOTE: this only
        # disables gradients; the backbone's train/eval mode still follows
        # this module's .train()/.eval() calls.
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Read the feature width before the original head is discarded, then
        # strip the head so the backbone returns features instead of logits.
        in_features = self.backbone.head.in_features
        self.backbone.head = nn.Identity()

        self.classifier = nn.Linear(in_features, num_classes)
        self.projector = nn.Linear(in_features, embed_size)

    def forward(self, images, return_features=False):
        """Encode a batch of images.

        Args:
            images: Batch fed straight to the backbone. Presumably
                (B, 3, 224, 224) as the model name suggests; confirm against
                the caller's preprocessing.
            return_features: If False, return classification logits of width
                ``num_classes``. If True, return projected features of width
                ``embed_size`` (used for captioning).

        Returns:
            A (B, num_classes) logits tensor or a (B, embed_size) feature
            tensor, depending on ``return_features``.
        """
        features = self.backbone(images)
        # Collapse everything after the batch dimension. reshape (unlike view)
        # also works when the backbone output is not contiguous.
        features = features.reshape(features.size(0), -1)

        # Run only the head that is actually requested.
        if return_features:
            # captioning
            return self.projector(features)
        # classification
        return self.classifier(features)