| #pragma once |
|
|
| #include "../../../attention/attention_dtypes.h" |
| #include <assert.h> |
| #include <float.h> |
| #include <stdint.h> |
| #include <type_traits> |
|
|
| namespace vllm { |
| #ifndef USE_ROCM |
|
|
| namespace fp8 { |
| #ifdef ENABLE_FP8 |
|
|
| #if 0 |
// Fallback used when no specialization matches: `x` is handed back through
// the implicit Tin -> Tout conversion and `fp8_type` plays no role.
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(
    const Tin& x, const __nv_fp8_interpretation_t fp8_type) {
  (void)fp8_type;
  return x;
}
|
|
| |
// fp8 -> half: returns the raw 16-bit payload (`.x`) of the __half_raw
// produced by __nv_cvt_fp8_to_halfraw.
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const __half_raw as_half = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return as_half.x;
}
|
|
| |
// fp8x2 -> half2: converts both packed fp8 values at once and returns the two
// 16-bit results packed into a single 32-bit word.
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const __half2_raw halves = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  // The union overlays the two 16-bit lanes on the 32-bit return value.
  union {
    uint16_t lanes[2];
    uint32_t word;
  } packed;
  packed.lanes[0] = halves.x;
  packed.lanes[1] = halves.y;
  return packed.word;
}
|
|
| |
// fp8x4 -> half2x2: the low 16 bits of `a` fill result.x, the high 16 bits
// fill result.y, each through the fp8x2 -> half2 specialization.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  uint2 out;
  out.x = vec_conversion<uint32_t, uint16_t>(lo_pair, fp8_type);
  out.y = vec_conversion<uint32_t, uint16_t>(hi_pair, fp8_type);
  return out;
}
|
|
| |
// fp8x8 -> half2x4: a.x produces the (x, y) words of the result and a.y the
// (z, w) words, each through the fp8x4 -> half2x2 specialization.
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  const uint2 lo_half = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
  const uint2 hi_half = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
  uint4 out;
  out.x = lo_half.x;
  out.y = lo_half.y;
  out.z = hi_half.x;
  out.w = hi_half.y;
  return out;
}
|
|
| |
// fp8 -> bfloat16, staged as fp8 -> half (raw bits) -> float -> bfloat16.
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const __half_raw as_half = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  const float as_float = half_to_float(as_half.x);
  return __float2bfloat16(as_float);
}
|
|
| |
// fp8x2 -> bfloat162: the low byte of `a` becomes .x and the high byte .y.
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const uint8_t lo_byte = static_cast<uint8_t>(a);
  const uint8_t hi_byte = static_cast<uint8_t>(a >> 8U);
  __nv_bfloat162 pair;
  pair.x = vec_conversion<__nv_bfloat16, uint8_t>(lo_byte, fp8_type);
  pair.y = vec_conversion<__nv_bfloat16, uint8_t>(hi_byte, fp8_type);
  return pair;
}
|
|
| |
// fp8x4 -> bf16_4_t: the low 16 bits of `a` fill .x, the high 16 bits fill .y,
// each through the fp8x2 -> bfloat162 specialization.
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  bf16_4_t out;
  out.x = vec_conversion<__nv_bfloat162, uint16_t>(lo_pair, fp8_type);
  out.y = vec_conversion<__nv_bfloat162, uint16_t>(hi_pair, fp8_type);
  return out;
}
|
|
| |
// fp8x8 -> bf16_8_t: a.x supplies the (x, y) members of the result and a.y
// the (z, w) members, each through the fp8x4 -> bf16_4_t specialization.
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t lo_half = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
  bf16_4_t hi_half = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
  bf16_8_t out;
  out.x = lo_half.x;
  out.y = lo_half.y;
  out.z = hi_half.x;
  out.w = hi_half.y;
  return out;
}
|
|
| |
// fp8 -> float, staged as fp8 -> half bits -> float (via half_to_float).
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t half_bits = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
  return half_to_float(half_bits);
}
|
|
| |
// fp8x2 -> float2, staged as fp8x2 -> packed half2 bits -> float2 (via
// half2_to_float2).
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const uint32_t half2_bits = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
  return half2_to_float2(half2_bits);
}
|
|
| |
// fp8x4 -> Float4_: the low 16 bits of `a` fill .x, the high 16 bits fill .y,
// each through the fp8x2 -> float2 specialization.
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  Float4_ out;
  out.x = vec_conversion<float2, uint16_t>(lo_pair, fp8_type);
  out.y = vec_conversion<float2, uint16_t>(hi_pair, fp8_type);
  return out;
}
|
|
| |
// fp8x8 -> Float8_: a.x supplies the (x, y) members of the result and a.y the
// (z, w) members, each through the fp8x4 -> Float4_ specialization.
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ lo_half = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
  Float4_ hi_half = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
  Float8_ out;
  out.x = lo_half.x;
  out.y = lo_half.y;
  out.z = hi_half.x;
  out.w = hi_half.y;
  return out;
}
|
|
| |
// half -> fp8: `a` carries raw half bits; the conversion uses __NV_SATFINITE.
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw raw_half;
  raw_half.x = a;
  return static_cast<uint8_t>(
      __nv_cvt_halfraw_to_fp8(raw_half, __NV_SATFINITE, fp8_type));
}
|
|
| |
// bf16 -> fp8 using __NV_SATFINITE.
// The conversion is only compiled when __CUDA_ARCH__ is undefined or >= 800;
// on older device architectures the body is just assert(false).
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
      __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
  // Without this, the < sm_80 branch flows off the end of a non-void function
  // (and assert() vanishes entirely under NDEBUG). Same pattern as the scaled
  // bf16 -> fp8 specialization in this file.
  __builtin_unreachable();
}
|
|
| |
// float -> fp8 using __NV_SATFINITE.
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
    const float& a, const __nv_fp8_interpretation_t fp8_type) {
  return static_cast<uint8_t>(
      __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type));
}
|
|
| |
// fp8x4 -> float4: converts to Float4_ first, then flattens its two pairs
// into (x, y, z, w).
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ pairs = vec_conversion<Float4_, uint32_t>(a, fp8_type);
  float4 out;
  out.x = pairs.x.x;
  out.y = pairs.x.y;
  out.z = pairs.y.x;
  out.w = pairs.y.y;
  return out;
}
|
|
// float2 -> half2 bits: converts with __float22half2_rn and returns the
// result reinterpreted as a 32-bit word. `fp8_type` is not used.
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
    const float2& a, const __nv_fp8_interpretation_t fp8_type) {
  // Anonymous union: both members share storage.
  union {
    half2 as_half2;
    uint32_t as_bits;
  };

  as_half2 = __float22half2_rn(a);
  return as_bits;
}
|
|
// Float4_ -> two packed 32-bit words: a.x feeds result.x and a.y feeds
// result.y, each through the float2 -> uint32_t specialization.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  float2 first_pair;
  first_pair.x = a.x.x;
  first_pair.y = a.x.y;

  float2 second_pair;
  second_pair.x = a.y.x;
  second_pair.y = a.y.y;

  uint2 out;
  out.x = vec_conversion<uint32_t, float2>(first_pair, fp8_type);
  out.y = vec_conversion<uint32_t, float2>(second_pair, fp8_type);
  return out;
}
|
|
// Float4_ -> float4: flattens the two pairs into (x, y, z, w). No numeric
// conversion takes place and `fp8_type` is not used.
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  float4 flat;
  // First pair -> (x, y).
  flat.x = a.x.x;
  flat.y = a.x.y;
  // Second pair -> (z, w).
  flat.z = a.y.x;
  flat.w = a.y.y;
  return flat;
}
|
|
// Float8_ -> four packed 32-bit words: each of a.x/a.y/a.z/a.w goes through
// the float2 -> uint32_t specialization into the matching result member.
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
    const Float8_& a, const __nv_fp8_interpretation_t fp8_type) {
  const uint32_t w0 = vec_conversion<uint32_t, float2>(a.x, fp8_type);
  const uint32_t w1 = vec_conversion<uint32_t, float2>(a.y, fp8_type);
  const uint32_t w2 = vec_conversion<uint32_t, float2>(a.z, fp8_type);
  const uint32_t w3 = vec_conversion<uint32_t, float2>(a.w, fp8_type);
  uint4 out;
  out.x = w0;
  out.y = w1;
  out.z = w2;
  out.w = w3;
  return out;
}
|
|
// float2 -> __nv_bfloat162, delegated to from_float. `fp8_type` is not used.
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
    const float2& a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 out;
  from_float(out, a);
  return out;
}
|
|
// Float4_ -> bf16_4_t, delegated to from_float. `fp8_type` is not used.
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t out;
  from_float(out, a);
  return out;
}
|
|
// Float8_ -> bf16_8_t, delegated to from_float. `fp8_type` is not used.
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
    const Float8_& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_8_t out;
  from_float(out, a);
  return out;
}
| #endif |
|
|
| |
| |
| |
| |
| |
|
|
// Fallback used when no specialization matches: `x` is handed back through
// the implicit Tin -> Tout conversion; neither `scale` nor `fp8_type` is
// applied.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
    const Tin& x, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  (void)scale;
  (void)fp8_type;
  return x;
}
|
|
| |
// fp8 -> half with scaling: the fp8 value is widened to float (through half),
// multiplied by `scale`, and converted back to half bits via float_to_half.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const __half_raw raw = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  const float scaled = half_to_float(raw.x) * scale;
  return float_to_half(scaled);
}
|
|
| |
// fp8x2 -> half2 with scaling: both lanes are converted together, scaled in
// float individually, and packed into one 32-bit word.
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const __half2_raw halves = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  // The union overlays the two 16-bit lanes on the 32-bit return value.
  union {
    uint16_t lanes[2];
    uint32_t word;
  } packed;
  packed.lanes[0] = float_to_half(half_to_float(halves.x) * scale);
  packed.lanes[1] = float_to_half(half_to_float(halves.y) * scale);
  return packed.word;
}
|
|
| |
// fp8x4 -> half2x2 with scaling: the low 16 bits of `a` fill result.x, the
// high 16 bits fill result.y, each through the scaled fp8x2 -> half2 path.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  uint2 out;
  out.x = scaled_vec_conversion<uint32_t, uint16_t>(lo_pair, scale, fp8_type);
  out.y = scaled_vec_conversion<uint32_t, uint16_t>(hi_pair, scale, fp8_type);
  return out;
}
|
|
| |
// fp8x8 -> half2x4 with scaling: a.x produces the (x, y) words of the result
// and a.y the (z, w) words, each through the scaled fp8x4 -> half2x2 path.
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const uint2 lo_half =
      scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
  const uint2 hi_half =
      scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
  uint4 out;
  out.x = lo_half.x;
  out.y = lo_half.y;
  out.z = hi_half.x;
  out.w = hi_half.y;
  return out;
}
|
|
| |
// fp8 -> bfloat16 with scaling, staged as fp8 -> half (raw bits) -> float,
// multiplied by `scale`, then float -> bfloat16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const __half_raw as_half = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  const float as_float = half_to_float(as_half.x);
  return __float2bfloat16(as_float * scale);
}
|
|
| |
// fp8x2 -> bfloat162 with scaling: the low byte of `a` becomes .x and the
// high byte .y.
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const uint8_t lo_byte = static_cast<uint8_t>(a);
  const uint8_t hi_byte = static_cast<uint8_t>(a >> 8U);
  __nv_bfloat162 pair;
  pair.x =
      scaled_vec_conversion<__nv_bfloat16, uint8_t>(lo_byte, scale, fp8_type);
  pair.y =
      scaled_vec_conversion<__nv_bfloat16, uint8_t>(hi_byte, scale, fp8_type);
  return pair;
}
|
|
| |
// fp8x4 -> bf16_4_t with scaling: the low 16 bits of `a` fill .x, the high 16
// bits fill .y, each through the scaled fp8x2 -> bfloat162 path.
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  bf16_4_t out;
  out.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>(lo_pair, scale,
                                                          fp8_type);
  out.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>(hi_pair, scale,
                                                          fp8_type);
  return out;
}
|
|
| |
// fp8x8 -> bf16_8_t with scaling: a.x supplies the (x, y) members of the
// result and a.y the (z, w) members, each through the scaled fp8x4 path.
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t lo_half =
      scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
  bf16_4_t hi_half =
      scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
  bf16_8_t out;
  out.x = lo_half.x;
  out.y = lo_half.y;
  out.z = hi_half.x;
  out.w = hi_half.y;
  return out;
}
|
|
| |
// fp8 -> float with scaling: fp8 -> half (raw bits) -> float, then multiplied
// by `scale` in float.
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const __half_raw as_half = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  const uint16_t half_bits = as_half.x;
  const float widened = half_to_float(half_bits);
  return widened * scale;
}
|
|
| |
// fp8x2 -> float2 with scaling.
//
// Each lane is widened fp8 -> half (raw bits) -> float and only then
// multiplied by `scale`, exactly like the scalar fp8 -> float specialization.
// The previous implementation went through the scaled fp8x2 -> half2 path,
// which rounded the already-scaled value to half precision (and limited it to
// half's range) before widening it to float.
//
// a        : two packed fp8 values (low byte -> .x, high byte -> .y, as laid
//            out by __nv_cvt_fp8x2_to_halfraw2).
// scale    : multiplicative dequantization factor.
// fp8_type : fp8 interpretation passed to the conversion intrinsic.
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2 (unscaled, exact).
  const __half2_raw halves = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  // half -> float per lane, scaling in float so no precision is dropped.
  float2 out;
  out.x = half_to_float(halves.x) * scale;
  out.y = half_to_float(halves.y) * scale;
  return out;
}
|
|
| |
// fp8x4 -> Float4_ with scaling: the low 16 bits of `a` fill .x, the high 16
// bits fill .y, each through the scaled fp8x2 -> float2 path.
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  Float4_ out;
  out.x = scaled_vec_conversion<float2, uint16_t>(lo_pair, scale, fp8_type);
  out.y = scaled_vec_conversion<float2, uint16_t>(hi_pair, scale, fp8_type);
  return out;
}
|
|
| |
// fp8x8 -> Float8_ with scaling: a.x supplies the (x, y) members of the
// result and a.y the (z, w) members, each through the scaled fp8x4 path.
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ lo_half =
      scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
  Float4_ hi_half =
      scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
  Float8_ out;
  out.x = lo_half.x;
  out.y = lo_half.y;
  out.z = hi_half.x;
  out.w = hi_half.y;
  return out;
}
|
|
| |
// half -> fp8 with scaling: `a` (half bits) is widened to float, divided by
// `scale`, and converted to fp8 with __NV_SATFINITE.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const float quant_input = half_to_float(a) / scale;
  return static_cast<uint8_t>(
      __nv_cvt_float_to_fp8(quant_input, __NV_SATFINITE, fp8_type));
}
|
|
| |
// bf16 -> fp8 with scaling: `a` is widened to float, divided by `scale`, and
// converted to fp8 with __NV_SATFINITE. The conversion is only compiled when
// __CUDA_ARCH__ is undefined or >= 800; older device architectures get
// assert(false) instead.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  const float quant_input = __bfloat162float(a) / scale;
  return static_cast<uint8_t>(
      __nv_cvt_float_to_fp8(quant_input, __NV_SATFINITE, fp8_type));
#endif
  // Keeps the assert-only branch from flowing off the end of a non-void
  // function.
  __builtin_unreachable();
}
|
|
| |
// float -> fp8 with scaling: `a` is divided by `scale` and converted to fp8
// with __NV_SATFINITE.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
    const float& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const float quant_input = a / scale;
  return static_cast<uint8_t>(
      __nv_cvt_float_to_fp8(quant_input, __NV_SATFINITE, fp8_type));
}
|
|
| |
// fp8x4 -> float4 with scaling: converts to Float4_ first, then flattens its
// two pairs into (x, y, z, w).
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ pairs = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  float4 out;
  out.x = pairs.x.x;
  out.y = pairs.x.y;
  out.z = pairs.y.x;
  out.w = pairs.y.y;
  return out;
}
| #endif |
|
|
// Unscaled conversion entry point, selected at compile time by `kv_dt`.
// The dispatch to vec_conversion is compiled out with `#if 0`, so every
// instantiation currently reduces to assert(false).
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#if 0
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  }
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  }
#endif
  assert(false);
  // Marks the end of this non-void function as never reached.
  __builtin_unreachable();
}
|
|
// Scaled conversion entry point, selected at compile time by `kv_dt`.
// With ENABLE_FP8 defined, kFp8E4M3 and kFp8E5M2 forward to
// scaled_vec_conversion with __NV_E4M3 / __NV_E5M2 respectively. Any other
// `kv_dt`, or a build without ENABLE_FP8, ends in assert(false).
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  }
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  }
#endif
  assert(false);
  // Marks the end of this non-void function as never reached.
  __builtin_unreachable();
}
|
|
| |
| |
| |
| |
// Expands to an if/else chain that invokes
//   FN(<first type>, <second type>, <vllm::Fp8KVCacheDataType value>)
// chosen from SRC_DTYPE (an at::ScalarType) and KV_DTYPE:
//   "auto"              -> FN(T, T, kAuto)
//   "fp8" / "fp8_e4m3"  -> FN(T, uint8_t, kFp8E4M3)
//   "fp8_e5m2"          -> FN(T, uint8_t, kFp8E5M2)
// where T is float, uint16_t or __nv_bfloat16 for Float, Half and BFloat16.
// Anything else is rejected through TORCH_CHECK.
// NOTE(review): KV_DTYPE is compared to string literals with ==, so it is
// presumably a std::string-like object; confirm at the call sites.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \
  if (KV_DTYPE == "auto") {                                                  \
    if (SRC_DTYPE == at::ScalarType::Float) {                                \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
    }                                                                        \
  } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                  \
    if (SRC_DTYPE == at::ScalarType::Float) {                                \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);             \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);        \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
    }                                                                        \
  } else if (KV_DTYPE == "fp8_e5m2") {                                       \
    if (SRC_DTYPE == at::ScalarType::Float) {                                \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);                \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);             \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);        \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
    }                                                                        \
  } else {                                                                   \
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);     \
  }
|
|
| } |
| #endif |
| } |
|
|