GPTQ & AWQ Fused MOE #2761

Status: Open. Wants to merge 29 commits into base: main.

Conversation

@chu-tianxiang (Contributor) commented Feb 5, 2024

Thanks to the very smart MoE align strategy introduced in #2453, each block only uses a single expert, making it much easier to be adapted to quantized methods. This PR refactors the code to support quantized fused-MoE and adds GPTQ group gemm kernels based on exllamav2.
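
For readers skimming the thread, a toy sketch (in PyTorch, not the actual kernel added here) of that alignment idea may help: token slots are sorted by their assigned expert and each expert's segment is padded to a multiple of block_size, so every block maps to exactly one expert and a quantized group GEMM only has to load that expert's weights once per block. All names below are illustrative.

import torch

def moe_align_sketch(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Sort token slots by expert id and pad each expert's segment to a
    # multiple of block_size; illustrative only, not the CUDA kernel.
    flat = topk_ids.flatten()
    order = torch.argsort(flat)                 # token slots grouped by expert
    counts = torch.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size
    total = int(padded.sum())
    # The pad value flat.numel() marks slots that carry no real token.
    sorted_token_ids = torch.full((total,), flat.numel(), dtype=torch.long)
    expert_ids = torch.empty(total // block_size, dtype=torch.long)
    offset, src = 0, 0
    for e in range(num_experts):
        n = int(counts[e])
        sorted_token_ids[offset:offset + n] = order[src:src + n]
        expert_ids[offset // block_size:(offset + int(padded[e])) // block_size] = e
        offset += int(padded[e])
        src += n
    return sorted_token_ids, expert_ids

# Example: 4 tokens, top-2 routing over 8 experts, 16-row blocks.
sorted_ids, block_experts = moe_align_sketch(
    torch.randint(0, 8, (4, 2)), block_size=16, num_experts=8)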

Tokens/s of Mixtral, measured on an A100 using benchmark_latency.py with input_len=256 and output_len=1024.

  • GPTQ:

    | Batch size | 1   | 4   | 16  | 64  | 256  |
    |------------|-----|-----|-----|-----|------|
    | main       | 38  | 99  | 176 | 341 | 846  |
    | pr         | 100 | 207 | 395 | 556 | 1092 |

  • AWQ:

    | Batch size | 1   | 4   | 16  | 64   | 256  |
    |------------|-----|-----|-----|------|------|
    | main       | 20  | 77  | 255 | 452  | 533  |
    | pr         | 71  | 183 | 474 | 1003 | 1207 |

Todo:

  • Support Deepseek MoE
  • Support AWQ via repacking
  • Add tests

@casper-hansen (Contributor)

@chu-tianxiang Great job on optimizing GPTQ! Is there another option than repacking for AWQ?

@chu-tianxiang (Contributor, Author)

> @chu-tianxiang Great job on optimizing GPTQ! Is there another option than repacking for AWQ?

I can implement the AWQ kernel based on the current AWQ GEMM implementation too. Which do you think is better?

@casper-hansen (Contributor)

> @chu-tianxiang Great job on optimizing GPTQ! Is there another option than repacking for AWQ?
>
> I can implement the AWQ kernel based on the current AWQ GEMM implementation too. Which do you think is better?

I would prefer it if you based it on the current AWQ GEMM kernel.

@chu-tianxiang chu-tianxiang changed the title GPTQ Fused MOE GPTQ & AWQ Fused MOE Feb 7, 2024
@chu-tianxiang chu-tianxiang marked this pull request as ready for review February 7, 2024 05:55
@chu-tianxiang (Contributor, Author)

> @chu-tianxiang Great job on optimizing GPTQ! Is there another option than repacking for AWQ?
>
> I can implement the AWQ kernel based on the current AWQ GEMM implementation too. Which do you think is better?
>
> I would prefer it if you based it on the current AWQ GEMM kernel.

I have updated the AWQ kernels. AWQ GEMM uses tensor cores and has better performance at large batch sizes, which turns out to be a better fit for the MoE case.

@casper-hansen (Contributor)

This is excellent work! Looking forward to seeing this merged for a big speedup.

@casper-hansen (Contributor)

@chu-tianxiang On a side note, I tried importing the kernels from here into AutoAWQ and I am getting a CUDA illegal memory access on multi-GPU, while it works fine on a single GPU. It triggers at awq_group_gemm, which usually means the preceding operation (moe_align_block_size) performed an illegal memory access.

However, I do not get the same issue in vLLM. Do you have any idea how to address this issue for AutoAWQ?

@chu-tianxiang (Contributor, Author)

> @chu-tianxiang On a side note, I tried importing the kernels from here into AutoAWQ and I am getting a CUDA illegal memory access on multi-GPU, while it works fine on a single GPU. It triggers at awq_group_gemm, which usually means the preceding operation (moe_align_block_size) performed an illegal memory access.
>
> However, I do not get the same issue in vLLM. Do you have any idea how to address this issue for AutoAWQ?

Could you please provide the branch / code to reproduce? vLLM uses separate processes for tensor parallelism, while AutoAWQ and transformers use torch hooks for pipeline parallelism. An initial guess is that moe_align_block_size not using a device guard might be the problem.
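
For what it's worth, the device-guard idea amounts to making sure the current CUDA device matches the input tensors before the custom ops run, so a model dispatched to cuda:1 by accelerate-style hooks does not launch kernels against cuda:0 pointers. A minimal Python-side sketch (the real fix would be a device guard such as OptionalCUDAGuard inside the C++ ops; kernel here is just whatever op chain would otherwise be launched):

import torch

def launch_on_input_device(kernel, hidden_states: torch.Tensor, *args, **kwargs):
    # Switch the current CUDA device to the input's device before launching,
    # so the op does not dereference pointers that live on another GPU.
    with torch.cuda.device(hidden_states.device):
        return kernel(hidden_states, *args, **kwargs)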

@casper-hansen (Contributor)

Hi @chu-tianxiang, I added an issue to track it. I attempted to put a device guard in place; it fixes the illegal memory access error, but the generated output is then garbage. See details in the issue below.

casper-hansen/AutoAWQ#341

@lroberts7 commented Feb 22, 2024

I built this branch and ran all the tests under python3.10 -m pytest tests/kernels, and only ~20% pass. The failing ones all seem to hit the runtime CUDA illegal memory access error mentioned in this thread (see screenshot).

[Screenshot: Screen Shot 2024-02-22 at 5:20:38 PM, pytest failure output]

The tests that were added in this PR do all seem to pass, though:

lroberts@GPU77B9:~/update-vllm-env/vllm-source/vllm$ python3.10 -m pytest tests/kernels/test_moe.py -k "test_fused_moe_gptq or test_fused_moe_awq" 
======================================================================================= test session starts =======================================================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.3.0
rootdir: /home/lroberts/update-vllm-env/vllm-source/vllm
plugins: asyncio-0.23.3, forked-1.6.0, anyio-3.7.1
asyncio: mode=strict
collected 1299 items / 291 deselected / 1008 selected                                                                                                                                             

tests/kernels/test_moe.py ................................................................................................................................................................. [ 15%]
........................................................................................................................................................................................... [ 34%]
........................................................................................................................................................................................... [ 53%]
........................................................................................................................................................................................... [ 71%]
........................................................................................................................................................................................... [ 90%]
...................................................................................................                                                                                         [100%]

======================================================================================== warnings summary =========================================================================================
../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87
  /usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.1.0) or chardet (5.2.0) doesn't match a supported version!
    warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:121
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:121: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("best_of")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:140
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:140: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("repetition_penalty")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:146
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:146: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("seed")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:152
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:152: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("temperature")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:158
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:158: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("top_k")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:164
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:164: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("top_p")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:170
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:170: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("truncate")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:176
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:176: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("typical_p")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:204
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:204: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("inputs")

../../../.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:210
  /home/lroberts/.local/lib/python3.10/site-packages/huggingface_hub/inference/_text_generation.py:210: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
    @validator("stream")

../../../.local/lib/python3.10/site-packages/cupy/_environment.py:404
  /home/lroberts/.local/lib/python3.10/site-packages/cupy/_environment.py:404: UserWarning: 
  nccl library could not be loaded.
  
  Reason: ImportError (libnccl.so.2: cannot open shared object file: No such file or directory)
  
  You can install the library by:
  
    $ python -m cupyx.tools.install_library --library nccl --cuda 12.x
  
    warnings.warn(msg)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================= 1008 passed, 291 deselected, 12 warnings in 19.40s ========================================================================

EDIT: some details on environment
CUDA driver version: 530.30.02
lroberts@GPU77B9:~/update-vllm-env/vllm-source/vllm$ python3 -c "import torch; print(torch.__version__)"
2.1.2+cu121
lroberts@GPU77B9:~/update-vllm-env/vllm-source/vllm$ python3 -c "import transformers; print(transformers.__version__)"
/usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.1.0) or chardet (5.2.0) doesn't match a supported version!
  warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "
4.37.1

@casper-hansen (Contributor)

@lroberts7 it seems your tests are failing for reasons unrelated to this PR. I think you may have an environment issue or some problem with the GPUs.

@chu-tianxiang (Contributor, Author)

The PR previously broke the Mixtral unit test and I pushed a fix for it, but I'm still seeing illegal memory access in the CI test after that commit. I'm not sure what the problem is yet; I pulled the Docker image built in CI and cannot reproduce the problem locally.

@joennlae (Contributor)

What is the merge plan here?

@omarsou commented Mar 26, 2024

> What is the merge plan here?

+1
We would be more than happy to see this merged :)

@simon-mo (Collaborator)

Hi everyone, thank you for the active development on this PR. We would really like to include this in the next release. However, we identified a few issues: (1) the code makes some significant changes to the existing MoE implementation that need to be carefully reviewed; (2) there are some merge conflicts; (3) the main "code owners" who are familiar with the code path for the recent MoE changes, @pcmoritz and @WoosukKwon, are short on bandwidth.

Therefore, we would like to push this to the next release, v0.4.1, which is targeted around mid-April.

@chu-tianxiang (Contributor, Author)

Thanks for all the attention. I fixed the conflicts and added quantization support for the Qwen2Moe model. Tested with Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4 without problems. dbrx is trickier, as discussed in databricks/dbrx-instruct, and is not supported by AutoGPTQ / AutoAWQ yet, so I'll leave it until quantized models are available.

Btw, yapf and isort seem to have conflicting formatting rules; I'm not sure how that should be handled.

@robertgshaw2-neuralmagic (Collaborator) commented Mar 31, 2024

@chu-tianxiang Thanks for the great PR.

I have one major piece of feedback. This PR effectively supports two cases:

  1. There is a fused MoE kernel for the quantization type, in which case we use the fused kernels (matching the logic in current main Mixtral.py).
  2. There is not a fused MoE kernel for the quantization type, in which case we use naive looping over the experts with the GEMM kernels (matching the logic in current main MixtralQuant.py).

Supporting both of these cases adds significant complexity to the implementation, since we now have a big if statement in each of the core methods in the model definition:

if (not isinstance(self.linear_method, UnquantizedLinearMethod)
        and not self.linear_method.quant_config.support_fused_moe()):
    ...  # case 2 --> there is no fused kernel; loop over the experts
else:
    ...  # case 1 --> there is a fused kernel

This impacts each of the core methods in the model definitions:

  • __init__ --> now, we need to maintain two weight definitions for MLP
  • forward --> now, we need to maintain two forward methods for MLP
  • load_weights --> now, we have to have two cases for loading the weights for MLP

Since we now have kernels for GPTQ and AWQ, which are by far the most popular quantization methods, I think it makes sense to remove support for case 2 and simply fail if the user tries to run a quantization method that does not support fused_moe execution. This will dramatically simplify the code and make it much easier to (a) maintain and (b) add new MoE models in the future.

Neural Magic is already working on a fused MoE version of Marlin as well. So it will really just be SqueezeLLM that lacks a fused kernel. I think this is a completely worthwhile tradeoff.
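
A minimal sketch of the fail-fast guard described above, assuming the support_fused_moe() flag that this PR adds to the quant config (other names are illustrative, not the final API):

def assert_fused_moe_supported(linear_method) -> None:
    # Fail fast instead of falling back to the naive per-expert loop.
    quant_config = getattr(linear_method, "quant_config", None)
    if quant_config is not None and not quant_config.support_fused_moe():
        raise ValueError(
            "MoE models require a quantization method with a fused MoE "
            f"kernel (e.g. GPTQ or AWQ); got {type(quant_config).__name__}.")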

@chu-tianxiang (Contributor, Author)

@robertgshaw2-neuralmagic Thanks for the suggestion; the current logic does increase the code complexity of the MoE models quite a bit. Inspired by your analysis, I think the root cause of the complexity is that the fused MoE path uses tensor parallelism while the unfused path uses expert parallelism. Maybe we should change the unfused MoE implementation from expert parallelism back to the original tensor parallelism; if that works out, we can have simple code and full quantization support at the same time.
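
To make the distinction concrete, here is a rough sketch of the two partitioning schemes under discussion (shapes and names are illustrative, not the PR's actual layout):

import torch

num_experts, hidden_size, intermediate_size = 8, 4096, 14336
tp_size, tp_rank = 2, 0

# Expert parallelism: each rank owns whole experts and only runs its local
# ones, which forces a separate (quantized) GEMM path per local expert.
experts_per_rank = num_experts // tp_size
local_expert_ids = list(range(tp_rank * experts_per_rank,
                              (tp_rank + 1) * experts_per_rank))

# Tensor parallelism: every rank owns a slice of every expert (column-split
# for gate/up, row-split for down), so the fused kernels can run over all
# experts on each rank, just like the dense linear layers.
w13_shard = torch.empty(num_experts, 2 * intermediate_size // tp_size, hidden_size)
w2_shard = torch.empty(num_experts, hidden_size, intermediate_size // tp_size)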

@robertgshaw2-neuralmagic (Collaborator) commented Apr 1, 2024

@chu-tianxiang Are you okay if I make a proposal for a refactor to the logic?

@chu-tianxiang (Contributor, Author)

> @chu-tianxiang Are you okay if I make a proposal for a refactor to the logic?

Sure, please feel free to do so.

@robertgshaw2-neuralmagic (Collaborator) commented Apr 8, 2024

@chu-tianxiang LMK when you're ready for a re-review

The refactor I have been working on basically makes a shared FusedMoELinear layer that can be used across Deepseek, Mixtral, and QwenMoE. I should be able to sync it up with the changes you have made so far.

Since all that logic is duplicated across the models, I thought it made sense to abstract it into a new Linear type.

@chu-tianxiang (Contributor, Author)

@robertgshaw2-neuralmagic Thanks, it is ready now. Following your suggestion, I removed the expert parallel part and it's much cleaner. Currently tested on Mixtral (fp16, awq, gptq-4bit, gptq-3bit), Deepseek (fp16, gptq-4bit) and Qwen-moe (fp16, gptq-4bit).

@robertgshaw2-neuralmagic (Collaborator)

Sweet - thanks @chu-tianxiang

Will take a look later this week.

@robertgshaw2-neuralmagic (Collaborator) commented Apr 18, 2024

Hey @chu-tianxiang, these changes are looking much better; the logic is much simpler and easier to parse.

The final architectural change I would want to see to feel comfortable merging is to abstract the fused MoE layer logic from the model definitions into a new class in linear.py called FusedMoELinear. This would resolve the final issues I see with the PR:

  • A) We have the same logic for making an MoE layer in Mixtral, DeepSeek, and Qwen. Having a single implementation will be easier to maintain.

  • B) We currently overload MergedColumnParallelLinear and RowParallelLinear in each model to handle the weight loading. The FusedMoELinear class can handle this logic (likely reusing the code from MergedColumnParallelLinear and RowParallelLinear).

  • C) We currently have each Linear layer creating weights with create_moe_weights, and each linear constructor requires expert_id to be passed. This is confusing. By having FusedMoELinear, we can call create_moe_weights just in that class.

The key problem with the proposal I laid out is that it makes the mapping of the vllm state dict to the hf state dict more difficult. So we will need to handle this in each model's load_weights function

FusedMoELinear Proposal

class FusedMoELinear(torch.nn.Module):
    def __init__(self, shapes, linear_method):
        self.linear_method = linear_method

        # gate_up_proj
        self.ws = linear_method.create_moe_weights(shapes)
        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader_merged_column,
        })

        # down_proj
        self.w2s = linear_method.create_moe_weights(shapes)
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader_row_parallel,
        })
        # ...

    # weight loader for gate_up_proj
    def weight_loader_merged_column(self, param, loaded_weight, expert_id):
        # refactor to share with MergedColumnParallelLinear? << make method static in MergedColumnParallelLinear?
        pass

    # weight loader for down_proj
    def weight_loader_row_parallel(self, param, loaded_weight, expert_id):
        # refactor to share with RowParallelLinear? << make method static in RowParallelLinear?
        pass

    def forward(self, hidden_states, router_logits):
        return self.linear_method.apply_moe_weights(...)

Then, this layer would be part of Mixtral:

class MixtralMoE(torch.nn.Module):
    def __init__(self):
        self.gate = ReplicatedLinear(...)
        # note: this breaks the disk state dict (model.layers.0.mlp.w1 --> model.layers.0.mlp.fused_moe.ws)
        self.fused_moe = FusedMoELinear(...)

    def forward(self, hidden_states):
        router_logits = self.gate(hidden_states)
        return self.fused_moe(hidden_states, router_logits)

    # handle complexity of state dict remapping here
    def load_weights(self):
        # model.layers.0.mlp.w1 --> model.layers.0.mlp.fused_moe.ws
        # model.layers.0.mlp.w2 --> model.layers.0.mlp.fused_moe.w2
        # model.layers.0.mlp.w3 --> model.layers.0.mlp.fused_moe.w3
        pass

WDYT?

@robertgshaw2-neuralmagic (Collaborator) left a comment

See my comments for final requested architectural changes. Also we need tests for this.

It would be nice to have some big model tests and small model tests (that can run on a single GPU).

The small model tests should be possible for deepseek and qwen, as they have sizes that fit on a single GPU.
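
As a rough illustration of the kind of single-GPU smoke test being asked for (reusing the GPTQ checkpoint already exercised in this thread; the prompt and assertion are placeholders, not the final test):

import pytest
from vllm import LLM, SamplingParams

@pytest.mark.parametrize("model", ["Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4"])
def test_quantized_moe_smoke(model):
    # Load the quantized MoE model on one GPU and check it produces text.
    llm = LLM(model=model, quantization="gptq", max_model_len=1024)
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=8))
    assert outputs and len(outputs[0].outputs[0].text) > 0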

@jinzhen-lin (Contributor)

Hi, is this PR still active? Looking forward to it being merged.

@chu-tianxiang (Contributor, Author)

Sorry, I've been quite busy with personal life over the past month, which left me little time to update this PR. Additionally, when I attempted to update it last month, I encountered some conflicts that were hard to resolve.

Originally I created the MoE weights by adding an axis to every weight in linear_weights; however, #3977 removed linear_weights and registers the parameters directly. #4342 and #4373 suggested using fully separate logic for the linear layer and the MoE layer. Now I don't know how I can create MoE weights without copy-pasting the create_weight method for every quantization method, which is far too inelegant. If anyone can come up with a better design, please let me know, or feel free to create another PR.
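
For reference, the "adding an axis" approach mentioned above amounts to roughly the following (a hedged sketch; the helper name is illustrative and the real create_weight signatures differ across vLLM versions):

import torch

def stack_expert_weights(per_expert_weights: list) -> dict:
    # Given one dict of quantized tensors (qweight, qzeros, scales, ...) per
    # expert, build a single parameter per name with a new leading expert
    # axis so the fused MoE kernels can index it by expert id.
    stacked = {}
    for name in per_expert_weights[0]:
        tensors = [weights[name] for weights in per_expert_weights]
        stacked[name] = torch.nn.Parameter(torch.stack(tensors, dim=0),
                                           requires_grad=False)
    return stacked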

@fengyang95

Does this PR support DeepSeek-V2 AWQ?

@casper-hansen (Contributor)

@fengyang95 This was intended for v1 but should be extendable to v2. I hope @robertgshaw2-neuralmagic is able to pick this up at some point to get it through :)

@github-actions (bot)

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 30, 2024

@mergify (bot) commented Oct 30, 2024

This pull request has merge conflicts that must be resolved before it can be merged. @chu-tianxiang please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 30, 2024