Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
class EncoderViTB16(nn.Module):
    """Frozen ViT-B/16 image encoder with a classification head and a caption projector.

    The torchvision ``vit_b_16`` backbone is loaded with its default pretrained
    weights and frozen. Two trainable linear layers sit on top of it:

    * ``classifier`` maps the backbone's pooled feature vector to class logits.
    * ``projector`` maps every patch token of the final encoder layer to
      ``embed_size`` (presumably the caption decoder's d_model — confirm).

    Args:
        num_classes: Number of output classes for ``classifier``.
        embed_size: Output width of ``projector``.
    """

    def __init__(self, num_classes=50, embed_size=512):
        super().__init__()
        backbone = models.vit_b_16(
            weights=models.ViT_B_16_Weights.DEFAULT
        )
        self.backbone = backbone
        # Freeze all pretrained weights; only the heads created below train.
        self.backbone.requires_grad_(False)

        # Input width of the original ImageNet head. For vit_b_16 this equals
        # the encoder token width, so it also sizes the per-token projector.
        in_features = backbone.heads.head.in_features
        # Drop the ImageNet head so backbone(images) yields pooled features.
        self.backbone.heads = nn.Identity()

        self.classifier = nn.Linear(in_features, num_classes)
        self.projector = nn.Linear(in_features, embed_size)

    def forward(self, images, return_features=False):
        """Encode ``images`` for either classification or captioning.

        Only the requested branch is computed. The previous version ran the
        backbone twice per call and discarded one of the results.

        Args:
            images: Image batch for the ViT backbone; presumably
                ``(B, 3, 224, 224)`` as torchvision's ViT expects — confirm
                against callers.
            return_features: If ``False`` (default), return classification
                logits of shape ``(B, num_classes)``. If ``True``, return the
                projected patch features of shape
                ``(B, num_patches, embed_size)``.

        Returns:
            The logits tensor or the patch-feature tensor, as selected above.
        """
        if return_features:
            # Captioning branch.
            return self._caption_features(images)
        # Classification branch.
        return self.classifier(self._pooled_features(images))

    def _pooled_features(self, images):
        """Return the backbone's pooled features flattened to ``(B, in_features)``."""
        features = self.backbone(images)
        # Defensive: tolerate backbones that return a (features, ...) tuple.
        if isinstance(features, tuple):
            features = features[0]
        return features.reshape(features.size(0), -1)

    def _caption_features(self, images):
        """Return per-patch encoder tokens projected to ``embed_size``."""
        vit = self.backbone
        encoder = vit.encoder

        # Patchify and embed the images: (B, 196, 768) for 224x224 input.
        tokens = vit._process_input(images)
        # Add positional embeddings, skipping slot 0 (presumably the
        # class-token position).
        # NOTE(review): no class token is prepended here, so patch tokens
        # attend only to each other. This differs from the backbone's own
        # forward. Kept unchanged because anything trained on these features
        # depends on their exact values; confirm this is intentional.
        tokens = tokens + encoder.pos_embedding[:, 1:, :]
        for layer in encoder.layers:
            tokens = layer(tokens)
        tokens = encoder.ln(tokens)  # final LayerNorm of the encoder
        return self.projector(tokens)  # (B, num_patches, embed_size)