Instructions to use sunweiwei/AirRep-Flan-Small with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use sunweiwei/AirRep-Flan-Small with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="sunweiwei/AirRep-Flan-Small")
```
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("sunweiwei/AirRep-Flan-Small", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """AirRep model implementation.""" | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from transformers import BertModel, BertConfig, PreTrainedModel | |
| from transformers.modeling_outputs import BaseModelOutput | |
def mean_pooling(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Apply mean pooling to hidden states.

    Averages ``last_hidden_states`` over dimension 1, using only the
    positions where ``attention_mask`` is non-zero.

    Args:
        last_hidden_states: Token states shaped (batch, seq, hidden); the mask
            is broadcast over the last dimension.
        attention_mask: Mask shaped (batch, seq); any non-zero entry marks a
            position to include.

    Returns:
        The masked mean, shaped (batch, hidden), in the dtype of
        ``last_hidden_states``. A row whose mask is entirely zero yields a
        zero vector instead of NaN.
    """
    mask = attention_mask.bool()
    # masked_fill (rather than multiplying by the mask) guarantees padded
    # positions contribute exactly 0, even if they hold inf/NaN.
    summed = last_hidden_states.masked_fill(~mask[..., None], 0.0).sum(dim=1)
    # Count exactly the positions the numerator kept; clamp so a fully padded
    # row divides by 1 (giving zeros) rather than 0 (giving NaN).
    counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
    return summed / counts
class AirRepConfig(BertConfig):
    """
    Configuration for the AirRep model.

    Identical to ``BertConfig`` apart from the ``model_type`` tag; every
    keyword argument is handed unchanged to ``BertConfig.__init__``.
    """

    model_type = "airrep"

    def __init__(self, **kwargs) -> None:
        # No AirRep-specific options; defer entirely to BertConfig.
        super(AirRepConfig, self).__init__(**kwargs)
class AirRepModel(PreTrainedModel):
    """
    AirRep model with BERT encoder and projection layer.

    This is a standalone model, not a wrapper. ``forward`` encodes token IDs
    with BERT, mean-pools the final hidden states over the attention mask, and
    applies a square (hidden_size -> hidden_size) linear projection.
    """

    config_class = AirRepConfig
    base_model_prefix = "airrep"

    def __init__(self, config: AirRepConfig):
        super().__init__(config)
        self.config = config

        # BERT encoder; the pooler head is omitted because mean pooling over
        # token states is used instead.
        self.bert = BertModel(config, add_pooling_layer=False)

        # Projection layer. Its parameters are created in bfloat16 regardless
        # of the encoder's dtype, so forward() casts its input to match.
        self.projector = nn.Linear(
            config.hidden_size,
            config.hidden_size,
            dtype=torch.bfloat16,
        )

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Encode a batch and return pooled, projected embeddings.

        Args:
            input_ids: Input token IDs.
            attention_mask: Attention mask. When None, every position is
                included in the mean pooling.
            token_type_ids: Token type IDs, forwarded to BERT.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Pooled and projected embeddings (batch_size, hidden_size), in the
            projector's parameter dtype.
        """
        # Only the last layer is used, so don't ask BERT to return (and keep
        # alive) every intermediate hidden state.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state

        # Mean pooling
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        pooled = mean_pooling(last_hidden_state, attention_mask)

        # nn.Linear requires input and weight dtypes to match. The encoder may
        # run in float32 (or another dtype) while the projector is bfloat16;
        # without this cast that combination raises a dtype-mismatch error.
        pooled = pooled.to(self.projector.weight.dtype)
        return self.projector(pooled)