Skip to content

Commit

Permalink
[Cherry-pick] fix weight quant kernel bug when n div 64 != 0 (PaddleP…
Browse files Browse the repository at this point in the history
…addle#60184)

* fix weight-only quant kernel error for n div 64 !=0

* code style fix
  • Loading branch information
wwbitejotunn authored Dec 26, 2023
1 parent a4cd847 commit 20d3558
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
11 changes: 9 additions & 2 deletions paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
input_data, output_data, numel, total_k, total_n);
}
}

template <typename T, int VectorSize = 8>
__global__ void per_channel_quant_gpu(const T* weight_data,
int8_t* quanted_weight_data,
Expand Down Expand Up @@ -160,7 +161,6 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
}
}
}

template <typename T, typename GPUContext>
void weight_quant_gpu(const GPUContext& dev_ctx,
const T* weight_data,
Expand All @@ -174,8 +174,15 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
constexpr int kBlockSize = 64;
constexpr int kWarpNum = kBlockSize / kWarpSize;
constexpr int kVectorSize = 128 / sizeof(T) / 8;
PADDLE_ENFORCE_EQ(total_n % kVectorSize,
0,
phi::errors::PreconditionNotMet(
"Currently, weight_quant_gpu kernel only support n "
"with multiple of %d, please use",
kVectorSize));
int vec_total_n = total_n / kVectorSize;
int kGridSize = max(vec_total_n / kBlockSize, static_cast<int>(1));
int kGridSize =
max((vec_total_n + kBlockSize - 1) / kBlockSize, static_cast<int>(1));
per_channel_quant_gpu<T, kVectorSize><<<kGridSize, kBlockSize>>>(
weight_data, quanted_weight_data, scale_data, total_k, vec_total_n);
}
Expand Down
42 changes: 42 additions & 0 deletions test/quantization/test_weight_only_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,5 +399,47 @@ def test_weightonly_linear_backward(self):
np.testing.assert_allclose(quant_x.grad, x.grad, rtol=1e-3, atol=1e-3)


@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase11(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.weight_dtype = "int8"
self.in_features = 128
self.out_features = 288


@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase12(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.weight_dtype = "int8"
self.in_features = 128
self.out_features = 288


@unittest.skipIf(
not core.is_compiled_with_cuda()
or get_cuda_version() < 11020
or paddle.device.cuda.get_device_capability()[0] < 8,
"quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
)
class WeightOnlyLinearTestCase13(WeightOnlyLinearTestCase):
def config(self):
super().config()
self.dtype = 'bfloat16'
self.weight_dtype = "int8"
self.in_features = 128
self.out_features = 288


if __name__ == '__main__':
unittest.main()

0 comments on commit 20d3558

Please sign in to comment.