
Commit 718c186

angelayilywa1998 authored and committed
[bugfix] Fix SP + PP without specifying compile size (vllm-project#26955)
Signed-off-by: angelayi <yiangela7@gmail.com>
1 parent 591a4f2 · commit 718c186

File tree

2 files changed: +21 -3 lines changed


tests/distributed/test_sequence_parallel.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@
 from vllm.config.compilation import CompilationMode
 from vllm.config.model import RunnerOption
 from vllm.logger import init_logger
+from vllm.utils import is_torch_equal_or_newer

 from ..models.registry import HF_EXAMPLE_MODELS
 from ..utils import compare_two_settings, create_new_process_for_each_test
@@ -159,6 +160,7 @@ def _compare_sp(
     runner: RunnerOption,
     test_options: SPTestOptions,
     num_gpus_available: int,
+    use_inductor_graph_partition: bool,
     *,
     method: Literal["generate", "encode"],
     is_multimodal: bool,
@@ -243,6 +245,7 @@ def _compare_sp(
             "enable_fusion": enable_fusion,
             "enable_noop": True,
         },
+        "use_inductor_graph_partition": use_inductor_graph_partition,
     }

     tp_sp_args = [
@@ -297,6 +300,7 @@ def _compare_sp(
         if model_id in SP_TEST_MODELS
     ],
 )
+@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
 @create_new_process_for_each_test()
 def test_tp_sp_generation(
     model_id: str,
@@ -305,14 +309,19 @@ def test_tp_sp_generation(
     runner: RunnerOption,
     test_options: SPTestOptions,
     num_gpus_available,
+    use_inductor_graph_partition: bool,
 ):
+    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
+        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
+
     _compare_sp(
         model_id,
         parallel_setup,
         distributed_backend,
         runner,
         test_options,
         num_gpus_available,
+        use_inductor_graph_partition,
         method="generate",
         is_multimodal=False,
     )
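
The new parametrization follows a common pytest gating pattern: run both partition modes, but skip the inductor-graph-partition variant when PyTorch is too old. A minimal self-contained sketch of that pattern, with a placeholder test name and body that are not part of this commit:

import pytest

from vllm.utils import is_torch_equal_or_newer


@pytest.mark.parametrize("use_inductor_graph_partition", [True, False])
def test_partition_gating(use_inductor_graph_partition: bool):
    # Skip the partitioned variant on PyTorch < 2.9, mirroring the commit above.
    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
    # Placeholder body; the real test forwards this flag to _compare_sp.
    assert isinstance(use_inductor_graph_partition, bool)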

vllm/v1/worker/utils.py

Lines changed: 12 additions & 3 deletions
@@ -328,8 +328,12 @@ def is_residual_scattered_for_sp(
     """Check if the residual tensor is scattered for sequence parallelism.

     The residual tensor is scattered across tensor parallel ranks when sequence
-    parallelism and tensor parallelism is enabled, and the number of
-    input tokens is one of the compilation sizes.
+    parallelism and tensor parallelism is enabled.
+
+    This follows the same logic as SequenceParallelismPass.is_applicable():
+    - In full-graph compilation mode (no splitting ops or using inductor graph
+      partition), SP is always applied
+    - Otherwise, SP is only applied for specific shapes in compile_sizes
     """
     if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
         return False
@@ -343,5 +347,10 @@
     # to be a multiple of tensor_parallel_size (tp) earlier.
     assert num_input_tokens % tp == 0

-    # Currently, SP is only enabled for static size fx graphs.
+    if (
+        not vllm_config.compilation_config.splitting_ops
+        or vllm_config.compilation_config.use_inductor_graph_partition
+    ):
+        return True
+
     return num_input_tokens in vllm_config.compilation_config.compile_sizes
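
Put together, the post-change decision flow of is_residual_scattered_for_sp reads roughly as below. The function signature and the source of tp are not visible in the hunks, so treat them as assumptions inferred from context:

from vllm.config import VllmConfig


def is_residual_scattered_for_sp(
    vllm_config: VllmConfig, num_input_tokens: int
) -> bool:
    """Check if the residual tensor is scattered for sequence parallelism."""
    if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism:
        return False

    # Assumed source of tp; the diff only shows that tp is already in scope.
    tp = vllm_config.parallel_config.tensor_parallel_size
    if tp == 1:
        return False  # assumed guard; not visible in the hunks above

    # num_input_tokens was already padded to a multiple of tp upstream.
    assert num_input_tokens % tp == 0

    # Full-graph compilation (no splitting ops, or inductor graph partition):
    # SequenceParallelismPass applies to every shape, so the residual is
    # always scattered. This is the case the commit fixes for SP + PP when
    # no compile sizes are specified.
    if (
        not vllm_config.compilation_config.splitting_ops
        or vllm_config.compilation_config.use_inductor_graph_partition
    ):
        return True

    # Otherwise SP is only applied for the statically compiled sizes.
    return num_input_tokens in vllm_config.compilation_config.compile_sizes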
