Commit bbad4f0

Fix
1 parent a20adbf commit bbad4f0

File tree: 1 file changed (+22, -7)

axlearn/common/flash_attention/tpu_attention.py

Lines changed: 22 additions & 7 deletions
@@ -278,29 +278,44 @@ def make_tpu_splash_attention(
         interpret=interpret,
         residual_checkpoint_name=f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}",
     )
+    # args contains fwd_mask_info, dq_mask_info and dkv_mask_info, corresponding to the first three
+    # positional arguments to `splash_attention_kernel._splash_attention`.
     args, kwargs = kernel.tree_flatten()
     specs, _ = kernel.manual_sharding_spec(sharding).tree_flatten()

     def shard_map_fn(q_proj, k_proj, v_proj, bias, _, *args):
+        assert len(args) == 3
         if softmax_scale != 1.0:
             q_proj *= softmax_scale
         _, segment_ids, _ = split(bias, MaskFnAttentionBias, SegmentIdAttentionBias)
+        # Note: we cannot pass bias to vmap directly since it's possible that not all its tensors
+        # have the same batch dimension, which is required by vmap. For example, `target_positions`
+        # and `source_positions` from MaskFnAttentionBias may have batch dim == 1. Therefore, we
+        # extract the info we need from bias and pass that to vmap instead.
         seg_ids = None
-        if segment_ids.has_value():
+        if hasattr(segment_ids, "segment_ids"):
+            seg_ids = segment_ids.segment_ids
+        return jax.vmap(vmap_fn, in_axes=(0, 0, 0, 0) + (None,) * 3)(
+            q_proj, k_proj, v_proj, seg_ids, *args
+        )
+
+    def vmap_fn(q_proj, k_proj, v_proj, kv_seg_ids, *args):
+        if kv_seg_ids is None:
+            seg_ids = None
+        else:
             # SplashAttention requires q_seg_ids to have the same sequence length as q_proj and
-            # kv_seq_ids to have the same sequence length as k|v_proj. Therefore, we pass in a
+            # kv_seg_ids to have the same sequence length as k|v_proj. Therefore, we pass in a
             # segment id that's not sharded in the sequence dimension, and manually slice the
             # sequence dim to populate q_seg_ids.
-            kv_seq_ids = segment_ids.segment_ids
             if q_seq_shards == 1:
                 q_shard_idx = 0
             else:
                 q_shard_idx = jax.lax.axis_index("seq")
-            q_shard_size = kv_seq_ids.shape[0] // q_seq_shards
+            q_shard_size = kv_seg_ids.shape[0] // q_seq_shards
             q_seq_ids = jax.lax.dynamic_slice_in_dim(
-                kv_seq_ids, q_shard_idx * q_shard_size, q_shard_size
+                kv_seg_ids, q_shard_idx * q_shard_size, q_shard_size
             )
-            seg_ids = splash_attention_kernel.SegmentIds(q_seq_ids, kv_seq_ids)
+            seg_ids = splash_attention_kernel.SegmentIds(q_seq_ids, kv_seg_ids)

         q_proj = jnp.einsum("tnh->nth", q_proj)
         k_proj = jnp.einsum("snh->nsh", k_proj)
@@ -312,7 +327,7 @@ def shard_map_fn(q_proj, k_proj, v_proj, bias, _, *args):
         return jnp.einsum("nth->tnh", out)

     return FlashAttentionShardMapSpecs(
-        fn=jax.vmap(shard_map_fn, in_axes=(0, 0, 0, 0) + (None,) * 4),
+        fn=shard_map_fn,
         additional_in_specs=specs,
         additional_args=args,
     )
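Why the vmap moved inside shard_map_fn: the removed line wrapped shard_map_fn itself in jax.vmap with in_axes=(0, 0, 0, 0) + (None,) * 4, but as the new comment explains, the bias object can contain tensors whose batch dimensions disagree (e.g. `target_positions` with batch dim 1), which jax.vmap rejects. The fix extracts only the kv segment ids from the bias and vmaps a smaller vmap_fn over plain per-example tensors, broadcasting the three mask-info arguments via None axes. The toy sketch below is not axlearn code; per_example_attention, the shapes, and the scalar scale argument are made up purely to illustrate how in_axes mixes mapped (0) and broadcast (None) arguments.

# Toy sketch (not the axlearn kernel): batched tensors map over axis 0, while
# None-axis arguments are broadcast unchanged to every example.
import jax
import jax.numpy as jnp


def per_example_attention(q, k, v, seg_ids, scale):
    # q: [T, H], k/v: [S, H], seg_ids: [S] per example, scale: scalar shared by all examples.
    logits = (q @ k.T) * scale                      # [T, S]
    if seg_ids is not None:
        # Block cross-segment attention (this toy assumes q and kv share segment ids).
        mask = seg_ids[None, :] == seg_ids[: q.shape[0]][:, None]
        logits = jnp.where(mask, logits, -1e9)
    return jax.nn.softmax(logits, axis=-1) @ v      # [T, H]


batch, t, s, h = 2, 4, 4, 8
q = jnp.ones((batch, t, h))
k = jnp.ones((batch, s, h))
v = jnp.ones((batch, s, h))
seg_ids = jnp.array([[0, 0, 1, 1], [0, 1, 1, 1]])   # batched: one row per example
scale = 0.125                                        # not batched: passed with in_axes=None

# Mirrors in_axes=(0, 0, 0, 0) + (None,) * 3 in the commit, just with one broadcast arg.
out = jax.vmap(per_example_attention, in_axes=(0, 0, 0, 0, None))(q, k, v, seg_ids, scale)
print(out.shape)  # (2, 4, 8)

Anything passed with a None axis must be identical for every example, which is exactly the property the heterogeneous bias object could not guarantee.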

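On the q_seg_ids slicing that the renamed kv_seg_ids variable feeds into: SplashAttention needs query segment ids matching the query shard's sequence length, so each sequence shard locates itself with jax.lax.axis_index("seq") and carves its window out of the replicated kv segment ids with jax.lax.dynamic_slice_in_dim. The standalone sketch below shows just that slicing step under shard_map; the mesh axis name "seq", the device layout, and the toy segment-id values are assumptions for illustration, not axlearn's actual configuration.

# Minimal sketch of the per-shard query segment-id slicing, assuming a 1-D mesh
# with a "seq" axis over whatever devices are available.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices())
q_seq_shards = devices.size          # e.g. 4 on a small TPU slice, 1 on a CPU host
mesh = Mesh(devices, axis_names=("seq",))

kv_seg_ids = jnp.repeat(jnp.arange(8), 16)  # [128]: kv segment ids, not sharded over sequence


def slice_q_seg_ids(kv_seg_ids):
    # Each program along "seq" sees the full (replicated) kv ids and slices out the
    # window aligned with its own query shard, as vmap_fn does in the commit.
    q_shard_idx = jax.lax.axis_index("seq")
    q_shard_size = kv_seg_ids.shape[0] // q_seq_shards
    return jax.lax.dynamic_slice_in_dim(kv_seg_ids, q_shard_idx * q_shard_size, q_shard_size)


q_seg_ids = shard_map(
    slice_q_seg_ids,
    mesh=mesh,
    in_specs=P(),        # kv segment ids replicated across "seq"
    out_specs=P("seq"),  # each shard contributes its slice of the query ids
)(kv_seg_ids)
print(q_seg_ids.shape)  # (128,) when the length divides evenly across the "seq" shards

Keeping the kv ids unsharded and slicing per shard matches the retained comment in the diff: a segment id that is not sharded in the sequence dimension is passed in, and the query window is taken manually.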