Spaces:
Build error
Build error
| import contextlib | |
| import functools | |
| import inspect | |
| from enum import Enum | |
| from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union | |
| import torch | |
| # Since we will be patching the `scaled_dot_product_attention` function with `attention_dispatch` to take | |
| # control for dispatching to different attention providers, we need to import the original function | |
| # to be able to use it and not go into infinite recursion when the dispatcher calls `scaled_dot_product_attention`. | |
| import torch.autograd | |
| from diffusers.utils.import_utils import OptionalDependencyNotAvailable | |
| from torch.nn.functional import scaled_dot_product_attention as native_sdpa | |
| from finetrainers.constants import FINETRAINERS_ATTN_CHECKS, FINETRAINERS_ATTN_PROVIDER | |
| from finetrainers.logging import get_logger | |
| from finetrainers.utils.import_utils import ( | |
| is_flash_attn_available, | |
| is_flash_attn_version, | |
| is_sageattention_available, | |
| is_sageattention_version, | |
| is_torch_version, | |
| is_xformers_available, | |
| is_xformers_version, | |
| ) | |
| if is_flash_attn_available(): | |
| if is_flash_attn_version("<", "2.6.3"): | |
| raise OptionalDependencyNotAvailable( | |
| "The `flash-attn` library version is too old. Please update it to at least 2.6.3." | |
| ) | |
| from flash_attn import flash_attn_func, flash_attn_varlen_func | |
| from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward | |
| else: | |
| flash_attn_func = None | |
| flash_attn_varlen_func = None | |
| _flash_attn_forward = None | |
| _flash_attn_backward = None | |
| if is_sageattention_available(): | |
| if is_sageattention_version("<", "2.1.1"): | |
| raise OptionalDependencyNotAvailable( | |
| "The `sageattention` library version is too old. Please update it to at least 2.1.1." | |
| ) | |
| from sageattention import ( | |
| sageattn, | |
| sageattn_qk_int8_pv_fp8_cuda, | |
| sageattn_qk_int8_pv_fp8_cuda_sm90, | |
| sageattn_qk_int8_pv_fp16_cuda, | |
| sageattn_qk_int8_pv_fp16_triton, | |
| sageattn_varlen, | |
| ) | |
| else: | |
| sageattn = None | |
| sageattn_qk_int8_pv_fp16_cuda = None | |
| sageattn_qk_int8_pv_fp16_triton = None | |
| sageattn_qk_int8_pv_fp8_cuda = None | |
| sageattn_qk_int8_pv_fp8_cuda_sm90 = None | |
| sageattn_varlen = None | |
# FlexAttention ships with PyTorch >= 2.5. Keep the name bound on older versions so the module still imports.
if is_torch_version(">=", "2.5.0"):
    import torch.nn.attention.flex_attention as flex_attention
else:
    flex_attention = None

# Ring-attention (context parallel) helpers ship with PyTorch >= 2.6.
if is_torch_version(">=", "2.6.0"):
    from torch.distributed.tensor.experimental._attention import (
        _AttentionOp,
        _cp_options,
        _templated_ring_attention,
        _templated_ring_attention_backward,
        set_rotate_method,
    )
else:
    # Placeholders for every name imported above. Module-level references must not raise NameError
    # on older PyTorch versions.
    _cp_options = None
    _templated_ring_attention = None
    _templated_ring_attention_backward = None
    set_rotate_method = None

    class _AttentionOp:
        """Stand-in that refuses construction when ring-attention support is unavailable."""

        def __init__(self, *args, **kwargs):
            raise OptionalDependencyNotAvailable(
                "The `torch.distributed.tensor.experimental._attention` module is not available. Please update PyTorch to at least 2.6.0."
            )
| if is_xformers_available(): | |
| if is_xformers_version("<", "0.0.29"): | |
| raise OptionalDependencyNotAvailable( | |
| "The `xformers` library version is too old. Please update it to at least 0.0.29." | |
| ) | |
| import xformers.ops as xops | |
| else: | |
| xops = None | |
# `Literal` aliases listing accepted string options, presumably for annotating SageAttention provider
# arguments (TODO confirm against the Sage provider signatures).
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]

# Module logger used by the provider registry and the dispatcher.
logger = get_logger()
| # ===== Custom operator implementations/wrappers ===== | |
def _finetrainers_scaled_dot_product_efficient_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[torch.Tensor] = None,
    compute_log_sumexp: bool = False,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call ``aten::_scaled_dot_product_efficient_attention`` and return an unpadded LSE.

    Wrapper for
    https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946
    (see https://github.com/pytorch/pytorch/issues/152942).

    The kernel aligns the log-sum-exp's last dimension to a multiple of 32. PyTorch's ring attention expects the
    LSE without that alignment (otherwise it fails with a shape mismatch). So when ``compute_log_sumexp`` is True,
    the LSE is sliced back to the query sequence length ``query.shape[-2]``.

    Returns:
        ``(out, lse, philox_seed, philox_offset)`` as produced by the aten op, with the LSE trimmed.
    """
    out, lse, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention(
        query=query,
        key=key,
        value=value,
        attn_bias=attn_bias,
        compute_log_sumexp=compute_log_sumexp,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )
    if not compute_log_sumexp:
        return out, lse, philox_seed, philox_offset

    assert lse.ndim == 3
    num_queries = query.shape[-2]
    return out, lse[..., :num_queries], philox_seed, philox_offset
| # aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) | |
def _finetrainers_scaled_dot_product_efficient_attention_backward(
    grad_out_: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    philox_seed: torch.Tensor,
    philox_offset: torch.Tensor,
    dropout_p: float,
    grad_input_mask: List[bool],
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Call ``aten::_scaled_dot_product_efficient_attention_backward`` with a re-aligned LSE.

    This is the counterpart of the forward wrapper, which trims the kernel's 32-aligned LSE down to the query
    length. Here the LSE's last dimension is padded back up to the next multiple of 32 before calling the kernel.
    The padded entries do not correspond to any query row and are filled with ``+inf``.

    Args:
        grad_input_mask: Four booleans selecting which of (query, key, value, attn_bias) gradients to compute.

    Returns:
        ``(grad_query, grad_key, grad_value, grad_attn_bias)`` from the aten op.
    """
    assert len(grad_input_mask) == 4
    # https://github.com/pytorch/pytorch/blob/bb9fbb294af385057a72e5b1386cf40f86aadbec/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h#L113
    kAlignLSE = 32
    # Pad only up to the next multiple of kAlignLSE. An already-aligned LSE needs no padding; the previous
    # `kAlignLSE - n % kAlignLSE` appended a full extra 32 columns in that case.
    pad = (-logsumexp.shape[-1]) % kAlignLSE
    if pad:
        logsumexp = torch.nn.functional.pad(logsumexp, (0, pad), value=float("inf"))
    grad_query, grad_key, grad_value, grad_attn_bias = torch.ops.aten._scaled_dot_product_efficient_attention_backward(
        grad_out_=grad_out_,
        query=query,
        key=key,
        value=value,
        attn_bias=attn_bias,
        out=out,
        logsumexp=logsumexp,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        dropout_p=dropout_p,
        grad_input_mask=grad_input_mask,
        is_causal=is_causal,
        scale=scale,
    )
    return grad_query, grad_key, grad_value, grad_attn_bias
| # This function wraps the actual _flash_attn_forward call to return LSE at index 1 to be compatible with pytorch's native ring attention | |
def _finetrainers_flash_attn_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    return_softmax: bool = False,
):
    """Run flash-attn's ``_flash_attn_forward`` and move the LSE to index 1 of the result.

    PyTorch's native ring attention expects the LSE as the second returned value, hence the reordering.
    ``query``/``key``/``value`` are permuted [B, N, S, D] -> [B, S, N, D] (made contiguous) before the call.
    The attention output is permuted back to [B, N, S, D]. The remaining tensors are returned exactly as
    flash-attn produced them.

    Returns:
        ``(out, softmax_lse, q, k, v, out_padded, S_dmask, rng_state)``
    """
    q_bsnd, k_bsnd, v_bsnd = [t.permute(0, 2, 1, 3).contiguous() for t in (query, key, value)]
    flash_outputs = _flash_attn_forward(
        q_bsnd, k_bsnd, v_bsnd, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, return_softmax
    )
    out_bsnd, q_out, k_out, v_out, out_padded, softmax_lse, s_dmask, rng_state = flash_outputs
    out_bnsd = out_bsnd.permute(0, 2, 1, 3).contiguous()
    return out_bnsd, softmax_lse, q_out, k_out, v_out, out_padded, s_dmask, rng_state
| # This function wraps the actual _flash_attn_backward call as the counterpart of the _finetrainers_flash_attn_forward function | |
def _finetrainers_flash_attn_backward(
    grad_out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,  # Named differently from flash-attn: _templated_ring_attention_backward looks up `logsumexp`
    dropout_p: float,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    rng_state: Optional[torch.Tensor] = None,
    _permute_outputs: bool = True,
):
    """Run flash-attn's ``_flash_attn_backward``; counterpart of ``_finetrainers_flash_attn_forward``.

    ``grad_out`` is permuted [B, N, S, D] -> [B, S, N, D] (made contiguous). ``query``, ``key``, ``value``,
    ``out`` and ``logsumexp`` are passed through unchanged; presumably they are the tensors saved from the
    forward wrapper (TODO confirm layouts at each call site). The gradients are cut to ``grad_out``'s last
    dimension because the head dimension may have been padded. When ``_permute_outputs`` is True, they are
    permuted back [B, S, N, D] -> [B, N, S, D].

    Returns:
        ``(dq, dk, dv)``
    """
    grad_out_bsnd = grad_out.permute(0, 2, 1, 3).contiguous()
    dq_buffer = torch.empty_like(query)
    dk_buffer = torch.empty_like(key)
    dv_buffer = torch.empty_like(value)
    dq, dk, dv, _softmax_d = _flash_attn_backward(
        grad_out_bsnd,
        query,
        key,
        value,
        out,
        logsumexp,
        dq_buffer,
        dk_buffer,
        dv_buffer,
        dropout_p,
        scale,
        is_causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        rng_state,
    )
    head_dim = grad_out_bsnd.shape[-1]
    grads = tuple(g[..., :head_dim] for g in (dq, dk, dv))
    if not _permute_outputs:
        return grads
    return tuple(g.permute(0, 2, 1, 3).contiguous() for g in grads)
| # ===== Attention provider ===== | |
class AttentionProvider(str, Enum):
    """Identifiers of the attention backends known to the dispatcher.

    Values are plain strings, so ``AttentionProvider("flash")`` resolves a member from configuration
    (the registry does this with ``FINETRAINERS_ATTN_PROVIDER``). Underscore-prefixed members name a
    specific kernel variant of a backend.
    """

    # flash-attn library
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"

    # PyTorch built-ins
    FLEX = "flex"
    NATIVE = "native"
    _NATIVE_CUDNN = "_native_cudnn"
    _NATIVE_EFFICIENT = "_native_efficient"
    _NATIVE_FLASH = "_native_flash"
    _NATIVE_MATH = "_native_math"

    # sageattention library
    SAGE = "sage"
    SAGE_VARLEN = "sage_varlen"
    _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
    _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
    _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
    _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
    # Not provided yet: an "eager" provider, and Sparge attention ("sparge"). Sparge needs per-model
    # tuning; some form of autotuning may make it viable later.

    # xformers library
    XFORMERS = "xformers"
class _AttentionProviderRegistry:
    """Class-level registry of attention implementations plus the active dispatch state.

    All state lives on the class and is read or written through the classmethods below, for example
    ``_AttentionProviderRegistry.get_active_provider()``. The methods take ``cls`` and are invoked on the
    class, so they must be ``@classmethod``s; without the decorator, ``cls`` binds to the first user argument.
    """

    # provider -> implementation function
    _providers = {}
    # provider -> list of constraint callables (run by attention_dispatch when checks are enabled)
    _constraints = {}
    # provider -> whether the implementation supports context parallelism
    _supports_cp = {}
    # provider -> set of parameter names in the implementation's signature
    _supported_arg_names = {}
    _active_provider = AttentionProvider(FINETRAINERS_ATTN_PROVIDER)
    _checks_enabled = FINETRAINERS_ATTN_CHECKS

    # Context parallel attributes; `_mesh is None` means context parallelism is disabled.
    _mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None
    _convert_to_fp32: Optional[bool] = None
    _rotate_method: Optional[Literal["allgather", "alltoall"]] = None

    @classmethod
    def register(
        cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None, supports_cp: bool = False
    ):
        """Return a decorator registering the decorated function as the implementation of ``provider``.

        Args:
            provider: Provider the function implements.
            constraints: Callables invoked with the dispatch kwargs to validate a call (default: none).
            supports_cp: Whether the implementation can run under context parallelism.
        """
        logger.debug(f"Registering attention provider: {provider}")

        def decorator(func):
            cls._providers[provider] = func
            cls._constraints[provider] = constraints or []
            cls._supports_cp[provider] = supports_cp
            # Used by attention_dispatch to drop keyword arguments the implementation does not accept.
            cls._supported_arg_names[provider] = set(inspect.signature(func).parameters.keys())
            return func

        return decorator

    @classmethod
    def get_active_provider(cls):
        """Return ``(provider, implementation)`` for the currently active provider."""
        return cls._active_provider, cls._providers[cls._active_provider]

    @classmethod
    def list_providers(cls):
        """Return the registered providers, in registration order."""
        return list(cls._providers.keys())

    @classmethod
    def supports_context_parallel(cls, provider: AttentionProvider):
        """Return whether ``provider`` was registered with ``supports_cp=True``.

        Raises:
            ValueError: If ``provider`` is not registered.
        """
        if provider not in cls._providers:
            raise ValueError(f"Provider {provider} is not registered.")
        return cls._supports_cp.get(provider, False)

    @classmethod
    def context_parallel_enabled(cls):
        """Return True when a device mesh is set, i.e. context parallelism is active."""
        return cls._mesh is not None

    @classmethod
    def _set_context_parallel(
        cls,
        mesh: torch.distributed.device_mesh.DeviceMesh = None,
        convert_to_fp32: bool = None,
        rotate_method: str = None,
        *,
        reset: bool = False,
    ):
        """Set (or, with ``reset=True``, clear) the context-parallel mesh and options."""
        if reset:
            mesh = convert_to_fp32 = rotate_method = None
        cls._mesh = mesh
        cls._convert_to_fp32 = convert_to_fp32
        cls._rotate_method = rotate_method

    @classmethod
    def _raise_cp_error_if_mesh_not_set(cls):
        """Raise ValueError if no device mesh has been configured."""
        if cls._mesh is None:
            raise ValueError(
                "`_AttentionProviderRegistry._mesh` is None. It must be set before calling context parallel attention methods."
            )
@contextlib.contextmanager
def attention_provider(
    provider: AttentionProvider = AttentionProvider.NATIVE,
    *,
    mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    convert_to_fp32: bool = True,
    rotate_method: str = "allgather",
):
    """Context manager to set the active attention provider and possibly enable context parallelism.

    On exit, the registry state (active provider, mesh, fp32 conversion, rotate method) is restored to what it
    was on entry, so nested uses behave correctly. When ``mesh`` is given, PyTorch's ring-attention options in
    ``_cp_options`` are also restored.

    Args:
        provider: Registered provider to activate.
        mesh: Device mesh enabling context parallelism; ``None`` disables it inside the block.
        convert_to_fp32: Stored on the registry and later applied to ``_cp_options.convert_to_f32``.
        rotate_method: Stored on the registry and later passed to ``set_rotate_method``.

    Raises:
        ValueError: If ``provider`` is not registered, does not support context parallelism while ``mesh`` is
            given, or ``mesh`` is given but ring-attention support is unavailable (PyTorch < 2.6).
    """
    if provider not in _AttentionProviderRegistry._providers:
        raise ValueError(f"Provider {provider} is not registered.")
    if mesh is not None and not _AttentionProviderRegistry.supports_context_parallel(provider):
        raise ValueError(f"Provider {provider} does not support context parallelism.")
    if mesh is not None and _cp_options is None:
        raise ValueError("Context parallelism requires PyTorch >= 2.6.0.")

    registry = _AttentionProviderRegistry
    # Snapshot everything before mutating anything, so a failure cannot leave partially-updated state behind.
    saved_registry_state = (
        registry._active_provider,
        registry._mesh,
        registry._convert_to_fp32,
        registry._rotate_method,
    )
    saved_cp_options = None
    if mesh is not None:
        saved_cp_options = (
            _cp_options.convert_to_f32,
            _cp_options.enable_load_balance,
            _cp_options.rotate_method,
        )

    registry._active_provider = provider
    registry._mesh = mesh
    registry._convert_to_fp32 = convert_to_fp32
    registry._rotate_method = rotate_method
    try:
        yield
    finally:
        (
            registry._active_provider,
            registry._mesh,
            registry._convert_to_fp32,
            registry._rotate_method,
        ) = saved_registry_state
        if saved_cp_options is not None:
            (
                _cp_options.convert_to_f32,
                _cp_options.enable_load_balance,
                _cp_options.rotate_method,
            ) = saved_cp_options
def attention_dispatch(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """Run attention through the currently active provider.

    ``attention_kwargs`` are merged over the standard SDPA-style arguments. When checks are enabled, the
    provider's constraints are run on the merged arguments, and dropped arguments are reported (rate-limited).
    Arguments not in the provider's signature are then dropped before the call. With context parallelism
    enabled, the ring-attention options are configured first.
    """
    attention_kwargs = attention_kwargs or {}
    provider_name, provider_fn = _AttentionProviderRegistry.get_active_provider()
    supported_arg_names = _AttentionProviderRegistry._supported_arg_names[provider_name]

    kwargs = {
        "query": query,
        "key": key,
        "value": value,
        "attn_mask": attn_mask,
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
        "enable_gqa": enable_gqa,
        **attention_kwargs,
    }

    if _AttentionProviderRegistry._checks_enabled:
        removed_kwargs = set(kwargs) - set(supported_arg_names)
        if removed_kwargs:
            log_freq = 512
            msg = (
                f"Removing unsupported arguments for attention provider {provider_name}: {removed_kwargs}. This "
                f"message will be logged every {log_freq} calls."
            )
            logger.log_freq("WARNING", "REMOVING_ATTN_UNSUPPORTED_KWARGS", msg, log_freq)
        for check in _AttentionProviderRegistry._constraints.get(provider_name, []):
            check(**kwargs)

    # Read causality before filtering: a provider whose signature lacks `is_causal` would otherwise make the
    # context-parallel setup below fail with a missing-argument TypeError.
    effective_is_causal = kwargs["is_causal"]
    kwargs = {k: v for k, v in kwargs.items() if k in supported_arg_names}

    if _AttentionProviderRegistry.context_parallel_enabled():
        _set_context_parallel_options(is_causal=effective_is_causal)

    return provider_fn(**kwargs)
| # ===== Helper functions ===== | |
| # @torch.compiler.assume_constant_result | |
def _set_context_parallel_options(is_causal: bool, **kwargs):
    """Copy the registry's context-parallel settings into PyTorch's ring-attention options.

    Load balancing is turned on exactly when the attention is causal. Extra keyword arguments are accepted
    and ignored.
    """
    registry = _AttentionProviderRegistry
    _cp_options.enable_load_balance = is_causal
    _cp_options.convert_to_f32 = registry._convert_to_fp32
    set_rotate_method(registry._rotate_method)
def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None:
    """Constraint: reject any attention mask.

    Raises:
        ValueError: If ``attn_mask`` is not None.
    """
    if attn_mask is None:
        return
    raise ValueError("Attention mask must be None for this provider.")
def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
    """Constraint: an explicit mask and ``is_causal=True`` are mutually exclusive.

    Raises:
        ValueError: If both are given.
    """
    if attn_mask is None or not is_causal:
        return
    raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    """Constraint: query, key and value share one device and one dtype.

    Raises:
        ValueError: On a device mismatch (checked first) or a dtype mismatch.
    """
    others = (key, value)
    if any(t.device != query.device for t in others):
        raise ValueError("Query, key, and value must be on the same device.")
    if any(t.dtype != query.dtype for t in others):
        raise ValueError("Query, key, and value must have the same dtype.")
def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    """Constraint: ``_check_device`` passes and the tensors live on a CUDA device.

    Raises:
        ValueError: From ``_check_device``, or if ``query`` is not on CUDA.
    """
    _check_device(query, key, value)
    on_cuda = query.device.type == "cuda"
    if not on_cuda:
        raise ValueError("Query, key, and value must be on a CUDA device.")
def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
    """Build a constraint requiring CUDA tensors on a device with compute capability >= ``major.minor``."""
    required_capability = (major, minor)

    def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
        _check_device_cuda(query, key, value)
        actual_capability = torch.cuda.get_device_capability(query.device)
        if actual_capability >= required_capability:
            return
        raise ValueError(
            f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
        )

    return check_device_cuda
def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    """Constraint: key and value have the same dtype as query (key is checked first).

    Raises:
        ValueError: Naming the first tensor whose dtype differs from the query's.
    """
    for other, label in ((key, "key"), (value, "value")):
        if other.dtype != query.dtype:
            raise ValueError(f"Query and {label} must have the same dtype.")
def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
    """Constraint: q/k/v share a dtype, and that dtype is bfloat16 or float16.

    Raises:
        ValueError: From ``_check_qkv_dtype_match``, or for any other dtype.
    """
    _check_qkv_dtype_match(query, key, value)
    half_precision_dtypes = (torch.bfloat16, torch.float16)
    if query.dtype in half_precision_dtypes:
        return
    raise ValueError("Query, key, and value must be either bfloat16 or float16.")
def _check_shape(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> None:
    """Constraint: validate attention input shapes.

    Assumes a [..., seq_len, head_dim] trailing layout, as in the [B, N, S, D] tensors the providers in this
    module unpack. The checks are:

    * query and key share the last (head) dimension;
    * key and value share the sequence length (second-to-last dimension). The query length may differ, as in
      cross-attention;
    * a mask's last dimension equals the key sequence length.

    Raises:
        ValueError: On the first failed check.
    """
    if query.shape[-1] != key.shape[-1]:
        raise ValueError("Query and key must have the same last dimension.")
    # Previously query vs value was compared here. That rejected valid cross-attention and let
    # key/value length mismatches through.
    if key.shape[-2] != value.shape[-2]:
        raise ValueError("Key and value must have the same second to last dimension.")
    if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
        raise ValueError("Attention mask must match the key's second to last dimension.")
def _prepare_for_flash_attn_or_sage_varlen(
    batch_size: int,
    seq_len_q: int,
    seq_len_kv: int,
    attn_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Compute the sequence-length metadata required by varlen attention kernels.

    Every sample contributes ``seq_len_q`` queries. Key lengths are ``seq_len_kv`` for every sample, or, when
    ``attn_mask`` is given, the number of True entries in each row of the mask. The mask is assumed to be a
    [batch_size, seq_len_kv] bool mask, e.g. from ``_normalize_attn_mask``.

    Returns:
        ``((seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))``. The per-sample
        lengths and cumulative offsets (length ``batch_size + 1``, starting at 0) are int32 tensors; the maxima
        are Python ints.
    """
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
    if attn_mask is None:
        seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
    else:
        seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)

    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    # cumsum promotes int32 to int64; assigning into the int32 buffers keeps the kernel-required dtype.
    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)

    # `.item()` synchronizes with the device; the kernels take these maxima as host ints.
    max_seqlen_q = seqlens_q.max().item()
    max_seqlen_k = seqlens_k.max().item()
    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
| def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: | |
| """ | |
| Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in | |
| FlashAttention/Sage varlen. | |
| Supports 1D to 4D shapes and common broadcasting patterns. | |
| """ | |
| if attn_mask.dtype != torch.bool: | |
| raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") | |
| if attn_mask.ndim == 1: | |
| # [seq_len_k] -> broadcast across batch | |
| attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) | |
| elif attn_mask.ndim == 2: | |
| # [batch_size, seq_len_k]. Maybe broadcast across batch | |
| if attn_mask.size(0) not in [1, batch_size]: | |
| raise ValueError( | |
| f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." | |
| ) | |
| attn_mask = attn_mask.expand(batch_size, seq_len_k) | |
| elif attn_mask.ndim == 3: | |
| # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension | |
| if attn_mask.size(0) not in [1, batch_size]: | |
| raise ValueError( | |
| f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." | |
| ) | |
| attn_mask = attn_mask.any(dim=1) | |
| attn_mask = attn_mask.expand(batch_size, seq_len_k) | |
| elif attn_mask.ndim == 4: | |
| # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions | |
| if attn_mask.size(0) not in [1, batch_size]: | |
| raise ValueError( | |
| f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." | |
| ) | |
| attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] | |
| attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] | |
| else: | |
| raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") | |
| if attn_mask.shape != (batch_size, seq_len_k): | |
| raise ValueError( | |
| f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" | |
| ) | |
| return attn_mask | |
def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
    """FlexAttention ``mask_mod`` for causal attention: keep keys at or before the query position."""
    return kv_idx <= q_idx
| # ===== Attention provider implementations ===== | |
| # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 | |
class _flash_attn_flash_attention(torch.autograd.Function):
    """Autograd function running flash-attn without context parallelism.

    Adapted from flash-attn's ``flash_attn_interface.py``. ``forward`` and ``backward`` are static methods, as
    ``torch.autograd.Function`` requires: ``apply`` calls them on the class with ``ctx`` as the first argument.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[int, int] = (-1, -1),
        softcap: float = 0.0,
        alibi_slopes: Optional[torch.Tensor] = None,
        deterministic: bool = False,
        return_softmax: bool = False,
    ):
        """Compute attention; return ``(out, lse)`` if ``return_softmax`` else ``out``.

        ``softmax_scale`` defaults to ``head_dim ** -0.5`` (from ``q.shape[-1]``). Non-tensor settings are
        stored on ``ctx``. The tensors returned by the forward wrapper are saved for backward.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        out, lse, q, k, v, out_padded, S_dmask, rng_state = _finetrainers_flash_attn_forward(
            query=q,
            key=k,
            value=v,
            dropout_p=dropout_p,
            scale=softmax_scale,
            is_causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )

        ctx.save_for_backward(q, k, v, out_padded, lse, rng_state)
        return (out, lse) if return_softmax else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for q, k, v and ``None`` for the 8 non-differentiable forward inputs."""
        q, k, v, out, lse, rng_state = ctx.saved_tensors

        grad_query, grad_key, grad_value = _finetrainers_flash_attn_backward(
            grad_out=grad_out,
            query=q,
            key=k,
            value=v,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            scale=ctx.softmax_scale,
            is_causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            rng_state=rng_state,
        )

        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
| # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 | |
class _native_ring_flash_attn_flash_attention(torch.autograd.Function):
    """Autograd function running flash-attn under PyTorch's ring attention (context parallelism).

    Adapted from flash-attn's ``flash_attn_interface.py``. ``forward`` and ``backward`` are static methods, as
    ``torch.autograd.Function`` requires. The device mesh comes from ``_AttentionProviderRegistry._mesh``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[int, int] = (-1, -1),
        softcap: float = 0.0,
        alibi_slopes: Optional[torch.Tensor] = None,
        deterministic: bool = False,
        return_softmax: bool = False,
    ):
        """Compute ring attention; return ``(out, lse)`` if ``return_softmax`` else ``out``."""
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # For ring flash attention using the flash-attn repo, we want the LSE but flash-attn only supports it if dropout_p > 0
        dropout_p = dropout_p if dropout_p > 0 else 1e-30

        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        out, lse, q, k, v, out_padded, S_dmask, rng_state = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_flash_attn_forward,
            query=q,
            key=k,
            value=v,
            dropout_p=dropout_p,
            scale=softmax_scale,
            is_causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=True,
        )

        ctx.save_for_backward(q, k, v, out_padded, lse, rng_state)
        return (out, lse) if return_softmax else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for q, k, v (as [B, N, S, D]) and ``None`` for the 8 non-differentiable inputs."""
        q, k, v, out, lse, rng_state = ctx.saved_tensors
        lse = lse.permute(0, 2, 1).contiguous()  # [B, N, S] -> [B, S, N]

        grad_query, grad_key, grad_value = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            # This needs to be 1 because q, k, v, out_padded returned from forward are BSND instead of BNSD
            # The grad_out permutation is handled in _finetrainers_flash_attn_backward, and the outputs from that are expected to have
            # shape BSND instead of BNSD (requirement of _templated_ring_attention_backward), so we need to set seq_dim=1 and permute the
            # returned outputs
            seq_dim=1,
            op=functools.partial(_finetrainers_flash_attn_backward, _permute_outputs=False),
            grad_out=grad_out,
            grad_out_name="grad_out",
            query=q,
            key=k,
            value=v,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            scale=ctx.softmax_scale,
            is_causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            rng_state=rng_state,
        )
        grad_query, grad_key, grad_value = (
            x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value)
        )

        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
def flash_attn_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """flash-attn attention; uses the ring-attention variant when context parallelism is enabled.

    ``return_lse`` is forwarded as the autograd function's ``return_softmax`` flag.
    """
    if _AttentionProviderRegistry.context_parallel_enabled():
        attention_fn = _native_ring_flash_attn_flash_attention
    else:
        attention_fn = _flash_attn_flash_attention
    apply_args = (
        query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, deterministic, return_lse
    )
    return attention_fn.apply(*apply_args)
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Variable-length attention through flash-attn's ``flash_attn_varlen_func``.

    Inputs are [B, N, S, D]. Queries are packed as full rows (every sample contributes ``seq_len_q`` tokens).
    Keys and values are packed by taking each sample's first ``seqlens_k[b]`` positions, so valid key tokens
    are assumed to form a prefix of each row.

    Sequence metadata is derived from ``attn_mask`` (normalized to [B, seq_len_kv]) unless all of
    ``cu_seqlens_q``, ``cu_seqlens_k``, ``max_seqlen_q`` and ``max_seqlen_k`` are supplied. In that case the
    per-sample key lengths are taken from ``cu_seqlens_k``.

    Returns:
        The output as [B, N, S_q, D]. When ``return_attn_probs`` is set (forced on under context parallelism),
        returns ``(out, second_result)``, where ``second_result`` is the second element flash-attn returns
        (its softmax LSE -- TODO confirm for the installed flash-attn version).
    """
    batch_size, _, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
        # Per-sample key lengths must agree with the caller's offsets. Assuming `max_seqlen_k` for every sample
        # packs more key rows than `cu_seqlens_k` describes whenever the lengths differ.
        seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]

    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))  # [B, N, S, D] -> [B, S, N, D]

    # One host transfer for all lengths instead of one per sample while slicing.
    valid_lens_k = seqlens_k.tolist()
    query_packed = query.flatten(0, 1)
    key_packed = torch.cat([key[b, :n] for b, n in enumerate(valid_lens_k)], dim=0)
    value_packed = torch.cat([value[b, :n] for b, n in enumerate(valid_lens_k)], dim=0)

    if _AttentionProviderRegistry.context_parallel_enabled():
        # Context parallelism needs the extra outputs (presumably the LSE), so always request them.
        return_attn_probs = True

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )

    rest = None
    if return_attn_probs:
        out, *rest = out
    out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)  # [B*S, N, D] -> [B, N, S, D]

    if return_attn_probs:
        return out, *rest[:1]
    return out
def _native_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """Attention through PyTorch FlexAttention on [B, N, S, D] inputs.

    Mask handling, in priority order:

    * a ``BlockMask`` is used as-is;
    * ``is_causal=True`` builds a causal block mask. This applies also when ``attn_mask`` is None; before,
      the causal flag was silently ignored in that case;
    * a bool tensor becomes a ``mask_mod``; any other tensor is added to the scores through a ``score_mod``;
    * a 2D tensor is treated as a [batch, seq_len_kv] key mask and broadcast over heads and queries.

    ``kernel_options`` is forwarded to ``flex_attention``; it was previously dropped. ``dropout_p`` is
    accepted for interface compatibility but is not applied.

    Raises:
        ValueError: If ``attn_mask`` is neither None, a BlockMask, nor a tensor.
    """
    # TODO: should we LRU cache the block mask creation?
    score_mod = None
    block_mask = None
    batch_size, num_heads, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape

    if isinstance(attn_mask, flex_attention.BlockMask):
        block_mask = attn_mask
    elif is_causal:
        block_mask = flex_attention.create_block_mask(
            _flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, query.device
        )
    elif torch.is_tensor(attn_mask):
        if attn_mask.ndim == 2:
            # [batch, seq_len_kv] key mask -> [batch, 1, 1, seq_len_kv], matching _normalize_attn_mask's 2D layout.
            attn_mask = attn_mask.view(attn_mask.size(0), 1, 1, attn_mask.size(1))

        attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)

        if attn_mask.dtype == torch.bool:
            # NOTE(review): tensor indexing inside mask_mod is unverified with create_block_mask -- TODO confirm.
            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                return attn_mask[batch_idx, head_idx, q_idx, kv_idx]

            block_mask = flex_attention.create_block_mask(
                mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
            )
        else:

            def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]

    elif attn_mask is not None:
        raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")

    return flex_attention.flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        enable_gqa=enable_gqa,
        return_lse=return_lse,
        kernel_options=kernel_options,
    )
| def _native_attention( | |
| query: torch.Tensor, | |
| key: torch.Tensor, | |
| value: torch.Tensor, | |
| attn_mask: Optional[torch.Tensor] = None, | |
| dropout_p: float = 0.0, | |
| is_causal: bool = False, | |
| scale: Optional[float] = None, | |
| enable_gqa: bool = False, | |
| ) -> torch.Tensor: | |
| return native_sdpa( | |
| query=query, | |
| key=key, | |
| value=value, | |
| attn_mask=attn_mask, | |
| dropout_p=dropout_p, | |
| is_causal=is_causal, | |
| scale=scale, | |
| enable_gqa=enable_gqa, | |
| ) | |
class _native_cudnn_attention(torch.autograd.Function):
    """Autograd wrapper around PyTorch's private cuDNN SDPA forward/backward ops.

    Use via ``.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)``.
    Forward always asks the op for the log-sum-exp because the backward op consumes it.
    """

    # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
    # forward declaration:
    #   aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
    # backward declaration:
    #   aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute attention and stash what the cuDNN backward op needs.

        Returns:
            ``out``, or ``(out, lse)`` when ``return_lse`` is set.
        """
        # Scalars and the optional bias go on ctx; tensors produced/consumed by the op are saved below.
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_cudnn_attention(
                query=query,
                key=key,
                value=value,
                attn_bias=attn_mask,
                compute_log_sumexp=True,
                dropout_p=dropout_p,
                is_causal=is_causal,
                return_debug_mask=False,
                scale=scale,
            )
        )
        # max_q / max_k are SymInts, not tensors, so they cannot go through save_for_backward.
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value and ``None`` for the five remaining inputs.

        Any incoming gradient for the returned ``lse`` (collected in ``args``) is ignored.
        """
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
        grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
            grad_out=grad_out,
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            attn_bias=ctx.attn_mask,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )
        return grad_query, grad_key, grad_value, None, None, None, None, None
class _native_ring_native_cudnn_attention(torch.autograd.Function):
    """Context-parallel variant of the cuDNN SDPA autograd wrapper.

    Forward and backward hand the private aten cuDNN ops to ``_templated_ring_attention`` /
    ``_templated_ring_attention_backward`` with ``_AttentionProviderRegistry._mesh`` and
    ``seq_dim=2``. Use via ``.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute ring attention and stash what the backward pass needs.

        Returns:
            ``out``, or ``(out, lse)`` when ``return_lse`` is set.
        """
        # Presumably raises when no device mesh was registered; see the registry for details.
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _debug_attn_mask = (
            _templated_ring_attention(
                mesh=_AttentionProviderRegistry._mesh,
                seq_dim=2,
                op=torch.ops.aten._scaled_dot_product_cudnn_attention,
                query=query,
                key=key,
                value=value,
                attn_bias=attn_mask,
                compute_log_sumexp=True,
                dropout_p=dropout_p,
                is_causal=is_causal,
                return_debug_mask=False,
                scale=scale,
            )
        )
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value and ``None`` for the five remaining inputs.

        Any incoming gradient for the returned ``lse`` (collected in ``args``) is ignored.
        """
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
        grad_query, grad_key, grad_value = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_cudnn_attention_backward,
            grad_out=grad_out,
            grad_out_name="grad_out",
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            attn_bias=ctx.attn_mask,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )
        return grad_query, grad_key, grad_value, None, None, None, None, None
def native_cudnn_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """cuDNN scaled-dot-product attention with autograd support.

    Uses the context-parallel autograd function when
    ``_AttentionProviderRegistry.context_parallel_enabled()`` is true, otherwise the single-device one.

    Returns:
        ``out``, or ``(out, lse)`` when ``return_lse`` is set.
    """
    if _AttentionProviderRegistry.context_parallel_enabled():
        autograd_fn = _native_ring_native_cudnn_attention
    else:
        autograd_fn = _native_cudnn_attention
    positional = (query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)
    return autograd_fn.apply(*positional)
class _native_efficient_attention(torch.autograd.Function):
    """Autograd wrapper around the memory-efficient SDPA forward/backward ops.

    Goes through finetrainers' re-registered efficient-attention ops rather than the aten ones
    (see the NOTE comments below). Use via
    ``.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)``.
    """

    # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14946
    # forward declaration:
    #   aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
    # backward declaration:
    #   aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute attention and stash what the backward op needs.

        Returns:
            ``out``, or ``(out, lse)`` when ``return_lse`` is set.
        """
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
        out, lse, philox_seed, philox_offset = _finetrainers_scaled_dot_product_efficient_attention_forward(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)
        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value and ``None`` for the five remaining inputs.

        The bias gradient is not requested (``grad_input_mask`` ends in ``False``). Any incoming
        gradient for the returned ``lse`` (collected in ``args``) is ignored.
        """
        query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors
        # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
        grad_query, grad_key, grad_value, _grad_attn_bias = (
            _finetrainers_scaled_dot_product_efficient_attention_backward(
                grad_out_=grad_out,
                query=query,
                key=key,
                value=value,
                attn_bias=ctx.attn_mask,
                out=out,
                logsumexp=lse,
                philox_seed=philox_seed,
                philox_offset=philox_offset,
                dropout_p=ctx.dropout_p,
                grad_input_mask=[True, True, True, False],
                is_causal=ctx.is_causal,
                scale=ctx.scale,
            )
        )
        return grad_query, grad_key, grad_value, None, None, None, None, None
class _native_ring_native_efficient_attention(torch.autograd.Function):
    """Context-parallel variant of the memory-efficient SDPA autograd wrapper.

    Forward and backward hand finetrainers' efficient-attention ops to ``_templated_ring_attention`` /
    ``_templated_ring_attention_backward`` with ``_AttentionProviderRegistry._mesh`` and
    ``seq_dim=2``. Use via ``.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute ring attention and stash what the backward pass needs.

        Returns:
            ``out``, or ``(out, lse)`` when ``return_lse`` is set.
        """
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
        out, lse, philox_seed, philox_offset = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_scaled_dot_product_efficient_attention_forward,
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)
        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value and ``None`` for the five remaining inputs.

        Any incoming gradient for the returned ``lse`` (collected in ``args``) is ignored.
        """
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors
        # NOTE: Uses finetrainers registered op because of LSE alignment issue. See the op registration for more details.
        grad_query, grad_key, grad_value, _grad_attn_bias = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_scaled_dot_product_efficient_attention_backward,
            grad_out=grad_out,
            grad_out_name="grad_out_",
            query=query,
            key=key,
            value=value,
            attn_bias=ctx.attn_mask,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            dropout_p=ctx.dropout_p,
            grad_input_mask=[True, True, True, False],
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )
        return grad_query, grad_key, grad_value, None, None, None, None, None
def native_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """Memory-efficient scaled-dot-product attention with autograd support.

    Uses the context-parallel autograd function when
    ``_AttentionProviderRegistry.context_parallel_enabled()`` is true, otherwise the single-device one.
    ``return_lse`` mirrors ``native_cudnn_attention`` / ``native_flash_attention``. Both efficient
    autograd functions already support it, so it is simply passed through.

    Returns:
        ``out``, or ``(out, lse)`` when ``return_lse`` is set.
    """
    dispatch_fn = (
        _native_ring_native_efficient_attention
        if _AttentionProviderRegistry.context_parallel_enabled()
        else _native_efficient_attention
    )
    return dispatch_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)
class _native_flash_attention(torch.autograd.Function):
    """Autograd wrapper around PyTorch's private flash SDPA forward/backward ops.

    This op takes no attention mask. Use via
    ``.apply(query, key, value, dropout_p, is_causal, scale, return_lse)``.
    """

    # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14910
    # forward declaration:
    #   aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
    # backward declaration:
    #   aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute attention and stash what the flash backward op needs.

        Returns:
            ``out``, or ``(out, lse)`` when ``return_lse`` is set.
        """
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_flash_attention(
                query=query,
                key=key,
                value=value,
                dropout_p=dropout_p,
                is_causal=is_causal,
                return_debug_mask=False,
                scale=scale,
            )
        )
        # max_q / max_k are SymInts, not tensors, so they cannot go through save_for_backward.
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value and ``None`` for the four remaining inputs.

        Any incoming gradient for the returned ``lse`` (collected in ``args``) is ignored.
        """
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
        grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_flash_attention_backward(
            grad_out=grad_out,
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            scale=ctx.scale,
        )
        return grad_query, grad_key, grad_value, None, None, None, None
class _native_ring_native_flash_attention(torch.autograd.Function):
    """Context-parallel variant of the flash SDPA autograd wrapper.

    Forward and backward hand the private aten flash ops to ``_templated_ring_attention`` /
    ``_templated_ring_attention_backward`` with ``_AttentionProviderRegistry._mesh`` and
    ``seq_dim=2``. Use via ``.apply(query, key, value, dropout_p, is_causal, scale, return_lse)``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute ring attention and stash what the backward pass needs.

        Returns:
            ``out``, or ``(out, lse)`` when ``return_lse`` is set.
        """
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _debug_attn_mask = (
            _templated_ring_attention(
                mesh=_AttentionProviderRegistry._mesh,
                seq_dim=2,
                op=torch.ops.aten._scaled_dot_product_flash_attention,
                query=query,
                key=key,
                value=value,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
            )
        )
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Return gradients for query/key/value and ``None`` for the four remaining inputs.

        Any incoming gradient for the returned ``lse`` (collected in ``args``) is ignored.
        """
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
        grad_query, grad_key, grad_value, *_ = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_flash_attention_backward,
            grad_out=grad_out,
            grad_out_name="grad_out",
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
        )
        return grad_query, grad_key, grad_value, None, None, None, None
def native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """Flash scaled-dot-product attention (no mask support) with autograd support.

    Uses the context-parallel autograd function when
    ``_AttentionProviderRegistry.context_parallel_enabled()`` is true, otherwise the single-device one.

    Returns:
        ``out``, or ``(out, lse)`` when ``return_lse`` is set.
    """
    use_context_parallel = _AttentionProviderRegistry.context_parallel_enabled()
    if not use_context_parallel:
        return _native_flash_attention.apply(query, key, value, dropout_p, is_causal, scale, return_lse)
    return _native_ring_native_flash_attention.apply(query, key, value, dropout_p, is_causal, scale, return_lse)
| # class _native_math_attention(torch.autograd.Function): | |
| # # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14901 | |
| # # forward declaration: | |
| # # aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0., bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) | |
| # # backward declaration: | |
| # # does not exist | |
| # @staticmethod | |
| # def forward( | |
| # ctx: torch.autograd.function.FunctionCtx, | |
| # query: torch.Tensor, | |
| # key: torch.Tensor, | |
| # value: torch.Tensor, | |
| # attn_mask: Optional[torch.Tensor] = None, | |
| # dropout_p: float = 0.0, | |
| # is_causal: bool = False, | |
| # dropout_mask: Optional[torch.Tensor] = None, | |
| # scale: Optional[float] = None, | |
| # enable_gqa: bool = False, | |
| # return_scores: bool = False, | |
| # ): | |
| # ctx.dropout_p = dropout_p | |
| # ctx.is_causal = is_causal | |
| # ctx.scale = scale | |
| # ctx.enable_gqa = enable_gqa | |
| # print(f"query.shape: {query.shape}") | |
| # with torch.enable_grad(): | |
| # out, scores = torch.ops.aten._scaled_dot_product_attention_math( | |
| # query=query, | |
| # key=key, | |
| # value=value, | |
| # attn_mask=attn_mask, | |
| # dropout_p=dropout_p, | |
| # is_causal=is_causal, | |
| # dropout_mask=dropout_mask, | |
| # scale=scale, | |
| # enable_gqa=enable_gqa, | |
| # ) | |
| # ctx.save_for_backward(query, key, value, out) | |
| # return (out, scores) if return_scores else out | |
| # @staticmethod | |
| # def backward( | |
| # ctx: torch.autograd.function.FunctionCtx, | |
| # grad_out: torch.Tensor, | |
| # ): | |
| # raise NotImplementedError("Backward pass for native math attention is not implemented.") | |
| def native_math_attention( | |
| query: torch.Tensor, | |
| key: torch.Tensor, | |
| value: torch.Tensor, | |
| attn_mask: Optional[torch.Tensor] = None, | |
| dropout_p: float = 0.0, | |
| is_causal: bool = False, | |
| scale: Optional[float] = None, | |
| enable_gqa: bool = False, | |
| ) -> torch.Tensor: | |
| with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): | |
| return native_sdpa( | |
| query=query, | |
| key=key, | |
| value=value, | |
| attn_mask=attn_mask, | |
| dropout_p=dropout_p, | |
| is_causal=is_causal, | |
| scale=scale, | |
| enable_gqa=enable_gqa, | |
| ) | |
def _sage_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """Run SageAttention's ``sageattn`` on ``HND``-layout inputs.

    When context parallelism is enabled the log-sum-exp is always requested, so the result is then
    a tuple even if ``return_lse`` was ``False``.

    Returns:
        The kernel output alone, or ``(output, first_extra_result)`` when the LSE was requested. Any
        further values the kernel returns are dropped.
    """
    context_parallel = _AttentionProviderRegistry.context_parallel_enabled()
    return_lse = True if context_parallel else return_lse
    result = sageattn(
        q=query,
        k=key,
        v=value,
        tensor_layout="HND",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )
    if not return_lse:
        return result
    out, *extras = result
    return (out, *extras[:1])
def _sage_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    smooth_k: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Variable-length SageAttention over padded ``(batch, heads, seq_len, head_dim)`` inputs.

    Keys/values are packed by keeping only the first ``seqlens_k[b]`` positions of each sample.
    Queries are packed in full. The packed output is reshaped back to
    ``(batch, heads, seq_len_q, head_dim)``.

    If any of ``cu_seqlens_q``, ``cu_seqlens_k``, ``max_seqlen_q`` or ``max_seqlen_k`` is missing,
    all four are recomputed by ``_prepare_for_flash_attn_or_sage_varlen`` from ``attn_mask`` (after
    ``_normalize_attn_mask``). Otherwise the per-sample key lengths are taken from consecutive
    differences of ``cu_seqlens_k``, so the packed keys always agree with ``cu_seqlens_k``.

    ``enable_gqa`` is accepted for signature compatibility but has no dedicated handling yet (TODO).
    """
    batch_size, _, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape
    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
        # Derive per-sample lengths from the offsets instead of assuming every sample has
        # `max_seqlen_k` keys, which would pack a token count inconsistent with `cu_seqlens_k`.
        seqlens_k = cu_seqlens_k.diff()

    # Fetch lengths to the host once rather than syncing on every per-sample slice.
    lengths = seqlens_k.tolist() if torch.is_tensor(seqlens_k) else list(seqlens_k)

    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    key_packed = torch.cat([key[b, : lengths[b]] for b in range(batch_size)], dim=0)
    value_packed = torch.cat([value[b, : lengths[b]] for b in range(batch_size)], dim=0)
    query_packed = query.flatten(0, 1)

    out = sageattn_varlen(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        is_causal=is_causal,
        sm_scale=scale,
        smooth_k=smooth_k,
    )
    out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)  # .contiguous()
    return out
def _sage_qk_int8_pv_fp8_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """Call SageAttention's ``sageattn_qk_int8_pv_fp8_cuda`` kernel on ``HND``-layout inputs.

    Each argument maps onto the kernel keyword of the same name (``scale`` becomes ``sm_scale``). The
    kernel's result is returned unchanged.
    """
    kernel_options = {
        "tensor_layout": "HND",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "smooth_v": smooth_v,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda(q=query, k=key, v=value, **kernel_options)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    """Call SageAttention's ``sageattn_qk_int8_pv_fp8_cuda_sm90`` kernel on ``HND``-layout inputs.

    Unlike the non-sm90 fp8 variant, this wrapper exposes no ``smooth_v`` option. ``scale`` is
    passed as ``sm_scale``, and the kernel's result is returned unchanged.
    """
    kernel_options = {
        "tensor_layout": "HND",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda_sm90(q=query, k=key, v=value, **kernel_options)
def _sage_qk_int8_pv_fp16_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """Call SageAttention's ``sageattn_qk_int8_pv_fp16_cuda`` kernel on ``HND``-layout inputs.

    Each argument maps onto the kernel keyword of the same name (``scale`` becomes ``sm_scale``). The
    kernel's result is returned unchanged.
    """
    kernel_options = {
        "tensor_layout": "HND",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "smooth_v": smooth_v,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_cuda(q=query, k=key, v=value, **kernel_options)
def _sage_qk_int8_pv_fp16_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    """Call SageAttention's ``sageattn_qk_int8_pv_fp16_triton`` kernel on ``HND``-layout inputs.

    ``quantization_backend`` selects the kernel's quantization backend, and ``scale`` is passed as
    ``sm_scale``. The kernel's result is returned unchanged.
    """
    kernel_options = {
        "tensor_layout": "HND",
        "quantization_backend": quantization_backend,
        "is_causal": is_causal,
        "sm_scale": scale,
        "smooth_k": smooth_k,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_triton(q=query, k=key, v=value, **kernel_options)
def _xformers_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Run attention through ``xformers.ops.memory_efficient_attention``.

    Inputs and output use ``(batch, heads, seq_len, head_dim)``. They are permuted to the
    ``(batch, seq_len, heads, head_dim)`` layout xformers expects, and the output is permuted back.

    Args:
        attn_mask: Optional 2D ``(batch, seq_len_kv)`` key mask, or a 4D mask broadcastable to
            ``(batch, heads_q, seq_len_q, seq_len_kv)``. Boolean masks mark positions to attend with
            ``True`` and are converted to an additive 0/-inf bias. Floating masks are used as an
            additive bias directly. Ignored when ``is_causal`` is set.
        dropout_p: Dropout probability forwarded to xformers.
        is_causal: Use xformers' ``LowerTriangularMask``.
        scale: Softmax scale forwarded to xformers.
        enable_gqa: Allow fewer key/value heads than query heads. The key/value heads are then
            expanded through xformers' 5D grouped layout.

    Raises:
        ValueError: If the mask is not 2D/4D, or if ``enable_gqa`` is set and the query head count is
            not divisible by the key/value head count.
    """
    batch_size, num_heads_q, seq_len_q, _ = query.shape
    _, num_heads_kv, seq_len_kv, _ = key.shape
    has_tensor_bias = not is_causal and attn_mask is not None

    # TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns
    if is_causal:
        attn_mask = xops.LowerTriangularMask()
    elif attn_mask is not None:
        if attn_mask.ndim == 2:
            # Key mask: broadcast over heads and query positions so that it masks key columns.
            attn_mask = attn_mask.view(attn_mask.size(0), 1, 1, attn_mask.size(1))
        elif attn_mask.ndim != 4:
            raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
        if attn_mask.dtype == torch.bool:
            # xformers takes an additive bias; casting a bool mask would add 1.0/0.0 and mask nothing.
            attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(~attn_mask, float("-inf"))
        attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)

    # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers
    # query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))

    if enable_gqa:
        if num_heads_q % num_heads_kv != 0:
            raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
        num_heads_per_group = num_heads_q // num_heads_kv
        query = query.unflatten(2, (num_heads_kv, -1))
        key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
        value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
        if has_tensor_bias:
            # Grouped (BMGHK) inputs need a (batch, groups, heads_per_group, seq_q, seq_kv) bias.
            attn_mask = attn_mask.unflatten(1, (num_heads_kv, -1))

    out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
    if enable_gqa:
        out = out.flatten(2, 3)
    out = out.permute(0, 2, 1, 3)  # .contiguous()
    return out