
[Misc]: Cross-attention QKV computation is inefficient #7397

@afeldman-nm

Description


This issue is not in response to a performance regression.

The method of performing cross-attention QKV computations introduced in #4942 could be improved. Because this issue relates to cross-attention, it only impacts encoder/decoder models, not decoder-only models.

For context, QKVParallelLinear computes QKV from a single input: the previous decoder layer's hidden-state output. The problem is that cross-attention requires QKV to be computed from two inputs: Q must be computed from the previous decoder layer's hidden-state output, while KV must be computed from the encoder's output hidden states. Additionally,

  • During the prefill phase, both Q and KV must be computed
  • During the decode phase, only Q must be computed, because the encoder sequence is static and therefore produces no new encoder KVs
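
Putting the two points together, the projections required in each phase are:

  • Prefill: $Q = (decoder\ hidden\ states)\ W_Q$, $K = (encoder\ hidden\ states)\ W_K$, $V = (encoder\ hidden\ states)\ W_V$
  • Decode: $Q = (decoder\ hidden\ states)\ W_Q$ only; K and V are read from the cross-attention KV cache populated during prefill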

The current, inefficient workaround for cross-attention is to construct a single QKVParallelLinear layer and apply it up to twice in a given run of the cross-attention forward() method: once to decoder_hidden_states to obtain Q, and (only during prefill) a second time to encoder_hidden_states to obtain KV:

# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
                        dim=-1)
if encoder_hidden_states is None:
    k = None
    v = None
else:
    qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
    _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)

Cost breakdown of the current method

During prefill,

  • $(decoder\ hidden\ states) [W_Q W_K W_V]$ is computed in order to obtain Q, so 2 out of 3 GEMMs are unnecessary
  • $(encoder\ hidden\ states) [W_Q W_K W_V]$ is computed in order to obtain KV, so 1 out of 3 GEMMs is unnecessary
  • In total, half of the GEMMs are unnecessary (50% efficiency)

During decode,

  • $(decoder\ hidden\ states) [W_Q W_K W_V]$ is computed in order to obtain Q, so 2 out of 3 GEMMs are unnecessary (33% efficiency)
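
As a quick sanity check of these percentages, the following small Python snippet (illustrative only; it counts GEMMs per cross-attention layer rather than FLOPs, which is the same counting used in this breakdown) tallies what the current workaround issues against what is actually needed:

# Tally of GEMMs issued vs. GEMMs needed per cross-attention layer under
# the current workaround (counting GEMMs, not FLOPs, to match the text).
gemms_issued = {
    "prefill": 6,  # full QKV on decoder states + full QKV on encoder states
    "decode": 3,   # full QKV on decoder states
}
gemms_needed = {
    "prefill": 3,  # Q from decoder states, K and V from encoder states
    "decode": 1,   # Q from decoder states only
}

for phase in ("prefill", "decode"):
    needed, issued = gemms_needed[phase], gemms_issued[phase]
    print(f"{phase}: {needed}/{issued} GEMMs needed "
          f"({needed / issued:.0%} efficiency)")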

Proposed solution

What is needed is a modification of, or a subclass of, QKVParallelLinear with the following properties (a rough sketch follows the list):

  • Exploits parallelism over multiple GPUs
  • forward() takes a decoder hidden states argument, and an optional encoder hidden states argument
  • forward() always computes $(decoder\ hidden\ states) W_Q$
  • forward() computes $(encoder\ hidden\ states) [W_K W_V]$ conditionally: only if the encoder hidden states are not None
  • 100% of GEMMs are necessary
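
Below is a minimal, single-GPU sketch of what such a layer's forward() contract could look like. It deliberately ignores vLLM's tensor-parallel machinery (a real implementation would keep the Q and KV projection weights sharded over attention heads, as QKVParallelLinear does), and the class name CrossAttentionQKVLinear and its constructor parameters are hypothetical, not part of the vLLM API:

import torch
import torch.nn as nn
from typing import Optional, Tuple

class CrossAttentionQKVLinear(nn.Module):
    """Single-GPU sketch of the proposed layer (hypothetical name/API).

    Q is always projected from the decoder hidden states; K and V are
    projected from the encoder hidden states only when they are provided
    (i.e. during prefill). A real vLLM implementation would instead modify
    or subclass QKVParallelLinear so the projection weights stay sharded
    across tensor-parallel ranks.
    """

    def __init__(self, hidden_size: int, q_size: int, kv_size: int,
                 bias: bool = False):
        super().__init__()
        # Separate Q and fused-KV projections, so no unnecessary GEMM is
        # ever issued in either phase.
        self.q_proj = nn.Linear(hidden_size, q_size, bias=bias)
        self.kv_proj = nn.Linear(hidden_size, 2 * kv_size, bias=bias)
        self.kv_size = kv_size

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Always computed: (decoder hidden states) W_Q -- one GEMM.
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: the encoder KVs already live in the
            # cross-attention KV cache, so nothing else is computed.
            return q, None, None
        # Prefill phase: (encoder hidden states) [W_K W_V] -- one fused GEMM.
        k, v = self.kv_proj(encoder_hidden_states).split(
            [self.kv_size, self.kv_size], dim=-1)
        return q, k, v

With the projections split this way, prefill issues exactly the three required projections (Q from the decoder stream, K and V fused from the encoder stream) and decode issues only the Q projection, so 100% of GEMMs are necessary in both phases.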
