Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales #4343

Merged

Conversation

pcmoritz
Copy link
Collaborator

@pcmoritz pcmoritz commented Apr 24, 2024

This PR contains the performance improvements mentioned in #4244

The performance numbers are as follows (TP2, H100) -- with mistralai/Mixtral-8x7B-Instruct-v0.1:

Screenshot 2024-04-24 at 3 18 31 PM

Note that to get the performance improvements from static scaling, you will need a model checkpoint with the scales.

Instructions on how to create the checkpoint:

git clone git@hf.co:mistralai/Mixtral-8x7B-Instruct-v0.1

and copy the directory over to Mixtral-8x7B-Instruct-v0.1-fp8.

Inside the Mixtral-8x7B-Instruct-v0.1-fp8 directory, run the following script (with mixtral_scales.pth generated as described in #3208 (comment)):

from safetensors import safe_open
from safetensors.torch import save_file
import torch

activation_scales = torch.load("/tmp/mixtral_scales.pth")

def rewrite_safetensors(name):
    tensors = {}
    with safe_open(name, framework="pt") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
            if "w1" in k or "w2" in k or "w3" in k:
                activation_scale_prefix = k.removesuffix(".weight")
                activation_scale_name = activation_scale_prefix + ".act_scale"
                print(f"activation scale for {k} is {activation_scale_name}")
                tensors[activation_scale_name] = activation_scales[activation_scale_prefix].max()
    save_file(tensors, name)

for i in range(1, 20):
    filename = f"model-{i:05}-of-00019.safetensors"
    print(f"rewriting {filename}")
    rewrite_safetensors(filename)

Then create a file quantize_config.json inside of Mixtral-8x7B-Instruct-v0.1-fp8 with

{
    "activation_scheme": "static"
}

You can then run the model with

   from vllm import LLM, SamplingParams
   prompts = [
       "Hello, my name is",
       "The president of the United States is",
       "The capital of France is",
       "The future of AI is",
   ]
   sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
   
   llm = LLM(model="/mnt/local_storage/Mixtral-8x7B-Instruct-v0.1-fp8", tensor_parallel_size=2, quantization="fp8")
   
   outputs = llm.generate(prompts, sampling_params)
   
   for output in outputs:
      prompt = output.prompt
      generated_text = output.outputs[0].text
      print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Note that this is probably not the best way to generate these scales, I'm looking forward to suggestions to generate better ones :)


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM!
The way to propagate activation scaling method is a bit tricky, but it should be a lot easier with #4342. We can merge this PR first given it's more straightforward and no design changes.

vllm/_custom_ops.py Outdated Show resolved Hide resolved
vllm/model_executor/layers/quantization/fp8.py Outdated Show resolved Hide resolved
@WoosukKwon WoosukKwon self-assigned this Apr 25, 2024
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pcmoritz Thanks for the PR! The code looks very clean. I like it! Left some minor comments on the code style. Please check them out.

vllm/_custom_ops.py Outdated Show resolved Hide resolved
vllm/_custom_ops.py Show resolved Hide resolved
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
return ["quantize_config.json"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wondering: For static quantization, can we just have this scaling factor in config.json? I'm not sure if this is a better decision than having a separate quantization config file, but it seems feasible. WDYT?

Copy link
Collaborator Author

@pcmoritz pcmoritz Apr 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is feasible. The reason why I put it into the safetensor files is because that's a better place to store tensors than the config.json and for some schemes, these tensors can be larger (e.g. for per-channel quantization). Since the checkpoint already needs to be rewritten (to add the quant_config.json and also possibly to convert weights to FP8), I don't think it is a big problem to rewrite the safetensors to include activation scales. This is also what https://huggingface.co/FriendliAI/Mistral-7B-Instruct-v0.2-fp8 does (note that their model format is otherwise not very useful since it stores the weights as int8).

I assume that a standard for FP8 will emerge, and I would expect it to store the scales in the safetensor files -- this is a no brainer for weight scales to keep them close to the weights and make sure they are consistent with the quantized weights but also makes sense for activation scales. Once we have a standard, we should use that. Right now, I don't think trying to invent our own standard in quantize_config.json is a good idea (since it involves a schema), whereas storing scales in the safetensor scales is pretty canonical and doesn't require us to invent a lot of convention.

These are my reasons -- let me know if you prefer otherwise :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also cc @robertgshaw2-neuralmagic who also thought about this a bunch I think :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @pcmoritz , the scales should be in a safetensors file

The mental model for the quantize_config.json is that it would hold metadata about what is in the safetensors file.

So examples could be:

  • Datatype of the weights
  • Whether the activations are static or dynamic (so we dont have to peek into safetensors)
  • Channelwise vs not, etc

For this first implementation, we don't this, but if we start supporting various different schemes, then we will (since we need to know this when create_weights is called - which happens before we look at the safetensors file)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, mostly quantize_config is not used anymore, even for AutoGPTQ

The config.json typically has a quantization_config.

https://huggingface.co/astronomer/Llama-3-8B-Instruct-GPTQ-8-Bit/blob/main/config.json

We only fall back to quantize_config.json if quantization_config is not found in the config.json

vllm/model_executor/layers/fused_moe/fused_moe.py Outdated Show resolved Hide resolved
vllm/model_executor/models/mixtral.py Outdated Show resolved Hide resolved
need_act_scales = (self.use_fp8 and
linear_method.quant_config.act_scaling == "static")
self.as_scale = nn.Parameter(
torch.zeros(1, device="cuda", dtype=torch.float32),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Theoretically, we shouldn't use "cuda" in the model code. Since the GPU worker sets "cuda" as the default device in torch, device="cuda" is not necessary. Also, it's not good for the compatibility with non-CUDA devices.

This rule is violated for Mixtral and other MoE models unfortunately. 😢

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll make a follow up PR to remove the device="cuda" -- since we also specify it explicitly for the other parameters, I don't want to be inconsistent for this PR :)

@WoosukKwon WoosukKwon removed their assignment Apr 25, 2024
pcmoritz and others added 4 commits April 24, 2024 20:48
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @pcmoritz generally looks good

Question - what is the UX for this?

I see three cases.

  • We have an fp8 model with dynamic activation quantization
  • We have an fp8 model with static activation quantization
  • We have an fp16 model which we want to quantize automatically and run with dynamic quantization

For 1-2, I don't think the user should have to specify that they want fp8. For these cases, I think we will need a simple quantization_config inside the config.json that we can use to parse which of the cases we have

For 3, there is no quantization config needed. The user should have to specify they want to use fp8 in this case

@classmethod
def get_config_filenames(cls) -> List[str]:
return []
return ["quantize_config.json"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @pcmoritz , the scales should be in a safetensors file

The mental model for the quantize_config.json is that it would hold metadata about what is in the safetensors file.

So examples could be:

  • Datatype of the weights
  • Whether the activations are static or dynamic (so we dont have to peek into safetensors)
  • Channelwise vs not, etc

For this first implementation, we don't this, but if we start supporting various different schemes, then we will (since we need to know this when create_weights is called - which happens before we look at the safetensors file)

})
set_weight_attrs(self.a2s_scale, {
"weight_loader": self.weight_loader,
})

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to this PR, I think we should have an MoELayer that is shared across models

All of these changes are currently only impacting Mixtral, but could also be applied to other models. Since we have all this generic logic in the model definitions, we are losing out at running others with these features

@classmethod
def get_config_filenames(cls) -> List[str]:
return []
return ["quantize_config.json"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, mostly quantize_config is not used anymore, even for AutoGPTQ

The config.json typically has a quantization_config.

https://huggingface.co/astronomer/Llama-3-8B-Instruct-GPTQ-8-Bit/blob/main/config.json

We only fall back to quantize_config.json if quantization_config is not found in the config.json

pcmoritz and others added 2 commits April 25, 2024 13:45
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@pcmoritz
Copy link
Collaborator Author

Hey @pcmoritz generally looks good

Question - what is the UX for this?

I see three cases.

  • We have an fp8 model with dynamic activation quantization
  • We have an fp8 model with static activation quantization
  • We have an fp16 model which we want to quantize automatically and run with dynamic quantization

For 1-2, I don't think the user should have to specify that they want fp8. For these cases, I think we will need a simple quantization_config inside the config.json that we can use to parse which of the cases we have

For 3, there is no quantization config needed. The user should have to specify they want to use fp8 in this case

Yeah exactly -- for 1 and 2, the config.json in the model checkpoint will have the right quantization_config and we will respect these configurations. For 3, the user will have to specify quantization="fp8" :)


quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in possible_config_filenames)
]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for GPTQ/AWQ, having a quantize_config.json is not necessarily required. So I think this check could break models that have:

  • quantization_config in config.json
  • no quantize_config.json

For example: https://huggingface.co/casperhansen/llama-3-70b-instruct-awq/tree/main

Our CI does not seem to have any models with this setup

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I fixed this now by removing the quantize_config.json support -- we can just use config.json for specifying the quantization for FP8 checkpoints for the time being :)

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic merged commit 12628d3 into vllm-project:main Apr 27, 2024
48 checks passed
pcmoritz added a commit that referenced this pull request May 1, 2024
Remove the device="cuda" declarations in mixtral as promised in #4343
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
…les (vllm-project#4343)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 6, 2024
Remove the device="cuda" declarations in mixtral as promised in vllm-project#4343
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
…les (vllm-project#4343)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
Remove the device="cuda" declarations in mixtral as promised in vllm-project#4343
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request May 7, 2024
Remove the device="cuda" declarations in mixtral as promised in vllm-project#4343
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
…les (vllm-project#4343)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Remove the device="cuda" declarations in mixtral as promised in vllm-project#4343
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants