
Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) #3290

Merged: 178 commits, Apr 3, 2024

Conversation

@AdrianAbeyta (Contributor) commented Mar 8, 2024

As part of a series of FP8 developments in vLLM, this pull request addresses an OCP-format (NVIDIA-compatible) FP8 KV cache. We elaborated upon the previous #2279, with the following changes, enhancements, and extensions:

  • Using the OCP FP8 data type, with E4M3 recommended for inference (Float8E4M3FN in MLIR, float8_e4m3fn in PyTorch)
  • Using a scaled FP8 KV cache to mitigate quantization loss (scaling factors are acquired from the AMD quantizer, AMMO, etc.)
  • Enabled on AMD MI3xx GPUs, MI300X (192GB HBM) in particular (less performant on older silicon without FP8 hardware)

Design reference:

  • RFC: FP8 Quantization Schema in vLLM #3218
  • RFC: FP8 in vLLM #2461

Scope:

  • Used in conjunction with the quantizer's output: KV cache scaling factors. This phase does not include the activation and weight sections from the JSON schema proposed in #2461. The quantizer's output may need to be formatted into a JSON file based on that schema for vLLM to consume; a utility script (3rdparty/quantizer/extract_scales.py) is provided to generate such a JSON from AMMO's output.
  • Quantizers supported: AMD quantizer, NVIDIA AMMO. A utility script (3rdparty/quantizer/quantize.py) is provided for using AMMO to quantize an HF model to FP8 with an FP8 KV cache, so that KV cache scaling factors are generated (over a calibration dataset, which you can change to your domain of interest); details are in 3rdparty/README.md.
  • Only e4m3fn, the common OCP format used for FP8 inference and model forward/eval, is enabled; it comes with hardware support (and is therefore performant) on AMD MI3xx GPUs. The same design is still functional, but less performant, on earlier AMD GPUs. The current design does not cover CUDA devices.
  • Model: Llama first; others will be added after approval.
  • FP8 KV cache only, with scaling; FP8 compute is coming next.

Scaling semantics:

  • In concept and in this design, we use the following definition:
    scaling_factor = AbsMax(input_tensor_fp16_or_bfloat16_or_fp32) / (OCP_E4M3_MAXNORM = 448.0)

  • These semantics are used by the AMD quantizer and, per our observation, by AMMO.
  • scaled_to_fp8_quant: fp8_tensor = fp8_quant(higher_precision_tensor / scaling_factor)
  • scaled_fr_fp8_dequant: higher_precision_tensor = fp8_dequant(fp8_tensor) * scaling_factor
  • Per-tensor scaling for now; per-channel, etc. scaling in the future (a minimal code sketch of these semantics follows below).
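
A minimal, illustrative sketch of these semantics in PyTorch (not vLLM code; the function names and the plain cast to torch.float8_e4m3fn are our own simplification, and the fused vLLM kernels may handle saturation differently):

import torch

E4M3_MAXNORM = 448.0  # max normal value of OCP float8_e4m3fn

def scaled_to_fp8_quant(x: torch.Tensor, scaling_factor: float) -> torch.Tensor:
    # Divide by the scaling factor so the tensor's dynamic range fits e4m3fn, then cast.
    return (x / scaling_factor).to(torch.float8_e4m3fn)

def scaled_fr_fp8_dequant(x_fp8: torch.Tensor, scaling_factor: float,
                          dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # Cast back to higher precision and re-apply the scale.
    return x_fp8.to(dtype) * scaling_factor

x = torch.randn(16, 128).half()
scaling_factor = x.float().abs().max().item() / E4M3_MAXNORM  # per-tensor scale
x_fp8 = scaled_to_fp8_quant(x, scaling_factor)
x_roundtrip = scaled_fr_fp8_dequant(x_fp8, scaling_factor)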

Usage:

To start, please refer to:

./tests/fp8_kv/README.md
./3rdparty/README.md

Two example JSON files are provided under:

./tests/fp8_kv/

If you run vLLM with kv_cache_dtype="fp8" but do not provide a JSON file containing scaling factors, then no scaling will be applied to the FP8 (e4m3fn) quantization (all scaling factors default to 1.0), which may lead to less accurate results.

Manual execution:

from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="/data/models/llama-2-70b-chat-hf", kv_cache_dtype="fp8", scales_path="./tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
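
For comparison, a run without a scaling-factor JSON (so all factors default to 1.0, as noted above) is the same call minus the scales argument; sketch only, using the same illustrative local model path as above:

from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="/data/models/llama-2-70b-chat-hf", kv_cache_dtype="fp8")  # scaling factors default to 1.0
print(llm.generate("London is the capital of", sampling_params)[0].outputs[0].text)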

Performance:

We observed 20~30% performance increases over the FP16 baseline just by switching the KV cache to FP8 (e4m3fn), even for the 70B model served on a single MI300X.

WizardCoder-34b score, dataset: HumanEval-Python-EN on 1-GPU MI300X

KV cache type and config      pass@1    pass@10    pass@100
FP16 (T=0.8)                  30.63%    79.82%     95.73%
FP8_scaled (T=0.8)            30.55%    78.76%     94.51%
FP16 (BEAM, T=0.0)            40.76%    55.04%     58.54%
FP8_scaled (BEAM, T=0.0)      43.38%    56.63%     58.54%

Contributors:

@HaiShaw, @AdrianAbeyta, @gshtras, @mawong-amd, @Alexei-V-Ivanov-AMD

gshtras and others added 30 commits February 5, 2024 21:07
Add non-MI300 compatible alternative for bulk conversions
Removed bf8 (e5m2) and renamed f8 to fp8 to explicitly specify that it is e4m3
Removed stochastic rounding for simplicity
Put bulk fp8 conversion hip intrinsics behind a define. Disabled by default
Using types from the proper vllm headers. Added namespace
Move amd specific headers under amd_detail
Reduce fp8 range in the conversion test to match e4m3
Add other MI300 architectures to the list
Simplify device guard use in conversion kernel
Rename remaining fp8_e5m2 to general fp8
…ing factor instead of Tensor, should be working out the box
…, remove PT support from scales extraction utility
@zhaoyang-star (Contributor) left a comment

Overall, the PR now LGTM. @zhuohan123, could you please take time to review it?

@HaiShaw (Contributor) left a comment

Approving, per @zhaoyang-star's review.

@@ -0,0 +1,32 @@
### Quantizer Utilities
Collaborator:

I'm confused about why we need to vendor the script. Is this the same script from AMMO, or was it written by you?

@HaiShaw (Contributor), Mar 29, 2024:

@simon-mo, this is a script taken from NVIDIA, as part of their quantizer examples. We kept it under 3rdparty with its license banner unchanged; it is included only for reference and convenience.

Collaborator:

If there's no change, let's not include it, please. We should be comfortable referring users to the AMMO/quantizer repo to perform the quantization.

#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "../quantization/fp8/amd_detail/quant_utils.cuh"
Collaborator:

I thought E4M3 is NV compatible as well? Also is it possible to enable both?

@HaiShaw (Contributor), Mar 29, 2024:

@simon-mo, OCP E4M3 is NV-compatible, but we don't address the NV platform in this PR.


These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If
this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an
unquantized model through a quantizer tool (e.g. AMD quantizer or nVIDIA AMMO).
Collaborator:

Suggested change
unquantized model through a quantizer tool (e.g. AMD quantizer or nVIDIA AMMO).
unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO).

Collaborator:

Also add a link to these?

Contributor:

Details on how to fetch NVIDIA AMMO are in the README; do you want to have that here as well?

unquantized model through a quantizer tool (e.g. AMD quantizer or nVIDIA AMMO).

Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent
silicon offerings e.g. AMD MI300, nVIDIA Hopper or later support native hardware conversion to/from fp32, fp16, bf16, etc.
Collaborator:

Suggested change
silicon offerings e.g. AMD MI300, nVIDIA Hopper or later support native hardware conversion to/from fp32, fp16, bf16, etc.
silicon offerings e.g. AMD MI300, NVIDIA Hopper or later support native hardware conversion to/from fp32, fp16, bf16, etc.

Contributor:

Okay, will do a full replacement. BTW, "nVIDIA" is official.

Contributor:

@simon-mo, good update on this 👍, and those were changed.

docs/source/quantization/fp8_e4m3_kvcache.rst (resolved)

from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=1.2, top_p=0.9)
llm = LLM(model="/data/models/llama-2-7b-chat-hf",
Collaborator:

Use an off-the-shelf one from Hugging Face so users can use it.

Contributor:

/data/models/llama-2-7b-chat-hf is the local path of a converted HF model (converted from the Meta-released Llama 2 model in the standard manner). Pointing at a Hugging Face model path instead behaves the same; will update after some verification.

Collaborator:

Ping on this?

Contributor:

Changed to HuggingFace off-the-shelf model.

Collaborator:

Why are we only doing this for Llama? Are other models supported as well?

Contributor:

@simon-mo, we plan to enable all other models (that both vLLM and the quantizer/AMMO support) once this PR (design) is approved. At present, other models will just use the default scaling factor of 1.0. We will send a follow-up PR for that.

Collaborator:

In that case, please note this in the documentation.

# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
Collaborator:

The comment makes sense, but the *= 2 does not.

Contributor:

We do the *2 only for HIP, to deal with the numeric difference of our chip's FP8 format; after the *2, the overall effect is identical to running without it on NV.
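
For context, our reading of this (an assumption based on the reply above, not text from this PR): MI300's native FP8 type is e4m3fnuz, whose exponent bias is 8 rather than the 7 used by OCP e4m3fn, so the same bits decode to roughly half the value; doubling the scaling factor compensates. The range difference is visible in PyTorch:

import torch

# Max representable values of the two FP8 variants, per torch.finfo:
# OCP e4m3fn tops out at 448, AMD's e4m3fnuz at 240, which is why the
# HIP path doubles the scaling factor computed as tensor_amax / 448.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0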

@@ -275,6 +275,107 @@ def hf_model_weights_iterator(
torch.cuda.empty_cache()


def kv_cache_scales_loader(
Collaborator:

Is this the best way to distribute the scaling factors? I wonder whether there could be a convention to include them in the model's weights dictionary, similar to how other quantization methods support it.

Contributor:

Most other methods are weight quantization only; we did look.
Here, we first introduce scaling factors for the KV cache (not a weight of any sort); soon we will also add scaling factors for activations (the X tensors). We found that adding them as extended parameters in this way makes sense and is also lightweight. A minimal sketch of such a loader is shown below.
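
Illustrative sketch only (the file name, JSON keys, and fallback behavior here are assumptions drawn from this thread, not the merged kv_cache_scales_loader implementation):

import json
from typing import Iterator, Tuple

def load_kv_cache_scales(path: str, tp_rank: int,
                         num_layers: int) -> Iterator[Tuple[int, float]]:
    # Read per-layer KV cache scaling factors for one TP rank from a JSON file
    # and yield (layer_index, scaling_factor) pairs; fall back to 1.0 on error.
    try:
        with open(path) as f:
            schema = json.load(f)
        layer_scales = schema["kv_cache"]["scaling_factor"][str(tp_rank)]
        for layer_idx in range(num_layers):
            yield layer_idx, float(layer_scales[str(layer_idx)])
    except Exception:
        for layer_idx in range(num_layers):
            yield layer_idx, 1.0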

# Since the number of layers is small and (for now) we use scalar
# scaling factors (so the size they use is also small), this is
# not a concern at present.
schema = json.load(f, parse_int=int, parse_constant=float)
Collaborator:

Suggested change
schema = json.load(f, parse_int=int, parse_constant=float)
schema = json.load(f)

@@ -275,6 +275,107 @@ def hf_model_weights_iterator(
torch.cuda.empty_cache()


def kv_cache_scales_loader(
Collaborator:

This function is not simple. We should not do this validation ourselves. Please use the pydantic library to define the schema to be read into: https://docs.pydantic.dev/latest/

Contributor:

Was looking to refactor the logic with jsonschema/pydantic once a formal schema is established for quantization params so we can generalize across models and also integrate it elsewhere (potential candidates: scales extraction utility, quantization config). But good call on at least doing the basic structure checking with Pydantic for now.
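
For reference, a minimal Pydantic (v2) sketch of what such a structure check could look like; the class names appear in this PR, but the fields shown are assumptions rather than the final vLLM schema:

from typing import Dict, Optional
from pydantic import BaseModel

class KVCacheQuantSchema(BaseModel):
    # Assumed fields: quantized KV cache dtype plus a nested mapping of
    # TP rank -> layer index -> per-tensor scaling factor.
    dtype: str
    scaling_factor: Dict[int, Dict[int, float]]

class QuantParamSchema(BaseModel):
    # Assumed layout: an optional model type tag and the kv_cache section.
    model_type: Optional[str] = None
    kv_cache: KVCacheQuantSchema

# Hypothetical usage: parse and validate a quantization-param JSON in one call.
with open("kv_cache_scales.json") as f:
    params = QuantParamSchema.model_validate_json(f.read())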

@simon-mo (Collaborator) left a comment

Concerns about the API naming (I would like to hear @zhuohan123 and @WoosukKwon's thoughts on this):

  1. --quantization-param-path: should it be named more narrowly, something like --scaling-factor-per-layer-json? Even with this, I find it difficult for users who don't understand the JSON format.
  2. Naming of kv_scale in the paged attention kernel. Should it be called scaling_factor or fp8_scaling?

Overall I think once the remaining comments are settled and these two API design questions resolved, this PR is in good shape to merge.

For the future, it would be a lot easier to review if the renaming part is isolated out to a separate PR.

Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec.
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#elif defined(ENABLE_FP8_E4M3)
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// Vector conversion from Quant_vec to K_vec. Scaled conversion: FP8 => higher precision
Collaborator:

What does "Scaled conversion: FP8 => higher precision" mean here? Please use a full sentence to help maintainers understand. Do you mean "We use the scaled_vec_conversion library function for better precision"?

Contributor:

solved


Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the
cache, improving throughput. OCP specifies two common floating point data formats: E5M2 (5 exponent bits and 2 mantissa
Collaborator:

What is OCP? Explain it in this doc for users.

Contributor:

solved


These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If
this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an
unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO [pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo]).
Collaborator:

The formatting is off. Please use a proper code block.

https://vllm--3290.org.readthedocs.build/en/3290/quantization/fp8_e4m3_kvcache.html

Contributor:

solved

unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO [pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo]).

Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent
silicon offerings e.g. AMD MI300, NVIDIA Hopper or later support native hardware conversion to/from fp32, fp16, bf16, etc.
Collaborator:

MI300 or MI300x?

Collaborator:

to/from -> to and from

Contributor:

In the codebase, we use MI300 so as not to differentiate between MI300 sub-models, though MI300X is a common SKU.

Contributor:

solved

Comment on lines 27 to 28
# two float8_e4m3fn kv cache scaling factor files are provided at tests/fp8_kv,
# refer to tests/fp8_kv/README.md to generate kv_cache_scales.json of your own.
Collaborator:

Please provide a GitHub link instead. Something like

https://github.com/vllm-project/vllm/blob/main/tests/fp8_kv/README.md

It will 404 now but will work later.

@HaiShaw (Contributor), Apr 2, 2024:

What is the "Ping on this" below about? The code seems to be different; a comment line was previously added to show the non-local HF model path to llama-2-7b-chat-hf.

Contributor:

Changed to HF-remote llama-2-7b path.

@@ -402,3 +414,28 @@ def load_weights(self,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

# Should not be called unless the KV cache dtype is FP8 on ROCm (AMD GPU)
Collaborator:

In that case, please use an assert to confirm the invariant.

Contributor:

I've removed the comment; it was meant to be informative ("FYI, this function is for the scaled KV cache, which is currently enabled on ROCm only") rather than marking bad behavior we should guard against. In particular, the scaling factor is only used when the KV cache dtype is FP8 and on ROCm, so calling this function in other settings has no observable side effects. So there's no need for an assert here.

More broadly, the current design largely decouples the KV cache implementation from the model implementation (which makes sense, as KV caches are not theoretically necessary). IMO, guarding against potential misuse (which is side effect free anyway) isn't a strong enough reason to newly endow the model with the ability to introspect KV cache details.

logger.warning("Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading.")
return ()
Collaborator:

Returning an empty tuple is not common in Python; it took me a while to understand.

Suggested change
return ()
return []

Contributor:

Done for the list-stans.

# schemas out into a separate file that deals solely with quantization
# params and their related schemas so that they can be generalized and
# shared across various use cases.
class KVCacheQuantSchema(BaseModel):
Collaborator:

Move this out of the function, please.

Contributor:

Done.

f"TP rank {tp_rank}.")
return self

class QuantParamSchema(BaseModel):
Collaborator:

Move this out of the function, please.

An inline class definition has bad performance because it is re-evaluated each time.

@mawong-amd (Contributor), Apr 2, 2024:

Performance should not be a significant consideration here. This is called at most tp_size times during the loading process and the runtime of this function is far eclipsed by the time it takes to load GB-scale weights. Good object-oriented decomposition considerations are more salient. In any case, they've been moved out of the function.

@computed_field
@property
def rank_keyword(self) -> str:
# Each TP rank key should be prefixed by a common rank_keyword.
Collaborator:

I didn't realize this is TP-dependent. Please add this to the documentation!

Contributor:

The rank keyword has been removed: the extra processing was more trouble than it was worth.

@HaiShaw (Contributor) commented Apr 2, 2024, quoting @simon-mo's review above:

Concerns about the API naming (I would like to hear @zhuohan123 and @WoosukKwon's thoughts on this):

  1. --quantization-param-path: should it be named more narrowly, something like --scaling-factor-per-layer-json? Even with this, I find it difficult for users who don't understand the JSON format.
  2. Naming of kv_scale in the paged attention kernel. Should it be called scaling_factor or fp8_scaling?

Overall I think once the remaining comments are settled and these two API design questions resolved, this PR is in good shape to merge.

For the future, it would be a lot easier to review if the renaming part is isolated out to a separate PR.

--quantization-param-path was called scales-path before; we changed it in the last round of review to address a comment and to be a bit forward-looking (we expect it to cover more than KV-cache-related scaling only). --scaling-factor-per-layer-json is accurate to the current state of this PR, and --layered-scaling-factor-path seems good too.
If we all agree on a name, we can go ahead and make the change.

@simon-mo merged commit 2ff767b into vllm-project:main on Apr 3, 2024
33 checks passed
@youkaichao (Member):
Seems like this PR fails the main branch now.

Alexei-V-Ivanov-AMD added a commit to Alexei-V-Ivanov-AMD/vllm that referenced this pull request Apr 4, 2024
Adding functionality to ingest scaling factors upon merge of the PR vllm-project#3290
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request Apr 22, 2024

Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: guofangze <guofangze@kuaishou.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@pcmoritz (Collaborator) commented May 3, 2024:

Can the files hip_float8.h and hip_float8_impl.h be part of some AMD SDK going forward? They shouldn't be part of vLLM :)

@HaiShaw (Contributor) commented May 6, 2024:

Can the files hip_float8.h and hip_float8_impl.h be part of some AMD SDK going forward? They shouldn't be part of vLLM :)

That was the plan, once we have a common fp8 header released.
