Overview:
For my project, I have added support for online rotations (for quantization) to vLLM: I implemented the QuaRot paper (rotations that improve int8 quantization by reducing outliers) in vLLM and achieved promising performance results.
Design Document:
Introduction
The recent QuaRot and SpinQuant papers propose quantization methods for large language models (LLMs) that use rotations to simplify quantization. These methods quantize all components, including weights, activations, and the KV cache, to 4 bits by removing outliers without changing the model output. They accomplish this by inserting four fundamental rotations (called R1, R2, R3, and R4 in SpinQuant) into Llama models.
To support these features, the ability to add online Hadamard rotations - explicitly rotating the input before it is passed into the quantized layer - is necessary (see figure below).
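To make the idea concrete, here is a minimal sketch (not part of this PR) showing why rotating the input by a Hadamard matrix, with the inverse rotation folded into the weights offline, leaves the layer output unchanged while spreading activation outliers across channels:

```python
import torch

def hadamard(n: int) -> torch.Tensor:
    """Build an orthonormal Hadamard matrix (n must be a power of two)."""
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1),
                       torch.cat([h, -h], dim=1)], dim=0)
    return h / h.shape[0] ** 0.5  # rows/columns are orthonormal: h @ h.T == I

n = 128
x = torch.randn(4, n)
x[:, 0] += 50.0                      # simulate an activation outlier channel
w = torch.randn(n, n)
h = hadamard(n)

y_ref = x @ w                        # original layer output
y_rot = (x @ h) @ (h.T @ w)          # rotate input online, fold h.T into the weights offline

print(torch.allclose(y_ref, y_rot, atol=1e-3))           # True: output is unchanged
print(x.abs().max().item(), (x @ h).abs().max().item())  # outlier magnitude drops after rotation
```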
Features
Requirements
The custom Python module method for the rotation function must satisfy the following:
- Take the layer input as a `torch.Tensor` and return the rotated result as a `torch.Tensor` in the same form (see the sketch below).
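For illustration, such a method could look like the following sketch; the function name and the placeholder rotation matrix are assumptions, not the actual module used in this work:

```python
import torch

def my_rotation(x: torch.Tensor) -> torch.Tensor:
    """Take the layer input as a torch.Tensor and return a torch.Tensor of the
    same shape and dtype, rotated by a fixed orthogonal matrix."""
    # Placeholder rotation: a real module would apply a Hadamard transform here.
    q = torch.eye(x.shape[-1], device=x.device, dtype=x.dtype)
    return x @ q
```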
Current Limitations
- Rotations must live in the `quark/schemes` directory: the user must define every rotation as a method inside a Python file in this directory.

Proof of Concept (PoC) & Performance Impact
To demonstrate the effectiveness of online rotations in vLLM, I implemented QuaRot in vLLM and measured significant end-to-end performance benefits.
Throughput / Latency:
A comprehensive set of tests has been run to measure performance:
documentation-repo/e2e-benchmarks/performance_compare.md at main · rcorzine/documentation-repo
Below is a specific example using `benchmarks/benchmark_throughput.py`:

```
python3 benchmark_throughput.py --model {name} --num-prompts 50 --input-len 64 --output-len 128
```

QuaRot demonstrates very strong performance improvements when integrated with vLLM, close to int8.
Accuracy:
To ensure the implementation's accuracy matched the claims from the research paper ([2404.00456] QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs), a comprehensive perplexity benchmark as well as empirical validation were used.
To measure perplexity, I used `benchmarks/P3L.py`. Note: this benchmark has not yet been merged into the main branch of vLLM; please refer to it here: vllm/benchmarks/P3L.py at main · ROCm/vllm

The perplexity score matches the claims of the paper: QuaRot achieves near-int8 efficiency with accuracy close to that of the unquantized model.
Reproducibility
Model
Utilize Quark to quantize the model:
Viva Engage - Conversation
Use this command to export to Hugging Face format:

```
python quantize_quark.py --output_dir {name} --model_export hf_format --model_dir meta-llama/Meta-Llama-3-8B --quant_scheme w_int8_a_int8_per_tensor_sym --pre_quantization_optimization quarot
```

vLLM
Build vLLM from this fork https://gitenterprise.xilinx.com/rcorzine/vllm-online-rotations (forked from main branch).
Install the custom rotation module used:

```
pip install git+https://gitenterprise.xilinx.com/rcorzine/FHT_ROCm.git
```

Usage Example
In the model's `config.json`, include a field `online_rotations` listing the layers whose input should be rotated. For each entry:
- `module_name`: defines the Python file (module) containing the rotation
- `method_name`: defines the rotation function method name
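For illustration, the parsed field might look roughly like this; the layer key and the module/method names below are assumptions, not values taken from the implementation:

```python
# Hypothetical shape of the "online_rotations" entry once config.json is parsed.
online_rotations = {
    "down_proj": {                      # layer(s) whose input should be rotated
        "module_name": "fht_rotation",  # Python file under quark/schemes
        "method_name": "my_rotation",   # rotation function defined in that file
    }
}
```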
System Architecture

Configuration
For reference, vLLM contains the following classes:
- `class QuarkConfig(QuantizationConfig)`
  - `def get_quant_method`: returns the `LinearMethodBase` quant method that will be applied to that layer. Associates the layer with a `QuarkScheme`. See `QuantizationConfig` for param details.
- `class QuarkLinearMethod(LinearMethodBase)`
  - `def apply_weights`: uses the weights from `create_weights` and the `QuarkScheme` associated with the layer to apply the forward pass with the layer input. See `LinearMethodBase` for param details.
- `class QuarkW8A8Int8(QuarkScheme)`
  - `def apply_weights`: applies the quantized forward pass for the layer. See `QuarkScheme` for param details.

To achieve online rotations, I use the `QuarkScheme` associated with the layer to apply the correct rotation to its input.

Below is a detailed explanation of how I implemented online rotations:
End goal:
A way to apply online rotation to the input of a particular layer.
Modifications:
- The `apply_weights` method, inside the `QuarkW8A8Int8` class, takes in a layer and the input, and applies the forward pass. I modify the `apply_weights` method to apply the online rotation (if it exists).
- The `get_quant_method` method, inside the `QuarkConfig` class, takes in the layer, associates it with a `QuarkW8A8Int8` scheme, and returns a `QuarkLinearMethod` object that will later use the `QuarkW8A8Int8` associated with the layer to apply the forward pass with the layer input.
- The `_get_scheme_from_config` method of the `QuarkConfig` class takes in a layer config and retrieves the appropriate `QuarkW8A8Int8` scheme. I modify `_get_scheme_from_config` to check whether an online rotation is specified in the layer config passed in, and if so, import the online rotation module and initialize the `QuarkW8A8Int8` scheme with the online rotation method.
- The `_find_matched_config` method of the `QuarkConfig` class takes in the layer name and retrieves the matching layer config. I modify `_find_matched_config` to check whether the layer name is listed in the configuration, and if so, add the online rotation method specification to the layer config.

A simplified sketch of this wiring is shown below.
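The sketch below illustrates this wiring; class and method names follow the vLLM Quark classes discussed in this document, but the bodies are simplified placeholders rather than the actual patch:

```python
import importlib

import torch

def resolve_rotation(layer_config: dict):
    """Resolve the rotation callable named in a layer config, if one is specified."""
    spec = layer_config.get("online_rotation")   # key name is an assumption
    if spec is None:
        return None
    module = importlib.import_module(spec["module_name"])  # Python file in quark/schemes
    return getattr(module, spec["method_name"])

class QuarkW8A8Int8Sketch:
    """Simplified stand-in for QuarkW8A8Int8 with an optional input rotation."""

    def __init__(self, rotation_fn=None):
        # _get_scheme_from_config would pass rotation_fn when the layer config
        # contains an online rotation specification.
        self.rotation_fn = rotation_fn

    def apply_weights(self, layer, x: torch.Tensor, bias=None) -> torch.Tensor:
        if self.rotation_fn is not None:
            x = self.rotation_fn(x)      # rotate the input before the quantized GEMM
        return self._int8_forward(layer, x, bias)

    def _int8_forward(self, layer, x, bias):
        # Placeholder for the existing int8 quantized matmul path.
        raise NotImplementedError
```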
Custom Operators (Compiled PyTorch ops for the rotations)
Have the online rotation module (Fast Hadamard Transform) registered as a custom operator, bundled within the vLLM codebase.
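As one possible direction (an assumption, not part of the current implementation), the rotation could be exposed through PyTorch's custom-op registration API (PyTorch 2.4+); the op name, fallback body, and fake kernel below are illustrative:

```python
import torch

@torch.library.custom_op("quark_rotations::fht", mutates_args=())
def fht(x: torch.Tensor) -> torch.Tensor:
    # A bundled build would dispatch to the compiled Fast Hadamard Transform
    # kernel here; this dense-matmul fallback is for illustration only.
    n = x.shape[-1]
    h = torch.ones(1, 1, device=x.device, dtype=x.dtype)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1),
                       torch.cat([h, -h], dim=1)], dim=0)
    return x @ (h / n ** 0.5)

@fht.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation so the op can be traced (e.g. by torch.compile).
    return torch.empty_like(x)
```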
Appendix
Fast Hadamard Transform - Algorithm Design Reference
Algorithm Overview
R4 corresponds to performing the Fast Hadamard Transform (FHT, aka the butterfly algorithm) on the input before multiplying by a Hadamard matrix.
More concretely, for a 14336-wide input it involves the following steps (a reference sketch follows):
- Split the input into 28 chunks of 512 elements each.
- Apply the FHT butterfly algorithm to each 512-element chunk.
- Multiply across chunks by a 28 x 28 Hadamard matrix.
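For reference, a minimal (non-optimized) PyTorch version of the butterfly recursion over a single power-of-two chunk could look like this; the fused CUDA kernel described next implements the same recurrence with warp shuffles:

```python
import torch

def fwht_chunk(x: torch.Tensor) -> torch.Tensor:
    """Normalized fast Walsh-Hadamard transform over the last dimension."""
    n = x.shape[-1]
    assert n & (n - 1) == 0, "chunk size must be a power of two (e.g. 512)"
    y = x
    h = 1
    while h < n:                                   # log2(n) butterfly stages
        y = y.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2).reshape(*x.shape[:-1], n)
        h *= 2
    return y / n ** 0.5                            # keep the transform orthonormal
```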
Tri Dao's implementation
Dao-AILab/fast-hadamard-transform: Fast Hadamard transform in CUDA, with a PyTorch interface
The transform is performed in $\log n$ iterations. This I/O-aware implementation effectively minimizes memory access.
For each 512 chunk of the 14336 input, apply the FHT butterfly algorithm iteratively:
- Each thread loads `kNelts` elements from its segment of the chunk.
- `kNelts` iterations are performed thread-wise.
- `kWarpSize` iterations are performed warp-wise, using `__shfl_xor_sync` for each pair of threads to receive their element from the complement thread.
- `kNWarps` iterations are performed block-wise, using `__shfl_xor_sync` for each pair of threads to receive their element from the complement thread.

Modifications for AMD Compatibility
- `__shfl_xor_sync` is replaced with `__shfl` (which seems more broadly compatible with ROCm).
- The way the `lane_id` of the thread in its warp is calculated is changed.

Implementation Nuance
If given the full 14336-size input, the Dao Lab kernel performs both the FHT and the 28 x 28 Hadamard post-matrix multiplication (using shared memory) in the same kernel. In my implementation, I perform only the FHT in the Dao Lab kernel and perform the matrix multiply separately.
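To illustrate the split, the full R4 rotation can be expressed as a per-chunk FHT followed by a separate 28 x 28 Hadamard mix across chunks. The sketch below assumes a 14336-wide input, a caller-supplied per-chunk transform (for example the `fwht_chunk` reference above), and a fixed 28 x 28 Hadamard matrix `had_28`; it is illustrative, not the code used in the PR:

```python
import torch

def r4_full(x: torch.Tensor, fht_chunk, had_28: torch.Tensor) -> torch.Tensor:
    """Apply the FHT to each 512-wide chunk, then mix chunks with a 28 x 28 Hadamard."""
    batch = x.shape[:-1]
    chunks = x.reshape(*batch, 28, 512)        # 14336 = 28 * 512
    chunks = fht_chunk(chunks)                 # step done inside the (modified) kernel
    # Separate matrix multiply across the 28 chunks, kept outside the kernel.
    mixed = torch.einsum("ec,...cd->...ed", had_28 / 28 ** 0.5, chunks)
    return mixed.reshape(*batch, 28 * 512)
```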