|
|
| import math |
| from functools import partial |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange, reduce, repeat |
| from torch.utils.checkpoint import checkpoint |
| from transformers import AutoModel, PreTrainedModel |
|
|
| from .config import LUARConfig |
|
|
| |
| |
|
|
| def exists(val): |
| return val is not None |
|
|
| def summarize_qkv_chunk( |
| q, k, v, |
| mask |
| ): |
| """Dot-Product Attention for a chunk of queries, keys, and values. |
| """ |
| weight = torch.einsum('b h i d, b h j d -> b h i j', q, k) |
|
|
| if exists(mask): |
| |
| weight += mask |
|
|
| weight_max = weight.amax(dim = -1, keepdim = True).detach() |
| weight = weight - weight_max |
|
|
| exp_weight = weight.exp() |
| weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v) |
|
|
| return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...') |
|
|
| checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk) |
|
|
| def memory_efficient_attention( |
| q, k, v, |
| mask = None, |
| q_bucket_size = 512, |
| k_bucket_size = 1024, |
| eps = 1e-8 |
| ): |
| scale = q.shape[-1] ** -0.5 |
| q = q * scale |
|
|
| |
| needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad |
| summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk |
|
|
| |
| q_chunks = q.split(q_bucket_size, dim = -2) |
| k_chunks = k.split(k_bucket_size, dim = -2) |
| v_chunks = v.split(k_bucket_size, dim = -2) |
| mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks)) |
|
|
| |
| out = [] |
| for q_index, q_chunk in enumerate(q_chunks): |
| exp_weights = [] |
| weighted_values = [] |
| weight_maxes = [] |
| |
| for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)): |
|
|
| exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn( |
| q_chunk, |
| k_chunk, |
| v_chunk, |
| mask_chunk, |
| ) |
|
|
| exp_weights.append(exp_weight_chunk) |
| weighted_values.append(weighted_value_chunk) |
| weight_maxes.append(weight_max_chunk) |
|
|
| exp_weights = torch.stack(exp_weights, dim = -1) |
| weighted_values = torch.stack(weighted_values, dim = -1) |
| weight_maxes = torch.stack(weight_maxes, dim = -1) |
|
|
| global_max = weight_maxes.amax(dim = -1, keepdim = True) |
| renorm_factor = (weight_maxes - global_max).exp().detach() |
|
|
| exp_weights = exp_weights * renorm_factor |
| weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c') |
|
|
| all_values = weighted_values.sum(dim = -1) |
| all_weights = exp_weights.sum(dim = -1) |
|
|
| normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps) |
| out.append(normalized_values) |
|
|
| return torch.cat(out, dim=-2) |
|
|
| class SelfAttention(nn.Module): |
| """Implements Dot-Product Self-Attention as used in "Attention is all You Need". |
| """ |
| def __init__( |
| self, |
| memory_efficient_attention=False, |
| q_bucket_size=512, |
| k_bucket_size=1024, |
| ): |
| super(SelfAttention, self).__init__() |
| self.use_memory_efficient_attention = memory_efficient_attention |
| self.q_bucket_size = q_bucket_size |
| self.k_bucket_size = k_bucket_size |
|
|
| def forward(self, k, q, v): |
|
|
| if self.use_memory_efficient_attention: |
| q, k, v = map( |
| lambda t: rearrange(t, 'b n (h d) -> b h n d', h = 12), |
| (q, k, v) |
| ) |
|
|
| out = memory_efficient_attention( |
| q, k, v, |
| q_bucket_size=self.q_bucket_size, |
| k_bucket_size=self.k_bucket_size |
| ) |
| out = rearrange(out, 'b h n d -> b n (h d)') |
| return out |
| else: |
| d_k = q.size(-1) |
| scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k) |
| p_attn = F.softmax(scores, dim=-1) |
| return torch.matmul(p_attn, v) |
|
|
| class LUAR(PreTrainedModel): |
| """Defines the LUAR model. |
| """ |
| config_class = LUARConfig |
| |
| def __init__(self, config): |
| super().__init__(config) |
| self.create_transformer(revision=config.upstream_transformer_revision) |
| self.attn_fn = SelfAttention( |
| config.use_memory_efficient_attention, |
| config.q_bucket_size, |
| config.k_bucket_size, |
| ) |
| self.linear = nn.Linear(self.hidden_size, config.embedding_size) |
|
|
| def create_transformer(self, revision: Optional[str] = None): |
| """Creates the Transformer backbone. |
| """ |
| kwargs = {"revision": revision} if revision else {} |
| self.transformer = AutoModel.from_pretrained("sentence-transformers/paraphrase-distilroberta-base-v1", **kwargs) |
| self.hidden_size = self.transformer.config.hidden_size |
| self.num_attention_heads = self.transformer.config.num_attention_heads |
| self.dim_head = self.hidden_size // self.num_attention_heads |
| |
| def mean_pooling(self, token_embeddings, attention_mask): |
| """Mean Pooling as described in the SBERT paper. |
| """ |
| input_mask_expanded = repeat(attention_mask, 'b l -> b l d', d=self.hidden_size).type(token_embeddings.type()) |
| sum_embeddings = reduce(token_embeddings * input_mask_expanded, 'b l d -> b d', 'sum') |
| sum_mask = torch.clamp(reduce(input_mask_expanded, 'b l d -> b d', 'sum'), min=1e-9) |
| return sum_embeddings / sum_mask |
| |
| def get_episode_embeddings(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0): |
| """Computes the Author Embedding. |
| """ |
| B, E, _ = attention_mask.shape |
|
|
| input_ids = rearrange(input_ids, 'b e l -> (b e) l') |
| attention_mask = rearrange(attention_mask, 'b e l -> (b e) l') |
|
|
| if document_batch_size > 0: |
| outputs = {"last_hidden_state": [], "attentions": []} |
| for i in range(0, len(input_ids), document_batch_size): |
| out = self.transformer( |
| input_ids=input_ids[i:i+document_batch_size], |
| attention_mask=attention_mask[i:i+document_batch_size], |
| return_dict=True, |
| output_hidden_states=False, |
| output_attentions=output_attentions, |
| ) |
| outputs["last_hidden_state"].append(out["last_hidden_state"]) |
| if output_attentions: |
| outputs["attentions"].append(out["attentions"]) |
| outputs["last_hidden_state"] = torch.cat(outputs["last_hidden_state"], dim=0) |
| if output_attentions: |
| outputs["attentions"] = tuple([torch.cat([x[i] for x in outputs["attentions"]], dim=0) for i in range(len(outputs["attentions"][0]))]) |
| else: |
| outputs = self.transformer( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| return_dict=True, |
| output_hidden_states=False, |
| output_attentions=output_attentions, |
| ) |
| |
| |
| comment_embeddings = self.mean_pooling(outputs['last_hidden_state'], attention_mask) |
| comment_embeddings = rearrange(comment_embeddings, '(b e) l -> b e l', b=B, e=E) |
|
|
| |
| episode_embeddings = self.attn_fn(comment_embeddings, comment_embeddings, comment_embeddings) |
| episode_embeddings = reduce(episode_embeddings, 'b e l -> b l', 'max') |
| |
| episode_embeddings = self.linear(episode_embeddings) |
| |
| if output_attentions: |
| return episode_embeddings, outputs["attentions"] |
|
|
| return episode_embeddings |
| |
| def forward(self, input_ids, attention_mask, output_attentions=False, document_batch_size=0): |
| """Calculates a fixed-length feature vector for a batch of episode samples. |
| """ |
| output = self.get_episode_embeddings(input_ids, attention_mask, output_attentions, document_batch_size) |
|
|
| return output |
|
|