Commit 30284c8

Legacy flash remat fix (apple#943)
* Fix the same problem for legacy tpu attn
* Fix
1 parent 6a9f980 commit 30284c8

2 files changed: +25 -19 lines changed

axlearn/common/flash_attention/remat_test.py

Lines changed: 4 additions & 4 deletions
@@ -187,12 +187,13 @@ def test_remat_combine_policy(self):
             )
         )
 
+        remat_hlo = str(jax.jit(remat).lower(params, inputs).as_text("hlo"))
         self.assertEqual(
-            str(jax.make_jaxpr(remat)(params, inputs)).count("_mha_forward_kernel"),
-            1,
+            remat_hlo.count('custom_call_target="__gpu$xla.gpu.triton"'),
+            3,
         )
         self.assertEqual(
-            str(jax.jit(remat).lower(params, inputs).as_text("hlo")).count(" dot("),
+            remat_hlo.count(" dot("),
             no_remat_dots_count,
         )
 
@@ -229,4 +230,3 @@ def test_remat_combine_policy(self):
             str(jax.jit(remat).lower(params, inputs).as_text("hlo")).count(" dot("),
             no_remat_dots_count,
         )
-        jax.jit(remat).lower(params, inputs).as_text("hlo")
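The updated test lowers the remat-wrapped function to HLO once and counts substrings in that text (the Triton custom-call target and " dot(") instead of counting kernel names in the jaxpr. A minimal sketch of the same lower-and-count technique, with a toy function standing in for the attention layer (names, shapes, and counts below are illustrative, not taken from the test):

import jax
import jax.numpy as jnp

# Toy stand-in for the remat-wrapped computation under test (illustrative only).
def fn(params, x):
    return jnp.tanh(x @ params).sum()

def loss(params, x):
    # jax.checkpoint (remat) forces intermediates to be recomputed in the backward pass.
    return jax.grad(jax.checkpoint(fn))(params, x).sum()

params = jnp.ones((8, 8))
x = jnp.ones((4, 8))

# Lower once and reuse the HLO text, mirroring the test's `remat_hlo` variable.
hlo = str(jax.jit(loss).lower(params, x).as_text("hlo"))

# Count occurrences of an op of interest; the real test counts the Triton
# custom-call target and " dot(" substrings in the same string.
print(hlo.count(" dot("))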

axlearn/common/flash_attention/tpu_attention.py

Lines changed: 21 additions & 15 deletions
@@ -474,18 +474,17 @@ def pallas_tpu_flash_attention(
         batch_size, num_heads, q_seq_len, kv_seq_len, d_model
     )
     return _flash_attention(
-        q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug, interpret
+        q, k, v, ab, segment_ids, causal, softmax_scale, block_sizes, debug, interpret
     )
 
 
-@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 11))
+@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
 def _flash_attention(
     q,
     k,
     v,
     ab,
     segment_ids,
-    save_residuals,
     causal,
     softmax_scale,
     block_sizes,
@@ -498,7 +497,7 @@ def _flash_attention(
         v,
         ab,
         segment_ids,
-        save_residuals,
+        False,
         causal,
         softmax_scale,
         block_sizes.block_b,
@@ -516,23 +515,32 @@ def _flash_attention_fwd(
     v,
     ab,
     segment_ids,
-    save_residuals,
     causal,
     softmax_scale,
     block_sizes,
     debug,
     interpret,
 ):
-    if save_residuals:
-        raise NotImplementedError("Higher-order AD not supported")
-    o, l, m = _flash_attention(
-        q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug, interpret
+    o, l, m = _flash_attention_impl(
+        q,
+        k,
+        v,
+        ab,
+        segment_ids,
+        True,
+        causal,
+        softmax_scale,
+        block_sizes.block_b,
+        block_sizes.block_q,
+        block_sizes.block_k_major,
+        block_sizes.block_k,
+        debug,
+        interpret,
     )
     return o, (q, k, v, ab, segment_ids, o, l, m)
 
 
 def _flash_attention_bwd(
-    save_residuals: bool,
     causal: bool,
     softmax_scale: float,
     block_sizes: LegacyBlockSizes,
@@ -542,8 +550,6 @@ def _flash_attention_bwd(
     do,
 ):
     """VJP rule for FlashAttention."""
-    if save_residuals:
-        raise NotImplementedError("Higher-order AD not supported")
     (q, k, v, ab, segment_ids, o, l, m) = residuals
     if not block_sizes.has_backward_blocks:
         raise ValueError(
@@ -789,11 +795,11 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
             )
         ),
     )(q, k, v, ab, q_segment_ids, kv_segment_ids)
-    o = jax.ad_checkpoint.checkpoint_name(o, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
-    l = jax.ad_checkpoint.checkpoint_name(l, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
-    m = jax.ad_checkpoint.checkpoint_name(m, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
     if save_residuals:
         l, m = (v[..., 0] for v in aux[-2:])
+        o = jax.ad_checkpoint.checkpoint_name(o, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
+        l = jax.ad_checkpoint.checkpoint_name(l, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
+        m = jax.ad_checkpoint.checkpoint_name(m, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
         return (o, l, m)
     else:
         return o
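Taken together, the changes keep the legacy TPU path on the usual custom-VJP-plus-remat pattern: the wrapper no longer exposes `save_residuals`, the forward rule requests residuals from `_flash_attention_impl` directly, and residuals are tagged with `jax.ad_checkpoint.checkpoint_name` only on the branch that actually produces them. A minimal, self-contained sketch of that pattern under these assumptions (toy math; `toy_attn`, `_impl`, and `"toy.residual"` are illustrative names, not the library's):

import functools

import jax
import jax.numpy as jnp
from jax import checkpoint_policies
from jax.ad_checkpoint import checkpoint_name

RESIDUAL_NAME = "toy.residual"  # illustrative tag, not the library's constant

def _impl(x, w, save_residuals, scale):
    """Toy kernel: returns the output, plus a named residual only when asked to."""
    z = x @ w
    y = jnp.tanh(scale * z)
    if save_residuals:
        # Tag the residual on the path that produces it (mirrors the fix).
        z = checkpoint_name(z, RESIDUAL_NAME)
        return y, z
    return y

# Trailing configuration arguments (here just `scale`) are non-differentiable.
@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def toy_attn(x, w, scale):
    return _impl(x, w, False, scale)

def _fwd(x, w, scale):
    # The forward rule asks the implementation for residuals directly.
    y, z = _impl(x, w, True, scale)
    return y, (x, w, z)

def _bwd(scale, residuals, dy):
    x, w, z = residuals
    g = dy * (1.0 - jnp.tanh(scale * z) ** 2) * scale
    return (g @ w.T, x.T @ g)

toy_attn.defvjp(_fwd, _bwd)

x, w = jnp.ones((4, 8)), jnp.ones((8, 8))
loss = lambda x, w: toy_attn(x, w, 0.5).sum()

# A remat policy that saves only the tagged residual and recomputes the rest.
policy = checkpoint_policies.save_only_these_names(RESIDUAL_NAME)
grads = jax.grad(jax.checkpoint(loss, policy=policy), argnums=(0, 1))(x, w)
print(jax.tree_util.tree_map(jnp.shape, grads))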
