Skip to content

merge paged attention feature and moe feature into llama_fp8_12062024 #370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: llama_fp8_12062024
Choose a base branch
from
Draft
6 changes: 4 additions & 2 deletions benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ def get_rocm_tuning_space(use_fp16):
# small search space, no pruning required
# bypassLDS: block_n/num_warps=16 for perf
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [128] if use_fp16 else [64]
block_n_range = [128] if use_fp16 else [128]
block_k_range = [128] if use_fp16 else [256]

num_warps_range = [8] if use_fp16 else [4]
num_warps_range = [8] if use_fp16 else [8]
group_m_range = [1]
# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
Expand Down Expand Up @@ -211,6 +211,8 @@ def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
keys, values = zip(*param_ranges.items())
for config_values in product(*values):
config = dict(zip(keys, config_values))
assert config['num_warps'] == config['BLOCK_SIZE_N'] // 16, \
"num_warps should be equal to BLOCK_SIZE_N divided by 16"
configs.append(config)
return configs

Expand Down
10 changes: 6 additions & 4 deletions benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
create_kv_caches_with_random)

NUM_BLOCKS = 1024 * 1024
PARTITION_SIZE = 512
PARTITION_SIZE = 256


@torch.inference_mode()
Expand Down Expand Up @@ -101,7 +101,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
start_time = time.perf_counter()

# Using default kv_scale
k_scale = v_scale = 1.0
k_scale = v_scale = 0.1

for _ in range(num_iters):
if version == "v1":
Expand Down Expand Up @@ -161,6 +161,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
kv_cache_dtype,
k_scale,
v_scale,
None,
PARTITION_SIZE
)
else:
raise ValueError(f"Invalid version: {version}")
Expand All @@ -174,13 +176,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
run_benchmark(num_iters=3, profile=False)
run_benchmark(num_iters=500, profile=False)

# Benchmark.
if do_profile:
latency = run_benchmark(num_iters=1, profile=True)
else:
latency = run_benchmark(num_iters=1000, profile=False)
latency = run_benchmark(num_iters=10000, profile=False)
print(f"Kernel running time: {latency * 1000000:.3f} us")


Expand Down
1 change: 1 addition & 0 deletions benchmarks/kernels/tune_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export FUSED_MOE_PERSISTENT=1
export VLLM_MOE_PADDING=128
export VLLM_MOE_SHUFFLE=1
export TRITON_HIP_USE_NEW_STREAM_PIPELINE=1
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

## ---- Mixtral fp8 tuning ---- ##

Expand Down
177 changes: 173 additions & 4 deletions csrc/quantization/fp8/common.cu
Original file line number Diff line number Diff line change
@@ -1,16 +1,185 @@
#include "common.cuh"
#include "dispatch_utils.h"

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#if defined(USE_CUDA_FP8_FORMAT)
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

#if defined(USE_CUDA_FP8_FORMAT)
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from PyTorch (240.0) causes accuracy issues
// when running dynamic quantization, so 224.0f is used on ROCm instead.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

// Atomic max for floats built on the integer atomics.
//
// The float's bit pattern is reinterpreted as an integer and the update is
// dispatched on the sign of `value`:
//   * value >= 0 : signed-integer atomicMax on the bit pattern;
//   * otherwise  : unsigned-integer atomicMin on the bit pattern.
//
// Returns whatever the underlying atomic reports as the previous contents of
// *addr, reinterpreted back to float.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  if (value >= 0) {
    int const previous_bits = atomicMax((int*)addr, __float_as_int(value));
    return __int_as_float(previous_bits);
  }

  unsigned int const previous_bits =
      atomicMin((unsigned int*)addr, __float_as_uint(value));
  return __uint_as_float(previous_bits);
}

// Convert a single float to FP8_TYPE after applying a scale.
//
// Template parameter:
//   is_scale_inverted - when true `scale` is treated as a multiplier
//                       (val * scale); when false it is a divisor
//                       (val / scale).
//
// The scaled value is clamped to [-FP8_E4M3_MAX, FP8_E4M3_MAX] before the
// conversion to FP8_TYPE.
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  // The condition is a compile-time constant, so only one arm survives.
  float const scaled = is_scale_inverted ? val * scale : val / scale;

  // Clamp into the range representable by the fp8 format.
  float const clamped = fmax(-FP8_E4M3_MAX, fmin(scaled, FP8_E4M3_MAX));

#if defined(USE_CUDA_FP8_FORMAT)
  return static_cast<c10::Float8_e4m3fn>(clamped);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(clamped).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch preconditions (imposed by the code below):
//   * blockDim.x <= 1024, the size of the shared-memory cache;
//   * blockDim.x is a power of two, otherwise the halving reduction tree
//     skips some cache entries.
//
// Parameters:
//   scale     - single float that receives the running maximum of
//               (block absmax / FP8_E4M3_MAX) via atomicMaxFloat.
//   input     - elements to scan; read with a stride of blockDim.x * gridDim.x.
//   num_elems - number of elements in `input`.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // Widen to 64 bits *before* multiplying: blockDim/blockIdx/gridDim are
  // 32-bit unsigned, so the product would otherwise wrap before being stored
  // in the 64-bit index.
  int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x].
  // Accumulate in float: the inputs are converted to float anyway and the
  // cache is float, so there is no reason to round-trip through scalar_t.
  float tmp = 0.0f;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += stride;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    // Reached by every thread in the block on every iteration; only the
    // cache update above is conditional.
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}

// Pack of four consecutive scalar_t values, declared with 8-byte alignment,
// used to read four input elements with a single vector access.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x, y, z, w;
};

// Pack of four FP8_TYPE values, declared with 4-byte alignment, used to write
// four converted outputs with a single vector access.
struct __align__(4) float8x4_t {
  FP8_TYPE x, y, z, w;
};

// Per-thread absolute maximum over a strided subset of `input`.
//
// The bulk of the data is read four elements at a time through vec4_t; the
// (num_elems % 4) trailing elements are read one by one.
//
// Parameters:
//   input     - elements to scan. It is reinterpreted as vec4_t<scalar_t>,
//               so it presumably must satisfy that type's alignment
//               (TODO confirm with callers).
//   num_elems - number of scalar elements in `input`.
//   tid       - index of the first vector (and first tail element offset)
//               this thread handles.
//   step      - distance between successive items handled by this thread.
//
// Returns the largest |value| seen by this thread, or 0.0f if it saw none.
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  using packed_t = vec4_t<scalar_t>;
  packed_t const* packed_input = reinterpret_cast<packed_t const*>(input);

  // Number of complete groups of four elements.
  int64_t const packed_count = num_elems >> 2;
  float running_max = 0.0f;

#pragma unroll 4
  for (int64_t v = tid; v < packed_count; v += step) {
    packed_t const quad = packed_input[v];
    running_max = max(running_max, fabs(quad.x));
    running_max = max(running_max, fabs(quad.y));
    running_max = max(running_max, fabs(quad.z));
    running_max = max(running_max, fabs(quad.w));
  }

  // Scalar tail: elements past the last complete group of four.
  int64_t const tail_begin = packed_count * 4;
  for (int64_t e = tail_begin + tid; e < num_elems; e += step) {
    running_max = max(running_max, fabs(input[e]));
  }

  return running_max;
}

// Scale and convert a strided subset of `input` to FP8_TYPE, writing to `out`.
//
// Complete groups of four elements are read through vec4_t and written
// through float8x4_t; the (num_elems % 4) trailing elements are handled one
// at a time. Every element goes through
// scaled_fp8_conversion<is_scale_inverted>.
//
// Parameters:
//   out       - destination buffer, same element count as `input`.
//   input     - source elements.
//   scale     - scale passed unchanged to scaled_fp8_conversion.
//   num_elems - number of scalar elements to convert.
//   tid       - index of the first vector (and first tail element offset)
//               this thread handles.
//   step      - distance between successive items handled by this thread.
//
// NOTE(review): both pointers are reinterpreted as vector types, so they
// presumably must satisfy those types' alignment - confirm with callers.
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  using packed_in_t = vec4_t<scalar_t>;
  packed_in_t const* packed_input =
      reinterpret_cast<packed_in_t const*>(input);
  float8x4_t* packed_output = reinterpret_cast<float8x4_t*>(out);

  // Number of complete groups of four elements.
  int64_t const packed_count = num_elems >> 2;

#pragma unroll 4
  for (int64_t v = tid; v < packed_count; v += step) {
    packed_in_t const quad = packed_input[v];

    // Widen each lane to float before scaling.
    float const fx = static_cast<float>(quad.x);
    float const fy = static_cast<float>(quad.y);
    float const fz = static_cast<float>(quad.z);
    float const fw = static_cast<float>(quad.w);

    float8x4_t converted;
    converted.x = scaled_fp8_conversion<is_scale_inverted>(fx, scale);
    converted.y = scaled_fp8_conversion<is_scale_inverted>(fy, scale);
    converted.z = scaled_fp8_conversion<is_scale_inverted>(fz, scale);
    converted.w = scaled_fp8_conversion<is_scale_inverted>(fw, scale);

    packed_output[v] = converted;
  }

  // Scalar tail: elements past the last complete group of four.
  int64_t const tail_begin = packed_count * 4;
  for (int64_t e = tail_begin + tid; e < num_elems; e += step) {
    float const widened = static_cast<float>(input[e]);
    out[e] = scaled_fp8_conversion<is_scale_inverted>(widened, scale);
  }
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
const scalar_t* __restrict__ input,
Expand Down
Loading
Loading