From e62a50d2090c5a1c11daf49243404ad435fb2735 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Thu, 20 Jun 2024 11:11:50 +0800
Subject: [PATCH] weight_only_linear arch check 89/90 (#65295)

---
 .../gpu/fused_weight_only_linear_pass.cc      |  2 +-
 paddle/phi/infermeta/fusion.cc                | 11 +++++++----
 paddle/phi/infermeta/unary.cc                 |  5 +++--
 .../phi/kernels/cpu/weight_quantize_kernel.cc |  8 +++++---
 paddle/phi/kernels/fusion/gpu/block_attn.h    | 16 ++++++++--------
 .../kernels/gpu/weight_only_linear_kernel.cu  |  6 ++++--
 .../phi/kernels/gpu/weight_quantize_kernel.cu |  5 +++--
 .../impl/weight_quantize_kernel_gpu_impl.h    |  3 ++-
 python/paddle/nn/quant/quantized_linear.py    | 18 ++++++++++++++----
 9 files changed, 47 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc
index 5babd4072a7b05..1904cbfbbb5722 100644
--- a/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc
+++ b/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc
@@ -311,7 +311,7 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {
 
   bool CanApplyOn(pir::Operation *op) const override {
     if (sm_version_ != 70 && sm_version_ != 75 && sm_version_ != 80 &&
-        sm_version_ != 86) {
+        sm_version_ != 86 && sm_version_ != 89 && sm_version_ != 90) {
       return false;
     }
     return op->num_regions() > 0;
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 455bb70ad63d51..de2a3eeef4278e 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -295,10 +295,13 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
   const int total_num_head =
       qkv.dims()[qkv.dims().size() - 1] / dim_head;
   const int q_num_head = total_num_head - 2 * kv_num_head;
-  PADDLE_ENFORCE_EQ(q_num_head % kv_num_head,
-                    0,
-                    errors::InvalidArgument(
-                        "The q num_head must be divisible by kv_num_head"));
+  PADDLE_ENFORCE_EQ(
+      q_num_head % kv_num_head,
+      0,
+      errors::InvalidArgument(
+          "The q num_head (%d) must be divisible by kv num_head (%d)",
+          q_num_head,
+          kv_num_head));
   PADDLE_ENFORCE_EQ(
       input_dims.size(),
       2UL,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 90c7e2726362b2..b8c3742183146d 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -5783,10 +5783,11 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
                              MetaTensor* out,
                              MetaTensor* scale) {
   PADDLE_ENFORCE_EQ(
-      ((arch == 80) || (arch == 86) || (arch == 70) || (arch == 75)),
+      ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) ||
+       (arch == 89) || (arch == 90)),
       true,
       phi::errors::InvalidArgument(
-          "Currently, arch only support 70, 75, 80, 86."));
+          "Currently, arch only support 70, 75, 80, 86, 89, 90."));
 
   auto x_dims = x.dims();
   PADDLE_ENFORCE_EQ(
diff --git a/paddle/phi/kernels/cpu/weight_quantize_kernel.cc b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc
index 61304e43d4e85a..bcfcee831b2651 100644
--- a/paddle/phi/kernels/cpu/weight_quantize_kernel.cc
+++ b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc
@@ -35,10 +35,11 @@ void quant_compute(const DeviceContext& dev_ctx,
                    const int32_t arch,
                    const int32_t group_size) {
   PADDLE_ENFORCE_EQ(
-      ((arch == 80) || (arch == 86) || (arch == 75) || (arch == 70)),
+      ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) ||
+       (arch == 89) || (arch == 90)),
       true,
       phi::errors::InvalidArgument(
-          "Currently, arch only support 70, 75, 80, 86."));
+          "Currently, arch only support 70, 75, 80, 86, 89, 90."));
90.")); const auto x_dims = x.dims(); PADDLE_ENFORCE_EQ( @@ -104,7 +105,8 @@ void quant_compute(const DeviceContext& dev_ctx, for (int i = 0; i < out->numel(); ++i) { out_data[i] = x_int_data[i]; } - } else if ((arch == 80) || (arch == 75) || (arch == 86)) { + } else if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) || + (arch == 75)) { permute_B_rows_for_mixed_gemm( int_processed_data, x_int_data, std::vector{m, n}); subbyte_transpose_impl( diff --git a/paddle/phi/kernels/fusion/gpu/block_attn.h b/paddle/phi/kernels/fusion/gpu/block_attn.h index df1ad4d952b39d..63f0a1d8303bab 100644 --- a/paddle/phi/kernels/fusion/gpu/block_attn.h +++ b/paddle/phi/kernels/fusion/gpu/block_attn.h @@ -1631,11 +1631,11 @@ inline cudaError_t GetNumBlocks(Func func, template __global__ void cache_int8_kernel( - const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * gqa_group_size, + const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads, // head_size] - uint8_t *__restrict__ key_cache, // [num_blocks, gqa_group_size, + uint8_t *__restrict__ key_cache, // [num_blocks, kv_num_heads, // block_size, head_size] - uint8_t *__restrict__ value_cache, // [num_blocks, gqa_group_size, + uint8_t *__restrict__ value_cache, // [num_blocks, kv_num_heads, // block_size, head_size] const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] const int *__restrict__ padding_offsets, // [num_tokens] @@ -1715,11 +1715,11 @@ __global__ void cache_int8_kernel( template __global__ void cache_kernel( - const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * gqa_group_size, - // head_size] - T *__restrict__ key_cache, // [num_blocks, gqa_group_size, block_size, - // head_size] - T *__restrict__ value_cache, // [num_blocks, gqa_group_size, block_size, + const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads, + // head_size] + T *__restrict__ key_cache, // [num_blocks, kv_num_heads, block_size, + // head_size] + T *__restrict__ value_cache, // [num_blocks, kv_num_heads, block_size, // head_size] const int *__restrict__ block_tables, // [bsz, max_blocks_per_seq] const int *__restrict__ padding_offsets, // [num_tokens] diff --git a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu index 901a291d3924db..668d2f29764700 100644 --- a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu @@ -35,9 +35,11 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, DenseTensor* out) { #if defined(PADDLE_WITH_CUTLASS) PADDLE_ENFORCE_EQ( - ((arch == 80) || (arch == 70) || (arch == 75) || (arch == 86)), + ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) || + (arch == 89) || (arch == 90)), true, - phi::errors::InvalidArgument("Currently, arch only support 70, 80.")); + phi::errors::InvalidArgument( + "Currently, arch only support 70, 75, 80, 86, 89, 90.")); #else PADDLE_THROW(phi::errors::Unimplemented( "Please compile with cutlass to make cutlass available")); diff --git a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu index 51b4786155a923..18244d38989623 100644 --- a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu @@ -45,10 +45,11 @@ void WeightQuantizeKernel(const Context& dev_ctx, std::vector weight_shape{static_cast(x.dims()[0]), static_cast(x.dims()[1])}; PADDLE_ENFORCE_EQ( - ((arch == 80) || (arch == 86) || (arch == 75) || (arch == 70)), + ((arch == 70) || 
+       (arch == 89) || (arch == 90)),
       true,
       phi::errors::InvalidArgument(
-          "Currently, arch only support 70, 75, 80, 86."));
+          "Currently, arch only support 70, 75, 80, 86, 89, 90."));
 
   if (algo == "llm.int8") {
     dev_ctx.template Alloc<float>(scale);
diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
index 05d0e47b314555..963608f7833210 100644
--- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
+++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
@@ -97,7 +97,8 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
   auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, 1);
   int grid_size = gpu_config.GetGridSize();
   int block_size = gpu_config.GetBlockSize();
-  if ((arch == 80) || (arch == 86) || (arch == 75)) {
+  if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) ||
+      (arch == 75)) {
     weight_permute_kernel_wint8<<<grid_size, block_size>>>(
         input_data, output_data, numel, total_k, total_n);
   } else if (arch == 70) {
diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py
index 41ad1839e1f8a4..417daad4479173 100644
--- a/python/paddle/nn/quant/quantized_linear.py
+++ b/python/paddle/nn/quant/quantized_linear.py
@@ -69,8 +69,13 @@ def weight_quantize(x, algo="weight_only_int8", arch=None, group_size=-1):
         arch = _get_arch_info()
 
     assert (
-        arch == 70 or arch == 80 or arch == 86 or arch == 75
-    ), f"Currently weight_quantize only support SM70/75/80/86. but got {arch} "
+        arch == 70
+        or arch == 75
+        or arch == 80
+        or arch == 86
+        or arch == 89
+        or arch == 90
+    ), f"Currently weight_quantize only support SM70/75/80/86/89/90. but got {arch} "
 
     assert (
         group_size == -1 or group_size == 64 or group_size == 128
@@ -193,8 +198,13 @@ def weight_only_linear(
         arch = _get_arch_info()
 
     assert (
-        arch == 70 or arch == 80 or arch == 86 or arch == 75
-    ), f"Currently weight_quantize only support SM70/75/80/86. but got {arch} "
+        arch == 70
+        or arch == 75
+        or arch == 80
+        or arch == 86
+        or arch == 89
+        or arch == 90
+    ), f"Currently weight_quantize only support SM70/75/80/86/89/90. but got {arch} "
     assert (
         group_size == -1 or group_size == 64 or group_size == 128
     ), f"Currently weight_quantize only support group size of -1, 64 or 128. but got {group_size} "
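
Reviewer note: the sketch below shows the user-visible effect of this patch
through the Python API it touches. It is a minimal example, not part of the
change itself, and it assumes a CUTLASS-enabled Paddle build running on an
SM89 (Ada) or SM90 (Hopper) GPU; the weight_scale and weight_dtype keyword
arguments of weight_only_linear come from that function's full signature in
python/paddle/nn/quant/quantized_linear.py, which this diff does not show.

    # Minimal sketch, assuming paddle.nn.quant as patched above; the tensor
    # shapes are illustrative and not taken from the patch.
    import paddle
    from paddle.nn.quant import weight_quantize, weight_only_linear

    x = paddle.randn([2, 64], dtype="float16")    # activations [batch, in_features]
    w = paddle.randn([64, 128], dtype="float16")  # weight [in_features, out_features]

    # Before this patch, arch=89/90 tripped the assert
    # "Currently weight_quantize only support SM70/75/80/86."
    qweight, scale = weight_quantize(w, algo="weight_only_int8", arch=89)
    out = weight_only_linear(
        x, qweight, weight_scale=scale, weight_dtype="int8", arch=89
    )
    print(out.shape)  # [2, 128]

The same arch whitelist is duplicated across the pass check, the infermeta,
the CPU and GPU quantize kernels, and the Python asserts, which is why this
small relaxation fans out across so many files; the block_attn.h hunks are an
unrelated comment rename from gqa_group_size to kv_num_heads.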