Commit 185b1b5
Repeat KV heads in Flash Attention (apple#938)
* Roll back the '_repeat_kv_heads' change in Flash Attention.

  A recent PR removed _repeat_kv_heads from Flash Attention as a GQA optimization, in the hope of reducing HBM usage. However, the actual HBM saving is limited in the model-parallel setting, since the heads are already sharded across devices. It also introduces a limitation that breaks some existing sharding configurations. For example, say num_heads = 8 and num_kv_heads = 4. When we repeat KV heads, we can set the model axis to 8 so that each device holds only one Q, K, and V head. Without repeat_kv_heads, the maximum model-axis size is 4, so each device ends up with 2 Q heads, increasing the actual HBM usage.

* Repeat KV heads as necessary for sharding.

* Unit tests.

* Address comments.
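To make the head-count arithmetic in the example above concrete, here is a small illustrative sketch (not part of the commit; the helper heads_per_device is hypothetical) that computes how many Q and KV heads land on each device for a given model-axis size, with and without repeating KV heads:

# Illustrative sketch only -- not code from this commit.
def heads_per_device(num_heads: int, num_kv_heads: int, model_axis: int, repeat_kv: bool):
    """Returns (q_heads_per_device, kv_heads_per_device) for a given model-axis size."""
    kv_heads = num_heads if repeat_kv else num_kv_heads
    assert num_heads % model_axis == 0, "Q heads must shard evenly."
    assert kv_heads % model_axis == 0, "KV heads must shard evenly."
    return num_heads // model_axis, kv_heads // model_axis

# With repeated KV heads, the model axis can be as large as num_heads = 8:
print(heads_per_device(8, 4, model_axis=8, repeat_kv=True))   # (1, 1)

# Without repeating, the model axis is capped at num_kv_heads = 4,
# so each device carries 2 Q heads instead of 1:
print(heads_per_device(8, 4, model_axis=4, repeat_kv=False))  # (2, 1)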
1 parent 4678740 commit 185b1b5

File tree

3 files changed: +158 −17 lines

axlearn/common/flash_attention/layer.py

Lines changed: 43 additions & 3 deletions
@@ -7,6 +7,7 @@
 
 import jax
 import jax.numpy as jnp
+import numpy as np
 from jax.experimental.shard_map import shard_map
 from jax.interpreters.pxla import thread_resources
 from jax.sharding import PartitionSpec
@@ -110,6 +111,42 @@ def _logit_biases_spec(self, attention_logit_biases: BaseAttentionBias) -> BaseA
         cfg = self.config
         return attention_logit_biases.partition_spec(cfg.mha_dim_to_partition_spec)
 
+    def _maybe_repeat_kv_heads(self, key_or_value: Tensor) -> Tensor:
+        """Repeats key or value heads dim to be shardable."""
+        cfg = self.config
+        partition_spec = cfg.mha_dim_to_partition_spec["bsnh"]
+        global_mesh = thread_resources.env.physical_mesh
+        if (
+            partition_spec == PartitionSpec(None)
+            or len(partition_spec) != 4
+            or partition_spec[-2] is None
+        ):
+            return key_or_value
+
+        axis = partition_spec[-2]
+        if isinstance(axis, tuple):
+            axis_size = np.prod([global_mesh.shape[x] for x in axis])
+        else:
+            axis_size = global_mesh.shape[axis]
+        # There will be sharding error if axis_size > num_heads.
+        if cfg.num_heads < axis_size:
+            raise ValueError(
+                f"num_heads ({cfg.num_heads}) must be greater than or equal to "
+                f"the number of devices {axis_size} in the mesh axis {axis}."
+            )
+        num_head_repeats = axis_size // key_or_value.shape[-2]
+        # Repeat along the num_heads dim: [batch, source_length, repeated_num_heads, per_head_dim].
+        if num_head_repeats > 1:
+            key_or_value = jnp.repeat(key_or_value, num_head_repeats, axis=-2)
+
+        if key_or_value.shape[-2] % axis_size != 0:
+            raise ValueError(
+                f"repeated_num_heads dim size {key_or_value.shape[-2]} must be "
+                f"fully divisible by mesh axis {axis} size {axis_size}."
+            )
+
+        return key_or_value
+
     def _compute_attention(
         self,
         *,
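As a sanity check on what the new helper does to the KV layout, the following standalone sketch repeats a [batch, source_length, num_kv_heads, per_head_dim] tensor the same way, using jnp.repeat along the heads axis. The shapes and the 8-way mesh axis size are illustrative assumptions, not values taken from the diff:

# Standalone sketch of the repeat step; shapes and axis size are assumptions.
import jax.numpy as jnp

batch, source_len, num_kv_heads, per_head_dim = 2, 16, 4, 64
model_axis_size = 8  # hypothetical size of the mesh axis that shards the heads dim

kv = jnp.zeros((batch, source_len, num_kv_heads, per_head_dim))

# Mirror _maybe_repeat_kv_heads: repeat KV heads until the heads dim is shardable.
num_head_repeats = model_axis_size // kv.shape[-2]
if num_head_repeats > 1:
    kv = jnp.repeat(kv, num_head_repeats, axis=-2)

assert kv.shape == (batch, source_len, 8, per_head_dim)
assert kv.shape[-2] % model_axis_size == 0  # now evenly shardable across 8 devices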
@@ -121,6 +158,10 @@ def _compute_attention(
         cfg: FlashAttention.Config = self.config
         backend = self._backend()
 
+        # Repeats key/value heads dim if necessary.
+        k_proj = self._maybe_repeat_kv_heads(k_proj)
+        v_proj = self._maybe_repeat_kv_heads(v_proj)
+
         batch, target_len, num_heads, _ = q_proj.shape
         _, source_len, _, _ = k_proj.shape
 
@@ -155,9 +196,8 @@ def _compute_attention(
             in_specs=(
                 # Q [batch_size, seq_len, num_heads, per_head_dim].
                 cfg.mha_dim_to_partition_spec["btnh"],
-                # KV [batch_size, seq_len, num_kv_heads, per_head_dim].
-                # Note: while num_kv_heads can be different from num_heads, their partition spec
-                # should be the same.
+                # KV [batch_size, seq_len, repeated_num_heads, per_head_dim].
+                # repeated_num_heads should be divided evenly by the n axis.
                 cfg.mha_dim_to_partition_spec["bsnh"],
                 cfg.mha_dim_to_partition_spec["bsnh"],
                 # Bias that can broadcast to [batch_size, num_heads, seq_len, seq_len].
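The updated in_specs comment says the repeated heads dim must divide evenly across the mesh axis named in the "bsnh" spec. The following sketch illustrates that divisibility check without needing real devices; the partition spec, mesh axis sizes, and repeated_num_heads value are assumptions for illustration:

# Illustrative check only; the partition spec and mesh shape below are assumptions.
import numpy as np
from jax.sharding import PartitionSpec

mesh_shape = {"data": 2, "model": 8}  # hypothetical mesh axis sizes
bsnh_spec = PartitionSpec("data", None, "model", None)  # [batch, seq, heads, per_head_dim]

axis = bsnh_spec[-2]
axis_size = (
    np.prod([mesh_shape[a] for a in axis]) if isinstance(axis, tuple) else mesh_shape[axis]
)

repeated_num_heads = 8  # e.g. num_kv_heads = 4 repeated 2x by _maybe_repeat_kv_heads
# The heads dim fed to shard_map must divide evenly across the sharding axis.
assert repeated_num_heads % axis_size == 0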
