weight_only_linear arch check 89/90 (PaddlePaddle#65295)
yuanlehome authored Jun 20, 2024
1 parent 47cae92 commit e62a50d
Showing 9 changed files with 47 additions and 27 deletions.
@@ -311,7 +311,7 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass {

   bool CanApplyOn(pir::Operation *op) const override {
     if (sm_version_ != 70 && sm_version_ != 75 && sm_version_ != 80 &&
-        sm_version_ != 86) {
+        sm_version_ != 86 && sm_version_ != 89 && sm_version_ != 90) {
       return false;
     }
     return op->num_regions() > 0;
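For context, `sm_version_` here is the GPU's SM architecture number, i.e. the compute capability encoded as major * 10 + minor, so this change admits SM 89 (Ada) and SM 90 (Hopper) alongside Volta/Turing/Ampere. A minimal sketch of how such a number can be derived from Python, assuming a CUDA build where `paddle.device.cuda.get_device_capability` is available:

```python
# Sketch: derive the SM number the pass compares against (assumes CUDA build).
import paddle

def sm_version() -> int:
    # get_device_capability() returns (major, minor), e.g. (8, 9) on Ada.
    major, minor = paddle.device.cuda.get_device_capability()
    return major * 10 + minor  # (8, 9) -> 89, (9, 0) -> 90

SUPPORTED_ARCHS = (70, 75, 80, 86, 89, 90)  # the list this commit extends
print(sm_version() in SUPPORTED_ARCHS)
```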
11 changes: 7 additions & 4 deletions paddle/phi/infermeta/fusion.cc
@@ -295,10 +295,13 @@ void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv,
   const int total_num_head = qkv.dims()[qkv.dims().size() - 1] / dim_head;
   const int q_num_head = total_num_head - 2 * kv_num_head;

-  PADDLE_ENFORCE_EQ(q_num_head % kv_num_head,
-                    0,
-                    errors::InvalidArgument(
-                        "The q num_head must be divisible by kv_num_head"));
+  PADDLE_ENFORCE_EQ(
+      q_num_head % kv_num_head,
+      0,
+      errors::InvalidArgument(
+          "The q num_head (%d) must be divisible by kv num_head (%d)",
+          q_num_head,
+          kv_num_head));
   PADDLE_ENFORCE_EQ(
       input_dims.size(),
       2UL,
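The rewritten message now reports both head counts. For background: the last dimension of the packed `qkv` tensor holds `q_num_head + 2 * kv_num_head` heads of `dim_head` elements each, and grouped-query attention needs the query heads to divide evenly across the KV heads. A self-contained sketch of the same arithmetic (the function name is illustrative, not a Paddle API):

```python
# Sketch of the head-count check above (illustrative, not a Paddle API).
def check_gqa_heads(qkv_last_dim: int, dim_head: int, kv_num_head: int) -> int:
    total_num_head = qkv_last_dim // dim_head      # q heads plus 2 * kv heads
    q_num_head = total_num_head - 2 * kv_num_head  # strip the K and V heads
    if q_num_head % kv_num_head != 0:
        raise ValueError(
            f"The q num_head ({q_num_head}) must be divisible by "
            f"kv num_head ({kv_num_head})"
        )
    return q_num_head

# 32 query heads grouped over 8 KV heads with head size 128:
# qkv last dim = (32 + 2 * 8) * 128 = 6144
assert check_gqa_heads(6144, 128, 8) == 32
```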
5 changes: 3 additions & 2 deletions paddle/phi/infermeta/unary.cc
@@ -5783,10 +5783,11 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
                              MetaTensor* out,
                              MetaTensor* scale) {
   PADDLE_ENFORCE_EQ(
-      ((arch == 80) || (arch == 86) || (arch == 70) || (arch == 75)),
+      ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) ||
+       (arch == 89) || (arch == 90)),
       true,
       phi::errors::InvalidArgument(
-          "Currently, arch only support 70, 75, 80, 86."));
+          "Currently, arch only support 70, 75, 80, 86, 89, 90."));

   auto x_dims = x.dims();
   PADDLE_ENFORCE_EQ(
8 changes: 5 additions & 3 deletions paddle/phi/kernels/cpu/weight_quantize_kernel.cc
@@ -35,10 +35,11 @@ void quant_compute(const DeviceContext& dev_ctx,
                    const int32_t arch,
                    const int32_t group_size) {
   PADDLE_ENFORCE_EQ(
-      ((arch == 80) || (arch == 86) || (arch == 75) || (arch == 70)),
+      ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) ||
+       (arch == 89) || (arch == 90)),
       true,
       phi::errors::InvalidArgument(
-          "Currently, arch only support 70, 75, 80, 86."));
+          "Currently, arch only support 70, 75, 80, 86, 89, 90."));

   const auto x_dims = x.dims();
   PADDLE_ENFORCE_EQ(
@@ -104,7 +105,8 @@ void quant_compute(const DeviceContext& dev_ctx,
     for (int i = 0; i < out->numel(); ++i) {
       out_data[i] = x_int_data[i];
     }
-  } else if ((arch == 80) || (arch == 75) || (arch == 86)) {
+  } else if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) ||
+             (arch == 75)) {
     permute_B_rows_for_mixed_gemm<bits>(
         int_processed_data, x_int_data, std::vector<size_t>{m, n});
     subbyte_transpose_impl<bits>(
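Note the dispatch shape: SM 89/90 join the existing SM 75/80/86 path (row permutation plus sub-byte transpose for the mixed-precision GEMM), while SM 70 keeps its plain copy path. A runnable toy sketch of that branching; the stand-ins below only mimic shapes, the real tile-interleaved permutation lives in the C++ templates above:

```python
# Toy sketch of the arch dispatch in quant_compute (shapes only; the real
# row reordering is done by permute_B_rows_for_mixed_gemm).
import numpy as np

def quant_layout(arch: int, w_int8: np.ndarray) -> np.ndarray:
    if arch == 70:
        # Volta path: quantized weights are copied out as-is.
        return w_int8.copy()
    if arch in (90, 89, 86, 80, 75):
        # Turing-and-newer path: permute rows, then transpose.
        perm = np.arange(w_int8.shape[0])  # identity stand-in for the permute
        return w_int8[perm].T.copy()
    raise ValueError("Currently, arch only support 70, 75, 80, 86, 89, 90.")

w = np.random.randint(-128, 128, size=(64, 32), dtype=np.int8)
assert quant_layout(89, w).shape == (32, 64)
assert quant_layout(70, w).shape == (64, 32)
```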
16 changes: 8 additions & 8 deletions paddle/phi/kernels/fusion/gpu/block_attn.h
@@ -1631,11 +1631,11 @@ inline cudaError_t GetNumBlocks(Func func,

 template <typename T, int VecSize = 1>
 __global__ void cache_int8_kernel(
-    const T *__restrict__ qkv,  // [num_tokens, num_heads + 2 * gqa_group_size,
+    const T *__restrict__ qkv,  // [num_tokens, num_heads + 2 * kv_num_heads,
                                 // head_size]
-    uint8_t *__restrict__ key_cache,  // [num_blocks, gqa_group_size,
+    uint8_t *__restrict__ key_cache,  // [num_blocks, kv_num_heads,
                                       // block_size, head_size]
-    uint8_t *__restrict__ value_cache,  // [num_blocks, gqa_group_size,
+    uint8_t *__restrict__ value_cache,  // [num_blocks, kv_num_heads,
                                         // block_size, head_size]
     const int *__restrict__ block_tables,     // [bsz, max_blocks_per_seq]
     const int *__restrict__ padding_offsets,  // [num_tokens]
@@ -1715,11 +1715,11 @@ __global__ void cache_int8_kernel(

 template <typename T, int VecSize = 1>
 __global__ void cache_kernel(
-    const T *__restrict__ qkv,  // [num_tokens, num_heads + 2 * gqa_group_size,
-                                // head_size]
-    T *__restrict__ key_cache,    // [num_blocks, gqa_group_size, block_size,
-                                  // head_size]
-    T *__restrict__ value_cache,  // [num_blocks, gqa_group_size, block_size,
+    const T *__restrict__ qkv,  // [num_tokens, num_heads + 2 * kv_num_heads,
+                                // head_size]
+    T *__restrict__ key_cache,    // [num_blocks, kv_num_heads, block_size,
+                                  // head_size]
+    T *__restrict__ value_cache,  // [num_blocks, kv_num_heads, block_size,
                                   // head_size]
     const int *__restrict__ block_tables,     // [bsz, max_blocks_per_seq]
     const int *__restrict__ padding_offsets,  // [num_tokens]
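The rename is comment-only: `gqa_group_size` and `kv_num_heads` describe the same quantity, the number of key/value heads shared by groups of query heads. A small sketch of how a (token, head) position maps into the packed `qkv` buffer those comments describe (pure index math, not the kernel itself):

```python
# Sketch: element offsets into the packed qkv buffer
# [num_tokens, num_heads + 2 * kv_num_heads, head_size].
def qkv_offsets(token: int, num_heads: int, kv_num_heads: int, head_size: int):
    stride = (num_heads + 2 * kv_num_heads) * head_size  # elements per token
    base = token * stride
    q = base                             # query heads come first
    k = base + num_heads * head_size     # then kv_num_heads key heads
    v = k + kv_num_heads * head_size     # then kv_num_heads value heads
    return q, k, v

# token 2, 32 query heads, 8 KV heads, head size 128:
print(qkv_offsets(2, 32, 8, 128))  # (12288, 16384, 17408)
```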
6 changes: 4 additions & 2 deletions paddle/phi/kernels/gpu/weight_only_linear_kernel.cu
@@ -35,9 +35,11 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
                             DenseTensor* out) {
 #if defined(PADDLE_WITH_CUTLASS)
   PADDLE_ENFORCE_EQ(
-      ((arch == 80) || (arch == 70) || (arch == 75) || (arch == 86)),
+      ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) ||
+       (arch == 89) || (arch == 90)),
       true,
-      phi::errors::InvalidArgument("Currently, arch only support 70, 80."));
+      phi::errors::InvalidArgument(
+          "Currently, arch only support 70, 75, 80, 86, 89, 90."));
 #else
   PADDLE_THROW(phi::errors::Unimplemented(
       "Please compile with cutlass to make cutlass available"));
5 changes: 3 additions & 2 deletions paddle/phi/kernels/gpu/weight_quantize_kernel.cu
@@ -45,10 +45,11 @@ void WeightQuantizeKernel(const Context& dev_ctx,
   std::vector<int> weight_shape{static_cast<int>(x.dims()[0]),
                                 static_cast<int>(x.dims()[1])};
   PADDLE_ENFORCE_EQ(
-      ((arch == 80) || (arch == 86) || (arch == 75) || (arch == 70)),
+      ((arch == 70) || (arch == 75) || (arch == 80) || (arch == 86) ||
+       (arch == 89) || (arch == 90)),
       true,
       phi::errors::InvalidArgument(
-          "Currently, arch only support 70, 75, 80, 86."));
+          "Currently, arch only support 70, 75, 80, 86, 89, 90."));

   if (algo == "llm.int8") {
     dev_ctx.template Alloc<float>(scale);
3 changes: 2 additions & 1 deletion paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
@@ -97,7 +97,8 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
   auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, 1);
   int grid_size = gpu_config.GetGridSize();
   int block_size = gpu_config.GetBlockSize();
-  if ((arch == 80) || (arch == 86) || (arch == 75)) {
+  if ((arch == 90) || (arch == 89) || (arch == 86) || (arch == 80) ||
+      (arch == 75)) {
     weight_permute_kernel_wint8<<<grid_size, block_size>>>(
         input_data, output_data, numel, total_k, total_n);
   } else if (arch == 70) {
18 changes: 14 additions & 4 deletions python/paddle/nn/quant/quantized_linear.py
@@ -69,8 +69,13 @@ def weight_quantize(x, algo="weight_only_int8", arch=None, group_size=-1):
         arch = _get_arch_info()

     assert (
-        arch == 70 or arch == 80 or arch == 86 or arch == 75
-    ), f"Currently weight_quantize only support SM70/75/80/86. but got {arch} "
+        arch == 70
+        or arch == 75
+        or arch == 80
+        or arch == 86
+        or arch == 89
+        or arch == 90
+    ), f"Currently weight_quantize only support SM70/75/80/86/89/90. but got {arch} "

     assert (
         group_size == -1 or group_size == 64 or group_size == 128
@@ -193,8 +198,13 @@ def weight_only_linear(
         arch = _get_arch_info()

     assert (
-        arch == 70 or arch == 80 or arch == 86 or arch == 75
-    ), f"Currently weight_quantize only support SM70/75/80/86. but got {arch} "
+        arch == 70
+        or arch == 75
+        or arch == 80
+        or arch == 86
+        or arch == 89
+        or arch == 90
+    ), f"Currently weight_quantize only support SM70/75/80/86/89/90. but got {arch} "
     assert (
         group_size == -1 or group_size == 64 or group_size == 128
     ), f"Currently weight_quantize only support group size of -1, 64 or 128. but got {group_size} "
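End to end, the relaxed asserts let the public API run on SM 89/90 devices. A minimal usage sketch, assuming a CUTLASS-enabled CUDA build on a supported GPU; the `weight_quantize` signature is the one shown in this file, while the `weight_only_linear` keyword names are an assumption based on the `paddle.nn.quant` API:

```python
# Minimal sketch: quantize a weight, then apply the weight-only matmul.
# Assumes a CUTLASS-enabled CUDA build on SM 70/75/80/86/89/90.
import paddle
from paddle.nn.quant import weight_quantize, weight_only_linear

x = paddle.randn([2, 64], dtype="float16")
w = paddle.randn([64, 64], dtype="float16")

qw, scale = weight_quantize(w, algo="weight_only_int8")  # arch auto-detected
y = weight_only_linear(x, qw, weight_scale=scale, weight_dtype="int8")
print(y.shape)  # [2, 64]
```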
