| | #include <ATen/cuda/CUDAContext.h> |
| | #include <torch/all.h> |
| |
|
| | #include <cmath> |
| |
|
| | #include "../dispatch_utils.h" |
| | #include "../vectorization_utils.cuh" |
| |
|
| | #ifndef USE_ROCM |
| | #include <cub/cub.cuh> |
| | #include <cub/util_type.cuh> |
| | #else |
| | #include <hipcub/hipcub.hpp> |
| | #include <hipcub/util_type.hpp> |
| | #endif |
| |
|
// Converts a float to int8, rounding to nearest (ties to even) and saturating
// out-of-range values to [-128, 127].
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  static constexpr auto i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static constexpr auto i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  // NaN would pass through nearbyint and both clamp comparisons below
  // unchanged, and static_cast<int8_t>(NaN) is undefined behaviour. Give it a
  // well-defined result (0, presumably what the CUDA cvt path yields too --
  // TODO(review) confirm against the PTX ISA).
  if (std::isnan(x)) {
    return 0;
  }

  // nearbyint honours the current floating-point rounding mode; this assumes
  // it is round-to-nearest-even so the result matches the CUDA path's `.rni`.
  float dst = std::nearbyint(x);

  // Saturate with plain comparisons rather than std::clamp.
  dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst;
  return static_cast<int8_t>(dst);
#else
  // cvt.rni.sat.s8.f32: round to nearest even and saturate to the s8 range in
  // one instruction. The result lands in a 32-bit register; the int8 value is
  // read back from its first byte.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
| |
|
// Converts a float to int32, rounding to nearest (ties to even) and saturating
// out-of-range values to INT32_MIN / INT32_MAX.
static inline __device__ int32_t float_to_int32_rn(float x) {
#ifdef USE_ROCM
  // INT32_MAX is not representable in float: static_cast<float>(INT32_MAX)
  // rounds up to 2^31, which is why the upper test below uses >=.
  // INT32_MIN (-2^31) is exact.
  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
  static constexpr auto i32_min_f = static_cast<float>(i32_min);
  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
  static constexpr auto i32_max_f = static_cast<float>(i32_max);

  // NaN fails both saturation tests below, and casting NaN to int32 is
  // undefined behaviour. Return a well-defined 0 instead (presumably matching
  // the CUDA cvt path -- TODO(review) confirm against the PTX ISA).
  if (std::isnan(x)) {
    return 0;
  }

  // Assumes the rounding mode is round-to-nearest-even, matching `.rni`.
  float dst = std::nearbyint(x);

  if (dst >= i32_max_f) {
    return i32_max;
  }
  if (dst <= i32_min_f) {
    return i32_min;
  }
  return static_cast<int32_t>(dst);
#else
  // cvt.rni.sat.s32.f32: round to nearest even with saturation to s32.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int32_t&>(dst);
#endif
}
| |
|
// Saturating narrowing from int32 to int8: anything below -128 becomes -128,
// anything above 127 becomes 127, and in-range values pass through unchanged.
static inline __device__ int8_t int32_to_int8(int32_t x) {
#ifdef USE_ROCM
  constexpr int32_t kLowest = std::numeric_limits<int8_t>::min();
  constexpr int32_t kHighest = std::numeric_limits<int8_t>::max();

  if (x < kLowest) {
    return static_cast<int8_t>(kLowest);
  }
  if (x > kHighest) {
    return static_cast<int8_t>(kHighest);
  }
  return static_cast<int8_t>(x);
#else
  // A single PTX saturating convert. The destination is a 32-bit register;
  // the int8 result is read back from its first byte.
  uint32_t packed;
  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(packed) : "r"(x));
  return reinterpret_cast<const int8_t&>(packed);
#endif
}
| |
|
namespace vllm {

// Rounds `v` to the nearest integer (ties to even), adds the zero point `azp`
// and saturates the sum to [-128, 127].
//
// The addition is done in 64-bit arithmetic. float_to_int32_rn saturates a
// large |v| to INT32_MIN / INT32_MAX, and adding a nonzero zero point to that
// in int32 is signed overflow: undefined behaviour, and in practice a wrap to
// the opposite end of the int8 range.
template <typename azp_t>
__device__ __forceinline__ int8_t round_add_azp_to_int8(float v, azp_t azp) {
  constexpr int64_t kI8Min = std::numeric_limits<int8_t>::min();
  constexpr int64_t kI8Max = std::numeric_limits<int8_t>::max();
  const int64_t q =
      static_cast<int64_t>(float_to_int32_rn(v)) + static_cast<int64_t>(azp);
  return static_cast<int8_t>(q < kI8Min ? kI8Min : (q > kI8Max ? kI8Max : q));
}

// Symmetric int8 quantization with one static scale:
//   output = sat_int8(round(input / *scale_ptr)).
//
// Launch layout: one block per row (blockIdx.x selects the row), with the
// row's `hidden_size` elements handed to vectorize_with_alignment<16> together
// with the thread index and blockDim.x as stride. That helper presumably
// spreads the per-element functor over the block's threads using 16-byte
// vector accesses where alignment allows -- see vectorization_utils.cuh.
template <typename scalar_t, typename scale_t>
__global__ void static_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    const scale_t* scale_ptr, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;
  const float scale = *scale_ptr;

  // token_idx is int64_t so the row offset is computed in 64-bit arithmetic.
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        dst = float_to_int8_rn(static_cast<float>(src) / scale);
      });
}

// Asymmetric int8 quantization with one static scale and zero point:
//   output = sat_int8(round(input * (1 / *scale_ptr)) + *azp_ptr).
// Same launch layout as static_scaled_int8_quant_kernel.
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void static_scaled_int8_azp_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;
  const float scale = *scale_ptr;
  const azp_t azp = *azp_ptr;
  // Multiply by the reciprocal instead of dividing for every element.
  const float inv_s = 1.0f / scale;

  // 64-bit row offset (token_idx is int64_t).
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        const auto v = static_cast<float>(src) * inv_s;
        dst = round_add_azp_to_int8(v, azp);
      });
}

// Symmetric int8 quantization with a per-row scale computed on the fly.
//
// Each block (one per row) reduces max|x| over its row, writes
// scale_out[row] = absmax / 127, and then quantizes
// output = sat_int8(round(x * 127 / absmax)). An all-zero row gets scale 0
// and quantizes to 0.
//
// cub::BlockReduce is instantiated for 256 threads, so blockDim.x must not
// exceed 256. The launchers in this file cap the block at
// min(hidden_size, 256) threads.
template <typename scalar_t, typename scale_t>
__global__ void dynamic_scaled_int8_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    scale_t* scale_out, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;

  // 64-bit row offset (token_idx is int64_t).
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  // 1. Per-thread partial max of |x|, then a block-wide max.
  float thread_max = 0.f;
  for (int i = tid; i < hidden_size; i += stride) {
    const auto v = fabsf(static_cast<float>(row_in[i]));
    thread_max = fmaxf(thread_max, v);
  }
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);

  // The BlockReduce aggregate is only valid in thread 0, so thread 0 publishes
  // it to the rest of the block through shared memory.
  __shared__ float absmax;
  if (tid == 0) {
    absmax = block_max;
    scale_out[blockIdx.x] = absmax / 127.f;
  }
  __syncthreads();

  // Guard against dividing by zero for an all-zero row.
  float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax;

  // 2. Quantize the row.
  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        dst = float_to_int8_rn(static_cast<float>(src) * inv_s);
      });
}

// Running [min, max] of a set of floats; used as the value type of the block
// reduction in dynamic_scaled_int8_azp_quant_kernel. The default state is the
// empty range (min = FLT_MAX, max = lowest float), so folding in any value
// replaces it.
struct MinMax {
  float min, max;

  __host__ __device__ MinMax()
      : min(std::numeric_limits<float>::max()),
        max(std::numeric_limits<float>::lowest()) {}

  __host__ __device__ explicit MinMax(float v) : min(v), max(v) {}

  // Folds a single value into the range. Per fminf/fmaxf semantics, a NaN
  // operand is ignored in favour of the other operand.
  __host__ __device__ MinMax& operator+=(float v) {
    min = fminf(min, v);
    max = fmaxf(max, v);
    return *this;
  }

  // Merges another range into this one (union of the two ranges).
  __host__ __device__ MinMax& operator&=(const MinMax& other) {
    min = fminf(min, other.min);
    max = fmaxf(max, other.max);
    return *this;
  }
};

// Value-returning counterparts of MinMax::operator+= and MinMax::operator&=.
__host__ __device__ inline MinMax operator+(MinMax a, float v) {
  return a += v;
}
__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) {
  return a &= b;
}

// Asymmetric int8 quantization with a per-row scale and zero point computed
// on the fly.
//
// Each block (one per row) reduces [min, max] over its row and maps that
// range onto [-128, 127]:
//   s  = (max - min) / 255
//   zp = round(-128 - min / s)
//   output = sat_int8(round(x / s) + zp)
// It writes scale_out[row] = s and azp_out[row] = zp.
//
// blockDim.x must not exceed 256 (see dynamic_scaled_int8_quant_kernel).
template <typename scalar_t, typename scale_t, typename azp_t>
__global__ void dynamic_scaled_int8_azp_quant_kernel(
    const scalar_t* __restrict__ input, int8_t* __restrict__ output,
    scale_t* scale_out, azp_t* azp_out, const int hidden_size) {
  const int tid = threadIdx.x;
  const int stride = blockDim.x;
  const int64_t token_idx = blockIdx.x;

  // 64-bit row offset (token_idx is int64_t).
  const scalar_t* row_in = input + token_idx * hidden_size;
  int8_t* row_out = output + token_idx * hidden_size;

  // 1. Per-thread partial [min, max], then a block-wide merge.
  MinMax thread_mm;
  for (int i = tid; i < hidden_size; i += stride) {
    thread_mm += static_cast<float>(row_in[i]);
  }

  using BlockReduce = cub::BlockReduce<MinMax, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;

  MinMax mm = BlockReduce(tmp).Reduce(
      thread_mm,
      [] __device__(MinMax a, const MinMax& b) {
        a &= b;
        return a;
      },
      blockDim.x);

  // Thread 0 holds the valid aggregate; it derives scale and zero point and
  // shares them with the block through shared memory.
  __shared__ float scale_sh;
  __shared__ azp_t azp_sh;
  if (tid == 0) {
    float lo = mm.min;
    float hi = mm.max;
    float s = (hi - lo) / 255.f;
    if (!(s > 0.f)) {
      // Degenerate range: the row is (numerically) constant, so s == 0. Left
      // alone, -lo / s would be +-inf or NaN, making the zero-point conversion
      // undefined and every output NaN-derived. Widen the range to include 0
      // instead: a constant c > 0 then maps to 127 and c < 0 maps to -128,
      // both of which dequantize back to c.
      lo = fminf(lo, 0.f);
      hi = fmaxf(hi, 0.f);
      s = (hi - lo) / 255.f;
      if (!(s > 0.f)) {
        // All-zero row: any positive scale reproduces it exactly.
        s = 1.f;
      }
    }
    scale_sh = s;
    // float_to_int32_rn rounds half-to-even (like nearbyintf) but also
    // saturates, so an extreme -lo / s cannot produce an out-of-range
    // float->int conversion.
    azp_sh = azp_t(float_to_int32_rn(-128.f - lo / s));
    scale_out[blockIdx.x] = s;
    azp_out[blockIdx.x] = azp_sh;
  }
  __syncthreads();

  const float inv_s = 1.f / scale_sh;
  const azp_t azp = azp_sh;

  // 2. Quantize the row.
  vectorize_with_alignment<16>(
      row_in, row_out, hidden_size, tid, stride,
      [=] __device__(int8_t& dst, const scalar_t& src) {
        const auto v = static_cast<float>(src) * inv_s;
        dst = round_add_azp_to_int8(v, azp);
      });
}

}
| |
|
// Quantizes `input` to int8 into `out` using one static per-tensor scale and
// an optional per-tensor zero point:
//   no azp:   out = sat_int8(round(input / scale))
//   with azp: out = sat_int8(round(input * (1 / scale)) + azp)
//
// The last dimension of `input` is the row length; every other dimension is
// flattened into rows, one CUDA block per row (at most 256 threads per block).
// `input` and `out` must be contiguous with equal element counts. `scale` must
// hold exactly one value (read as float32), and `azp`, when present, exactly
// one value (read as int32). The kernel is enqueued on the current CUDA stream
// and runs asynchronously. The input dtype is dispatched via
// VLLM_DISPATCH_FLOATING_TYPES (dispatch_utils.h).
void static_scaled_int8_quant(torch::Tensor& out,
                              torch::Tensor const& input,
                              torch::Tensor const& scale,
                              std::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(!azp || azp->numel() == 1);
  // The kernel writes exactly input.numel() int8 values into `out`.
  TORCH_CHECK(out.numel() == input.numel(),
              "out must have the same number of elements as input");

  // Nothing to do for an empty tensor. Returning here also avoids an integer
  // division by zero when the last dimension is 0, and avoids launching with
  // an empty grid/block, which CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  // Must not exceed 256: the kernels' reductions are sized for 256 threads.
  dim3 const block(std::min(hidden_size, 256));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), hidden_size);
        } else {
          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
}
| |
|
// Quantizes `input` to int8 into `out` with a scale (and optionally a zero
// point) computed per row on the GPU:
//   no azp:   symmetric, scales[row] = max|row| / 127
//   with azp: asymmetric, scales[row] = (max - min) / 255 and
//             azp[row] = round(-128 - min / scale)
//
// The last dimension of `input` is the row length; every other dimension is
// flattened into rows, one CUDA block per row (at most 256 threads per block).
// `input`, `out`, `scales` and `azp` must be contiguous. `out` must have as
// many elements as `input`, and `scales` / `azp` must hold exactly one entry
// per row (read as float32 / int32). The kernel is enqueued on the current
// CUDA stream and runs asynchronously.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,
    torch::Tensor const& input,
    torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scales.is_contiguous());
  TORCH_CHECK(!azp || azp->is_contiguous());
  // The kernel writes exactly input.numel() int8 values into `out`.
  TORCH_CHECK(out.numel() == input.numel(),
              "out must have the same number of elements as input");

  // Nothing to do for an empty tensor. Returning here also avoids an integer
  // division by zero when the last dimension is 0, and avoids launching with
  // an empty grid/block, which CUDA rejects as an invalid configuration.
  if (input.numel() == 0) {
    return;
  }

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;

  // Each block writes scales[row] (and azp[row]); undersized buffers would be
  // written out of bounds.
  TORCH_CHECK(scales.numel() == num_tokens,
              "scales must have one element per row of input");
  TORCH_CHECK(!azp || azp->numel() == num_tokens,
              "azp must have one element per row of input");

  dim3 const grid(num_tokens);
  // Must not exceed 256: the kernels' cub::BlockReduce is sized for 256.
  dim3 const block(std::min(hidden_size, 256));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        if (!azp) {
          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), hidden_size);
        } else {
          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
              <<<grid, block, 0, stream>>>(
                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
                  hidden_size);
        }
      });
}
| |
|