[Performance] Use optimized kernels for MQA/GQA #1880

Closed as not planned

Description

@WoosukKwon

In theory, MQA/GQA can reduce the memory bandwidth needed to read the KV cache and enable the use of Tensor Cores for the dot products in the attention mechanism. However, these benefits can only be realized with optimized kernels, which vLLM does not have at the moment.
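
As a rough illustration of both points, here is a minimal PyTorch sketch (not vLLM code; all shapes are made up) of a single GQA decode step in which the query heads that share a KV head are folded into one matmul, so each KV head's cache is read once per group and the per-head dot products become a small GEMM that can run on Tensor Cores:

```python
import torch

num_heads, num_kv_heads, head_dim = 32, 8, 128   # hypothetical model sizes
group = num_heads // num_kv_heads                # query heads per KV head
seq_k = 1024                                     # cached context length

q = torch.randn(num_heads, 1, head_dim)          # one decode token, all query heads
k = torch.randn(num_kv_heads, seq_k, head_dim)   # KV cache holds only 8 heads, not 32
v = torch.randn(num_kv_heads, seq_k, head_dim)

# Fold each group of 4 query heads into the row dimension of a single matmul.
q_g = q.view(num_kv_heads, group, head_dim)              # [8, 4, 128]
scores = q_g @ k.transpose(-1, -2) / head_dim ** 0.5     # [8, 4, 1024]
attn = torch.softmax(scores, dim=-1)
out = (attn @ v).view(num_heads, 1, head_dim)            # back to [32, 1, 128]
```

With plain MHA, the same step would read a 32-head KV cache and perform 32 separate vector-matrix products; here the cache is 4x smaller and each KV head is reused across its whole group of query heads.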

  1. For prefill, vLLM explicitly expands the incoming keys and values before running the attention op:
      key = key[:, :, None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
      value = value[:, :, None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv, value.shape[-1])
    because neither xformers nor PyTorch SDPA supports MQA/GQA at the moment. This is bad for performance since 1) expanding the tensors adds extra overhead, and 2) the attention kernel cannot exploit the advantages described above. While FlashAttention supports MQA/GQA efficiently (see the sketch after this list), we need to use it carefully since it does not cover all of the GPUs, data types, and head sizes that xformers supports.
  2. For decode, vLLM's current paged attention kernel also does not leverage the benefits of MQA/GQA. To get this benefit, we need to either significantly rewrite the paged attention kernel or modify the FlashAttention kernel to support the paged KV cache; a reference sketch of a grouped, paged decode step is shown below.
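
For point 1, FlashAttention 2 can already do this grouping inside the kernel. A minimal sketch (not vLLM's code) of a prefill call without the expand, assuming fp16 inputs on a supported GPU and illustrative shapes:

```python
import torch
from flash_attn import flash_attn_func

batch, seq_len = 2, 512
num_heads, num_kv_heads, head_dim = 32, 8, 128

q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# K/V carry only num_kv_heads heads; the kernel broadcasts each KV head across
# its group of query heads, so no expand() is needed.
out = flash_attn_func(q, k, v, causal=True)   # [batch, seq_len, num_heads, head_dim]
```

As noted above, this path is only usable where FlashAttention's GPU, dtype, and head-size constraints are met, so a fallback to the expand-based xformers path would still be needed elsewhere.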
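
For point 2, the following is a reference-only PyTorch sketch (not a kernel, and not vLLM's actual cache layout; `block_table`, `block_size`, etc. are illustrative) of what a GQA-aware paged decode step would compute: gather the sequence's KV blocks through a block table, then reuse each gathered KV head across its group of query heads as in the sketch further above:

```python
import torch

num_heads, num_kv_heads, head_dim = 32, 8, 128
group = num_heads // num_kv_heads
block_size, num_blocks = 16, 256

# Paged KV cache: [num_blocks, num_kv_heads, block_size, head_dim]
k_cache = torch.randn(num_blocks, num_kv_heads, block_size, head_dim)
v_cache = torch.randn(num_blocks, num_kv_heads, block_size, head_dim)

def paged_gqa_decode(q, block_table, context_len):
    """q: [num_heads, head_dim] for one sequence's current token."""
    num_ctx_blocks = (context_len + block_size - 1) // block_size
    blocks = block_table[:num_ctx_blocks]

    # Gather this sequence's context: [num_kv_heads, context_len, head_dim].
    k = k_cache[blocks].permute(1, 0, 2, 3).reshape(num_kv_heads, -1, head_dim)[:, :context_len]
    v = v_cache[blocks].permute(1, 0, 2, 3).reshape(num_kv_heads, -1, head_dim)[:, :context_len]

    # Each KV head is read once and serves a whole group of query heads.
    q_g = q.view(num_kv_heads, group, head_dim)
    scores = q_g @ k.transpose(-1, -2) / head_dim ** 0.5    # [num_kv_heads, group, context_len]
    attn = torch.softmax(scores, dim=-1)
    return (attn @ v).reshape(num_heads, head_dim)

out = paged_gqa_decode(
    q=torch.randn(num_heads, head_dim),
    block_table=torch.randint(0, num_blocks, (8,)),
    context_len=100,
)
```

An optimized kernel would fuse the gather and the grouped matmul; the sketch only shows where the group-sized GEMM and the reduced KV-cache reads come from.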

Metadata

Assignees: No one assigned
Labels: help wanted (Extra attention is needed), performance (Performance-related issues), stale (Over 90 days of inactivity)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
