
Conversation

@amd-rcorzine commented Mar 19, 2025

Overview:

For my project, I have added support for online rotations (for quantization) into vLLM by implementing the QuaRot paper (rotations that improve int8 quantization by reducing outliers), and achieved promising performance results:

  • Online rotations, such as those in QuaRot, greatly improve integer quantization accuracy by reducing outliers in the layer input - this gives the efficiency of integer quantization while maintaining most of the accuracy of the unquantized model
  • No existing interface within vLLM to support online rotations
  • This PR adds a feature to support online rotations within vLLM, specifically for Quark quantized models in vLLM
    • This feature is easily configurable
    • Integrates well with existing vLLM / Quark quantization design
  • Additionally, strong performance results have been demonstrated using this feature, and are detailed in the design doc

Design Document:

Introduction

The recent QuaRot and SpinQuant papers propose a quantization method for large language models (LLMs) that uses rotations to simplify quantization. It quantizes all components, including weights, activations, and KV cache, into 4 bits by removing outliers without changing the model output. It accomplishes this by inserting four fundamental rotations (called R1, R2, R3 and R4 in SpinQuant) into Llama models.

To support these features, the ability to add online Hadamard rotations - explicitly rotating the input before it is passed into the quantized layer - is necessary (see figure below).

[Figure: an online rotation applied to the layer input before it enters the quantized layer]

Features

  • Specify layers, by regular expression, and the corresponding online rotation to be applied to their inputs
  • Specify a custom Python module and method for the rotation function

Requirements

The rotation method in the custom Python module must satisfy the following:

  • Must accept torch.Tensor in the form $b \times n$ where $b$ denotes batch size and $n$ denotes embedding dimension
  • Must be a CUDA-graph-compatible Python function (cannot involve dynamic memory allocation in the nn.Module forward pass etc.)
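For illustration, here is a minimal sketch of a function that satisfies this interface. It is a toy orthogonal transform, not the QuaRot R4 rotation, and the function name is hypothetical:

import torch

def example_rotation(x: torch.Tensor) -> torch.Tensor:
    # Input and output are (batch, embedding_dim); shapes are static and there
    # is no data-dependent control flow, so the op is CUDA-graph friendly.
    b, n = x.shape                      # n assumed even for this toy example
    y = x.reshape(b, n // 2, 2)
    lo, hi = y[..., 0], y[..., 1]
    # 2 x 2 Hadamard butterfly on adjacent pairs, normalized by 1/sqrt(2)
    out = torch.stack((lo + hi, lo - hi), dim=-1) * 0.7071067811865476
    return out.reshape(b, n)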

Current Limitations

  • Requires rotations to be defined internally in the vLLM quark/schemes directory - the user must define every rotation as a method inside a Python file in this directory

Proof of Concept (PoC) & Performance Impact

To demonstrate the effectiveness of online rotations in vLLM, I have implemented QuaRot in vLLM and demonstrated significant end-to-end performance benefits.

Throughput / Latency:

A comprehensive set of tests has been run to measure performance:
documentation-repo/e2e-benchmarks/performance_compare.md at main · rcorzine/documentation-repo

Below is a specific example using benchmarks/benchmark_throughput.py:

python3 benchmark_throughput.py --model {name} --num-prompts 50 --input-len 64 --output-len 128

[Figure: benchmark_throughput.py results comparing QuaRot with int8]
QuaRot demonstrates very strong performance improvements when integrated with vLLM, close to int8.

Accuracy:

To ensure the implementation's accuracy matched the claims from the research paper ([2404.00456] QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs), a comprehensive perplexity benchmark as well as empirical validation were used.

To measure perplexity, I used benchmarks/P3L.py. Note: this benchmark has not yet been merged into the main branch of vLLM. Please refer to it here: vllm/benchmarks/P3L.py at main · ROCm/vllm
The perplexity score matches the claims of the paper: near-int8 efficiency with accuracy close to that of the unquantized model.

Scheme                            Perplexity
unquantized                       4.2
quarot (FHT 512, BMM Groups)      7.1
int8                              210.5

Reproducibility

Model

Utilize Quark to quantize the model:
Viva Engage - Conversation
Use this command to export to the Hugging Face format:

python quantize_quark.py --output_dir {name} --model_export hf_format --model_dir meta-llama/Meta-Llama-3-8B --quant_scheme w_int8_a_int8_per_tensor_sym --pre_quantization_optimization quarot

vLLM

Build vLLM from this fork https://gitenterprise.xilinx.com/rcorzine/vllm-online-rotations (forked from main branch).
Install the custom rotation module used.
pip install git+https://gitenterprise.xilinx.com/rcorzine/FHT_ROCm.git

Usage Example

  • Inside config.json, include a field online_rotations
    • Use the layer name regex as the key
      • Under module_name, give the Python file (module) name
      • Under method_name, give the name of the rotation function
"online_rotations": {
    "mlp.down_proj": {
        "name": "quarot_R4",
        "module_name": "quarot",
        "method_name": "R4_Function_Llama_3_8B"
    }
}
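Once the exported checkpoint's config.json contains this field, the model is served like any other Quark-quantized model; the quantization config (and with it the rotation) is picked up from the checkpoint. The model path below is a placeholder:

from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/llama3-8b-quark-quarot")  # placeholder path to the exported model
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)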

System Architecture

Configuration

For reference vLLM contains the following classes:

class QuarkConfig(QuantizationConfig)

  • def get_quant_method
    • Maps a layer to a particular LinearMethodBase quant method that will be applied to that layer. Associates the layer with a QuarkScheme. See QuantizationConfig for param details.

class QuarkLinearMethod(LinearMethodBase)

  • def apply_weights
    • Use the output of create_weights and the QuarkScheme associated with the layer to apply the forward pass with the layer input. See LinearMethodBase for param details.

class QuarkW8A8Int8(QuarkScheme)

  • def apply_weights
    • Apply the forward pass with the layer input. See QuarkScheme for param details.

To achieve online rotations, I use the QuarkScheme associated with the layer to apply the correct rotation to its input.

Below is a detailed explanation of how I implemented online rotations:

End goal:
A way to apply online rotation to the input of a particular layer.

Modifications:

  • The apply_weights method inside the QuarkW8A8Int8 class takes in a layer and the input and applies the forward pass. I modify apply_weights to apply the online rotation (if one is configured); a fuller sketch follows this snippet.
if self.online_rotation_method:
    x = self.online_rotation_method(x)
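A slightly fuller sketch of the modified method (the signature and the int8 GEMM helper below are simplified placeholders, not the exact vLLM/Quark code):

def apply_weights(self, layer, x, bias=None):
    # Rotate the activations online before they are quantized and multiplied
    # by the (already rotated) int8 weights; skipped when no rotation is set.
    if self.online_rotation_method:
        x = self.online_rotation_method(x)
    # Existing QuarkW8A8Int8 path: quantize x and run the int8 GEMM.
    return self._int8_forward(layer, x, bias)  # hypothetical helper standing in for the existing code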
  • The get_quant_method method, inside the QuarkConfig class, takes in the layer, associates it with a QuarkW8A8Int8 scheme, and returns a QuarkLinearMethod object that will later use the QuarkW8A8Int8 associated with the layer to apply the forward pass with the layer input.
    • The helper method _get_scheme_from_config of the QuarkConfig class takes in a layer config and retrieves the appropriate QuarkW8A8Int8 scheme. I modify _get_scheme_from_config so that, when an online rotation is specified in the layer config passed in, it imports the online rotation module and initializes the QuarkW8A8Int8 scheme with the online rotation method.
     if online_rotation_config:
         import importlib.util
         import os
         module_path = os.path.join(os.path.dirname(__file__), "schemes",
                                    online_rotation_config["module_name"] + ".py")
         if not os.path.exists(module_path):
             raise FileNotFoundError(f"The file at {module_path} does not exist.")
         spec = importlib.util.spec_from_file_location(
             online_rotation_config["module_name"], module_path)
         module = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(module)
         online_rotation_method = getattr(module, online_rotation_config["method_name"])
     else:
         online_rotation_method = None
    • The helper method _find_matched_config of the QuarkConfig class takes in the layer name and retrieves the matching layer config. I modify _find_matched_config so that, when the layer name matches an entry in the online_rotations configuration, the online rotation specification is added to the layer config.
     if "online_rotations" in self.quant_config:
         rot_info = next((value for key, value in self.quant_config["online_rotations"].items()
                          if layer_name.endswith(key)), None)
         rv["online_rotations"] = rot_info
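Tying these together, a sketch of how the resolved rotation reaches the scheme (the parameter name of _get_scheme_from_config and the remaining constructor arguments are assumptions, not the exact vLLM code):

# Inside _get_scheme_from_config, after _find_matched_config has run:
online_rotation_config = layer_quant_config.get("online_rotations")  # parameter name assumed
# ... the importlib-based loading shown above resolves online_rotation_method ...
return QuarkW8A8Int8(online_rotation_method=online_rotation_method)  # other existing args omitted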

Custom Operators (Compiled PyTorch ops for the rotations)

Have the online rotation module (Fast Hadamard Transform) registered as a custom operator, bundled within the vLLM codebase.
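One possible way to do this is with torch.library.custom_op (PyTorch 2.4+). The "quark_rotations" namespace and the identity body below are placeholders for illustration; the real op would dispatch to the compiled FHT_ROCm kernel:

import torch
from torch.library import custom_op

@custom_op("quark_rotations::fht_r4", mutates_args=())
def fht_r4(x: torch.Tensor) -> torch.Tensor:
    # The real integration would call the compiled FHT kernel here;
    # an identity keeps this sketch self-contained and runnable.
    return x.clone()

@fht_r4.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation for tracing (torch.compile, CUDA graphs).
    return torch.empty_like(x)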

Appendix

Fast Hadamard Transform - Algorithm Design Reference

Algorithm Overview

R4 corresponds to performing the Fast Hadamard Transform (FHT, computed via a butterfly algorithm) on the input before multiplying by a Hadamard matrix.

More concretely, it involves the following steps:

  • Apply FHT to each 512-element group of the length 14336 input
    • Performed in $O(n \cdot \log n)$
      • Recursively combine smaller Hadamard transforms into larger ones using butterfly operations.
      • Each butterfly operation merges pairs of elements by summing and subtracting them.
  • Reshape the input into a 28 x 512 matrix and post-multiply the 28 x 28 Hadamard matrix by it (i.e. compute $H_{28} X$); a PyTorch reference sketch follows this list
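The following is a reference sketch of these two steps in plain PyTorch, assuming the hidden size 14336 = 28 x 512 discussed above. The 28 x 28 Hadamard matrix h28 must be supplied by the caller (e.g. from the QuaRot reference code), and the final 1/sqrt(n) normalization is an assumption about the scaling convention:

import torch

def fht_groups(x: torch.Tensor, group: int = 512) -> torch.Tensor:
    # Unnormalized Fast (Walsh-)Hadamard Transform applied independently to each
    # `group`-element chunk of the last dimension, via log2(group) butterfly
    # passes. Reference only; the real kernel works in registers/shared memory.
    b, n = x.shape
    assert n % group == 0 and (group & (group - 1)) == 0
    y = x.reshape(b, n // group, group)
    h = 1
    while h < group:
        y = y.reshape(b, n // group, group // (2 * h), 2, h)
        a, c = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + c, a - c), dim=-2).reshape(b, n // group, group)
        h *= 2
    return y.reshape(b, n)

def r4_reference(x: torch.Tensor, h28: torch.Tensor) -> torch.Tensor:
    # x: (batch, 14336); h28: 28 x 28 Hadamard matrix supplied by the caller.
    b, n = x.shape
    y = fht_groups(x, group=512)                 # step 1: FHT on each 512-element group
    y = y.reshape(b, n // 512, 512)              # step 2: view as 28 x 512 ...
    y = torch.einsum("ij,bjk->bik", h28, y)      # ... and post-multiply H_28 by it
    return y.reshape(b, n) / n ** 0.5            # assumed normalization convention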

Tri Dao's implementation

Dao-AILab/fast-hadamard-transform: Fast Hadamard transform in CUDA, with a PyTorch interface
For each 512-element chunk of the length-14336 input, apply the FHT butterfly algorithm iteratively for $\log n$ iterations. This I/O-aware implementation effectively minimizes memory access.

  • Load kNelts elements from segment $i$ of the input into the registers of thread $i$
    • Apply the butterfly for $\log$ kNelts iterations, thread-wise
    • Apply the butterfly for $\log$ kWarpSize iterations, warp-wise
      • Use __shfl_xor_sync so each pair receives its element from the complement thread
    • Apply the butterfly for $\log$ kNWarps iterations, block-wise
      • Use shared memory to copy each pair complement onto the same warp
      • Use __shfl_xor_sync so each pair receives its element from the complement thread

Modifications for AMD Compatibility

  • Ensure the warp size is defined to be 64
  • Ensure __shfl_xor_sync is replaced with __shfl (which seems more broadly compatible with ROCm)
  • Ensure the correct lane_id of the thread within its warp is calculated

Implementation Nuance

If given the full 14336-element input, the Dao Lab kernel performs both the FHT and the 28 x 28 Hadamard post-multiplication (using shared memory) in the same kernel. In my implementation, I perform only the FHT in the Dao Lab kernel and perform the matrix multiply separately.

  • I chose this implementation because I found it performed slightly better than doing the matrix multiply inside the Dao Lab kernel - I believe this is because the separate matrix multiply uses the matrix cores for the 28 x 28 post-multiply, which are more efficient than performing the multiply in shared memory. The separate path also uses bfloat16 (while the shared-memory path uses float32).

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which starts a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@markurtz

Looks very promising! The design document isn't loading for me - is it an internal one? If that's the case, could you share what the expected disk formats look like / attach them to this PR? We have another initiative in LLM Compressor and Compressed Tensors to get general support for transforms, and it'd be great to have a unified format there.

@amd-rcorzine closed this by deleting the head repository on Apr 11, 2025
@mratsim commented Nov 12, 2025

Re

The recent QuaRot and SpinQuant papers propose a quantization method for large language models (LLMs) that uses rotations to simplify quantization. It quantizes all components, including weights, activations, and KV cache, into 4 bits by removing outliers without changing the model output. It accomplishes this by inserting four fundamental rotations (called R1, R2, R3 and R4 in SpinQuant) into Llama models.

I was looking into KV-cache quantization strategies; this PR doesn't implement KV-cache quantization using Hadamard transforms, right? Did you perhaps implement it in a fork of vLLM?

By the way, Hadamard transform support was added in:
