@@ -451,6 +451,19 @@ def test_backend(
         backend = test_layer._backend()  # pylint: disable=protected-access
         self.assertEqual(backend, "tpu")
 
+    def _maybe_skip_unsupported_context_parallel(
+        self, mesh, mesh_axis_names, use_bias, per_head_dim
+    ):
+        # TODO(hanzhi-zhou): Add GPU support.
+        if jax.default_backend() != "tpu":
+            self.skipTest("Context parallelism is only supported on TPU for now.")
+        for mesh_dim, mesh_name in zip(mesh, mesh_axis_names):
+            if mesh_name == "seq" and mesh_dim > 1 and (use_bias or per_head_dim % 128 != 0):
+                self.skipTest(
+                    "Context parallelism is not supported when falling back to the legacy TPU"
+                    " flash attention."
+                )
+
     @parameterized.parameters(_TEST_CONFIGS)
     def test_shard_biases(
         self, batch, seq_len, num_heads, num_kv_heads, per_head_dim, mesh, mesh_axis_names
@@ -501,7 +514,7 @@ def as_partition_spec(pytree: CompositeAttentionBias) -> PartitionSpec:
         spec = test_layer._logit_biases_spec(segment_ids)  # pylint: disable=protected-access
         spec = as_partition_spec(spec)
         self.assertIsInstance(spec, PartitionSpec)
-        self.assertEqual(spec, test_layer.config.mha_dim_to_partition_spec["btnh"][:2])
+        self.assertEqual(spec, test_layer.config.mha_dim_to_partition_spec["bsnh"][:2])
 
     @parameterized.product(
         _TEST_CONFIGS,
@@ -534,6 +547,10 @@ def test_forward(
             pytest.skip(reason=f"Unsupported mesh {mesh}.")
         if not causal and left_context is not None:
             pytest.skip(reason="Sliding window attention must be causal.")
+        if query_len_multiplier > 1 and left_context is not None:
+            # When sliding window is enabled and q_len > kv_len, there might be fully masked
+            # rows.
+            pytest.skip(reason="Sliding window attention does not make sense when q_len > kv_len.")
         if causal and use_bias:
             # TODO(c_lan): Investigate the numerical errors when both causal and bias are used.
             pytest.skip(reason="Only one of causal and use_bias can be True.")
@@ -544,6 +561,7 @@ def test_forward(
             pytest.skip(reason="Unsupported large bias matrix in fp32 format.")
         if dropout_rate > 0.0 and jax.default_backend() == "tpu":
             pytest.skip("Dropout is implemented for GPU only.")
+        self._maybe_skip_unsupported_context_parallel(mesh, mesh_axis_names, use_bias, per_head_dim)
 
         with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):
            test_layer, ref_layer, params, hidden_dim = _prepare_layers(
@@ -630,6 +648,7 @@ def test_backward(
         if attn_type == "causal" and use_bias:
             # TODO(c_lan): Investigate the numerical errors when both causal and bias are used.
             pytest.skip(reason="Only one of causal and use_bias can be True.")
+        self._maybe_skip_unsupported_context_parallel(mesh, mesh_axis_names, use_bias, per_head_dim)
 
         with Mesh(mesh_utils.create_device_mesh(mesh), mesh_axis_names):
             hidden_dim = num_heads * per_head_dim
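
For context, the skip logic added above boils down to: context-parallel test cases (a sharded `seq` mesh axis) only run on TPU, and only when the configuration would not fall back to the legacy TPU flash attention kernel, i.e. no explicit attention bias and `per_head_dim` a multiple of 128. Below is a minimal, self-contained sketch of that decision; the mesh shape, axis names, and parameter values are assumptions for illustration only (the real values come from `_TEST_CONFIGS`, which is not shown in this diff).

```python
# Standalone sketch of the new skip condition. The mesh, axis names, and the
# (use_bias, per_head_dim) combinations below are hypothetical.
mesh = (1, 2, 1)
mesh_axis_names = ("data", "seq", "model")

for use_bias, per_head_dim in [(False, 128), (True, 128), (False, 96)]:
    # The fallback to the legacy TPU flash attention kernel is triggered by an
    # explicit bias or a head dim that is not a multiple of 128.
    needs_legacy_fallback = use_bias or per_head_dim % 128 != 0
    # Context parallelism only matters when the "seq" axis is actually sharded.
    seq_sharded = any(
        name == "seq" and dim > 1 for dim, name in zip(mesh, mesh_axis_names)
    )
    action = "skip" if (seq_sharded and needs_legacy_fallback) else "run"
    print(f"use_bias={use_bias}, per_head_dim={per_head_dim} -> {action}")
```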