Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Misc ] Refactor MoE to isolate Fp8 From Mixtral #5970

Merged
merged 48 commits into from
Jul 2, 2024

Conversation

robertgshaw2-neuralmagic
Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic commented Jun 28, 2024

MOTIVATION:

  • Prior to this PR, MoE layer + weight loading logic is repeated for all models
  • Prior to this PR, fp8 logic for MoE is coupled with Mixtral and unsupported for other MoE Models as a result
  • Prior to this PR, difficult to add other quantization methods without significant duplicated code in model files, which is very hard to maintain (e.g. made it hard to land GPTQ & AWQ Fused MOE #2761)

THIS PR:

  • Creates concept of FusedMoEMethodBase, UnquantizedFusedMoEMethod, Fp8FusedMoeMethod
  • Creates concept of FusedMoE layer (akin to MergedColumnParallel + RowParallel
  • Confirm correctness with lm-eval harness
  • Confirm able to use FusedMoELinear with Qwen2Moe

FOLLOW UP PR:

  • Support Fp8 checkpoints for models other than Mixtral: (weight loading for scales is currently overfit to Mixtral)
  • Convert other MoE models to use FusedMoELinear: to enable Fp8 for those cases
  • Finally integrate GPTQ / AWQ MoE kernels (GPTQ & AWQ Fused MOE #2761)

ACCURACY:

  • Mixtral
MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=4,distributed_executor_backend="ray" \
  --tasks gsm8k --num_fewshot 5 --limit 1000 --batch_size "auto"

main:
|Tasks|Version|     Filter     |n-shot|  Metric   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|----:|---|-----:|
|gsm8k|      3|strict-match    |     5|exact_match|0.643|±  |0.0152|
|     |       |flexible-extract|     5|exact_match|0.648|±  |0.0151|

branch:
|Tasks|Version|     Filter     |n-shot|  Metric   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|----:|---|-----:|
|gsm8k|      3|strict-match    |     5|exact_match|0.643|±  |0.0152|
|     |       |flexible-extract|     5|exact_match|0.648|±  |0.0151|
  • Qwen2 MoE
MODEL="Qwen/Qwen2-57B-A14B-Instruct"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=4,distributed_executor_backend="ray" \
  --tasks gsm8k --num_fewshot 5 --limit 1000 --batch_size "auto"

main:
Tasks|Version|     Filter     |n-shot|  Metric   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|----:|---|-----:|
|gsm8k|      3|strict-match    |     5|exact_match|0.812|±  |0.0124|
|     |       |flexible-extract|     5|exact_match|0.834|±  |0.0118|


- branch:
|Tasks|Version|     Filter     |n-shot|  Metric   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|----:|---|-----:|
|gsm8k|      3|strict-match    |     5|exact_match|0.810|±  |0.0124|
|     |       |flexible-extract|     5|exact_match|0.834|±  |0.0118|

Until we have the large models re-added to the CI, I run the following script offline:

TP_SIZE=4
FEWSHOT=5
LIMIT=250
BATCH_SIZE="auto"

# Mixtral
MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray" \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT --batch_size $BATCH_SIZE

# Mixtral - Quantized in place
MODEL="mistralai/Mixtral-8x7B-Instruct-v0.1"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",quantization="fp8" \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT --batch_size $BATCH_SIZE

# Mixtral - Quantized
MODEL="neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",quantization="fp8" \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT --batch_size $BATCH_SIZE

# Qwen
MODEL="Qwen/Qwen2-57B-A14B-Instruct"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray" \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT --batch_size $BATCH_SIZE

# Qwen - Quantized in place
MODEL="Qwen/Qwen2-57B-A14B-Instruct"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",quantization="fp8" \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT --batch_size $BATCH_SIZE

# Llama3
MODEL="meta-llama/Meta-Llama-3-8B-Instruct"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray" \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT --batch_size $BATCH_SIZE

# Llama-3 Quantized in place
MODEL="meta-llama/Meta-Llama-3-8B-Instruct"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",quantization="fp8" \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT --batch_size $BATCH_SIZE

# Llama-3 - Quantized
MODEL="neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
lm_eval --model vllm \
  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray" \
  --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT --batch_size $BATCH_SIZE

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title [ Misc ] Refactor MoE To Isolate Fp8 From Mixtral Code [ Misc ] Isolate Fp8 From Mixtral Code Jun 28, 2024
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title [ Misc ] Isolate Fp8 From Mixtral Code [ Misc ] Isolate Fp8Moe From Mixtral Jun 28, 2024
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic marked this pull request as ready for review June 30, 2024 22:45
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title [ Misc ] Isolate Fp8Moe From Mixtral DRAFT [ Misc ] Isolate Fp8Moe From Mixtral Jun 30, 2024
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title DRAFT [ Misc ] Isolate Fp8Moe From Mixtral [ WIP DRAFT ] [ Misc ] Isolate Fp8Moe From Mixtral Jun 30, 2024
rshaw@neuralmagic.com added 2 commits June 30, 2024 23:03
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic changed the title [ WIP DRAFT ] [ Misc ] Isolate Fp8Moe From Mixtral [ WIP DRAFT ] [ Misc ] Refactor MoE to isolate Fp8 From Mixtral Jun 30, 2024
@robertgshaw2-neuralmagic
Copy link
Collaborator Author

@comaniac ready for re-review

@robertgshaw2-neuralmagic
Copy link
Collaborator Author

This should be good to go.

  • Note that Qwen and Mixtral run in the CI at fp16
  • I ran the accuracy results for fp8 offline

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for the refactoring. Should be good to go when CI is green.

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 7c008c5 into vllm-project:main Jul 2, 2024
72 checks passed
@robertgshaw2-neuralmagic robertgshaw2-neuralmagic deleted the refactor-moe branch July 2, 2024 21:54
prashantgupta24 pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 3, 2024
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
robertgshaw2-neuralmagic added a commit to neuralmagic/nm-vllm that referenced this pull request Jul 7, 2024
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 8, 2024
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This was referenced Jul 13, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants