@@ -4,6 +4,13 @@
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4


+lib = torch.library.Library("torchao", "FRAGMENT")
+lib.define("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor")
+lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor")
+lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor")
+lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor")
+
+
 def register_custom_op(name):
     def decorator(func):
         if TORCH_VERSION_AT_LEAST_2_4:
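The lib.define(...) calls added here register op schemas under the "torchao" namespace, so each kernel becomes reachable as torch.ops.torchao.<name>, and register_custom_op then attaches a Python fake/abstract implementation for shape propagation (judging by the version check, likely torch.library.register_fake on PyTorch >= 2.4 and the older torch.library.impl_abstract otherwise). A minimal sketch of the same define-and-implement pattern, using a scratch "demo" namespace and a made-up "scale" op rather than anything in torchao:

import torch

# Define a schema in a fresh namespace; the op becomes torch.ops.demo.scale.
demo_lib = torch.library.Library("demo", "DEF")
demo_lib.define("scale(Tensor x, float alpha) -> Tensor")

# Eager CPU implementation (torchao's real ops dispatch to CUDA kernels).
def scale_impl(x, alpha):
    return x * alpha

demo_lib.impl("scale", scale_impl, "CPU")

# Abstract/fake implementation: produces output metadata only, no compute.
@torch.library.impl_abstract("demo::scale")
def scale_abstract(x, alpha):
    return torch.empty_like(x)

print(torch.ops.demo.scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])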
@@ -39,7 +46,14 @@ def quant_llm_linear(


 @register_custom_op("torchao::quant_llm_linear")
-def _(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK = 1):
+def _(
+    EXPONENT: int,
+    MANTISSA: int,
+    _in_feats: Tensor,
+    _weights: Tensor,
+    _scales: Tensor,
+    splitK: int = 1,
+) -> Tensor:
     torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D")
     torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}")
     torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D")
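Since the decorated function above serves as the op's fake/abstract implementation, these torch._check preconditions also fire during tracing and compilation, not only on eager calls. A hypothetical smoke test under FakeTensorMode (assumes a built torchao extension is importable; the shapes below are illustrative and must satisfy whatever the full set of checks enforces):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode  # private API location

with FakeTensorMode():
    # Fake tensors carry device/dtype/shape metadata without real storage,
    # so no CUDA kernel is launched and no GPU memory is touched.
    x = torch.empty(8, 64, dtype=torch.float16, device="cuda")   # [M, K] activations
    w = torch.empty(128, 48, dtype=torch.uint8, device="cuda")   # packed weights (illustrative)
    s = torch.empty(128, dtype=torch.float16, device="cuda")     # per-channel scales
    out = torch.ops.torchao.quant_llm_linear(3, 2, x, w, s, 1)   # EXPONENT=3, MANTISSA=2
    print(out.shape)  # expected torch.Size([8, 128]) if the fake impl returns [M, OC]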
@@ -76,7 +90,7 @@ def unpack_tensor_core_tiled_layout(packed_w: Tensor, inner_k_tiles: int) -> Tensor:
     )


-@register_custom_op(f"torchao::unpack_tensor_core_tiled_layout")
+@register_custom_op("torchao::unpack_tensor_core_tiled_layout")
 def _(packed_w: Tensor, inner_k_tiles: int) -> Tensor:
     torch._check(
         packed_w.dim() == 4,
@@ -127,7 +141,7 @@ def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor:
     )


-@register_custom_op(f"torchao::dequantize_tensor_core_tiled_layout")
+@register_custom_op("torchao::dequantize_tensor_core_tiled_layout")
 def _(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor:
     # packed_w preconditions
     torch._check(
@@ -192,7 +206,7 @@ def marlin_24_gemm(
     )


-@register_custom_op(f"torchao::marlin_24_gemm")
+@register_custom_op("torchao::marlin_24_gemm")
 def _(
     x: Tensor,
     weight_marlin: Tensor,