Skip to content

Commit 2918c5e

Browse files
authored
Add TAO dependency to cuda predictor
Differential Revision: D83214553 Pull Request resolved: #3082
1 parent 75944ab commit 2918c5e

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass_f8f8.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ rowwise_scaled_linear_sparse_cutlass_f8f8(
4646
return at::Tensor{};
4747
}
4848

49+
#ifdef DEF_SPARSE_CUTLASS_OPS
50+
TORCH_LIBRARY(torchao, m) {
51+
m.def(
52+
"torchao::rowwise_scaled_linear_sparse_cutlass_f8f8(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_meta, Tensor weight_scale, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor",
53+
rowwise_scaled_linear_sparse_cutlass_f8f8);
54+
}
55+
#endif
56+
4957
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
5058
m.impl("torchao::rowwise_scaled_linear_sparse_cutlass_f8f8",
5159
&rowwise_scaled_linear_sparse_cutlass_f8f8);

0 commit comments

Comments
 (0)