Speed up AWQ 2.5x with updated kernel #2874

Closed

Conversation

@zcnrex (Contributor) commented Feb 14, 2024

Follow-up to #2566 and #2723.

Updated to the latest implementation of the AWQ GEMM kernel and saw significant speed improvements.

Tested on a Google Colab A100 (40 GB).

Sample command:
python benchmarks/benchmark_throughput.py --input-len 1024 --output-len 64 --model TheBloke/Mistral-7B-Instruct-v0.2-AWQ --quantization awq --num-prompts 100 --max-model-len 8192 --dtype half

[Screenshot: benchmark throughput results, 2024-02-14]

TODOs

  • Test and validate performance with longer input sequence lengths
  • Verify correctness / perplexity (see the sketch after this list)
  • Try calling cuBLAS for sequence lengths > 256
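
A minimal sketch of the perplexity check, not part of this PR: it assumes transformers and datasets are installed (plus autoawq so the quantized checkpoint loads), uses the same model as the benchmark command above with the wikitext-2 test split, and the window/stride sizes are arbitrary illustration values.

    import torch
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"  # same model as the benchmark above
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.float16
    )

    # Concatenate the wikitext-2 test split into one long token stream.
    text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
    ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)

    window, stride, nlls = 2048, 512, []
    for begin in range(0, ids.size(1) - 1, stride):
        end = min(begin + window, ids.size(1))
        chunk = ids[:, begin:end]
        labels = chunk.clone()
        labels[:, :-stride] = -100  # only score the last `stride` tokens of each window
        with torch.no_grad():
            nlls.append(model(chunk, labels=labels).loss)
        if end == ids.size(1):
            break
    print("perplexity:", torch.exp(torch.stack(nlls).mean()).item())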

@casper-hansen (Contributor)

Note that the GEMM kernel on the llm-awq main branch uses a new packed format for weights, which should be incompatible with existing weights on Hugging Face. But maybe you can make it work?

@zcnrex (Contributor, Author) commented Feb 17, 2024

Thanks @casper-hansen, could you share more information on this? And could you point to models that may not work with the latest awq main branch?

@casper-hansen (Contributor) commented Feb 18, 2024

Hi @zcnrex, TheBloke publishes most models on the Hub; I collaborated with him on using AutoAWQ to quantize models. For example, https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-AWQ should not be compatible since it is packed differently. vLLM is compatible with the first generation of kernels that the AWQ authors published.

Here are a few differences:

  • The order map is different: [0, 1, 2, 3, 4, 5, 6, 7] in the new format versus [0, 2, 4, 6, 1, 3, 5, 7] in the older format (illustrated in the sketch after this list).
  • Some zeros are skipped during packing as well.
  • They use a calculate_zeros_width function, which results in different shapes than the original GEMM format.
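
For intuition, here is a small, self-contained sketch of just the ordering difference (pack_int4 is an illustrative helper, not taken from either kernel's actual packing code): eight 4-bit values packed into one 32-bit word, using the older interleaved order map versus the newer sequential one. The real kernels additionally differ in zero-point handling and in the calculate_zeros_width shapes mentioned above.

    # Older (v1) GEMM kernels interleave nibbles; the newer kernels pack them in order.
    OLD_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
    NEW_ORDER = [0, 1, 2, 3, 4, 5, 6, 7]

    def pack_int4(vals, order):
        """Pack eight 4-bit integers into one 32-bit word; nibble i comes from vals[order[i]]."""
        packed = 0
        for i, src in enumerate(order):
            packed |= (vals[src] & 0xF) << (4 * i)
        return packed & 0xFFFFFFFF

    vals = [1, 2, 3, 4, 5, 6, 7, 8]
    print(hex(pack_int4(vals, OLD_ORDER)))  # 0x86427531
    print(hex(pack_int4(vals, NEW_ORDER)))  # 0x87654321

Because the nibble positions disagree, a checkpoint packed one way cannot be read directly by a kernel that expects the other layout without a repacking step.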

I have tried mapping from the GEMV format to the GEMM format in torch together with people from Hugging Face; however, our conclusion is simply that they are incompatible.

The new GEMM v2 is only used for processing context:

        if inputs.shape[0] > 8:
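            # Context/prefill path: batches with more than 8 input rows use the GEMM v2 kernel.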
            out = awq_ext.gemmv2_forward_cuda(
                inputs,
                self.qweight,
                self.scales,
                self.qzeros,
                self.group_size,
                self.split_k_iters,
            )
        else:
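            # Decode path: small batches fall back to the GEMV kernel.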
            out = awq_ext.gemv_forward_cuda(
                inputs, self.qweight, self.scales, self.qzeros, self.group_size
            )

References:

@zcnrex closed this Feb 20, 2024
@zcnrex (Contributor, Author) commented Feb 20, 2024

Closed the PR because the gain was likely from removing the CUDA stream.
