Skip to content

Commit

Permalink
[Doc] Add docs for llmcompressor INT8 and FP8 checkpoints (vllm-proje…
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Aug 16, 2024
1 parent 93478b6 commit b3f4e17
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 111 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ Documentation
quantization/supported_hardware
quantization/auto_awq
quantization/bnb
quantization/int8
quantization/fp8
quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache
Expand Down
217 changes: 106 additions & 111 deletions docs/source/quantization/fp8.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. _fp8:

FP8
FP8 W8A8
==================

vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x.
Expand All @@ -15,6 +15,11 @@ The FP8 types typically supported in hardware have two distinct representations,
- **E4M3**: Consists of 1 sign bit, 4 exponent bits, and 3 bits of mantissa. It can store values up to +/-448 and ``nan``.
- **E5M2**: Consists of 1 sign bit, 5 exponent bits, and 2 bits of mantissa. It can store values up to +/-57344, +/- ``inf``, and ``nan``. The tradeoff for the increased dynamic range is lower precision of the stored values.

.. note::

FP8 computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada Lovelace, Hopper).
FP8 models will run on compute capability > 8.0 (Ampere) as weight-only W8A16, utilizing FP8 Marlin.

Quick Start with Online Dynamic Quantization
--------------------------------------------

Expand All @@ -33,106 +38,134 @@ In this mode, all Linear modules (except for the final ``lm_head``) have their w

Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model.

Offline Quantization
Installation
------------

To produce performant FP8 quantized models with vLLM, you'll need to install the `llm-compressor <https://github.com/vllm-project/llm-compressor/>`_ library:

.. code-block:: console
$ pip install llmcompressor==0.1.0
Quantization Process
--------------------

For offline quantization to FP8, please install the `AutoFP8 library <https://github.com/neuralmagic/autofp8>`_.
The quantization process involves three main steps:

.. code-block:: bash
1. Loading the model
2. Applying quantization
3. Evaluating accuracy in vLLM

git clone https://github.com/neuralmagic/AutoFP8.git
pip install -e AutoFP8
1. Loading the Model
^^^^^^^^^^^^^^^^^^^^

This package introduces the ``AutoFP8ForCausalLM`` and ``BaseQuantizeConfig`` objects for managing how your model will be compressed.
Use ``SparseAutoModelForCausalLM``, which wraps ``AutoModelForCausalLM``, for saving and loading quantized models:

Offline Quantization with Dynamic Activation Scaling Factors
------------------------------------------------------------
.. code-block:: python
You can use AutoFP8 to produce checkpoints with their weights quantized to FP8 ahead of time and let vLLM handle calculating dynamic scales for the activations at runtime for maximum accuracy. You can enable this with the ``activation_scheme="dynamic"`` argument.
from llmcompressor.transformers import SparseAutoModelForCausalLM
from transformers import AutoTokenizer
.. warning::
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
model = SparseAutoModelForCausalLM.from_pretrained(
MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
2. Applying Quantization
^^^^^^^^^^^^^^^^^^^^^^^^

Please note that although this mode doesn't give you better performance, it reduces memory footprint compared to online quantization.
For FP8 quantization, we can recover accuracy with simple RTN quantization. We recommend targeting all ``Linear`` layers using the ``FP8_DYNAMIC`` scheme, which uses:

- Static, per-channel quantization on the weights
- Dynamic, per-token quantization on the activations

Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow.

.. code-block:: python
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
from llmcompressor.transformers import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8-Dynamic"
# Configure the simple PTQ quantization
recipe = QuantizationModifier(
targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
# Define quantization config with static activation scales
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="dynamic")
# For dynamic activation scales, there is no need for calbration examples
examples = []
# Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)
# Load the model, quantize, and save checkpoint
model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
# Save the model.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
3. Evaluating Accuracy
^^^^^^^^^^^^^^^^^^^^^^

Install ``vllm`` and ``lm-evaluation-harness``:

.. code-block:: console
$ pip install vllm lm_eval==0.4.3
Load and run the model in ``vllm``:

.. code-block:: python
from vllm import LLM
model = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic")
model.generate("Hello my name is")
Evaluate accuracy with ``lm_eval`` (for example on 250 samples of ``gsm8k``):

.. note::

Quantized models can be sensitive to the presence of the ``bos`` token. ``lm_eval`` does not add a ``bos`` token by default, so make sure to include the ``add_bos_token=True`` argument when running your evaluations.

.. code-block:: console
$ MODEL=$PWD/Meta-Llama-3-8B-Instruct-FP8-Dynamic
$ lm_eval \
--model vllm \
--model_args pretrained=$MODEL,add_bos_token=True \
--tasks gsm8k --num_fewshot 5 --batch_size auto --limit 250
In the output of the above script, you should be able to see the quantized Linear modules (FP8DynamicLinear) replaced in the model definition.
Note that the ``lm_head`` Linear module at the end is currently skipped by default.
Here's an example of the resulting scores:

.. code-block:: text
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): FP8DynamicLinear()
(k_proj): FP8DynamicLinear()
(v_proj): FP8DynamicLinear()
(o_proj): FP8DynamicLinear()
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): FP8DynamicLinear()
(up_proj): FP8DynamicLinear()
(down_proj): FP8DynamicLinear()
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
Saving the model to Meta-Llama-3-8B-Instruct-FP8-Dynamic
Your model checkpoint with quantized weights should be available at ``Meta-Llama-3-8B-Instruct-FP8/``.
We can see that the weights are smaller than the original BF16 precision.
|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.768|± |0.0268|
| | |strict-match | 5|exact_match|↑ |0.768|± |0.0268|
.. code-block:: bash
Troubleshooting and Support
---------------------------

ls -lh Meta-Llama-3-8B-Instruct-FP8-Dynamic/
total 8.5G
-rw-rw-r-- 1 user user 869 Jun 7 14:43 config.json
-rw-rw-r-- 1 user user 194 Jun 7 14:43 generation_config.json
-rw-rw-r-- 1 user user 4.7G Jun 7 14:43 model-00001-of-00002.safetensors
-rw-rw-r-- 1 user user 3.9G Jun 7 14:43 model-00002-of-00002.safetensors
-rw-rw-r-- 1 user user 43K Jun 7 14:43 model.safetensors.index.json
-rw-rw-r-- 1 user user 296 Jun 7 14:43 special_tokens_map.json
-rw-rw-r-- 1 user user 50K Jun 7 14:43 tokenizer_config.json
-rw-rw-r-- 1 user user 8.7M Jun 7 14:43 tokenizer.json
If you encounter any issues or have feature requests, please open an issue on the ``vllm-project/llm-compressor`` GitHub repository.

Finally, you can load the quantized model checkpoint directly in vLLM.

.. code-block:: python
Deprecated Flow
------------------

from vllm import LLM
model = LLM(model="Meta-Llama-3-8B-Instruct-FP8-Dynamic/")
# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
result = model.generate("Hello, my name is")
.. note::

The following information is preserved for reference and search purposes.
The quantization method described below is deprecated in favor of the ``llmcompressor`` method described above.

For static per-tensor offline quantization to FP8, please install the `AutoFP8 library <https://github.com/neuralmagic/autofp8>`_.

.. code-block:: bash
git clone https://github.com/neuralmagic/AutoFP8.git
pip install -e AutoFP8
This package introduces the ``AutoFP8ForCausalLM`` and ``BaseQuantizeConfig`` objects for managing how your model will be compressed.

Offline Quantization with Static Activation Scaling Factors
-----------------------------------------------------------

For the best inference performance, you can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the ``activation_scheme="static"`` argument.
You can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the ``activation_scheme="static"`` argument.

.. code-block:: python
Expand Down Expand Up @@ -169,41 +202,3 @@ Finally, you can load the quantized model checkpoint directly in vLLM.
# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
result = model.generate("Hello, my name is")
FP8 checkpoint structure explanation
-----------------------------------------------------------

Here we detail the structure for the FP8 checkpoints.

The following is necessary to be present in the model's ``config.json``:

.. code-block:: text
"quantization_config": {
"quant_method": "fp8",
"activation_scheme": "static" or "dynamic"
}
Each quantized layer in the state_dict will have these tensors:

* If the config has ``"activation_scheme": "static"``:

.. code-block:: text
model.layers.0.mlp.down_proj.weight < F8_E4M3
model.layers.0.mlp.down_proj.input_scale < F32
model.layers.0.mlp.down_proj.weight_scale < F32
* If the config has ``"activation_scheme": "dynamic"``:

.. code-block:: text
model.layers.0.mlp.down_proj.weight < F8_E4M3
model.layers.0.mlp.down_proj.weight_scale < F32
Additionally, there can be `FP8 kv-cache scaling factors <https://github.com/vllm-project/vllm/pull/4893>`_ contained within quantized checkpoints specified through the ``.kv_scale`` parameter present on the Attention Module, such as:

.. code-block:: text
model.layers.0.self_attn.kv_scale < F32
Loading

0 comments on commit b3f4e17

Please sign in to comment.