@@ -474,18 +474,17 @@ def pallas_tpu_flash_attention(
         batch_size, num_heads, q_seq_len, kv_seq_len, d_model
     )
     return _flash_attention(
-        q, k, v, ab, segment_ids, False, causal, softmax_scale, block_sizes, debug, interpret
+        q, k, v, ab, segment_ids, causal, softmax_scale, block_sizes, debug, interpret
     )
 
 
-@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 11))
+@functools.partial(jax.custom_vjp, nondiff_argnums=range(5, 10))
 def _flash_attention(
     q,
     k,
     v,
     ab,
     segment_ids,
-    save_residuals,
     causal,
     softmax_scale,
     block_sizes,
@@ -498,7 +497,7 @@ def _flash_attention(
         v,
         ab,
         segment_ids,
-        save_residuals,
+        False,
         causal,
         softmax_scale,
         block_sizes.block_b,
@@ -516,23 +515,32 @@ def _flash_attention_fwd(
     v,
     ab,
     segment_ids,
-    save_residuals,
     causal,
     softmax_scale,
     block_sizes,
     debug,
     interpret,
 ):
-    if save_residuals:
-        raise NotImplementedError("Higher-order AD not supported")
-    o, l, m = _flash_attention(
-        q, k, v, ab, segment_ids, True, causal, softmax_scale, block_sizes, debug, interpret
+    o, l, m = _flash_attention_impl(
+        q,
+        k,
+        v,
+        ab,
+        segment_ids,
+        True,
+        causal,
+        softmax_scale,
+        block_sizes.block_b,
+        block_sizes.block_q,
+        block_sizes.block_k_major,
+        block_sizes.block_k,
+        debug,
+        interpret,
     )
     return o, (q, k, v, ab, segment_ids, o, l, m)
 
 
 def _flash_attention_bwd(
-    save_residuals: bool,
     causal: bool,
     softmax_scale: float,
     block_sizes: LegacyBlockSizes,
@@ -542,8 +550,6 @@ def _flash_attention_bwd(
     do,
 ):
     """VJP rule for FlashAttention."""
-    if save_residuals:
-        raise NotImplementedError("Higher-order AD not supported")
     (q, k, v, ab, segment_ids, o, l, m) = residuals
     if not block_sizes.has_backward_blocks:
         raise ValueError(
@@ -789,11 +795,11 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
             )
         ),
     )(q, k, v, ab, q_segment_ids, kv_segment_ids)
-    o = jax.ad_checkpoint.checkpoint_name(o, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
-    l = jax.ad_checkpoint.checkpoint_name(l, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
-    m = jax.ad_checkpoint.checkpoint_name(m, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
     if save_residuals:
         l, m = (v[..., 0] for v in aux[-2:])
+        o = jax.ad_checkpoint.checkpoint_name(o, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
+        l = jax.ad_checkpoint.checkpoint_name(l, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
+        m = jax.ad_checkpoint.checkpoint_name(m, f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}")
        return (o, l, m)
     else:
         return o
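Below is a minimal, standalone sketch (not code from this diff) of the jax.custom_vjp pattern the decorator change relies on: nondiff_argnums marks trailing configuration arguments (here, the five that remain after save_residuals is dropped, hence range(5, 11) becoming range(5, 10)), the forward rule returns the primal output plus residuals, and the backward rule receives the non-differentiable arguments first, then the residuals and the cotangent. The scaled_matmul function and its scale argument are illustrative stand-ins, not names from this PR.

import functools

import jax
import jax.numpy as jnp


# Argument position 2 (scale) is a non-differentiable configuration argument,
# analogous to causal/softmax_scale/block_sizes/debug/interpret in the diff.
@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def scaled_matmul(x, y, scale):
    return scale * (x @ y)


def scaled_matmul_fwd(x, y, scale):
    # The forward rule returns the primal output and the residuals needed
    # by the backward pass (mirroring how _flash_attention_fwd saves o, l, m).
    out = scale * (x @ y)
    return out, (x, y)


def scaled_matmul_bwd(scale, residuals, g):
    # nondiff_argnums values arrive first, then residuals, then the cotangent.
    x, y = residuals
    return scale * (g @ y.T), scale * (x.T @ g)


scaled_matmul.defvjp(scaled_matmul_fwd, scaled_matmul_bwd)

# Usage: gradients flow only through the differentiable positional arguments.
x = jnp.ones((4, 8))
y = jnp.ones((8, 2))
grad_x = jax.grad(lambda a: scaled_matmul(a, y, 0.5).sum())(x)
print(grad_x.shape)  # (4, 8)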