@@ -1126,16 +1126,17 @@ def wrap_flash_attention(
     decoder_segment_ids,
     custom_mask=None,
     attn_logits_soft_cap=None,
+    head_shards=1,
+    q_seq_shards=1,
 ):
     if decoder_segment_ids is not None:
         assert query.shape[2] == decoder_segment_ids.q.shape[1], (
-            "Sharding along sequence dimension not allowed in tpu kernel "
-            "attention"
+            "Sharding along sequence dimension not allowed"
+            " in TPU kernel attention"
         )

     if custom_mask is not None:
         mask = splash_attention_mask.NumpyMask(array=custom_mask)
-
     else:
         mask = splash_attention_mask.CausalMask(
             shape=(query.shape[2], query.shape[2])
@@ -1147,8 +1148,8 @@ def wrap_flash_attention(
     )
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
+        head_shards=head_shards,
+        q_seq_shards=q_seq_shards,
         attn_logits_soft_cap=attn_logits_soft_cap,
     )

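Note: a rough sketch of how the updated helper could be exercised in isolation. The shapes, head count, and `head_shards=4` value are illustrative assumptions; the call needs a TPU runtime with the Pallas splash-attention kernels available, and in the backend itself the sharding values come from the device mesh (see the sketch after the third hunk).

    import jax.numpy as jnp

    # Hypothetical tensors in the kernel layout (batch, heads, q_len, head_dim).
    q = jnp.ones((2, 8, 1024, 128), dtype=jnp.bfloat16)
    k = jnp.ones((2, 8, 1024, 128), dtype=jnp.bfloat16)
    v = jnp.ones((2, 8, 1024, 128), dtype=jnp.bfloat16)

    out = wrap_flash_attention(
        q,
        k,
        v,
        decoder_segment_ids=None,  # skips the sequence-sharding assertion
        head_shards=4,  # illustrative; should match how heads are actually sharded
        q_seq_shards=1,
    )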
@@ -1168,6 +1169,38 @@ def dot_product_attention(
     flash_attention=None,
     attn_logits_soft_cap=None,
 ):
+    """Computes dot-product attention given query, key, and value.
+
+    This is the core computation of attention that is used in transformers.
+    For TPU platforms, flash attention optimizations are automatically applied
+    when possible, and sharding parameters are inferred from the layout map
+    in the current distribution context.
+
+    Args:
+        query: Queries with shape `[batch, time, heads, depth_k]`.
+        key: Keys with shape `[batch, time, heads, depth_k]`.
+        value: Values with shape `[batch, time, heads, depth_v]`.
+        bias: Optional bias with shape broadcastable to
+            `[batch, heads, dest_time, source_time]`.
+        mask: Optional mask with shape broadcastable to
+            `[batch, heads, dest_time, source_time]`.
+        scale: Float. Optional scale that is applied to the attention
+            computation.
+        is_causal: Boolean. Whether causal masking is applied.
+        flash_attention: Boolean. Whether to use the flash attention
+            optimization for increased performance. Defaults to `None`, which
+            means it is auto-determined based on the platform, input shapes,
+            and compatibility.
+        attn_logits_soft_cap: Float. Optional float to softly cap attention
+            logits for numerical stability. In the TPU flash attention kernel
+            this is applied as `logits = cap * tanh(logits / cap)`, where
+            `cap = attn_logits_soft_cap`.
+
+    Returns:
+        JAX Array of shape `[batch, time, heads, depth_v]`.
+    """
     query = convert_to_tensor(query)
     key = convert_to_tensor(key)
     value = convert_to_tensor(value)
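With the new docstring in mind, a hedged usage sketch of the public entry point. The import path assumes the Keras 3 source tree; shapes follow the `[batch, time, heads, depth]` layout documented above, and the soft cap value is arbitrary.

    import numpy as np
    from keras.src.backend.jax.nn import dot_product_attention

    batch, time, heads, depth = 2, 16, 4, 32
    query = np.random.normal(size=(batch, time, heads, depth)).astype("float32")
    key = np.random.normal(size=(batch, time, heads, depth)).astype("float32")
    value = np.random.normal(size=(batch, time, heads, depth)).astype("float32")

    out = dot_product_attention(
        query,
        key,
        value,
        is_causal=True,
        flash_attention=None,  # auto-detect TPU / cuDNN support
        attn_logits_soft_cap=50.0,  # consumed by the TPU splash-attention path
    )
    print(out.shape)  # (2, 16, 4, 32)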
@@ -1177,47 +1210,155 @@ def dot_product_attention(
             f"Received: query.shape={query.shape}, key.shape={key.shape}, "
             f"value.shape={value.shape}."
         )
-    if flash_attention is None:
-        flash_attention = _can_use_flash_attention(query, key, value, bias)
-    elif flash_attention is True:
-        # Use `raise_error=True` to provide more details if the inputs failed to
-        # use flash attention
-        _can_use_flash_attention(query, key, value, bias, raise_error=True)

-    if jax.devices()[0].platform == "tpu":
-        # Transpose to ('batch', 'heads', 'length', 'kv')
-        query = jnp.transpose(query, axes=(0, 2, 1, 3))
-        key = jnp.transpose(key, axes=(0, 2, 1, 3))
-        value = jnp.transpose(value, axes=(0, 2, 1, 3))
-        B, H, S, KV = query.shape
-
-        segment_ids = jnp.ones([B, S])
-        # {token_ids, padding_mask, segment_ids} enable packing
-        out = wrap_flash_attention(
-            query,
-            key,
-            value,
-            decoder_segment_ids=splash_attention_kernel.SegmentIds(
-                segment_ids, segment_ids
-            ),
-            custom_mask=mask,
-            attn_logits_soft_cap=attn_logits_soft_cap,
+    # Check platform
+    platform = jax.devices()[0].platform
+    is_tpu = platform == "tpu"
+
+    # Get sharding parameters from distribution context
+    head_shards = 1
+    q_seq_shards = 1
+
+    if is_tpu:
+        try:
+            from keras.src.distribution.distribution_lib import ModelParallel
+            from keras.src.distribution.distribution_lib import (
+                distribution as get_dist,
+            )
+
+            # Get current distribution if available
+            dist = get_dist()
+            if dist and isinstance(dist, ModelParallel):
+                mesh = dist.device_mesh
+                if "model" in mesh.axis_names:
+                    model_dim_index = mesh.axis_names.index("model")
+                    # Set head_shards based on the model dimension of the mesh
+                    head_shards = mesh.shape[model_dim_index]
+                    # Typically keep q_seq_shards=1 for best performance
+                    q_seq_shards = 1
+        except (ImportError, ValueError, AttributeError):
+            # Use default values if detection fails
+            head_shards = 1
+            q_seq_shards = 1
+
+    # Check if inputs use partial sharding (not fully replicated).
+    # Flash attention works well with fully replicated tensors on all
+    # platforms but may have issues with certain partial sharding patterns
+    # on non-TPU platforms.
+    partially_sharded_inputs = any(
+        hasattr(t, "sharding") and not t.sharding.is_fully_replicated
+        for t in (query, key, value)
+    )
+
+    # Determine flash attention compatibility
+    if flash_attention is None:
+        # Auto-detect flash attention availability
+        if is_tpu:
+            # TPUs have specialized hardware for attention that works with
+            # any sharding pattern
+            flash_attention = True
+        else:
+            # For GPU/CPU with partially sharded inputs, we need multiple
+            # devices to efficiently handle the sharding
+            if partially_sharded_inputs and len(jax.devices()) <= 1:
+                flash_attention = False
+            else:
+                flash_attention = _can_use_flash_attention(
+                    query, key, value, bias
+                )
+    elif flash_attention is True and not is_tpu:
+        # If flash attention is explicitly requested, validate compatibility.
+        # Skip validation for TPU as it has specialized hardware support.
+        try:
+            _can_use_flash_attention(query, key, value, bias, raise_error=True)
+        except Exception:
+            # Only disable flash attention on non-TPU platforms
+            # if validation fails
+            flash_attention = False
+
+    # TPU-specific flash attention path
+    if is_tpu and flash_attention:
+        # Transpose to ('batch', 'heads', 'length', 'head_dim')
+        query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
+        key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
+        value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))
+
+        bs, num_heads, q_len, head_dim = query_tpu_layout.shape
+
+        # Apply scale to query if provided
+        if scale is not None:
+            # The TPU kernel applies 1/sqrt(head_dim) internally; to get an
+            # overall QK^T * scale, scale the query by scale * sqrt(head_dim).
+            query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
+
+        # Create segment IDs for Splash Attention (for packing/batching)
+        segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
+        decoder_segment_ids = splash_attention_kernel.SegmentIds(
+            q=segment_ids, kv=segment_ids
         )
-        out = jnp.transpose(out, axes=(0, 2, 1, 3))
-        return out

-    # `dot_product_attention` is only available in jax>=0.4.31
+        # Process mask for Splash Attention
+        custom_mask = None
+        if mask is not None:
+            mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask
+
+            if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
+                custom_mask = mask_bool[0]
+            elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
+                custom_mask = mask_bool[0, 0]
+
+            if is_causal and custom_mask is not None:
+                causal_mask = jnp.tril(
+                    jnp.ones((q_len, q_len), dtype=jnp.bool_)
+                )
+                custom_mask = jnp.logical_and(custom_mask, causal_mask)
+
+        if custom_mask is None and is_causal:
+            custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
+
+        try:
+            output = wrap_flash_attention(
+                query_tpu_layout,
+                key_tpu_layout,
+                value_tpu_layout,
+                decoder_segment_ids=decoder_segment_ids,
+                custom_mask=custom_mask,
+                attn_logits_soft_cap=attn_logits_soft_cap,
+                head_shards=head_shards,
+                q_seq_shards=q_seq_shards,
+            )
+            # Transpose output back to Keras layout
+            return jnp.transpose(output, axes=(0, 2, 1, 3))
+        except Exception:
+            flash_attention = False
+
+    # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
-        return jax.nn.dot_product_attention(
-            query,
-            key,
-            value,
-            bias=bias,
-            mask=mask,
-            scale=scale,
-            is_causal=is_causal,
-            implementation="cudnn" if flash_attention else "xla",
-        )
+        try:
+            return jax.nn.dot_product_attention(
+                query,
+                key,
+                value,
+                bias=bias,
+                mask=mask,
+                scale=scale,
+                is_causal=is_causal,
+                implementation="cudnn" if flash_attention else "xla",
+            )
+        except Exception:
+            # If flash attention fails, fall back to XLA implementation
+            if flash_attention:
+                return jax.nn.dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    bias=bias,
+                    mask=mask,
+                    scale=scale,
+                    is_causal=is_causal,
+                    implementation="xla",
+                )
+            raise

     if flash_attention:
         raise RuntimeError(
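The `head_shards` inference earlier in this hunk reads the active Keras distribution. A minimal sketch of the setup it expects, assuming the Keras 3 distribution API; the mesh shape and axis names are illustrative:

    import keras

    devices = keras.distribution.list_devices()  # e.g. 8 TPU cores
    mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
    )
    layout_map = keras.distribution.LayoutMap(mesh)
    dist = keras.distribution.ModelParallel(
        layout_map=layout_map, batch_dim_name="batch"
    )
    keras.distribution.set_distribution(dist)

    # The backend above would then pick up
    # head_shards == mesh.shape[mesh.axis_names.index("model")] == len(devices).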
@@ -1228,6 +1369,9 @@ def dot_product_attention(
     # Ref: jax.nn.dot_product_attention
     # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
     # Not support `query_seq_lengths` and `key_value_seq_lengths` args
+
+    # Fallback to custom XLA implementation
+    # This is the reference implementation from jax.nn.dot_product_attention
     output_shape = query.shape
     _, _, K, H = key.shape
     scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
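For context on this final branch, a self-contained sketch of the einsum-based math that the custom XLA fallback mirrors; it follows the jax.nn.dot_product_attention reference formulation and is not the exact Keras code (bias and mask handling is simplified here).

    import jax
    import jax.numpy as jnp

    def reference_attention(query, key, value, bias=None, mask=None, scale=None):
        # Inputs are [batch, time, heads, depth]; logits are [batch, heads, T, S].
        scale = (1.0 / jnp.sqrt(query.shape[-1])) if scale is None else scale
        logits = jnp.einsum("BTNH,BSNH->BNTS", query, key) * scale
        if bias is not None:
            logits = logits + bias
        if mask is not None:
            logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)
        probs = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum("BNTS,BSNH->BTNH", probs, value)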