@@ -278,29 +278,44 @@ def make_tpu_splash_attention(
         interpret=interpret,
         residual_checkpoint_name=f"tpu_attention.{FLASH_ATTN_RESIDUAL_NAME}",
     )
+    # args contains fwd_mask_info, dq_mask_info and dkv_mask_info, corresponding to the first three
+    # positional arguments to `splash_attention_kernel._splash_attention`.
     args, kwargs = kernel.tree_flatten()
     specs, _ = kernel.manual_sharding_spec(sharding).tree_flatten()
 
     def shard_map_fn(q_proj, k_proj, v_proj, bias, _, *args):
+        assert len(args) == 3
         if softmax_scale != 1.0:
             q_proj *= softmax_scale
         _, segment_ids, _ = split(bias, MaskFnAttentionBias, SegmentIdAttentionBias)
+        # Note: we cannot pass bias to vmap directly since it's possible that not all its tensors
+        # have the same batch dimension, which is required by vmap. For example, `target_positions`
+        # and `source_positions` from MaskFnAttentionBias may have batch dim == 1. Therefore, we
+        # extract the info we need from bias and pass that to vmap instead.
         seg_ids = None
-        if segment_ids.has_value():
+        if hasattr(segment_ids, "segment_ids"):
+            seg_ids = segment_ids.segment_ids
+        return jax.vmap(vmap_fn, in_axes=(0, 0, 0, 0) + (None,) * 3)(
+            q_proj, k_proj, v_proj, seg_ids, *args
+        )
+
+    def vmap_fn(q_proj, k_proj, v_proj, kv_seg_ids, *args):
+        if kv_seg_ids is None:
+            seg_ids = None
+        else:
             # SplashAttention requires q_seg_ids to have the same sequence length as q_proj and
-            # kv_seq_ids to have the same sequence length as k|v_proj. Therefore, we pass in a
+            # kv_seg_ids to have the same sequence length as k|v_proj. Therefore, we pass in a
             # segment id that's not sharded in the sequence dimension, and manually slice the
             # sequence dim to populate q_seg_ids.
-            kv_seq_ids = segment_ids.segment_ids
             if q_seq_shards == 1:
                 q_shard_idx = 0
             else:
                 q_shard_idx = jax.lax.axis_index("seq")
-            q_shard_size = kv_seq_ids.shape[0] // q_seq_shards
+            q_shard_size = kv_seg_ids.shape[0] // q_seq_shards
             q_seq_ids = jax.lax.dynamic_slice_in_dim(
-                kv_seq_ids, q_shard_idx * q_shard_size, q_shard_size
+                kv_seg_ids, q_shard_idx * q_shard_size, q_shard_size
             )
-            seg_ids = splash_attention_kernel.SegmentIds(q_seq_ids, kv_seq_ids)
+            seg_ids = splash_attention_kernel.SegmentIds(q_seq_ids, kv_seg_ids)
 
         q_proj = jnp.einsum("tnh->nth", q_proj)
         k_proj = jnp.einsum("snh->nsh", k_proj)
@@ -312,7 +327,7 @@ def shard_map_fn(q_proj, k_proj, v_proj, bias, _, *args):
         return jnp.einsum("nth->tnh", out)
 
     return FlashAttentionShardMapSpecs(
-        fn=jax.vmap(shard_map_fn, in_axes=(0, 0, 0, 0) + (None,) * 4),
+        fn=shard_map_fn,
         additional_in_specs=specs,
         additional_args=args,
     )
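
For reference, the `in_axes=(0, 0, 0, 0) + (None,) * 3` pattern above maps over the batch dimension of q_proj, k_proj, v_proj and the segment ids, while broadcasting the three mask-info arguments unchanged to every batch element. Below is a minimal sketch of just that vmap behavior; the shapes and the `per_example` body are hypothetical stand-ins for the real splash kernel call, not the kernel's actual signature.

import jax
import jax.numpy as jnp

# Hypothetical shapes: batch 2, query/key length 8, 2 heads, head dim 4.
batch, q_len, kv_len, heads, dim = 2, 8, 8, 2, 4
q = jnp.ones((batch, q_len, heads, dim))
k = jnp.ones((batch, kv_len, heads, dim))
v = jnp.ones((batch, kv_len, heads, dim))
seg_ids = jnp.zeros((batch, kv_len), dtype=jnp.int32)
# Stand-ins for fwd_mask_info, dq_mask_info and dkv_mask_info: shared by all
# batch elements, so they are broadcast via in_axes=None rather than batched.
mask_infos = (jnp.int32(0), jnp.int32(1), jnp.int32(2))

def per_example(q, k, v, seg, *mask_infos):
    # Inside vmap the batch dimension is gone: q is (t, n, h), seg is (s,),
    # and the three mask infos arrive exactly as passed in, unbatched.
    assert q.ndim == 3 and seg.ndim == 1 and len(mask_infos) == 3
    scores = jnp.einsum("tnh,snh->nts", q, k)     # placeholder for the splash attention kernel
    return jnp.einsum("nts,snh->tnh", scores, v)  # (t, n, h), matching the kernel's output layout

out = jax.vmap(per_example, in_axes=(0, 0, 0, 0) + (None,) * 3)(q, k, v, seg_ids, *mask_infos)
print(out.shape)  # (2, 8, 2, 4)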
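
The segment-id handling in vmap_fn relies on the kv segment ids being passed in full (unsharded along the sequence axis), with each sequence shard slicing out its own query portion via `jax.lax.axis_index` and `jax.lax.dynamic_slice_in_dim`. The following is a rough standalone illustration of that slicing pattern under shard_map; the mesh construction, sequence length, and segment-id values are assumptions for the sketch, with the "seq" axis sized to whatever devices are available.

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("seq",))
q_seq_shards = len(devices)
seq_len = 4 * q_seq_shards  # keep the sequence length divisible by the shard count

# Full-length kv segment ids, replicated across the "seq" axis (in_specs=P() below).
kv_seg_ids = jnp.arange(seq_len, dtype=jnp.int32) // 4

def slice_q_seg_ids(kv_seg_ids):
    # Each shard picks the query slice matching its position on the "seq" mesh
    # axis, mirroring the q_shard_idx / q_shard_size logic in the diff above.
    q_shard_idx = jax.lax.axis_index("seq") if q_seq_shards > 1 else 0
    q_shard_size = kv_seg_ids.shape[0] // q_seq_shards
    return jax.lax.dynamic_slice_in_dim(kv_seg_ids, q_shard_idx * q_shard_size, q_shard_size)

q_seg_ids = shard_map(
    slice_q_seg_ids,
    mesh=mesh,
    in_specs=P(),        # the kv segment ids arrive whole on every shard
    out_specs=P("seq"),  # per-shard query slices are laid out along the sequence axis
)(kv_seg_ids)
print(q_seg_ids.shape)  # (seq_len,): the per-shard slices tile the full length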