Commit b149edb

Move op definition for custom kernel from C++ to Python (#949)

move op def to python

1 parent e83c35d

File tree

5 files changed (+21, -33 lines)

torchao/csrc/README.md

Lines changed: 3 additions & 3 deletions

```diff
@@ -11,12 +11,12 @@ To learn more about custom ops in PyTorch you can refer to the [PyTorch Custom O
 
 ## How to add your own kernel in ao
 
-We've integrated a test kernel which implements a non-maximum supression (NMS) op which you can use as a template for your own kernels.
+We've integrated several kernels which you can use as a template for your own kernels. `tensor_core_tiled_layout` is the most straight-forward to get started with.
 
 1. Install the cudatoolkit https://anaconda.org/conda-forge/cudatoolkit
 2. In `csrc/cuda` author your custom kernel and ensure you expose a `TORCH_LIBRARY_IMPL` which will expose `torchao::your_custom_kernel`
-3. In `csrc/` author a `cpp` stub which will include a `TORCH_LIBRARY_FRAGMENT` which will place your custom kernel in the `torchao.ops` namespace and also expose a public function with the right arguments
-4. In `torchao/ops.py` is where you'll expose the python API which your new end users will leverage
+3. In `torchao/ops.py`, define your op signature at the top of the file. You can refer to [this](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md) on how to write the signature correctly
+4. `torchao/ops.py` is also where you'll expose the python API which your new end users will leverage
 5. Write a new test in `test/test_ops.py` which most importantly needs to pass `opcheck()`, this ensures that your custom kernel composes out of the box with `torch.compile()`
 
 And that's it! Once CI passes and your code merged you'll be able to point people to `torchao.ops.your_custom_kernel`. If you're working on an interesting kernel and would like someone else to handle the release and package management please feel free to open an issue.
```
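The updated steps 3 and 4 replace the old C++ stub with a few `torch.library` calls in `torchao/ops.py`. Here is a rough sketch of what those steps produce for a hypothetical `your_custom_kernel` (the op name, shapes, and checks are illustrative, not part of this commit), assuming PyTorch >= 2.4 so `torch.library.register_fake` is available; the repo's own `register_custom_op` helper, visible in the `torchao/ops.py` diff below, branches on `TORCH_VERSION_AT_LEAST_2_4` to pick the registration API for you:

```python
# Hypothetical sketch of steps 3-4's Python side (PyTorch >= 2.4).
import torch
from torch import Tensor

lib = torch.library.Library("torchao", "FRAGMENT")

# Step 3: declare the op schema at the top of torchao/ops.py.
lib.define("your_custom_kernel(Tensor x) -> Tensor")

# A fake (meta) implementation lets torch.compile() trace the op without
# launching the CUDA kernel: validate inputs, then return an empty tensor
# of the right shape and dtype.
@torch.library.register_fake("torchao::your_custom_kernel")
def _(x: Tensor) -> Tensor:
    torch._check(x.dim() == 2, lambda: f"input should be a 2d tensor, got {x.dim()}D")
    return torch.empty_like(x)

# Step 4: the public Python API, dispatching to the CUDA kernel that
# csrc/cuda registers under TORCH_LIBRARY_IMPL (step 2).
def your_custom_kernel(x: Tensor) -> Tensor:
    return torch.ops.torchao.your_custom_kernel.default(x)
```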

torchao/csrc/fp6_llm.cpp

Lines changed: 0 additions & 8 deletions
This file was deleted.

torchao/csrc/sparse_marlin.cpp

Lines changed: 0 additions & 8 deletions
This file was deleted.

torchao/csrc/tensor_core_tiled_layout.cpp

Lines changed: 0 additions & 10 deletions
This file was deleted.
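Per the old step 3 of the README, each of these deleted `cpp` files was a thin stub whose only job was to declare op schemas in a `TORCH_LIBRARY_FRAGMENT`. Below is a hedged before/after reconstruction for one schema; the C++ shape is inferred from the old README wording rather than copied from the deleted file, while the schema string itself is the one that moved into `torchao/ops.py`:

```python
# Before this commit (torchao/csrc/tensor_core_tiled_layout.cpp, deleted),
# the schema was declared from a C++ fragment, roughly:
#
#   TORCH_LIBRARY_FRAGMENT(torchao, m) {
#     m.def("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor");
#   }
#
# After this commit, the identical schema string is declared from Python:
import torch

lib = torch.library.Library("torchao", "FRAGMENT")
lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor")
```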

torchao/ops.py

Lines changed: 18 additions & 4 deletions
```diff
@@ -4,6 +4,13 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
 
+lib = torch.library.Library("torchao", "FRAGMENT")
+lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor")
+lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor")
+lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor")
+lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor")
+
+
 def register_custom_op(name):
     def decorator(func):
         if TORCH_VERSION_AT_LEAST_2_4:
@@ -39,7 +46,14 @@ def quant_llm_linear(
 
 
 @register_custom_op("torchao::quant_llm_linear")
-def _(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK = 1):
+def _(
+    EXPONENT: int,
+    MANTISSA: int,
+    _in_feats: Tensor,
+    _weights: Tensor,
+    _scales: Tensor,
+    splitK: int = 1,
+) -> Tensor:
     torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
     torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}")
     torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D")
@@ -76,7 +90,7 @@ def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Ten
     )
 
 
-@register_custom_op(f"torchao::unpack_tensor_core_tiled_layout")
+@register_custom_op("torchao::unpack_tensor_core_tiled_layout")
 def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor:
     torch._check(
         packed_w.dim() == 4,
@@ -127,7 +141,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tens
     )
 
 
-@register_custom_op(f"torchao::dequantize_tensor_core_tiled_layout")
+@register_custom_op("torchao::dequantize_tensor_core_tiled_layout")
 def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor:
     # packed_w preconditions
     torch._check(
@@ -192,7 +206,7 @@ def marlin_24_gemm(
     )
 
 
-@register_custom_op(f"torchao::marlin_24_gemm")
+@register_custom_op("torchao::marlin_24_gemm")
 def _(
     x: Tensor,
     weight_marlin: Tensor,
```
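Step 5 of the README still applies to these relocated definitions: each schema/fake-impl pair should pass `opcheck()`, which is what certifies composition with `torch.compile()`. A minimal sketch of such a test for the hypothetical op from the earlier sketch (assuming PyTorch >= 2.4, where `torch.library.opcheck` is public, and a CUDA build; shapes are illustrative):

```python
# Sketch of a test/test_ops.py-style opcheck test for the hypothetical op.
import torch

def test_your_custom_kernel_opcheck():
    x = torch.randn(16, 32, dtype=torch.float16, device="cuda")
    # opcheck exercises the schema, the fake/meta implementation, and
    # compile-time dispatch in one call.
    torch.library.opcheck(torch.ops.torchao.your_custom_kernel.default, (x,))
```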
