[Feat]: Add support for Dynamic Quant 4 bit CPU kleidiai kernels #17112

Open · nikhil-arm wants to merge 1 commit into main

Conversation

@nikhil-arm commented Apr 24, 2025

Description:

  1. Add optimized kernel support for Arm 4-bit matmul


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mgoin
Member

mgoin commented Apr 24, 2025

This is really cool! Looking forward to review when ready

Description:
1. Add optimized kernel support for Arm 4-bit matmul

Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
@nikhil-arm
Author

This is really cool! Looking forward to review when ready

Thanks @mgoin, the change is ready for initial review from my end. Can you help me with the CI failures?

Error: retrieving gpg key timed out.
ERROR: process "/bin/sh -c echo 'tzdata tzdata/Areas select America' | debconf-set-selections     && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections     && apt-get update -y     && apt-get install -y ccache software-properties-common git curl sudo     && add-apt-repository ppa:deadsnakes/ppa     && apt-get update -y     && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv     && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1     && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION}     && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config     && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}     && python3 --version && python3 -m pip --version" did not complete successfully: exit code: 1

@nikhil-arm
Author

Also, will it be possible to allow some of my team members to push branches to the vLLM project, add specific reviewers, etc., for better contributions?
We have more patches incoming :)

Contributor

@dsikka left a comment

Thank you for your PR! How are you thinking of testing the integration?

@mgoin
Member

mgoin commented Apr 24, 2025

Also, will it be possible to allow some of my team members to push branches to the vLLM project, add specific reviewers, etc., for better contributions?

We don't recommend pushing branches directly to vLLM as there are many developers; rather, please continue to make branches on your fork and submit PRs against upstream main. I recommend joining the Developer Slack to discuss complex work!

Also, would you be willing to keep copyright changes/additions to just files that are completely new, or remove them? We don't use this header, just the Apache-2.0 one, and we want consistency across the code base.

@nikhil-arm
Author

Also, would you be willing to keep copyright changes/additions to just files that are completely new, or remove them? We don't use this header, just the Apache-2.0 one, and we want consistency across the code base.

Hello,

We are checking on this and will get back with a resolution soon.

@nikhil-arm
Author

Thank you for your PR! How are you thinking of testing the integration?

Hello,
We have tested this on Arm Neoverse cores and ran vLLM benchmarking.
It should ideally work with PyTorch versions > 2.6, though the latest nightly is preferred.

@nikhil-arm
Author

nikhil-arm commented Apr 29, 2025

Hello @mgoin,
Circling back on the copyright issue:

  1. Will it be possible to maintain the Arm copyright header in a single, common place in the project (probably LICENSE?)
    PyTorch takes a similar approach:
    https://github.com/pytorch/pytorch/blob/c6d3b8f861e1834049f8420526afe6b72a43ec56/LICENSE#L39
    If we can maintain it in a similar fashion, then individual files don't need to have copyright headers.

  2. How does vLLM handle copyright attribution for other contributors currently?


scales_and_zp_size = input_size_per_partition // effective_group_size

weight = ModelWeightParameter(data=torch.empty(
Contributor

These are just int8 values in the int4 range, I'm assuming?

Author

Yes, you are correct
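
For concreteness, a minimal sketch (not code from this PR; shapes and names are made up) of the assumption being confirmed here: the 4-bit weights are carried in an int8 tensor whose values all fall in the signed int4 range [-8, 7].

import torch

# Hypothetical example: 4-bit quantized weights stored in an int8 container,
# so every element lies within the signed int4 range [-8, 7].
w_q = torch.randint(-8, 8, (4096, 4096), dtype=torch.int8)
assert int(w_q.min()) >= -8 and int(w_q.max()) <= 7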

Member

@mgoin left a comment

Can you share a model checkpoint that has been quantized this way? We don't have an ARM runner at the moment for vLLM but we also don't have testing for W4A8 production in llmcompressor either, so I just worry about this format being tested over time.

Member

Can you remove these debug changes?

Author

I can remove the prints.
I remember an issue about a particular layer not being found was reported earlier. We faced the same issue, skipped the tensor, and it worked fine. I think a resolution along the same lines was provided by the vLLM team as well. I will find and link the issue for reference.

@russellb
Member

russellb commented May 2, 2025

Hello @mgoin, Circling back on the copyright issue:

  1. Will it be possible to maintain the Arm copyright header in a single, common place in the project (probably LICENSE?)
    PyTorch takes a similar approach:
    https://github.com/pytorch/pytorch/blob/c6d3b8f861e1834049f8420526afe6b72a43ec56/LICENSE#L39
    If we can maintain it in a similar fashion, then individual files don't need to have copyright headers.
  2. How does vLLM handle copyright attribution for other contributors currently?

Given that vLLM now has contributions from over 1000 individuals, maintaining a list of copyright holders seems overburdensome. It will inevitably be incomplete and/or out of date. I'd rather we avoid it. These notices do not actually reflect copyright ownership, and if really needed, all of that information can be retrieved through the git history.

I'd rather we not expand the use of these notices in the code base.

There are some exceptions, such as when we copy code from another place and want to (or are required to by the license) attribute the source, but in general I don't think it's necessary.

@nikhil-arm
Author

Can you share a model checkpoint that has been quantized this way? We don't have an ARM runner at the moment for vLLM but we also don't have testing for W4A8 production in llmcompressor either, so I just worry about this format being tested over time.

I can share the llmcompressor config that can quantize the model with W4A16. Please check this for more context:
neuralmagic/compressed-tensors#269 (comment)

@dsikka
Contributor

dsikka commented May 2, 2025

Can you share a model checkpoint that has been quantized this way? We don't have an ARM runner at the moment for vLLM but we also don't have testing for W4A8 production in llmcompressor either, so I just worry about this format being tested over time.

I can share the llmcompressor config that can quantize the model with W4A16. Please check this for more context: neuralmagic/compressed-tensors#269 (comment)

W4A16 is for weight only quantization. Do you have a model with W4A8 quantization from compressed-tensors to validate the integration? Or are you just editing the config from W4A16?

We have a sample scheme for W4A8: https://github.com/neuralmagic/compressed-tensors/blob/1068c848e420443d4d5d73fe031b78bf7a832926/src/compressed_tensors/quantization/quant_scheme.py#L158

@nikhil-arm
Author

nikhil-arm commented May 2, 2025

Can you share a model checkpoint that has been quantized this way? We don't have an ARM runner at the moment for vLLM but we also don't have testing for W4A8 production in llmcompressor either, so I just worry about this format being tested over time.

I can share the llmcompressor config that can quantize the model with W4A16. Please check this for more context: neuralmagic/compressed-tensors#269 (comment)

W4A16 is for weight only quantization. Do you have a model with W4A8 quantization from compressed-tensors to validate the integration? Or are you just editing the config from W4A16?

We have a sample scheme for W4A8: https://github.com/neuralmagic/compressed-tensors/blob/1068c848e420443d4d5d73fe031b78bf7a832926/src/compressed_tensors/quantization/quant_scheme.py#L158

The model quantization at the llmcompressor level is W4A16. When the Linear layers go to the KleidiAI kernel in PyTorch, the activations are quantized down to 8 bits. As per this comment neuralmagic/compressed-tensors#269 (comment) from @brian-dellabetta, we want to make it clear to the user that activations are quantized to 8 bits at some level (llmcompressor or PyTorch). This is the reason we are calling our scheme W4A8, even though it is actually W4A16 at the llmcompressor level.
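
To make the terminology concrete, here is a minimal sketch (not code from this PR; the function name is made up) of what dynamic, asymmetric per-token int8 activation quantization does conceptually. The KleidiAI path performs the equivalent step inside the PyTorch kernel at runtime:

import torch

def quantize_activations_per_token_int8(x: torch.Tensor):
    # One (scale, zero_point) pair per token (row), computed on the fly at runtime.
    qmin, qmax = -128, 127
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min).clamp(min=1e-8) / (qmax - qmin)
    zero_point = torch.round(qmin - x_min / scale)
    x_q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    return x_q, scale, zero_point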

@dsikka
Contributor

dsikka commented May 2, 2025

Can you share a model checkpoint that has been quantized this way? We don't have an ARM runner at the moment for vLLM but we also don't have testing for W4A8 production in llmcompressor either, so I just worry about this format being tested over time.

I can share the llmcompressor config that can quantize the model with W4A16. Please check this for more context: neuralmagic/compressed-tensors#269 (comment)

W4A16 is for weight only quantization. Do you have a model with W4A8 quantization from compressed-tensors to validate the integration? Or are you just editing the config from W4A16?
We have a sample scheme for W4A8: https://github.com/neuralmagic/compressed-tensors/blob/1068c848e420443d4d5d73fe031b78bf7a832926/src/compressed_tensors/quantization/quant_scheme.py#L158

The model quantization at the llmcompressor level is W4A16. When the Linear layers go to the KleidiAI kernel in PyTorch, the activations are quantized down to 8 bits. As per this comment neuralmagic/compressed-tensors#269 (comment) from @brian-dellabetta, we want to make it clear to the user that activations are quantized to 8 bits at some level (llmcompressor or PyTorch). This is the reason we are calling our scheme W4A8, even though it is actually W4A16 at the llmcompressor level.

@nikhil-arm
Models produced using W4A16 will not have a config defining their activation quantization scheme, so this will return False: _is_dynamic_token_w4a8

Even if compressed-tensors does not generate a scale for the activations to be stored in the checkpoint, this information needs to be captured in the config. We expect compressed-tensors models with dynamic activations to still have a config for their activations, as expected by the integration layer in vLLM.
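
For illustration only, an assumed sketch of the relevant quantization_config portion of such a checkpoint's config.json, written here as a Python dict (field names follow the compressed-tensors format as I understand it; the exact serialized output may differ). The point is that input_activations is present even though the quantization is dynamic:

# Assumed shape of the config entry; values mirror the recipe discussed in this thread.
expected_quantization_config = {
    "quant_method": "compressed-tensors",
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "num_bits": 4,
                "type": "int",
                "strategy": "channel",
                "symmetric": True,
                "dynamic": False,
            },
            # Dynamic per-token activations: no stored scales, but the config still
            # declares the activation scheme so vLLM can detect it.
            "input_activations": {
                "num_bits": 8,
                "type": "int",
                "strategy": "token",
                "symmetric": False,
                "dynamic": True,
            },
        }
    },
}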

Do you have a sample model that you can share that you ran to test the integration?

@nikhil-arm
Author

nikhil-arm commented May 2, 2025

Can you share a model checkpoint that has been quantized this way? We don't have an ARM runner at the moment for vLLM but we also don't have testing for W4A8 production in llmcompressor either, so I just worry about this format being tested over time.

I can share the llmcompressor config that can quantize the model with W4A16. Please check this for more context: neuralmagic/compressed-tensors#269 (comment)

W4A16 is for weight only quantization. Do you have a model with W4A8 quantization from compressed-tensors to validate the integration? Or are you just editing the config from W4A16?
We have a sample scheme for W4A8: https://github.com/neuralmagic/compressed-tensors/blob/1068c848e420443d4d5d73fe031b78bf7a832926/src/compressed_tensors/quantization/quant_scheme.py#L158

The model quantization at the llmcompressor level is W4A16. When the Linear layers go to the KleidiAI kernel in PyTorch, the activations are quantized down to 8 bits. As per this comment neuralmagic/compressed-tensors#269 (comment) from @brian-dellabetta, we want to make it clear to the user that activations are quantized to 8 bits at some level (llmcompressor or PyTorch). This is the reason we are calling our scheme W4A8, even though it is actually W4A16 at the llmcompressor level.

@nikhil-arm Models produced using W4A16 will not have a config defining their activation quantization scheme. So the values you have defined here will not return True, and this will return False: _is_dynamic_token_w4a8

Even if compressed-tensors does not generate a scale for the activations to be stored in the checkpoint, this information needs to be captured in the config. We expect compressed-tensors models with dynamic activations to still have a config for their activations, as expected by the integration layer in vLLM.

Do you have a sample model that you can share that you ran to test the integration?

Hello, I cannot share a model externally, but I can share a small script to quantize supported models.
PyTorch: https://download.pytorch.org/whl/nightly/cpu/torch-2.8.0.dev20250502%2Bcpu-cp312-cp312-manylinux_2_28_aarch64.whl

python vllm_quantize_model.py meta-llama/Llama-3.2-3B --mode int4 --scheme channelwise

# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate open-source-office@arm.com
import argparse
import os
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from llmcompressor.transformers import oneshot


def main():
    parser = argparse.ArgumentParser(
        description="Quantize a model using LLM Compressor with customizable mode, scheme, and group size."
    )
    parser.add_argument(
        "model_id",
        type=str,
        help="Model identifier or path (e.g., 'meta-llama/Llama-3.2-3B' or '/path/to/model')",
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["int4"],
        required=True,
        help="Quantization mode: int4",
    )
    parser.add_argument(
        "--scheme",
        type=str,
        choices=["channelwise", "groupwise"],
        required=True,
        help="Quantization scheme for weights (groupwise is only supported for int4)",
    )
    parser.add_argument(
        "--groupsize",
        type=int,
        default=32,
        help="Group size for groupwise quantization (only used when scheme is groupwise). Defaults to 32."
    )
    args = parser.parse_args()

    # Extract a base model name from the model id or path for the output directory
    if "/" in args.model_id:
        base_model_name = args.model_id.split("/")[-1]
    else:
        base_model_name = os.path.basename(args.model_id)

    # Determine output directory based on mode and scheme
    if args.mode == "int4":
        output_dir = f"{base_model_name}-w4a8-{args.scheme}"

    print(f"Loading model '{args.model_id}'...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id, device_map="auto", torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    # Define quantization arguments based on mode and chosen scheme.
    if args.mode == "int4":
        if args.scheme == "channelwise":
            strategy = QuantizationStrategy.CHANNEL
            weights_args = QuantizationArgs(
                num_bits=4,
                type=QuantizationType.INT,
                strategy=strategy,
                symmetric=True,
                dynamic=False,
            )
        else:  # groupwise
            strategy = QuantizationStrategy.GROUP
            weights_args = QuantizationArgs(
                num_bits=4,
                type=QuantizationType.INT,
                strategy=strategy,
                group_size=args.groupsize,
                symmetric=True,
                dynamic=False
            )

    # Activation quantization remains the same for both modes.
    activations_args = QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=False,
        dynamic=True,
        observer=None,
    )

    # Create a quantization scheme for Linear layers.
    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=weights_args,
        input_activations=activations_args,
    )

    # Create a quantization modifier. We ignore the "lm_head" layer.
    modifier = QuantizationModifier(config_groups={"group_0": scheme}, ignore=["lm_head"])

    # Apply quantization and save the quantized model.
    oneshot(
        model=model,
        recipe=modifier,
        tokenizer=tokenizer,
        output_dir=output_dir,
    )
    print(f"Quantized model saved to: {output_dir}")


if __name__ == "__main__":
    main()

When you run inference with the converted model, please upcast the model to float32 at vLLM inference time.
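
For example, something along these lines; treat it as a sketch rather than a verified command, and note that the model path is just the output directory name produced by vllm_quantize_model.py above:

from vllm import LLM, SamplingParams

# Assumed local path: the output directory written by the quantization script above.
llm = LLM(model="Llama-3.2-3B-w4a8-channelwise", dtype="float32")  # upcast to float32
params = SamplingParams(temperature=0.0, max_tokens=64)
out = llm.generate(["Hello, my name is"], params)
print(out[0].outputs[0].text)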
