
[RFC]: hybrid dtype: float32 for weights and activation, float16 or bfloat16 for attention. #18342

Motivation.

vLLM defaults to float16 inference for float32 models:

https://github.com/vllm-project/vllm/blob/275c5daeb0048c3b3f359bb5d9478b1e75e02857/vllm/config.py#L3053C1-L3055C44

# Set default dtype from model config
if config_dtype == torch.float32:
    # Following common practice, we use float16 for float32 models
    torch_dtype = torch.float16

Most models maintain their original precision at float16, but a few models require float32. However, flash attention does not support float32, which makes running these models in full float32 a very inefficient choice.
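
For reference, a user can already override the default and keep everything in float32 by passing dtype explicitly. A minimal sketch (the model name is just a placeholder):

from vllm import LLM

# Forcing float32 end to end keeps the original precision, but flash attention
# has no float32 support, so attention falls back to a slower backend.
llm = LLM(model="intfloat/e5-base-v2", dtype="float32")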

More reports of embedding model precision significantly decreasing at float16:
#17175 #17986 #17785 #15393 ....

For wider numerical issues, PTAL #17123.

Proposed Change.

This RFC proposes using float32 for weights and activations, and float16 or bfloat16 for attention.

For models whose precision drops significantly at float16, this might be a better choice, especially for models that support long context.
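
The core idea, sketched below with PyTorch's scaled_dot_product_attention standing in for the attention backend (illustrative only, not the actual vLLM change): keep weights and activations in float32 and down-cast only around the attention call.

import torch
import torch.nn.functional as F

def hybrid_attention(q, k, v, attn_dtype=torch.float16):
    # Weights and activations stay in float32; only the attention kernel
    # (e.g. flash attention, which lacks float32 support) runs in half precision.
    out = F.scaled_dot_product_attention(
        q.to(attn_dtype), k.to(attn_dtype), v.to(attn_dtype))
    # Cast back so the rest of the layer keeps computing in float32.
    return out.to(torch.float32)

q = k = v = torch.randn(1, 8, 128, 64)  # float32 activations
assert hybrid_attention(q, k, v).dtype == torch.float32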

However, kv_cache_dtype currently defaults to dtype. If generative models are to support hybrid dtype, kv_cache_dtype should instead default to attn_dtype, which requires changes in many places and can introduce a lot of bugs.

So it is better to implement this RFC in two phases: first supporting embedding models (pooling models), then supporting generative models.

Embedding models

PTAL #18940

Generative models

I carefully constructed a test with a small model that passes only with float32 and hybrid dtype, and fails with float16. This demonstrates the effectiveness of hybrid dtype.
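
Such a dtype sensitivity check could look roughly like the following; the model path is a placeholder and this is not the exact test used for this RFC. In practice each LLM would be run in its own process so GPU memory is freed between runs.

from vllm import LLM, SamplingParams

prompts = ["The capital of France is"]
params = SamplingParams(temperature=0.0, max_tokens=16)  # greedy decoding

# Greedy-decode the same prompts under each dtype and compare against the
# float32 reference output.
reference = LLM(model="path/to/small-model", dtype="float32").generate(prompts, params)
candidate = LLM(model="path/to/small-model", dtype="float16").generate(prompts, params)

for ref, cand in zip(reference, candidate):
    print("float16 matches float32:", ref.outputs[0].text == cand.outputs[0].text)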

Supporting generative models requires modifying a lot of code, but the benefits are not significant.

Future

  • Figure out which training configurations lead to a decrease in fp16 inference precision, and which configurations can be avoided.
  • Detect online which models may lose precision at fp16 inference and warn users. Currently this requires comparing offline against SentenceTransformer results on a dataset (see the sketch after this list). For example, can the distribution of model parameters be used to detect a potential fp16 precision drop?
  • Repair online float32 models that lose precision at fp16 inference, similar to an in-flight float16 quantization method.
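
As an example of the offline comparison mentioned above, the check might look like this, assuming a recent vLLM where the embed task and LLM.embed API are available; the model name is a placeholder.

import numpy as np
from sentence_transformers import SentenceTransformer
from vllm import LLM

model_name = "intfloat/e5-base-v2"  # placeholder embedding model
sentences = ["vllm is a fast inference engine", "precision matters for embeddings"]

# float32 reference embeddings from SentenceTransformer.
reference = SentenceTransformer(model_name).encode(sentences, normalize_embeddings=True)

# float16 embeddings from vLLM.
outputs = LLM(model=model_name, task="embed", dtype="float16").embed(sentences)
candidate = np.array([o.outputs.embedding for o in outputs])
candidate /= np.linalg.norm(candidate, axis=1, keepdims=True)

# Per-sentence cosine similarity; values far below 1.0 indicate a precision drop.
print("cosine similarity:", (reference * candidate).sum(axis=1))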

Feedback Period.

No response

CC List.

core vLLM team

Any Other Things.

No response
