Legacy flash remat fix (#943)
* Fix the same problem for legacy tpu attn

* Fix
hanzhi713 authored Jan 23, 2025
1 parent 6a9f980 commit 30284c8
Showing 2 changed files with 25 additions and 19 deletions.
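
Background on the mechanism this commit touches (a minimal illustrative sketch, not axlearn code): JAX rematerialization decides what to save for the backward pass via a checkpoint policy. Intermediates are tagged with jax.ad_checkpoint.checkpoint_name, and a policy such as jax.checkpoint_policies.save_only_these_names keeps the tagged values instead of recomputing them. The residual name, function, and shapes below are made up for illustration; axlearn tags its flash-attention outputs with FLASH_ATTN_RESIDUAL_NAME.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Hypothetical tag; axlearn uses FLASH_ATTN_RESIDUAL_NAME for its kernels.
RESIDUAL_NAME = "flash_residual"


def attn_like(q, k):
    # Tag the expensive output so a remat policy can choose to save it
    # instead of recomputing it on the backward pass.
    return checkpoint_name(jnp.tanh(q @ k.T), RESIDUAL_NAME)


def loss(q, k):
    f = jax.checkpoint(
        attn_like,
        policy=jax.checkpoint_policies.save_only_these_names(RESIDUAL_NAME),
    )
    return f(q, k).sum()


q = jnp.ones((4, 8))
k = jnp.ones((4, 8))
print(jax.grad(loss)(q, k).shape)  # (4, 8)
```

If the tag is never applied on the path that remat actually traces, the policy has nothing to match and the kernel gets recomputed; this commit applies the tagging consistently in the legacy TPU attention path.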
8 changes: 4 additions & 4 deletions axlearn/common/flash_attention/remat_test.py
@@ -187,12 +187,13 @@ def test_remat_combine_policy(self):
             )
         )

+        remat_hlo = str(jax.jit(remat).lower(params, inputs).as_text("hlo"))
         self.assertEqual(
-            str(jax.make_jaxpr(remat)(params, inputs)).count("_mha_forward_kernel"),
-            1,
+            remat_hlo.count('custom_call_target="__gpu$xla.gpu.triton"'),
+            3,
         )
         self.assertEqual(
-            str(jax.jit(remat).lower(params, inputs).as_text("hlo")).count(" dot("),
+            remat_hlo.count(" dot("),
             no_remat_dots_count,
         )

@@ -229,4 +230,3 @@ def test_remat_combine_policy(self):
             str(jax.jit(remat).lower(params, inputs).as_text("hlo")).count(" dot("),
             no_remat_dots_count,
         )
-        jax.jit(remat).lower(params, inputs).as_text("hlo")
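
The updated test lowers the rematted function once, then inspects the HLO text: it counts Triton custom-call targets and " dot(" occurrences rather than counting kernel names in the jaxpr. A self-contained sketch of that HLO-inspection pattern on a toy function follows; the function, shapes, and printed count are illustrative, not the axlearn test.

```python
import jax
import jax.numpy as jnp


def f(x, w):
    # Toy stand-in for the rematted forward pass exercised by the real test.
    return jnp.sum(jnp.tanh(x @ w) @ w.T)


x = jnp.ones((8, 16))
w = jnp.ones((16, 16))

# Lower the jitted gradient once and inspect the HLO text, as the updated
# test does with `remat_hlo`.
hlo = jax.jit(jax.grad(f)).lower(x, w).as_text("hlo")

# Substring counts over the HLO are a cheap way to assert how many dot ops or
# custom-call kernels survived a remat policy; the real test counts " dot("
# and 'custom_call_target="__gpu$xla.gpu.triton"'.
print(hlo.count(" dot("))
```

Checking the lowered HLO verifies what actually reaches the compiler after jit and remat, which is presumably why the assertion moved from jaxpr inspection to HLO inspection.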
36 changes: 21 additions & 15 deletions axlearn/common/flash_attention/tpu_attention.py
@@ -474,18 +474,17 @@ def pallas_tpu_flash_attention(
         batch_size, num_heads, q_seq_len, kv_seq_len, d_model
     )
     return _flash_attention(
-        q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug, interpret
+        q, k, v, ab, segment_ids, causal, softmax_scale, block_sizes, debug, interpret
     )


-@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 11))
+@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
 def _flash_attention(
     q,
     k,
     v,
     ab,
     segment_ids,
-    save_residuals,
     causal,
     softmax_scale,
     block_sizes,
@@ -498,7 +497,7 @@ def _flash_attention(
         v,
         ab,
         segment_ids,
-        save_residuals,
+        False,
         causal,
         softmax_scale,
         block_sizes.block_b,
@@ -516,23 +515,32 @@ def _flash_attention_fwd(
     v,
     ab,
     segment_ids,
-    save_residuals,
     causal,
     softmax_scale,
     block_sizes,
     debug,
     interpret,
 ):
-    if save_residuals:
-        raise NotImplementedError("Higher-order AD not supported")
-    o, l, m = _flash_attention(
-        q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug, interpret
+    o, l, m = _flash_attention_impl(
+        q,
+        k,
+        v,
+        ab,
+        segment_ids,
+        True,
+        causal,
+        softmax_scale,
+        block_sizes.block_b,
+        block_sizes.block_q,
+        block_sizes.block_k_major,
+        block_sizes.block_k,
+        debug,
+        interpret,
     )
     return o, (q, k, v, ab, segment_ids, o, l, m)


 def _flash_attention_bwd(
-    save_residuals: bool,
     causal: bool,
     softmax_scale: float,
     block_sizes: LegacyBlockSizes,
@@ -542,8 +550,6 @@
     do,
 ):
     """VJP rule for FlashAttention."""
-    if save_residuals:
-        raise NotImplementedError("Higher-order AD not supported")
     (q, k, v, ab, segment_ids, o, l, m) = residuals
     if not block_sizes.has_backward_blocks:
         raise ValueError(
@@ -789,11 +795,11 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
             )
         ),
     )(q, k, v, ab, q_segment_ids, kv_segment_ids)
+    o = jax.ad_checkpoint.checkpoint_name(o, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
+    l = jax.ad_checkpoint.checkpoint_name(l, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
+    m = jax.ad_checkpoint.checkpoint_name(m, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
     if save_residuals:
         l, m = (v[..., 0] for v in aux[-2:])
-        o = jax.ad_checkpoint.checkpoint_name(o, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
-        l = jax.ad_checkpoint.checkpoint_name(l, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
-        m = jax.ad_checkpoint.checkpoint_name(m, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
         return (o, l, m)
     else:
         return o
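
The tpu_attention.py change follows a common custom_vjp pattern: the save_residuals flag is dropped from the nondiff arguments of the custom_vjp entry point, the forward rule calls the underlying implementation directly with save_residuals=True, and the kernel outputs are tagged with checkpoint_name so that a save_only_these_names remat policy can keep them. Below is a toy sketch of that structure with dense softmax attention standing in for the Pallas kernel; toy_attention, _impl, and the residual name are hypothetical names for illustration, not axlearn APIs.

```python
import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

RESIDUAL_NAME = "toy_attention_residual"  # stand-in for FLASH_ATTN_RESIDUAL_NAME


def _impl(q, k, softmax_scale, save_residuals):
    # Stand-in for _flash_attention_impl: tag the output unconditionally so a
    # save_only_these_names policy can keep it on both the primal and fwd paths.
    s = (q @ k.T) * softmax_scale
    p = jax.nn.softmax(s, axis=-1)
    o = checkpoint_name(p @ k, RESIDUAL_NAME)
    return (o, s) if save_residuals else o


@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def toy_attention(q, k, softmax_scale):
    return _impl(q, k, softmax_scale, save_residuals=False)


def _fwd(q, k, softmax_scale):
    # The fwd rule calls the implementation directly with residuals enabled.
    o, s = _impl(q, k, softmax_scale, save_residuals=True)
    return o, (q, k, s)


def _bwd(softmax_scale, residuals, do):
    q, k, s = residuals
    p = jax.nn.softmax(s, axis=-1)
    dk = p.T @ do                                             # through o = p @ k
    dp = do @ k.T
    ds = p * (dp - jnp.sum(dp * p, axis=-1, keepdims=True))   # softmax VJP
    dq = (ds @ k) * softmax_scale                             # through s = scale * q @ k.T
    dk = dk + (ds.T @ q) * softmax_scale
    return dq, dk


toy_attention.defvjp(_fwd, _bwd)

q = jnp.ones((4, 8))
k = jnp.ones((4, 8))
loss = jax.checkpoint(
    lambda q, k: toy_attention(q, k, 1.0).sum(),
    policy=jax.checkpoint_policies.save_only_these_names(RESIDUAL_NAME),
)
print(jax.grad(loss)(q, k).shape)  # (4, 8)
```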
