Skip to content

Commit c6d0f65

Browse files
authored
refactor: remove unused template instantiation (#519)
1 parent ca1fa7d commit c6d0f65

File tree

1 file changed

+0
-35
lines changed

1 file changed

+0
-35
lines changed

python/csrc/bmm_fp8.cu

Lines changed: 0 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -21,41 +21,6 @@
21 21

22 22
#include "pytorch_extension_utils.h"
23 23

24-
namespace flashinfer {
25-
namespace bmm_fp8 {
26-
27-
template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>(
28-
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B,
29-
__nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale,
30-
const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream);
31-
32-
template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e4m3, half>(
33-
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B,
34-
half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale,
35-
cublasLtHandle_t lt_handle, cudaStream_t stream);
36-
37-
template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, __nv_bfloat16>(
38-
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B,
39-
__nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale,
40-
const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream);
41-
42-
template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e4m3, __nv_fp8_e5m2, half>(
43-
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e4m3* A, const __nv_fp8_e5m2* B,
44-
half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale,
45-
cublasLtHandle_t lt_handle, cudaStream_t stream);
46-
47-
template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, __nv_bfloat16>(
48-
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B,
49-
__nv_bfloat16* D, int batch_size, int m, int n, int k, const float* A_scale,
50-
const float* B_scale, cublasLtHandle_t lt_handle, cudaStream_t stream);
51-
52-
template cublasStatus_t bmm_fp8_internal_cublaslt<__nv_fp8_e5m2, __nv_fp8_e4m3, half>(
53-
void* workspace, size_t workspace_size_in_bytes, const __nv_fp8_e5m2* A, const __nv_fp8_e4m3* B,
54-
half* D, int batch_size, int m, int n, int k, const float* A_scale, const float* B_scale,
55-
cublasLtHandle_t lt_handle, cudaStream_t stream);
56-
57-
} // namespace bmm_fp8
58-
} // namespace flashinfer
59 24

60 25
void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
61 26
torch::Tensor& A_scale, torch::Tensor& B_scale) {

0 commit comments

Comments
 (0)