
Add 2:4 sparse Marlin kernels to torchao #549

Description

Neural Magic / IST-DASLab have written a fast INT4-weight / FP16-activation (W4A16) kernel with support for 2:4 sparsity, Sparse-Marlin: https://github.com/IST-DASLab/Sparse-Marlin


We'd like to integrate this kernel into torchao and test it for ViT acceleration as a data point for our PTC poster.

Implementation Details

To add a custom quant + sparse layout into torchao, we need to do three things:

1) Add and bind the CUDA kernel.

Sparse-Marlin is implemented as a custom CUDA extension for PyTorch, which should be straightforward to port over. Most of the logic is contained in https://github.com/IST-DASLab/Sparse-Marlin/blob/main/marlin/marlin_cuda_kernel_nm.cu

You can follow the tutorial at https://github.com/pytorch/ao/blob/main/torchao/csrc/README.md, which walks through adding a custom CUDA extension to torchao.

After this, you should have the Marlin 2:4 mm op registered as torchao.ops.marlin_24_mm.
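
As a very rough sketch, the Python-side wrapper in torchao/ops.py might end up looking like the following. The argument list and dtypes here are assumptions for illustration; the real signature should be copied from the kernel entry point in marlin_cuda_kernel_nm.cu once it is ported.

    # Sketch of a thin Python wrapper for the registered CUDA op.
    # Assumes the C++/CUDA extension registered "torchao::marlin_24_mm"
    # following torchao/csrc/README.md; the argument list is illustrative.
    import torch
    from torch import Tensor

    def marlin_24_mm(
        x: Tensor,               # fp16 activations, shape (m, k)
        weight_packed: Tensor,   # marlin-packed 2:4-sparse int4 weights
        meta: Tensor,            # 2:4 sparsity metadata
        scales: Tensor,          # per-group fp16 scales
        workspace: Tensor,       # scratch buffer required by the kernel
    ) -> Tensor:
        return torch.ops.torchao.marlin_24_mm.default(
            x, weight_packed, meta, scales, workspace
        )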

We would also want to benchmark the op at this point and make sure we see the same speedups reported by Neural Magic.
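
For the benchmark, something like torch.utils.benchmark against a dense fp16 matmul of the same shape is probably enough to start. The commented-out Marlin call below is a placeholder to be filled in with whatever packing / argument setup the ported op ends up needing.

    # Sketch: microbenchmark the op against dense fp16 mm on a typical shape.
    import torch
    from torch.utils import benchmark

    m, k, n = 16, 4096, 4096
    x = torch.randn(m, k, dtype=torch.float16, device="cuda")
    w_dense = torch.randn(k, n, dtype=torch.float16, device="cuda")

    dense = benchmark.Timer(
        stmt="torch.mm(x, w_dense)",
        globals={"torch": torch, "x": x, "w_dense": w_dense},
    ).blocked_autorange()
    print(dense)

    # Placeholder for the sparse-marlin measurement; fill in once the op
    # and its packed-weight inputs are available:
    # sparse = benchmark.Timer(
    #     stmt="ops.marlin_24_mm(x, weight_packed, meta, scales, workspace)",
    #     globals={"ops": torchao.ops, ...},
    # ).blocked_autorange()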

2) Register a custom sparse layout and quantized dispatch

Now that we have our kernel bound, we can hook it into our quantization API by writing a new sparse layout for AffineQuantizedTensor, MarlinSparseLayout.

You can use our semi-structured sparse layout implementation as a reference:

https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L36-L45

https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L471-L511

You'll want to replace the line

    int_data_compressed = torch._cslt_compress(int_data)

with the pack function from Sparse-Marlin, found here: https://github.com/IST-DASLab/Sparse-Marlin/blob/c2ffa2395a3ada26c8cb7f910a5ec65bd3ce288a/marlin/__init__.py#L331

While the semi-structured sparse layout extends PlainLayoutType, the Marlin packed layout should extend AQTLayout, since the Marlin format packs both the scales and weights together.
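
A highly abbreviated sketch of what that layout tensor could look like, mirroring the semi-structured implementation linked above. Everything below other than AQTLayout, register_layout_cls, and MarlinSparseLayoutType is illustrative, and marlin_24_pack stands in for the Sparse-Marlin pack() utility linked above; the real class also needs the usual tensor-subclass plumbing (__tensor_flatten__, dispatch overrides, etc.) that the semi-structured layout shows.

    # Abbreviated sketch of a Marlin 2:4 layout for AffineQuantizedTensor.
    @register_layout_cls(MarlinSparseLayoutType)
    class MarlinSparseAQTLayout(AQTLayout):
        def __init__(self, int_data, scale, zero_point, meta, layout_type):
            # Unlike the plain layout, int_data and scale are both stored in
            # the marlin packed format; meta holds the 2:4 sparsity metadata.
            self.int_data = int_data
            self.scale = scale
            self.zero_point = zero_point
            self.meta = meta
            self.layout_type = layout_type

        @classmethod
        def from_plain(cls, int_data, scale, zero_point, layout_type):
            # This is where torch._cslt_compress gets replaced by the
            # Sparse-Marlin pack step, which compresses the 2:4 weights
            # and reorders the scales into the marlin format.
            packed_w, packed_scale, meta = marlin_24_pack(int_data, scale)
            return cls(packed_w, packed_scale, zero_point, meta, layout_type)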

Finally, once your Layout is registered, you'll want to define the quantized_linear_op dispatch. This will call into your earlier registered torchao.ops.marlin_24_mm op, instead of the normal dense mm.

https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor.py#L708-L732

The conditional would look something like this, after line 780, as we want to overload the int4-weight-only dispatch path with the sparse marlin kernels:

        if (
            weight_is_uint4 and
            weight_qtensor.dtype == torch.float16 and
            len(weight_qtensor.shape) == 2 and
            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
            isinstance(weight_qtensor.layout_type, MarlinSparseLayoutType)
        ):
             # call torchao.ops.marlin_24_mm 
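
Inside that branch, the dispatch would unpack the layout tensor and hand everything to the custom op. A rough sketch of what it could call into is below; the attribute names, workspace sizing, and the op's argument order are assumptions that need to match your layout class and the registered kernel binding.

    # Sketch of the branch body as a helper; names and argument order are
    # illustrative only.
    import torch
    import torchao.ops

    def _linear_fp16_act_int4_weight_sparse_marlin(input_tensor, weight_qtensor, bias):
        layout = weight_qtensor.layout_tensor
        orig_shape = input_tensor.shape
        out_features = weight_qtensor.shape[0]
        # Scratch buffer the marlin kernel expects; sizing here is a
        # placeholder, follow the Sparse-Marlin reference implementation.
        workspace = torch.zeros(out_features // 128 * 16,
                                dtype=torch.int32, device=input_tensor.device)
        out = torchao.ops.marlin_24_mm(
            input_tensor.view(-1, orig_shape[-1]),  # flatten batch dims
            layout.int_data,                        # packed 2:4 int4 weights
            layout.meta,                            # 2:4 sparsity metadata
            layout.scale,                           # packed scales
            workspace,
        )
        out = out.view(*orig_shape[:-1], -1)
        return out + bias if bias is not None else out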

3) Add a layout option to int4_weight_only()

Finally, we need to add an entry point to our new sparse layout from the quantize_ API, like we do in https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L462, but for int4_weight_only quantization instead.
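
The plumbing is probably just a new keyword argument on int4_weight_only that gets forwarded to to_affine_quantized. A sketch is below; the existing quantization parameters are elided, and the default layout type shown here is an assumption.

    # Sketch: thread layout_type through int4_weight_only so quantize_ can
    # select the marlin layout; only the new plumbing is shown.
    def int4_weight_only(group_size=128, inner_k_tiles=8, layout_type=None):
        def apply_int4_weight_only_quant(weight):
            # Default to the existing tensor-core-tiled layout when no
            # layout_type is passed (sketch; existing args elided).
            lt = layout_type or TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
            return to_affine_quantized(
                weight,
                # ... existing quantization args ...
                layout_type=lt,
            )
        return apply_int4_weight_only_quant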

You'll then be able to call into your Marlin kernels and test end-to-end with

    quantize_(m, int4_weight_only(layout_type=MarlinSparseLayoutType()))
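
A minimal smoke test, assuming MarlinSparseLayoutType ends up exported from torchao.dtypes and that the weights being quantized are (or get pruned to be) 2:4 sparse, might look like:

    # Minimal e2e smoke test (sketch): quantize a small fp16 linear model
    # with the marlin sparse layout and compare against the dense output.
    import torch
    from torchao.quantization import quantize_, int4_weight_only
    from torchao.dtypes import MarlinSparseLayoutType  # assumed export location

    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).half().cuda()
    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    ref = model(x)

    quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
    out = model(x)
    # Expect quantization + 2:4 pruning error, not an exact match.
    print((out - ref).abs().max())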

Validation

To test our kernel in an end-to-end setting, we can extend our SAM benchmarks with a new compression option:

https://github.com/pytorch/ao/blob/main/scripts/sam/eval_combo.py#L296
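
Concretely, this is likely just one more branch next to the existing compress options around that line. The fragment below is a sketch; the option name, and whether the image encoder is the right module to quantize, should be checked against the script.

    # Sketch: new compression branch for scripts/sam/eval_combo.py,
    # alongside the existing int8 / sparse options.
    elif compress == "int4_weight_only_sparse_marlin":
        quantize_(predictor.model.image_encoder,
                  int4_weight_only(layout_type=MarlinSparseLayoutType()))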
