
Commit 1b95ac8

Fix tests
1 parent bbad4f0 commit 1b95ac8

File tree

1 file changed: +20 -1 lines changed


axlearn/common/flash_attention/layer_test.py

Lines changed: 20 additions & 1 deletion
@@ -451,6 +451,19 @@ def test_backend(
         backend = test_layer._backend()  # pylint: disable=protected-access
         self.assertEqual(backend, "tpu")

+    def _maybe_skip_unsupported_context_parallel(
+        self, mesh, mesh_axis_names, use_bias, per_head_dim
+    ):
+        # TODO(hanzhi-zhou): Add GPU support.
+        if jax.default_backend() != "tpu":
+            self.skipTest("Context parallelism is only supported on TPU for now.")
+        for mesh_dim, mesh_name in zip(mesh, mesh_axis_names):
+            if mesh_name == "seq" and mesh_dim > 1 and (use_bias or per_head_dim % 128 != 0):
+                self.skipTest(
+                    "Context parallelism is not supported when we need to fall back to"
+                    " legacy TPU flash attention."
+                )
+
     @parameterized.parameters(_TEST_CONFIGS)
     def test_shard_biases(
         self, batch, seq_len, num_heads, num_kv_heads, per_head_dim, mesh, mesh_axis_names
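
Note: a minimal, self-contained sketch of the skip predicate the new helper implements. The function needs_skip and all argument values below are hypothetical, for illustration only; the modulo check mirrors the per_head_dim % 128 condition above, which is presumably tied to the TPU lane width.

    # Hypothetical standalone sketch of the helper's skip logic.
    def needs_skip(mesh, mesh_axis_names, use_bias, per_head_dim, backend="tpu"):
        if backend != "tpu":
            return True  # context parallelism is TPU-only for now
        # Skip when the "seq" axis is sharded but the conditions that force a
        # fallback to the legacy TPU kernel hold (bias present, or per_head_dim
        # not a multiple of 128).
        return any(
            name == "seq" and dim > 1 and (use_bias or per_head_dim % 128 != 0)
            for dim, name in zip(mesh, mesh_axis_names)
        )

    # A 2-way "seq" shard with per_head_dim=96 would be skipped:
    assert needs_skip((1, 2), ("data", "seq"), use_bias=False, per_head_dim=96)
    # ...while per_head_dim=128 without bias runs:
    assert not needs_skip((1, 2), ("data", "seq"), use_bias=False, per_head_dim=128)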
@@ -501,7 +514,7 @@ def as_partition_spec(pytree: CompositeAttentionBias) -> PartitionSpec:
         spec = test_layer._logit_biases_spec(segment_ids)  # pylint: disable=protected-access
         spec = as_partition_spec(spec)
         self.assertIsInstance(spec, PartitionSpec)
-        self.assertEqual(spec, test_layer.config.mha_dim_to_partition_spec["btnh"][:2])
+        self.assertEqual(spec, test_layer.config.mha_dim_to_partition_spec["bsnh"][:2])

     @parameterized.product(
         _TEST_CONFIGS,
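
Note: the one-line fix swaps which key of mha_dim_to_partition_spec the expected spec is read from. A hedged illustration, assuming the usual einsum-style dim names (b=batch, t=target/query length, s=source/key-value length, n=heads, h=per-head dim); the concrete PartitionSpec values below are made up, not from this commit.

    from jax.sharding import PartitionSpec

    # Made-up example mapping in the style of mha_dim_to_partition_spec.
    mha_dim_to_partition_spec = {
        "btnh": PartitionSpec("data", "seq", "model", None),  # query layout
        "bsnh": PartitionSpec("data", "seq", "model", None),  # key/value layout
    }
    # Segment-id biases span (batch, sequence), hence only the first two
    # entries of the key/value ("bsnh") spec are compared: ("data", "seq").
    expected = mha_dim_to_partition_spec["bsnh"][:2]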
@@ -534,6 +547,10 @@ def test_forward(
             pytest.skip(reason=f"Unsupported mesh {mesh}.")
         if not causal and left_context is not None:
             pytest.skip(reason="Sliding window attention must be causal.")
+        if query_len_multiplier > 1 and left_context is not None:
+            # When sliding window is enabled and q_len > kv_len, there might be fully masked
+            # rows.
+            pytest.skip(reason="Sliding window attention does not make sense when q_len > kv_len.")
         if causal and use_bias:
             # TODO(c_lan): Investigate the numerical errors when both causal and bias are used.
             pytest.skip(reason="Only one of causal and use_bias can be True.")
@@ -544,6 +561,7 @@ def test_forward(
             pytest.skip(reason="Unsupported large bias matrix in fp32 format.")
         if dropout_rate > 0.0 and jax.default_backend() == "tpu":
             pytest.skip("Dropout is implemented for GPU only.")
+        self._maybe_skip_unsupported_context_parallel(mesh, mesh_axis_names, use_bias, per_head_dim)

         with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):
             test_layer, ref_layer, params, hidden_dim = _prepare_layers(
@@ -630,6 +648,7 @@ def test_backward(
         if attn_type == "causal" and use_bias:
             # TODO(c_lan): Investigate the numerical errors when both causal and bias are used.
             pytest.skip(reason="Only one of causal and use_bias can be True.")
+        self._maybe_skip_unsupported_context_parallel(mesh, mesh_axis_names, use_bias, per_head_dim)

         with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):
             hidden_dim = num_heads * per_head_dim
