Commit e151d69

Supports TPU context parallel training (#981)
Fix tests
1 parent 67645d0 commit e151d69

88 files changed: +819 / -1118 lines


axlearn/common/attention.py

Lines changed: 16 additions & 0 deletions
@@ -82,6 +82,7 @@
 import einops
 import jax
 from jax import numpy as jnp
+from jax._src.mesh import thread_resources

 from axlearn.common import ops, param_init
 from axlearn.common.attention_bias import (
@@ -149,6 +150,7 @@
     save_and_offload_only_these_names_regex,
     shapes,
     split_prng_key,
+    with_sharding_constraint,
 )
@@ -1419,6 +1421,20 @@ def forward(
         )
         if query_positions is None:
             query_positions = jnp.arange(q_proj.shape[1])[None]
+        # This sharding hint is needed because the compiler sometimes generates a large
+        # allgather before the split and then slices, which is not ideal. Constraining the
+        # sharding after the split ensures the allgather is inserted after the split.
+        axis_names = thread_resources.env.physical_mesh.axis_names
+        batch_axes = tuple(x for x in axis_names if x in ("data", "fsdp")) or None
+        spec = PartitionSpec(
+            batch_axes,
+            "seq" if "seq" in axis_names else None,
+            "model" if "model" in axis_names else None,
+            None,
+        )
+        q_proj = with_sharding_constraint(q_proj, spec)
+        k_proj = with_sharding_constraint(k_proj, spec)
+        v_proj = with_sharding_constraint(v_proj, spec)
         return self.Output(
             query=q_proj,
             key=k_proj,
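
For reference, a minimal sketch of the same sharding hint using only public JAX APIs (jax.lax.with_sharding_constraint with a NamedSharding built from the mesh axis names). The mesh shape, axis names, and tensor shapes below are illustrative assumptions, not part of this commit:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def qkv_partition_spec(axis_names):
    # Batch over data/fsdp, sequence over "seq", heads over "model"; head_dim unsharded.
    batch_axes = tuple(a for a in axis_names if a in ("data", "fsdp")) or None
    return PartitionSpec(
        batch_axes,
        "seq" if "seq" in axis_names else None,
        "model" if "model" in axis_names else None,
        None,
    )


# Assumed single-axis mesh; real configs would use e.g. ("data", "fsdp", "seq", "model").
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
spec = qkv_partition_spec(mesh.axis_names)


@jax.jit
def constrain_qkv(q, k, v):
    # Constrain sharding right after the Q/K/V split so XLA places the allgather afterwards.
    sharding = NamedSharding(mesh, spec)
    return tuple(jax.lax.with_sharding_constraint(x, sharding) for x in (q, k, v))


q = k = v = jnp.zeros((2 * jax.device_count(), 128, 4, 64))  # [batch, seq, heads, head_dim]
q, k, v = constrain_qkv(q, k, v)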

axlearn/common/attention_bias.py

Lines changed: 9 additions & 3 deletions
@@ -439,10 +439,16 @@ def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union[BaseAttentionBias, PartitionSpec]:
         # Segment IDs: [batch_size, seq_len].
-        q_spec = mha_dim_to_partition_spec["btnh"]
-        if q_spec == PartitionSpec(None):
+        # We use the partition spec of KV (which is not sequence-sharded) for segment ids,
+        # because Splash requires two segment ids, q_seg and kv_seg. We therefore pass
+        # non-sequence-sharded segment ids into the shard map, manually shard them inside
+        # for q_seg, and leave them unsharded for kv_seg.
+        kv_spec = mha_dim_to_partition_spec["bsnh"]
+        if kv_spec == PartitionSpec(None):
             return PartitionSpec(None)
-        return PartitionSpec(q_spec[0], q_spec[1])
+        if kv_spec[1] is not None:
+            raise ValueError("The partition spec of `s` in `bsnh` should be None.")
+        return PartitionSpec(kv_spec[0], kv_spec[1])


 class MaskFn(Protocol):
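
As a quick illustration of the new logic, a self-contained sketch (the example mesh axes are assumptions) showing that the segment-id spec follows the "bsnh" KV spec and keeps the sequence dimension unsharded:

from jax.sharding import PartitionSpec


def segment_ids_partition_spec(mha_dim_to_partition_spec):
    # Segment ids are [batch_size, seq_len]; derive their spec from the KV ("bsnh") spec.
    kv_spec = mha_dim_to_partition_spec["bsnh"]
    if kv_spec == PartitionSpec(None):
        return PartitionSpec(None)
    if kv_spec[1] is not None:
        raise ValueError("The partition spec of `s` in `bsnh` should be None.")
    # Keep the seq dim unsharded so both q_seg and kv_seg can be derived inside the shard map.
    return PartitionSpec(kv_spec[0], kv_spec[1])


specs = {"bsnh": PartitionSpec(("data", "fsdp"), None, "model", None)}
assert segment_ids_partition_spec(specs) == PartitionSpec(("data", "fsdp"), None)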

axlearn/common/flash_attention/layer.py

Lines changed: 20 additions & 16 deletions
@@ -15,10 +15,7 @@
 from axlearn.common.attention import Dropout, GroupedQueryAttention
 from axlearn.common.attention_bias import BaseAttentionBias
 from axlearn.common.config import REQUIRED, ConfigBase, ConfigModifier, Required, config_class
-from axlearn.common.flash_attention.utils import (
-    MultiHeadAttentionImpl,
-    flash_attention_implementation,
-)
+from axlearn.common.flash_attention.utils import flash_attention_implementation
 from axlearn.common.module import Module
 from axlearn.common.utils import Tensor, with_sharding_constraint
@@ -167,13 +164,6 @@ def _compute_attention(

         attention_logit_biases = attention_logit_biases.astype(q_proj.dtype)

-        jit_attn: MultiHeadAttentionImpl = flash_attention_implementation(
-            backend=backend,
-            softmax_scale=1.0,
-            block_size=cfg.tpu_block_size,
-            dropout_rate=cfg.dropout.rate,
-        )
-
         attention_logit_biases_spec = self._logit_biases_spec(attention_logit_biases)
         attention_logit_biases = with_sharding_constraint(
             attention_logit_biases, attention_logit_biases_spec
@@ -188,10 +178,21 @@ def _compute_attention(
         k_proj = with_sharding_constraint(k_proj, cfg.mha_dim_to_partition_spec["bsnh"])
         v_proj = with_sharding_constraint(v_proj, cfg.mha_dim_to_partition_spec["bsnh"])

+        shard_map_specs = flash_attention_implementation(
+            query=q_proj,
+            key=k_proj,
+            value=v_proj,
+            bias=attention_logit_biases,
+            backend=backend,
+            softmax_scale=1.0,
+            block_size=cfg.tpu_block_size,
+            dropout_rate=cfg.dropout.rate,
+        )
+
         # We need to manually partition pallas | jax-triton calls.
         # Note: shard_map doesn't support kwargs.
         partitioned_mha = shard_map(
-            jit_attn,
+            shard_map_specs.fn,
             mesh=thread_resources.env.physical_mesh,
             in_specs=(
                 # Q [batch_size, seq_len, num_heads, per_head_dim].
@@ -204,6 +205,7 @@ def _compute_attention(
                 attention_logit_biases_spec,
                 # PRNG Key.
                 PartitionSpec(None),
+                *shard_map_specs.additional_in_specs,
             ),
             # O [batch_size, seq_len, num_heads, per_head_dim].
             out_specs=cfg.mha_dim_to_partition_spec["btnh"],
@@ -221,6 +223,7 @@ def _compute_attention(
                 v_proj,
                 attention_logit_biases,
                 self.dropout.get_prng_key(),
+                *shard_map_specs.additional_args,
             ),
             cfg.output_dim_to_partition_spec["btnh"],
         )
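
Independent of the kernel itself, the surrounding pattern is roughly the following. This is a generic sketch with a toy per-device attention body and an assumed single-axis ("data") mesh, not axlearn's actual kernel or its additional_in_specs/additional_args plumbing:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
qkv_spec = PartitionSpec("data", None, None, None)


def per_device_attention(q, k, v):
    # Placeholder for the flash-attention call that runs on each shard.
    logits = jnp.einsum("bqnh,bknh->bnqk", q, k)
    probs = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bnqk,bknh->bqnh", probs, v)


# shard_map doesn't support kwargs, so all operands are passed positionally.
partitioned_mha = shard_map(
    per_device_attention,
    mesh=mesh,
    in_specs=(qkv_spec, qkv_spec, qkv_spec),
    out_specs=qkv_spec,
)

q = k = v = jnp.zeros((jax.device_count(), 16, 2, 8))  # [batch, seq, heads, head_dim]
out = partitioned_mha(q, k, v)
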
@@ -247,12 +250,13 @@ def default_mha_dim_to_partition_spec(
     Returns:
         A dictionary keyed by MHA tensor dims with partition spec values.
     """
-    batch_axis_names = tuple(el for el in mesh_axis_names if el != "model")
+    batch_axis_names = tuple(el for el in mesh_axis_names if el in ["data", "fsdp"])
     tp_axis_name = "model" if "model" in mesh_axis_names else None
+    sp_axis_name = "seq" if "seq" in mesh_axis_names else None
     return {
-        "btnh": PartitionSpec(batch_axis_names, None, tp_axis_name, None),
+        "btnh": PartitionSpec(batch_axis_names, sp_axis_name, tp_axis_name, None),
         "bsnh": PartitionSpec(batch_axis_names, None, tp_axis_name, None),
-        "bnts": PartitionSpec(batch_axis_names, tp_axis_name, None, None),
+        "bnts": PartitionSpec(batch_axis_names, tp_axis_name, sp_axis_name, None),
     }
@@ -271,7 +275,7 @@ def default_output_dim_to_partition_spec(
     Returns:
         A dictionary keyed by FlashAttention output tensor dims with partition spec values.
     """
-    batch_axis_names = tuple(el for el in mesh_axis_names if el not in ["seq", "model"])
+    batch_axis_names = tuple(el for el in mesh_axis_names if el in ["data", "fsdp"])
     tp_axis_name = "model" if "model" in mesh_axis_names else None
     sp_axis_name = "seq" if "seq" in mesh_axis_names else None
     return {
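
To make the effect of the new defaults concrete, the sketch below reproduces the updated helper for an assumed ("data", "fsdp", "seq", "model") mesh: the query sequence dim is sharded over "seq" (context parallelism) while the KV sequence dim stays unsharded:

from jax.sharding import PartitionSpec


def default_mha_dim_to_partition_spec(mesh_axis_names):
    # Batch over data/fsdp, heads over "model", query sequence over "seq".
    batch_axis_names = tuple(el for el in mesh_axis_names if el in ["data", "fsdp"])
    tp_axis_name = "model" if "model" in mesh_axis_names else None
    sp_axis_name = "seq" if "seq" in mesh_axis_names else None
    return {
        "btnh": PartitionSpec(batch_axis_names, sp_axis_name, tp_axis_name, None),
        "bsnh": PartitionSpec(batch_axis_names, None, tp_axis_name, None),
        "bnts": PartitionSpec(batch_axis_names, tp_axis_name, sp_axis_name, None),
    }


specs = default_mha_dim_to_partition_spec(("data", "fsdp", "seq", "model"))
assert specs["btnh"] == PartitionSpec(("data", "fsdp"), "seq", "model", None)
assert specs["bsnh"] == PartitionSpec(("data", "fsdp"), None, "model", None)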

axlearn/common/flash_attention/layer_test.py

Lines changed: 20 additions & 1 deletion
@@ -451,6 +451,19 @@ def test_backend(
         backend = test_layer._backend()  # pylint: disable=protected-access
         self.assertEqual(backend, "tpu")

+    def _maybe_skip_unsupported_context_parallel(
+        self, mesh, mesh_axis_names, use_bias, per_head_dim
+    ):
+        # TODO(hanzhi-zhou): Add GPU support.
+        if jax.default_backend() != "tpu":
+            self.skipTest("Context parallelism is only supported on TPU for now.")
+        for mesh_dim, mesh_name in zip(mesh, mesh_axis_names):
+            if mesh_name == "seq" and mesh_dim > 1 and (use_bias or per_head_dim % 128 != 0):
+                self.skipTest(
+                    "Context parallelism is not supported when we need to fall back to"
+                    " legacy TPU flash attention."
+                )
+
     @parameterized.parameters(_TEST_CONFIGS)
     def test_shard_biases(
         self, batch, seq_len, num_heads, num_kv_heads, per_head_dim, mesh, mesh_axis_names
@@ -501,7 +514,7 @@ def as_partition_spec(pytree: CompositeAttentionBias) -> PartitionSpec:
         spec = test_layer._logit_biases_spec(segment_ids)  # pylint: disable=protected-access
         spec = as_partition_spec(spec)
         self.assertIsInstance(spec, PartitionSpec)
-        self.assertEqual(spec, test_layer.config.mha_dim_to_partition_spec["btnh"][:2])
+        self.assertEqual(spec, test_layer.config.mha_dim_to_partition_spec["bsnh"][:2])

     @parameterized.product(
         _TEST_CONFIGS,
@@ -534,6 +547,10 @@ def test_forward(
             pytest.skip(reason=f"Unsupported mesh {mesh}.")
         if not causal and left_context is not None:
             pytest.skip(reason="Sliding window attention must be causal.")
+        if query_len_multiplier > 1 and left_context is not None:
+            # When sliding window is enabled and q_len > kv_len, there might be fully masked
+            # rows.
+            pytest.skip(reason="Sliding window attention does not make sense when q_len > kv_len.")
         if causal and use_bias:
             # TODO(c_lan): Investigate the numerical errors when both causal and bias are used.
             pytest.skip(reason="Only one of causal and use_bias can be True.")
@@ -544,6 +561,7 @@ def test_forward(
             pytest.skip(reason="Unsupported large bias matrix in fp32 format.")
         if dropout_rate > 0.0 and jax.default_backend() == "tpu":
             pytest.skip("Dropout is implemented for GPU only.")
+        self._maybe_skip_unsupported_context_parallel(mesh, mesh_axis_names, use_bias, per_head_dim)

         with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):
             test_layer, ref_layer, params, hidden_dim = _prepare_layers(
@@ -630,6 +648,7 @@ def test_backward(
         if attn_type == "causal" and use_bias:
             # TODO(c_lan): Investigate the numerical errors when both causal and bias are used.
             pytest.skip(reason="Only one of causal and use_bias can be True.")
+        self._maybe_skip_unsupported_context_parallel(mesh, mesh_axis_names, use_bias, per_head_dim)

         with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):
             hidden_dim = num_heads * per_head_dim
