-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC/WIP] First steps towards FP8 for Mixtral #3208
base: main
Are you sure you want to change the base?
Conversation
Instructions on how to generate activation scales with https://github.com/AniZpZ/AutoSmoothQuant:
And then run
(you need to download the pile validation data file |
I uploaded the model with quantized MoE layers to https://huggingface.co/pcmoritz/Mixtral-8x7B-Instruct-v0.1-FP8MOE now, it can be used with this PR. It needs to be loaded with
I'd love to hear about people's experience with this if anybody tries it out :) |
This PR is the first step towards fixing #3208 It implements dynamic per-tensor scaling (see #4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in #3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing vllm-project#3208 It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing vllm-project#3208 It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing vllm-project#3208 It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing vllm-project/vllm#3208 It implements dynamic per-tensor scaling (see vllm-project/vllm#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project/vllm#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing vllm-project/vllm#3208 It implements dynamic per-tensor scaling (see vllm-project/vllm#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project/vllm#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This PR is the first step towards fixing vllm-project#3208 It implements dynamic per-tensor scaling (see vllm-project#4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this: ```python from vllm import LLM, SamplingParams prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", "The future of AI is", ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8") outputs = llm.generate(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` **Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in vllm-project#3954). With this PR, the results are as follows: <img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03"> **Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows: ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7018|± |0.0036| | - humanities |N/A |none | 5|acc |0.6472|± |0.0065| | - other |N/A |none | 5|acc |0.7673|± |0.0072| | - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070| | - stem |N/A |none | 5|acc |0.6131|± |0.0083| ``` this compares favorably with the fp16 results which are ``` | Groups |Version|Filter|n-shot|Metric|Value | |Stderr| |------------------|-------|------|-----:|------|-----:|---|-----:| |mmlu |N/A |none | 0|acc |0.7020|± |0.1313| | - humanities |N/A |none | 5|acc |0.6425|± |0.1349| | - other |N/A |none | 5|acc |0.7744|± |0.1038| | - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695| | - stem |N/A |none | 5|acc |0.6108|± |0.1383| ``` Happy hacking!
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you! |
This pull request has merge conflicts that must be resolved before it can be |
Even after the optimizations #2979 and #2542, most of the time in the Mixtral forward pass is still spent in the fully connected layer / fused_moe kernel, see e.g. the following profiling run:
The next step to optimize this will be to use fp8 for these computations, which is known to perform well (https://arxiv.org/pdf/2208.09225.pdf, https://blog.fireworks.ai/fireattention-serving-open-source-models-4x-faster-than-vllm-by-quantizing-with-no-tradeoffs-a29a85ad28d0, https://arxiv.org/pdf/2303.17951.pdf).
NOTE: This is not yet working end-to-end, i.e. the numerics are not correct yet. We will most likely need to introduce scaling factors like in https://arxiv.org/abs/2211.10438
Can be tried out with
Current performance on master: 4.40s
Current performance on this PR: 3.40s
This PR has a long way to go before it can be ready. I'd love to get the communities input on what the best open source quantization codebases are that we could leverage to compute the scaling factors (e.g. I'm aware of https://github.com/AniZpZ/AutoSmoothQuant and https://github.com/mit-han-lab/smoothquant).