|
21 | 21 |
|
22 | 22 | #include "pytorch_extension_utils.h" |
23 | 23 |
|
24 | | -namespace flashinfer { |
25 | | -namespace bmm_fp8 { |
26 | | - |
27 | | -template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>( |
28 | | - void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, |
29 | | - __nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale, |
30 | | - const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream); |
31 | | - |
32 | | -template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>( |
33 | | - void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, |
34 | | - half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale, |
35 | | - cublasLtHandle_t lt_handle, cudaStream_t stream); |
36 | | - |
37 | | -template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>( |
38 | | - void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, |
39 | | - __nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale, |
40 | | - const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream); |
41 | | - |
42 | | -template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>( |
43 | | - void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B, |
44 | | - half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale, |
45 | | - cublasLtHandle_t lt_handle, cudaStream_t stream); |
46 | | - |
47 | | -template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>( |
48 | | - void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, |
49 | | - __nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale, |
50 | | - const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream); |
51 | | - |
52 | | -template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>( |
53 | | - void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B, |
54 | | - half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale, |
55 | | - cublasLtHandle_t lt_handle, cudaStream_t stream); |
56 | | - |
57 | | -} // namespace bmm_fp8 |
58 | | -} // namespace flashinfer |
59 | 24 |
|
60 | 25 | void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, |
61 | 26 | torch::Tensor& A_scale, torch::Tensor& B_scale) { |
|
0 commit comments