Commit b16d3e9

jasonlizhengjian authored and ProExpertProg committed
[BugFix][torch.compile] Fix fused_scaled_matmul_reduce_scatter signature for PyTorch 2.8 (vllm-project#26038)
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: <>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: bbartels <benjamin@bartels.dev>
1 parent 1f3fc1c commit b16d3e9
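
For context: in PyTorch 2.8 the fused op no longer accepts the old keyword arguments. Below is a minimal sketch of the call-site change, assembled from the inline comments in the diff further down; the placeholder names dtype and group_name are illustrative, not taken from the source, and the exact PyTorch 2.8 operator schema is not reproduced here:

    # Before (PyTorch <= 2.7): keyword arguments
    torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input, mat2, scale_a, scale_b, "avg",
        scatter_dim=0,
        out_dtype=dtype,
        group_name=group_name,
    )

    # After (PyTorch 2.8): positional arguments, expanded signature
    torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        input, mat2, scale_a, scale_b, "avg",
        0,                                   # orig_scatter_dim
        0,                                   # scatter_dim_after_maybe_reshape
        group_name,
        [*input.shape[:-1], mat2.shape[1]],  # output_shape
        None,                                # bias
        None,                                # result_scale
        dtype,                               # out_dtype
        False,                               # use_fast_accum
    )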

File tree: 2 files changed, 24 insertions(+), 8 deletions(-)

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 2 deletions
@@ -400,8 +400,6 @@ steps:
   - pytest -v -s compile/test_fusion_attn.py
   - pytest -v -s compile/test_functionalization.py
   - pytest -v -s compile/test_silu_mul_quant_fusion.py
-  - pytest -v -s compile/test_sequence_parallelism.py
-  - pytest -v -s compile/test_async_tp.py
   - pytest -v -s compile/test_fusion_all_reduce.py
   - pytest -v -s compile/test_decorator.py
   - pytest -v -s compile/test_noop_elimination.py
@@ -1093,6 +1091,8 @@ steps:
   working_dir: "/vllm-workspace/"
   num_gpus: 2
   commands:
+    - pytest -v -s tests/compile/test_async_tp.py
+    - pytest -v -s tests/compile/test_sequence_parallelism.py
     - pytest -v -s tests/distributed/test_context_parallel.py
     - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
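
(Note: the two tests removed from the single-GPU compile job above are re-added here under a job that runs with num_gpus: 2, presumably because the async-TP and sequence-parallelism passes need multiple GPUs to exercise their collective ops.)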

vllm/compilation/collective_fusion.py

Lines changed: 22 additions & 6 deletions
@@ -169,15 +169,23 @@ def replacement(
     scale_a: torch.Tensor,
     scale_b: torch.Tensor,
 ) -> torch.Tensor:
+    # Calculate output shape: input @ mat2 with scatter_dim reduced
+    output_shape = [*input.shape[:-1], mat2.shape[1]]
+    scatter_dim = 0
     gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
         input,
         mat2,
         scale_a,
         scale_b,
         "avg",
-        scatter_dim=0,
-        out_dtype=self.dtype,
-        group_name=self.tp.device_group.group_name,
+        scatter_dim,  # orig_scatter_dim
+        scatter_dim,  # scatter_dim_after_maybe_reshape
+        self.tp.device_group.group_name,
+        output_shape,
+        None,  # bias
+        None,  # result_scale
+        self.dtype,  # out_dtype
+        False,  # use_fast_accum
     )
 
     return gemm_rs
@@ -296,15 +304,23 @@ def replacement(
     scale_b: torch.Tensor,
     cutlass_mm_output: torch.Tensor,
 ) -> torch.Tensor:
+    # Calculate output shape: input @ mat2 with scatter_dim reduced
+    output_shape = [*input.shape[:-1], mat2.shape[1]]
+    scatter_dim = 0
     gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
         input,
         mat2,
         scale_a,
         scale_b,
         "avg",
-        scatter_dim=0,
-        out_dtype=self.dtype,
-        group_name=self.tp.device_group.group_name,
+        scatter_dim,  # orig_scatter_dim
+        scatter_dim,  # scatter_dim_after_maybe_reshape
+        self.tp.device_group.group_name,
+        output_shape,
+        None,  # bias
+        None,  # result_scale
+        self.dtype,  # out_dtype
+        False,  # use_fast_accum
     )
 
     return gemm_rs
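
As a quick, self-contained check of the output_shape expression introduced in both hunks (the shapes here are hypothetical, chosen only for illustration):

    import torch

    # A [128, 4096] GEMM input against a [4096, 2048] weight.
    input = torch.empty(128, 4096)
    mat2 = torch.empty(4096, 2048)

    # Same expression as in the patch: keep every leading dim of `input`,
    # replace the contraction dim with mat2's output dim.
    output_shape = [*input.shape[:-1], mat2.shape[1]]
    assert output_shape == [128, 2048]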
