
[Feature] TurboMind support W8A8 or FP8 KV Cache #1463

Closed
zhyncs opened this issue Apr 19, 2024 · 12 comments

@zhyncs
Collaborator

zhyncs commented Apr 19, 2024

Motivation

We plan to add support for either W8A8 SmoothQuant or FP8 KV Cache in TurboMind. There is currently no clear decision on which to prioritize. We would like to understand how the community judges the priority of these two options and would welcome any advice. Thanks. @lvhan028 @lzhangzz @grimoire @irexyc cc @ispobock

Related resources

No response

Additional context

No response

@zhyncs
Collaborator Author

zhyncs commented Apr 19, 2024

We can currently do FP8 development and testing on the L40.

@zhyncs
Collaborator Author

zhyncs commented Apr 19, 2024

Since FP8 has a significant precision advantage over Int8, we are more likely to use it in actual online serving.

Refer to the following blog posts:
https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482
https://blog.fireworks.ai/fireattention-serving-open-source-models-4x-faster-than-vllm-by-quantizing-with-no-tradeoffs-a29a85ad28d0

Currently, we are inclined to conduct research on FP8 internally first, and then decide which feature to work on. Do you have any suggestions?

@lzhangzz
Collaborator

lzhangzz commented Apr 19, 2024

FP8 KV cache will be a lot easier. You will need to add some template specializations for type conversion and some code for dispatching the kernels.
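To make this concrete, here is a minimal sketch of the kind of specialization that could be added, assuming CUDA 11.8+ headers and an sm_89+ target; the `ConvertKvCache` trait name is hypothetical and not part of TurboMind's actual code. A real implementation would also carry a quantization scale and extend the kernel dispatch to the new storage type.

```cpp
// Hypothetical conversion trait for the KV-cache load/store path (illustrative only).
#include <cuda_fp16.h>
#include <cuda_fp8.h>   // CUDA 11.8+, native FP8 on sm_89+

template <typename Tstore, typename Tcompute>
struct ConvertKvCache;

// Specialization: store KV entries as FP8 E4M3, compute in half precision.
template <>
struct ConvertKvCache<__nv_fp8_e4m3, half> {
    __device__ static __nv_fp8_e4m3 store(half x)
    {
        // Round-to-nearest conversion to E4M3; in practice a per-head scale would
        // be applied here to keep values inside the E4M3 range (max ~448).
        return __nv_fp8_e4m3(__half2float(x));
    }
    __device__ static half load(__nv_fp8_e4m3 x)
    {
        // Widen back to half for the attention math.
        return __float2half(float(x));
    }
};
```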

@zhyncs
Collaborator Author

zhyncs commented Apr 19, 2024

> FP8 KV cache will be a lot easier. You will need to add some template specializations for type conversion and some code for dispatching the kernels.

Exactly. We plan to use the existing online KV cache Int8 implementation as a reference.
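For context, a rough sketch of the pattern an online KV cache Int8 path typically follows (symmetric, scale-based quantization); the names here are illustrative, not TurboMind's actual code. An FP8 path could mirror this structure.

```cpp
// Illustrative online INT8 KV-cache conversion with an online per-head scale.
#include <cuda_fp16.h>
#include <cstdint>

struct KvCacheInt8 {
    __device__ static int8_t store(half x, float scale)
    {
        // scale is derived online, e.g. from the running absmax of the head.
        int q = __float2int_rn(__half2float(x) / scale);
        return static_cast<int8_t>(max(-128, min(127, q)));
    }
    __device__ static half load(int8_t q, float scale)
    {
        return __float2half(static_cast<float>(q) * scale);
    }
};
```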

@zhyncs
Collaborator Author

zhyncs commented Apr 23, 2024

Hi all. After internal discussions, we plan to start the development work related to FP8 in May. Please stay tuned. Cheers.

@zhyncs
Collaborator Author

zhyncs commented Apr 24, 2024

The following blogs and documents are not directly related to FP8 KV Cache; they mainly cover FP8 Attention, but they still offer some inspiration. The format FriendliAI uses for its FP8 Attention implementation is E4M3, and ColfaxResearch provides an implementation reference for type conversion.

https://friendli.ai/blog/weight-activation-quantization-fp8/
https://docs.friendli.ai/guides/container/efficient_inference_with_fp8/
https://research.colfax-intl.com/adding-fp8-to-flashattention/
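As a side note on why E4M3 is the usual choice for attention/KV values: it trades range for precision compared to E5M2. A quick host-side check of my own (compiled with nvcc and CUDA 11.8+ headers) illustrates the difference:

```cpp
#include <cstdio>
#include <cuda_fp8.h>

int main()
{
    const float samples[] = {0.017f, 1.3f, 96.0f, 300.0f, 1000.0f};
    for (float x : samples) {
        float e4m3 = float(__nv_fp8_e4m3(x));  // 4 exponent / 3 mantissa bits, max ~448
        float e5m2 = float(__nv_fp8_e5m2(x));  // 5 exponent / 2 mantissa bits, max ~57344
        printf("x = %8.3f  e4m3 = %8.3f  e5m2 = %8.3f\n", x, e4m3, e5m2);
    }
    return 0;
}
```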

@zhyncs
Collaborator Author

zhyncs commented Apr 24, 2024

> We plan to add support for either W8A8 SmoothQuant or FP8 KV Cache in TurboMind.

From https://friendli.ai/blog/quantization-reduce-llm-size/, it can be seen that, in terms of speed, SmoothQuant > AWQ > GPTQ, while in terms of accuracy, AWQ > GPTQ > SmoothQuant. AWQ strikes a good balance and has already been implemented efficiently in LMDeploy. GPTQ, like AWQ, is W4A16, but offers no advantage over it. Our team implemented W8A8 on vLLM in the second half of last year (vllm-project/vllm#1508). Due to precision issues, SmoothQuant is difficult to use realistically in online environments. From this perspective, trying FP8 KV Cache on the L40 would make it easier to adopt for online serving.

@ispobock
Contributor

I evaluated KV Cache INT8 with the llama2 and llama3 models and got the following results:

| dataset | metrics | llama2-13b-chat | llama2-13b-chat-kvint8 | llama3-8b | llama3-8b-kvint8 | llama3-80b | llama3-80b-kvint8 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ceval | naive_average | 35.44 | 35.18 | 48.29 | 48.99 | 67.26 | 67.08 |
| mmlu | naive_average | 48.61 | 48.64 | 62.68 | 62.74 | 79.68 | 79.53 |
| WiC | accuracy | 35.58 | 35.42 | 0 | 0 | 29.78 | 28.06 |
| WSC | accuracy | 34.62 | 33.65 | 6.73 | 6.73 | 34.62 | 33.65 |
| triviaqa | score | 60.12 | 60.27 | 60.11 | 60.26 | 76.89 | 76.77 |
| gsm8k | accuracy | 42.61 | 41.77 | 56.71 | 56.63 | 90.07 | 90.67 |
| race-middle | accuracy | 46.38 | 46.17 | 33.84 | 32.87 | 93.11 | 93.04 |
| race-high | accuracy | 34.59 | 34.33 | 26.07 | 25.76 | 89.25 | 89.22 |
  • KV Cache INT8 seems to keep most of the accuracy; do we still need FP8 quantization for the KV cache?
  • It's abnormal that the accuracy is 0 for llama3 on the WiC dataset. Could you help check whether I did the evaluation correctly?

Here are my evaluation steps:

# start server
lmdeploy serve api_server /workdir/llm_models/Meta-Llama-3-8B --server-name 0.0.0.0 --server-port 23333 --tp 1 --quant-policy 8

# start opencompass evaluation
python run.py configs/eval_internlm_chat_lmdeploy_apiserver.py -w outputs

@lvhan028 @lzhangzz Do you have any suggestions? cc: @zhyncs

@lvhan028
Collaborator

The OpenCompass team said WiC and WSC can be neglected.

@ispobock
Contributor

> The OpenCompass team said WiC and WSC can be neglected.

OK, got it.

@zhyncs
Collaborator Author

zhyncs commented May 10, 2024

> KV Cache INT8 seems to keep most of the accuracy; do we still need FP8 quantization for the KV cache?

Given the excellent performance improvement and negligible accuracy loss of the online KV Cache Int8 already implemented in LMDeploy, we are inclined not to proceed with FP8 KV Cache for now; the ROI is not very high for us. Does the community have any plans to work on this? Looking forward to your reply. Thanks. @lvhan028 @lzhangzz

@lzhangzz
Collaborator

We don't have plans to support FP8 KV cache, as the current INT8 implementation works just fine and also runs on pre-sm_89 devices. (In fact, I don't even have an sm_89+ device to start with.)

We aim to improve the accuracy of the current INT8/INT4 implementations through more advanced quantization methods.

zhyncs closed this as completed on May 11, 2024.