Description
Neuralmagic / IST-DASLab have written a fast INT4A16 kernel with support for 2:4 sparsity (Sparse-Marlin): https://github.com/IST-DASLab/Sparse-Marlin
We'd like to integrate this kernel into torchao and test it for ViT acceleration as a datapoint for our PTC poster.
Implementation Details
To add a custom quant + sparse layout into torchao, we need to do three things:
1) Add and bind the CUDA kernel.
Sparse-Marlin is implemented as a custom CUDA extension for PyTorch, which should be easy to port over. Most of the logic is contained in https://github.com/IST-DASLab/Sparse-Marlin/blob/main/marlin/marlin_cuda_kernel_nm.cu
You can follow the tutorial at https://github.com/pytorch/ao/blob/main/torchao/csrc/README.md, which provides details on how to add a custom CUDA extension to torchao.
After this, you should have registered the marlin 2:4 mm op as torchao.ops.marlin_24_mm.
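As a rough sketch, the Python-side binding might look like the following (the op name follows this issue, but the argument list and schema are illustrative assumptions; the real signature comes from the ported kernel):

import torch

# Hypothetical wrapper in torchao/ops.py; assumes the C++/CUDA extension has
# registered a "torchao::marlin_24_mm" op via TORCH_LIBRARY. The argument
# names below are placeholders, not the kernel's actual schema.
def marlin_24_mm(x, weight_packed, meta, scales, workspace):
    return torch.ops.torchao.marlin_24_mm(x, weight_packed, meta, scales, workspace)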
We would also want to benchmark the op at this point and make sure we see the same speedups reported by Neural Magic.
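For example, a quick microbenchmark against a dense fp16 matmul could look like this (a sketch only; the inputs to marlin_24_mm would come from Sparse-Marlin's pack(), and the reference numbers come from the Sparse-Marlin benchmark scripts):

import torch
from torch.utils import benchmark

def bench(fn, label, **inputs):
    # Autoranged timing of a matmul-like callable (Timer handles CUDA sync).
    return benchmark.Timer(
        stmt="fn(**inputs)", globals={"fn": fn, "inputs": inputs}, label=label
    ).blocked_autorange()

m, k, n = 16, 4096, 4096
a = torch.randn(m, k, dtype=torch.float16, device="cuda")
w = torch.randn(n, k, dtype=torch.float16, device="cuda")

# fp16 dense baseline
print(bench(torch.nn.functional.linear, "dense fp16", input=a, weight=w))

# The sparse-marlin run would time the newly bound op with a weight packed by
# Sparse-Marlin's pack() (same m/k/n so speedups are directly comparable), e.g.
#   print(bench(torchao.ops.marlin_24_mm, "marlin 2:4 int4", ...))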
2) Register a custom sparse layout and quantized dispatch
Now that we have our kernel bound, we can connect it to our quantization API by writing a new sparse layout for AffineQuantizedTensor, MarlinSparseLayout.
You can use our semi-structured sparse layout implementation as a reference:
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L36-L45
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L471-L511
You'll want to replace the line int_data_compressed = torch._cslt_compress(int_data) with the pack function from Sparse-Marlin, found here: https://github.com/IST-DASLab/Sparse-Marlin/blob/c2ffa2395a3ada26c8cb7f910a5ec65bd3ce288a/marlin/__init__.py#L331
While the semi-structured sparse layout extends PlainLayoutType, the marlin packed layout should extend AQTLayout, as the marlin packed format packs both the scales and weights together.
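Putting that together, a minimal sketch of the new layout could look like the following (class and field names are assumptions modeled on the semi-structured sparse layout linked above; import paths, the registration helper, and the from_plain signature should be copied from whatever your torchao version actually uses):

from dataclasses import dataclass

# Import paths are a guess; LayoutType / AQTLayout live wherever your torchao
# version defines them (see the affine_quantized_tensor.py links above).
from torchao.dtypes.affine_quantized_tensor import AQTLayout, LayoutType

@dataclass(frozen=True)
class MarlinSparseLayoutType(LayoutType):
    # Marker layout type selected via int4_weight_only(layout_type=...) in step 3.
    pass

class MarlinSparseAQTLayout(AQTLayout):
    # Holds the marlin-packed weight, 2:4 sparsity metadata, and scales together,
    # since the marlin format packs scales and weights into one buffer layout.
    # The usual __tensor_flatten__ / __tensor_unflatten__ / _apply_fn_to_data
    # boilerplate would mirror the semi-structured sparse layout implementation.
    @classmethod
    def from_plain(cls, int_data, scale, zero_point, layout_type):
        # Instead of torch._cslt_compress(int_data), call Sparse-Marlin's pack()
        # here to produce the packed weight + 2:4 metadata buffers.
        ...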
Finally, once your layout is registered, you'll want to define the quantized_linear_op dispatch. This will call into your earlier registered torchao.ops.marlin_24_mm op instead of the normal dense mm.
https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L708-L732
The conditional would look something like this, after line 780, as we want to overload the int4-weight-only dispatch path with the sparse marlin kernels:
if (
weight_is_uint4 and
weight_qtensor.dtype == torch.float16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
isinstance(weight_qtensor.layout_type, MarlinSparseLayoutType)
):
# call torchao.ops.marlin_24_mm
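Inside that branch, the call into the custom op might look roughly like this (the field names on the packed layout and the marlin_24_mm argument list are assumptions; they should match whatever from_plain stored in step 2):

# Hypothetical branch body: weight_qtensor.layout_tensor is assumed to be the
# MarlinSparseAQTLayout from step 2, holding the packed weight, 2:4 metadata,
# scales, and a workspace buffer.
sparse_layout = weight_qtensor.layout_tensor
out = torchao.ops.marlin_24_mm(
    input_tensor, sparse_layout.int_data, sparse_layout.meta,
    sparse_layout.scale, sparse_layout.workspace,
)
return out if bias is None else out + bias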
3) Add a layout option to int4_weight_only()
Finally, we need to add an entrypoint to our sparse layout from the quantize_ API, like we do in https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L462, but for int4_weight_only quantization instead.
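A hedged sketch of that change (the current signature and internals of int4_weight_only are simplified assumptions; the point is just the new layout_type keyword):

def int4_weight_only(group_size=128, inner_k_tiles=8, layout_type=None):
    # New optional layout_type argument: None keeps the existing default int4
    # layout, while MarlinSparseLayoutType() selects the marlin packed layout
    # registered in step 2. Internally, the only change is forwarding
    # layout_type into the to_affine_quantized(...) call that builds the weight.
    ...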
You'll then be able to call into your marlin kernels and test end-to-end with
quantize_(m, int4_weight_only(layout_type=MarlinSparseLayoutType()))
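For instance, on a toy module (the MarlinSparseLayoutType import is hypothetical until step 2 lands, and in practice the weight would first be pruned to 2:4 sparsity):

import torch
from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType  # hypothetical export from step 2

m = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()
quantize_(m, int4_weight_only(layout_type=MarlinSparseLayoutType()))

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = m(x)  # linear layers now dispatch to torchao.ops.marlin_24_mm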
Validation
To test our kernel in an e2e setting, we can extend our SAM benchmarks to add a new compression option:
https://github.com/pytorch/ao/blob/main/scripts/sam/eval_combo.py#L296
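A hedged sketch of what that new option could look like (the option name and the image-encoder attribute path are assumptions; the real branch should mirror the existing compression branches in eval_combo.py):

from torchao.quantization import quantize_, int4_weight_only
from torchao.dtypes import MarlinSparseLayoutType  # hypothetical export from step 2

def apply_int4_marlin_24(predictor):
    # Hypothetical new compression branch for eval_combo.py: quantize the ViT
    # image encoder to int4 with the 2:4-sparse marlin layout so its linear
    # layers dispatch to torchao.ops.marlin_24_mm during the SAM eval.
    quantize_(predictor.model.image_encoder,
              int4_weight_only(layout_type=MarlinSparseLayoutType()))
    return predictor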