Description
This issue is not in response to a performance regression.
The method of performing cross-attention QKV computations introduced in #4942 could be improved. Because this issue relates to cross-attention, it only impacts encoder/decoder models, not decoder-only models.
For context, `QKVParallelLinear` computes QKV from the previous decoder layer's hidden-state output, i.e. from only a single input. The problem is that cross-attention requires QKV to be computed from two inputs: Q must be computed from the previous decoder layer's hidden-state output, and KV must be computed from the encoder's output hidden states. Additionally,
- During the prefill phase, both Q and KV must be computed.
- During the decode phase, only Q is computed, because the encoder sequence is static and therefore no new encoder KVs are needed.
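To make the required data flow concrete, here is a minimal single-GPU sketch in plain PyTorch of the computations cross-attention actually needs; the weight names, shapes, and token counts are illustrative assumptions, not vLLM APIs:

```python
import torch

# Illustrative sizes (assumptions, not taken from vLLM).
hidden_size, num_heads, head_size = 1024, 16, 64
q_size = kv_size = num_heads * head_size

W_q = torch.randn(hidden_size, q_size)        # Q projection (decoder side)
W_kv = torch.randn(hidden_size, 2 * kv_size)  # K/V projection (encoder side)

# Prefill: both inputs are present, so both projections are needed.
decoder_hidden_states = torch.randn(8, hidden_size)    # 8 decoder tokens
encoder_hidden_states = torch.randn(20, hidden_size)   # 20 encoder tokens
q = decoder_hidden_states @ W_q
k, v = (encoder_hidden_states @ W_kv).split(kv_size, dim=-1)

# Decode: only the new decoder token needs Q; the encoder K/V computed during
# prefill are reused from the cross-attention KV cache.
decoder_hidden_states = torch.randn(1, hidden_size)
q = decoder_hidden_states @ W_q
k = v = None
```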
The current, inefficient workaround for cross-attention is to construct a single `QKVParallelLinear` layer and apply it up to two times in a given run of the cross-attention `forward()` method: once to `decoder_hidden_states` to obtain Q, and (only during prefill) a second time to `encoder_hidden_states` to obtain KV:
```python
# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
                        dim=-1)
if encoder_hidden_states is None:
    k = None
    v = None
else:
    qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
    _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
```
Cost breakdown of the current method
During prefill,
- $(decoder\ hidden\ states)\,[W_Q\ W_K\ W_V]$ is computed in order to obtain Q, so 2 out of 3 GEMMs are unnecessary
- $(encoder\ hidden\ states)\,[W_Q\ W_K\ W_V]$ is computed in order to obtain KV, so 1 out of 3 GEMMs is unnecessary
- In total, half of the GEMMs are unnecessary (50% efficiency)

During decode,
- $(decoder\ hidden\ states)\,[W_Q\ W_K\ W_V]$ is computed in order to obtain Q, so 2 out of 3 GEMMs are unnecessary (33% efficiency)
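As a sanity check on those percentages, here is a back-of-the-envelope count of per-projection GEMMs (treating the $W_Q$, $W_K$ and $W_V$ sub-GEMMs as equal cost, which is the same assumption the figures above make):

```python
# Prefill with the current workaround:
#   qkv_proj(decoder_hidden_states) -> 3 GEMMs, only Q is kept (1 useful)
#   qkv_proj(encoder_hidden_states) -> 3 GEMMs, only K and V are kept (2 useful)
prefill_useful, prefill_total = 1 + 2, 3 + 3
print(prefill_useful / prefill_total)  # 0.5  -> 50% efficiency

# Decode with the current workaround:
#   qkv_proj(decoder_hidden_states) -> 3 GEMMs, only Q is kept (1 useful)
decode_useful, decode_total = 1, 3
print(decode_useful / decode_total)    # ~0.33 -> 33% efficiency
```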
Proposed solution
What is needed is a modification of, or subclass of, `QKVParallelLinear` with the following properties:
- Exploits parallelism over multiple GPUs
- `forward()` takes a decoder hidden states argument and an optional encoder hidden states argument
- `forward()` always computes $(decoder\ hidden\ states)\,W_Q$
- `forward()` computes $(encoder\ hidden\ states)\,[W_K\ W_V]$ conditionally: only if the encoder hidden states are not `None`
- 100% of GEMMs are necessary
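A minimal sketch of what such a layer's interface could look like, written here as a plain single-GPU PyTorch module for clarity; the class name `CrossAttnQKVLinear` and its parameters are placeholders, and a real implementation would build on `QKVParallelLinear` / `ColumnParallelLinear` so that the Q and K/V heads stay sharded across GPUs:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class CrossAttnQKVLinear(nn.Module):
    """Hypothetical cross-attention QKV layer: Q from decoder states,
    K/V from encoder states, with no wasted GEMMs."""

    def __init__(self, hidden_size: int, head_size: int, num_heads: int,
                 num_kv_heads: int) -> None:
        super().__init__()
        self.q_size = num_heads * head_size
        self.kv_size = num_kv_heads * head_size
        # Separate Q and K/V projections so no unnecessary GEMM is ever run.
        self.q_proj = nn.Linear(hidden_size, self.q_size, bias=False)
        self.kv_proj = nn.Linear(hidden_size, 2 * self.kv_size, bias=False)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        # Always needed: Q from the decoder hidden states.
        q = self.q_proj(decoder_hidden_states)
        if encoder_hidden_states is None:
            # Decode phase: encoder K/V already live in the cross-attention
            # KV cache, so no K/V GEMM is run.
            return q, None, None
        # Prefill phase: K/V from the encoder hidden states.
        k, v = self.kv_proj(encoder_hidden_states).split(self.kv_size, dim=-1)
        return q, k, v
```

Keeping K and V fused in one projection preserves a single large GEMM on the encoder side during prefill, while the separate Q projection means the decode path never multiplies by $W_K$ or $W_V$.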