diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
index 5dafe445e2b46..201dd403270f3 100644
--- a/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
+++ b/paddle/phi/kernels/impl/weight_quantize_kernel_gpu_impl.h
@@ -105,6 +105,7 @@ void weight_permute_gpu(const GPUContext& dev_ctx,
         input_data, output_data, numel, total_k, total_n);
   }
 }
+
 template <typename T, int VectorSize>
 __global__ void per_channel_quant_gpu(const T* weight_data,
                                       int8_t* quanted_weight_data,
@@ -160,7 +161,6 @@ __global__ void per_channel_quant_gpu(const T* weight_data,
     }
   }
 }
-
 template <typename T>
 void weight_quant_gpu(const GPUContext& dev_ctx,
                       const T* weight_data,
@@ -174,8 +174,15 @@ void weight_quant_gpu(const GPUContext& dev_ctx,
   constexpr int kBlockSize = 64;
   constexpr int kWarpNum = kBlockSize / kWarpSize;
   constexpr int kVectorSize = 128 / sizeof(T) / 8;
+  PADDLE_ENFORCE_EQ(total_n % kVectorSize,
+                    0,
+                    phi::errors::PreconditionNotMet(
+                        "Currently, the weight_quant_gpu kernel only "
+                        "supports total_n as a multiple of %d.",
+                        kVectorSize));
   int vec_total_n = total_n / kVectorSize;
-  int kGridSize = max(vec_total_n / kBlockSize, static_cast<int>(1));
+  int kGridSize =
+      max((vec_total_n + kBlockSize - 1) / kBlockSize, static_cast<int>(1));
   per_channel_quant_gpu<T, kVectorSize><<<kGridSize, kBlockSize>>>(
       weight_data, quanted_weight_data, scale_data, total_k, vec_total_n);
 }
diff --git a/test/quantization/test_weight_only_linear.py b/test/quantization/test_weight_only_linear.py
index 81f84f138e70b..c7bbc1c658267 100644
--- a/test/quantization/test_weight_only_linear.py
+++ b/test/quantization/test_weight_only_linear.py
@@ -399,5 +399,47 @@ def test_weightonly_linear_backward(self):
         np.testing.assert_allclose(quant_x.grad, x.grad, rtol=1e-3, atol=1e-3)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2",
+)
+class WeightOnlyLinearTestCase11(WeightOnlyLinearTestCase):
+    def config(self):
+        super().config()
+        self.dtype = 'float16'
+        self.weight_dtype = "int8"
+        self.in_features = 128
+        self.out_features = 288
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
+    "quantized_matmul requires CUDA >= 11.2",
+)
+class WeightOnlyLinearTestCase12(WeightOnlyLinearTestCase):
+    def config(self):
+        super().config()
+        self.dtype = 'float16'
+        self.bias = False
+        self.weight_dtype = "int8"
+        self.in_features = 128
+        self.out_features = 288
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11020
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8",
+)
+class WeightOnlyLinearTestCase13(WeightOnlyLinearTestCase):
+    def config(self):
+        super().config()
+        self.dtype = 'bfloat16'
+        self.weight_dtype = "int8"
+        self.in_features = 128
+        self.out_features = 288
+
+
 if __name__ == '__main__':
     unittest.main()
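
Note on the kGridSize fix: it replaces floor division with ceiling division, so a trailing partial block of elements still gets a thread block, while the kernel's own index guard keeps the surplus threads from writing out of bounds. Below is a minimal sketch of this launch pattern with a generic element-wise kernel; ceil_div, copy_kernel, and launch_copy are illustrative names, not part of the patch.

#include <algorithm>

// Ceiling division: number of size-b blocks needed to cover a elements.
__host__ __device__ inline int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Element-wise kernel with a tail guard: the last block may be only
// partially full, so every thread checks its index against n.
__global__ void copy_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i];
  }
}

void launch_copy(const float* in, float* out, int n) {
  constexpr int kBlockSize = 64;
  // Floor division (n / kBlockSize) drops the tail whenever n is not a
  // multiple of kBlockSize; ceiling division covers it.
  int grid = std::max(ceil_div(n, kBlockSize), 1);
  copy_kernel<<<grid, kBlockSize>>>(in, out, n);
}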
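
For the shapes added in the tests, the arithmetic works out as follows, assuming dtype float16 or bfloat16 (both 2 bytes) and treating out_features = 288 as the quantized dimension total_n; the total_n = 800 case is an illustrative value, not a test shape.

// kVectorSize  = 128 / sizeof(T) / 8 = 128 / 2 / 8 = 8
// 288 % 8 == 0, so the new PADDLE_ENFORCE_EQ precondition passes
// vec_total_n  = 288 / 8 = 36
// grid before: max(36 / 64, 1)        = 1
// grid after:  max((36 + 63) / 64, 1) = 1   (same here, since 36 < 64)
// A larger non-multiple of kBlockSize, total_n = 800 -> vec_total_n = 100:
// grid before: max(100 / 64, 1)        = 1  -> 36 vectors never quantized
// grid after:  max((100 + 63) / 64, 1) = 2  -> all 100 vectors covered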