Commit 6e6322e

QiliangCui authored and Lumosis committed
[RpaV3] Renable Tests Part 1 (#534)
Signed-off-by: Qiliang Cui <derrhein@gmail.com>
1 parent 14371a6 commit 6e6322e

6 files changed: +33 additions, -13 deletions

.buildkite/pipeline_jax.yml

Lines changed: 0 additions & 1 deletion
@@ -76,7 +76,6 @@ steps:
     python3 -m pytest -s -v -x /workspace/tpu_commons/tests/ \
       --ignore=/workspace/tpu_commons/tests/kernels \
       --ignore=/workspace/tpu_commons/tests/e2e \
-      --ignore=/workspace/tpu_commons/tests/models/vllm \
       --ignore=/workspace/tpu_commons/tpu_commons/mock \
       --cov-config=/workspace/tpu_commons/.coveragerc --cov tpu_commons --cov-report term-missing --cov-fail-under=69

tests/models/vllm/test_jax_attention.py

Lines changed: 11 additions & 10 deletions
@@ -57,7 +57,7 @@ def generate_attention_metadata(num_tokens, mesh) -> AttentionMetadata:
 
 def generate_kv_caches(num_kv_heads, head_size, mesh, dtype):
     cache_shape = get_kv_cache_shape_with_mesh(mesh, 1024, 16, num_kv_heads,
-                                               head_size, dtype)
+                                               head_size, t2j_dtype(dtype))
     sharding = NamedSharding(mesh, PartitionSpec())
 
     def _allocate():
@@ -138,15 +138,16 @@ def test_jax_attention(mesh, num_heads, head_size, num_kv_heads, num_tokens):
     vllm_model_wrapper_context = get_vllm_model_wrapper_context()
     kv_cache = vllm_model_wrapper_context.kv_caches[0]
 
-    ref_output = ref_ragged_paged_attention(q,
-                                            k,
-                                            v,
-                                            kv_cache,
-                                            md.seq_lens,
-                                            md.block_tables,
-                                            md.query_start_loc,
-                                            md.request_distribution,
-                                            sm_scale=scale)
+    ref_output, _ = ref_ragged_paged_attention(
+        q,
+        jax.device_put(t2j(k), NamedSharding(mesh, P())),
+        jax.device_put(t2j(v), NamedSharding(mesh, P())),
+        kv_cache,
+        md.seq_lens,
+        md.block_tables,
+        md.query_start_loc,
+        md.request_distribution,
+        sm_scale=scale)
     ref_output = j2t(ref_output.astype(jnp.float32)).to(dtype)
 
     torch.testing.assert_close(ref_output, jax_output, atol=1e-2, rtol=1e-5)
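For readers unfamiliar with the torchax interop used above: k and v are torch tensors, so the reference call now converts them with t2j and replicates them across the mesh before the JAX reference kernel runs, then converts the result back with j2t. Below is a minimal, self-contained sketch of that round trip; the mesh, shapes, and the local t2j/j2t stand-ins are illustrative only, while the test itself uses torchax's helpers and test_utils.get_spmd_mesh().

import jax
import jax.numpy as jnp
import numpy as np
import torch
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative single-device mesh standing in for test_utils.get_spmd_mesh().
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("model",))

def t2j(t: torch.Tensor) -> jax.Array:
    # Stand-in for torchax's t2j: torch -> JAX via a host numpy copy.
    return jnp.asarray(t.detach().cpu().numpy())

def j2t(a: jax.Array) -> torch.Tensor:
    # Stand-in for torchax's j2t: JAX -> torch via a host numpy copy.
    return torch.from_numpy(np.array(a))

k = torch.randn(8, 4, 128)  # illustrative (tokens, kv_heads, head_size)
# Replicate the converted array across every device in the mesh, as the test
# now does for k and v before calling the reference attention kernel.
k_jax = jax.device_put(t2j(k), NamedSharding(mesh, P()))
k_back = j2t(k_jax.astype(jnp.float32)).to(torch.bfloat16)

The dtype change in the first hunk follows the same idea: t2j_dtype presumably maps a torch dtype to its JAX counterpart so the KV-cache shape helper receives a JAX dtype rather than a torch one.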

tests/models/vllm/test_jax_merged_column_parallel_linear.py

Lines changed: 8 additions & 0 deletions
@@ -49,6 +49,10 @@ def setup_environment():
     ensure_model_parallel_initialized(1, 1)
 
 
+@pytest.mark.skip(
+    reason=
+    "b/440248045. The failure is not caused by Rpav3. Will fix in another change."
+)
 @pytest.mark.parametrize("bias", [False, True])
 @pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()])
 @pytest.mark.parametrize("fuse_matmuls", [False, True])
@@ -95,6 +99,10 @@ def test_jax_merged_column_parallel_linear(bias, mesh, fuse_matmuls,
     torch.testing.assert_close(output, jax_output)
 
 
+@pytest.mark.skip(
+    reason=
+    "b/440248045. The failure is not caused by Rpav3. Will fix in another change."
+)
 @pytest.mark.parametrize("bias", [False, True])
 @pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()])
 @pytest.mark.parametrize("fuse_matmuls", [False, True])

tests/models/vllm/test_jax_qkv_parallel_linear.py

Lines changed: 8 additions & 0 deletions
@@ -49,6 +49,10 @@ def setup_environment():
     ensure_model_parallel_initialized(1, 1)
 
 
+@pytest.mark.skip(
+    reason=
+    "b/440248045. The failure is not caused by Rpav3. Will fix in another change."
+)
 @pytest.mark.parametrize("bias", [False, True])
 @pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()])
 @pytest.mark.parametrize("fuse_matmuls", [False, True])
@@ -89,6 +93,10 @@ def test_jax_qkv_parallel_linear(bias, mesh, fuse_matmuls):
     torch.testing.assert_close(output, jax_output)
 
 
+@pytest.mark.skip(
+    reason=
+    "b/440248045. The failure is not caused by Rpav3. Will fix in another change."
+)
 @pytest.mark.parametrize("bias", [False, True])
 @pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()])
 @pytest.mark.parametrize("fuse_matmuls", [False, True])

tests/models/vllm/test_jax_row_parallel_linear.py

Lines changed: 4 additions & 2 deletions
@@ -98,7 +98,8 @@ def test_jax_row_parallel_linear(bias, mesh, enable_sp):
 
 @pytest.mark.parametrize("bias", [False, True])
 @pytest.mark.parametrize("mesh", [test_utils.get_spmd_mesh()])
-def test_jax_row_parallel_linear_w8a8_int8(bias, mesh):
+@pytest.mark.parametrize("enable_sp", [False, True])
+def test_jax_row_parallel_linear_w8a8_int8(bias, mesh, enable_sp):
     dtype = torch.bfloat16
 
     engine_args = EngineArgs(
@@ -142,7 +143,8 @@ def test_jax_row_parallel_linear_w8a8_int8(bias, mesh):
 
     # Set jax default device to workaround a layout bug in JAX 0.7.0 and earlier
     with torchax.default_env(), jax.default_device(jax.devices("tpu")[0]):
-        jax_row_linear = JaxRowParallelLinear(row_linear, mesh=mesh)
+        jax_row_linear = JaxRowParallelLinear(
+            row_linear, mesh=mesh, enable_sequence_parallelism=enable_sp)
         jax_input_tensor = torch_view(t2j(input_tensor, use_dlpack=False))
         jax_input_tensor.apply_jax_(jax.device_put,
                                     NamedSharding(mesh, P(None, None)))
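With this change the w8a8 int8 path exercises enable_sequence_parallelism as well, mirroring the existing test_jax_row_parallel_linear parametrization. A rough sketch of what the flag changes about the input layout follows; the axis name, shapes, and partition specs are illustrative and not JaxRowParallelLinear's actual implementation.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

x = jnp.ones((16, 4096))  # (sequence, hidden), illustrative shape

# enable_sp=False: the activation is replicated, matching the
# NamedSharding(mesh, P(None, None)) placement used for the test input above.
x_replicated = jax.device_put(x, NamedSharding(mesh, P(None, None)))

# enable_sp=True: the sequence dimension is sharded across the mesh, so each
# device holds a slice of the tokens instead of the whole activation.
x_seq_sharded = jax.device_put(x, NamedSharding(mesh, P("model", None)))

Parametrizing enable_sp simply runs the quantized test once under each placement.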

tests/models/vllm/test_torchax_wrapper.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,8 @@ def test_func():
     mock_disable.assert_called_once()
 
 
+@pytest.mark.skip(
+    reason="b/440250062. Delete the test when deleting torchax-pt path.")
 @pytest.mark.parametrize("tensor_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("use_mesh", [True, False])
 def test_get_cpu_tensor_from_torchax_tensor(tensor_dtype, use_mesh):
