Commit d8f3f70

Fixed issue with dot_product_attention when using TPU. (#21254)
* Update nn.py
* Update nn.py
* Update nn.py
* Update nn.py
* Update nn.py: corrected indentation in docstring
* Update nn.py
1 parent 3318d8f commit d8f3f70
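
For orientation, here is a minimal usage sketch (not part of the commit) of the public op that routes to the patched JAX backend function. It assumes the `keras.ops.dot_product_attention` signature with a `flash_attention` argument and a JAX backend; shapes follow the docstring added in this commit.

import numpy as np
import keras  # assumes the JAX backend is active

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
query = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")
key = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")
value = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")

# With flash_attention=None the backend auto-detects the fastest path;
# on TPU this is the Splash Attention route patched in this commit.
out = keras.ops.dot_product_attention(
    query, key, value, is_causal=True, flash_attention=None
)
print(out.shape)  # (2, 128, 8, 64)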

File tree

  • keras/src/backend/jax/nn.py

1 file changed: +186 −42 lines changed

keras/src/backend/jax/nn.py

Lines changed: 186 additions & 42 deletions
@@ -1126,16 +1126,17 @@ def wrap_flash_attention(
     decoder_segment_ids,
     custom_mask=None,
     attn_logits_soft_cap=None,
+    head_shards=1,
+    q_seq_shards=1,
 ):
     if decoder_segment_ids is not None:
         assert query.shape[2] == decoder_segment_ids.q.shape[1], (
-            "Sharding along sequence dimension not allowed in tpu kernel "
-            "attention"
+            "Sharding along sequence dimension not allowed"
+            " in TPU kernel attention"
         )

     if custom_mask is not None:
         mask = splash_attention_mask.NumpyMask(array=custom_mask)
-
     else:
         mask = splash_attention_mask.CausalMask(
             shape=(query.shape[2], query.shape[2])
@@ -1147,8 +1148,8 @@ def wrap_flash_attention(
     )
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
-        head_shards=1,
-        q_seq_shards=1,
+        head_shards=head_shards,
+        q_seq_shards=q_seq_shards,
         attn_logits_soft_cap=attn_logits_soft_cap,
     )

@@ -1168,6 +1169,38 @@ def dot_product_attention(
     flash_attention=None,
     attn_logits_soft_cap=None,
 ):
+    """Computes dot-product attention given query, key, and value.
+
+    This is the core computation of attention that is used in transformers.
+    For TPU platforms, flash attention optimizations are automatically applied
+    when possible, and sharding parameters are inferred from the layout map
+    in the current distribution context.
+
+    Args:
+        query: Queries with shape `[batch, time, heads,
+            depth_k]`.
+        key: Keys with shape `[batch, time, heads,
+            depth_k]`.
+        value: Values with shape `[batch, time, heads,
+            depth_v]`.
+        bias: Optional bias with shape broadcastable to
+            `[batch, heads, dest_time, source_time]`.
+        mask: Optional mask with shape broadcastable to
+            `[batch, heads, dest_time, source_time]`.
+        scale: Float. Optional scale that is applied to the attention
+            computation.
+        is_causal: Boolean. Whether causal masking is applied.
+        flash_attention: Boolean. Whether to use flash attention optimization
+            for increased performance. Defaults to `None`, which means it will
+            be auto-determined based on the platform, input shapes and
+            compatibility.
+        attn_logits_soft_cap: Float. Optional float to softly cap attention
+            logits to avoid numerical stability issues. Applied as:
+            `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.
+
+    Returns:
+        JAX Array of shape `[batch, time, heads, depth_v]`.
+    """
     query = convert_to_tensor(query)
     key = convert_to_tensor(key)
     value = convert_to_tensor(value)
@@ -1177,47 +1210,155 @@ def dot_product_attention(
             f"Received: query.shape={query.shape}, key.shape={key.shape}, "
             f"value.shape={value.shape}."
         )
-    if flash_attention is None:
-        flash_attention = _can_use_flash_attention(query, key, value, bias)
-    elif flash_attention is True:
-        # Use `raise_error=True` to provide more details if the inputs failed to
-        # use flash attention
-        _can_use_flash_attention(query, key, value, bias, raise_error=True)

-    if jax.devices()[0].platform == "tpu":
-        # Transpose to ('batch', 'heads', 'length', 'kv')
-        query = jnp.transpose(query, axes=(0, 2, 1, 3))
-        key = jnp.transpose(key, axes=(0, 2, 1, 3))
-        value = jnp.transpose(value, axes=(0, 2, 1, 3))
-        B, H, S, KV = query.shape
-
-        segment_ids = jnp.ones([B, S])
-        # {token_ids, padding_mask, segment_ids} enable packing
-        out = wrap_flash_attention(
-            query,
-            key,
-            value,
-            decoder_segment_ids=splash_attention_kernel.SegmentIds(
-                segment_ids, segment_ids
-            ),
-            custom_mask=mask,
-            attn_logits_soft_cap=attn_logits_soft_cap,
+    # Check platform
+    platform = jax.devices()[0].platform
+    is_tpu = platform == "tpu"
+
+    # Get sharding parameters from distribution context
+    head_shards = 1
+    q_seq_shards = 1
+
+    if is_tpu:
+        try:
+            from keras.src.distribution.distribution_lib import ModelParallel
+            from keras.src.distribution.distribution_lib import (
+                distribution as get_dist,
+            )
+
+            # Get current distribution if available
+            dist = get_dist()
+            if dist and isinstance(dist, ModelParallel):
+                mesh = dist.device_mesh
+                if "model" in mesh.axis_names:
+                    model_dim_index = mesh.axis_names.index("model")
+                    # Set head_shards based on the model dimension of the mesh
+                    head_shards = mesh.shape[model_dim_index]
+                    # Typically keep q_seq_shards=1 for best performance
+                    q_seq_shards = 1
+        except (ImportError, ValueError, AttributeError):
+            # Use default values if detection fails
+            head_shards = 1
+            q_seq_shards = 1
+
+    # Check if inputs use partial sharding (not fully replicated)
+    # Flash attention works well with fully replicated tensors on all platforms
+    # but may have issues with certain partial sharding patterns on non-TPU
+    # platforms
+    partially_sharded_inputs = any(
+        hasattr(t, "sharding") and not t.sharding.is_fully_replicated
+        for t in (query, key, value)
+    )
+
+    # Determine flash attention compatibility
+    if flash_attention is None:
+        # Auto-detect flash attention availability
+        if is_tpu:
+            # TPUs have specialized hardware for attention that works with any
+            # sharding pattern
+            flash_attention = True
+        else:
+            # For GPU/CPU with partially sharded inputs, we need
+            # multiple devices to efficiently handle the sharding
+            if partially_sharded_inputs and len(jax.devices()) <= 1:
+                flash_attention = False
+            else:
+                flash_attention = _can_use_flash_attention(
+                    query, key, value, bias
+                )
+    elif flash_attention is True and not is_tpu:
+        # If flash attention is explicitly requested, validate compatibility
+        # Skip validation for TPU as it has specialized hardware support
+        try:
+            _can_use_flash_attention(query, key, value, bias, raise_error=True)
+        except Exception:
+            # Only disable flash attention on non-TPU platforms
+            # if validation fails
+            flash_attention = False
+
+    # TPU-specific flash attention path
+    if is_tpu and flash_attention:
+        # Transpose to ('batch', 'heads', 'length', 'head_dim')
+        query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))
+        key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))
+        value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))
+
+        bs, num_heads, q_len, head_dim = query_tpu_layout.shape
+
+        # Apply scale to query if provided
+        if scale is not None:
+            # TPU kernel applies 1/sqrt(head_dim) internally; to achieve
+            # overall QK^T * scale, scale query by (scale * sqrt(head_dim))
+            query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))
+
+        # Create segment IDs for Splash Attention (for packing/batching)
+        segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)
+        decoder_segment_ids = splash_attention_kernel.SegmentIds(
+            q=segment_ids, kv=segment_ids
         )
-        out = jnp.transpose(out, axes=(0, 2, 1, 3))
-        return out

-    # `dot_product_attention` is only available in jax>=0.4.31
+        # Process mask for Splash Attention
+        custom_mask = None
+        if mask is not None:
+            mask_bool = mask.astype("bool") if mask.dtype != jnp.bool_ else mask
+
+            if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:
+                custom_mask = mask_bool[0]
+            elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:
+                custom_mask = mask_bool[0, 0]
+
+            if is_causal and custom_mask is not None:
+                causal_mask = jnp.tril(
+                    jnp.ones((q_len, q_len), dtype=jnp.bool_)
+                )
+                custom_mask = jnp.logical_and(custom_mask, causal_mask)
+
+        if custom_mask is None and is_causal:
+            custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))
+
+        try:
+            output = wrap_flash_attention(
+                query_tpu_layout,
+                key_tpu_layout,
+                value_tpu_layout,
+                decoder_segment_ids=decoder_segment_ids,
+                custom_mask=custom_mask,
+                attn_logits_soft_cap=attn_logits_soft_cap,
+                head_shards=head_shards,
+                q_seq_shards=q_seq_shards,
+            )
+            # Transpose output back to Keras layout
+            return jnp.transpose(output, axes=(0, 2, 1, 3))
+        except Exception:
+            flash_attention = False
+
+    # JAX native dot_product_attention for GPU or fallback for TPU
     if hasattr(jax.nn, "dot_product_attention"):
-        return jax.nn.dot_product_attention(
-            query,
-            key,
-            value,
-            bias=bias,
-            mask=mask,
-            scale=scale,
-            is_causal=is_causal,
-            implementation="cudnn" if flash_attention else "xla",
-        )
+        try:
+            return jax.nn.dot_product_attention(
+                query,
+                key,
+                value,
+                bias=bias,
+                mask=mask,
+                scale=scale,
+                is_causal=is_causal,
+                implementation="cudnn" if flash_attention else "xla",
+            )
+        except Exception:
+            # If flash attention fails, fall back to XLA implementation
+            if flash_attention:
+                return jax.nn.dot_product_attention(
+                    query,
+                    key,
+                    value,
+                    bias=bias,
+                    mask=mask,
+                    scale=scale,
+                    is_causal=is_causal,
+                    implementation="xla",
+                )
+            raise

     if flash_attention:
         raise RuntimeError(
@@ -1228,6 +1369,9 @@ def dot_product_attention(
     # Ref: jax.nn.dot_product_attention
     # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
     # Not support `query_seq_lengths` and `key_value_seq_lengths` args
+
+    # Fallback to custom XLA implementation
+    # This is the reference implementation from jax.nn.dot_product_attention
     output_shape = query.shape
     _, _, K, H = key.shape
     scale = (1.0 / jnp.sqrt(H)) if scale is None else scale
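
The query pre-scaling in the TPU path can be checked in isolation: the Splash kernel divides the logits by sqrt(head_dim) internally, so multiplying the query by (scale * sqrt(head_dim)) yields an overall factor of exactly `scale`. A small NumPy check (illustration only, not part of the commit):

import numpy as np

head_dim = 64
scale = 0.25
q = np.random.normal(size=(8, head_dim))
k = np.random.normal(size=(8, head_dim))

logits_wanted = scale * (q @ k.T)
# What the kernel effectively computes after the pre-scaling above:
logits_kernel = (q * (scale * np.sqrt(head_dim))) @ k.T / np.sqrt(head_dim)
print(np.allclose(logits_wanted, logits_kernel))  # True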

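The new TPU branch reads `head_shards` from the "model" axis of the active `ModelParallel` device mesh. A rough sketch of a distribution setup that this detection would pick up, assuming the Keras 3 `keras.distribution` API (`DeviceMesh`, `LayoutMap`, `ModelParallel`, `set_distribution`); exact constructor arguments may differ across Keras versions:

import keras
from keras.distribution import DeviceMesh
from keras.distribution import LayoutMap
from keras.distribution import ModelParallel
from keras.distribution import set_distribution

devices = keras.distribution.list_devices()  # e.g. 8 TPU cores

# A 1 x N mesh whose second axis is named "model"; the patched code looks up
# mesh.shape at the "model" axis and uses it as head_shards.
mesh = DeviceMesh(
    shape=(1, len(devices)), axis_names=("batch", "model"), devices=devices
)
layout_map = LayoutMap(mesh)

# The keyword-only layout_map/batch_dim_name form is assumed here.
set_distribution(ModelParallel(layout_map=layout_map, batch_dim_name="batch"))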
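The `attn_logits_soft_cap` behaviour described in the new docstring can likewise be illustrated numerically: large logits are squashed toward ±attn_logits_soft_cap instead of growing without bound (illustration only, using the formula quoted in the docstring):

import numpy as np

attn_logits_soft_cap = 30.0
logits = np.array([-1000.0, -30.0, 0.0, 30.0, 1000.0])
# Formula from the docstring:
capped = logits / (1.0 + np.abs(logits) / attn_logits_soft_cap)
print(capped)  # approximately [-29.13, -15.0, 0.0, 15.0, 29.13]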