Commit 33aaa15

[Kernel] Apply torch.Tag.needs_fixed_stride_order only for torch==2.6.0

Summary: In torch 2.6.0, torch accidentally changed the default behavior for
custom operators to "requires_contiguous". As a workaround, vLLM added
needs_fixed_stride_order to a large number of custom operators. vLLM is
currently on torch 2.7.0, which has reverted the default for custom operators
back to needs_fixed_stride_order, so this PR cleans up the kernel logic by
flipping the default back. The other reason to flip the default back is that
needs_fixed_stride_order is actually buggy, and torch 2.8.0 has better
behavior for custom operators with no layout tags set. Also, Kaichao tells me
that some backends may not have moved to PyTorch 2.7.0 yet
(vllm-project#8932), so I didn't delete the code in this PR.

Test Plan:
- Existing tests
- Ran `pytest tests/compile/test_full_graph.py` (this was the test that
  originally caused us to add the needs_fixed_stride_order tag; see
  vllm-project#12721 for context)

Signed-off-by: rzou <zou3519@gmail.com>
1 parent 2e3e3c8 commit 33aaa15
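
For context, a minimal sketch (not from this commit) of how such a layout tag
is attached when defining a custom operator through torch.library; the library
name and operator schema below are hypothetical, while the tag and the
registration API come from PyTorch itself:

    import torch

    # Hypothetical illustration, not code from this commit. With no layout
    # tag, torch.compile applies its default input-layout policy; in torch
    # 2.6.0 that default was accidentally "requires_contiguous", which
    # inserts .contiguous() calls and breaks ops expecting e.g. column-major
    # weights. Attaching needs_fixed_stride_order makes torch.compile pass
    # inputs with their eager-mode strides instead.
    lib = torch.library.Library("my_ext", "FRAGMENT")
    lib.define(
        "column_major_mm(Tensor a, Tensor b) -> Tensor",
        tags=(torch.Tag.needs_fixed_stride_order, ),
    )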

File tree: 3 files changed (+19, -9 lines)

- csrc/torch_bindings.cpp
- vllm/attention/ops/rocm_aiter_mla.py
- vllm/model_executor/layers/fused_moe/fused_moe.py
csrc/torch_bindings.cpp
Lines changed: 8 additions & 4 deletions

@@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
   //
 
-  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
+  // The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
+  // so we need
   // to override this for many GEMMs with the following tag. Otherwise,
   // torch.compile will force all input tensors to be contiguous(), which
   // will break many custom ops that require column-major weight matrices.
-  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
-  // to match exact eager-mode strides.
-  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
+  // This was a bug and PyTorch 2.7 has since fixed this.
+#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
+  #define stride_tag at::Tag::needs_fixed_stride_order
+#else
+  #define stride_tag
+#endif
 
   ops.def("weak_ref_tensor(Tensor input) -> Tensor");
   ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

vllm/attention/ops/rocm_aiter_mla.py
Lines changed: 6 additions & 2 deletions

@@ -6,7 +6,7 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
 
 
 def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
@@ -93,8 +93,12 @@ def mla_decode_fwd_fake(
 
 
 if current_platform.is_rocm():
+    if is_torch_equal_or_newer("2.7.0"):
+        tags = ()
+    else:
+        tags = (torch.Tag.needs_fixed_stride_order, )
     direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd",
                               op_func=mla_decode_fwd_impl,
                               mutates_args=["o"],
                               fake_impl=mla_decode_fwd_fake,
-                              tags=[torch.Tag.needs_fixed_stride_order])
+                              tags=tags)
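
This diff leans on vllm.utils.is_torch_equal_or_newer, whose implementation is
not shown here. A minimal sketch of what such a helper could look like
(assumed, not vLLM's actual code):

    from importlib.metadata import version

    from packaging.version import Version


    def is_torch_equal_or_newer(target: str) -> bool:
        # Assumed sketch: compare the installed torch distribution's version
        # against the target. Version handles local suffixes like "+cu126",
        # which still compare >= the plain "2.7.0" release.
        return Version(version("torch")) >= Version(target)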

vllm/model_executor/layers/fused_moe/fused_moe.py
Lines changed: 5 additions & 3 deletions

@@ -22,7 +22,7 @@
     _resize_cache, moe_kernel_quantize_input)
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.utils import direct_register_custom_op
+from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer
 
 from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
 
@@ -1053,7 +1053,8 @@ def inplace_fused_experts_fake(
     op_func=inplace_fused_experts,
     mutates_args=["hidden_states"],
     fake_impl=inplace_fused_experts_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
+    tags=(() if is_torch_equal_or_newer("2.7.0") else
+          (torch.Tag.needs_fixed_stride_order, )),
 )
 
 
@@ -1117,7 +1118,8 @@ def outplace_fused_experts_fake(
     op_func=outplace_fused_experts,
     mutates_args=[],
     fake_impl=outplace_fused_experts_fake,
-    tags=(torch.Tag.needs_fixed_stride_order, ),
+    tags=(() if is_torch_equal_or_newer("2.7.0") else
+          (torch.Tag.needs_fixed_stride_order, )),
 )
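
Taken together, the Python-side pattern is: compute the tag tuple once at
import time, then hand it to the registration helper. A self-contained sketch
of that pattern with a hypothetical op (the op, its implementations, and their
names are illustrative only; direct_register_custom_op and
is_torch_equal_or_newer are the vLLM helpers used in the diffs above):

    import torch

    from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer


    def scale_by_two(x: torch.Tensor) -> torch.Tensor:
        # Hypothetical real implementation.
        return x * 2


    def scale_by_two_fake(x: torch.Tensor) -> torch.Tensor:
        # Shape/dtype-only "fake" implementation for torch.compile tracing.
        return torch.empty_like(x)


    # Same gate as in the diffs: only torch 2.6.x builds still need the
    # needs_fixed_stride_order workaround.
    _tags = (() if is_torch_equal_or_newer("2.7.0") else
             (torch.Tag.needs_fixed_stride_order, ))

    direct_register_custom_op(op_name="scale_by_two",
                              op_func=scale_by_two,
                              mutates_args=[],
                              fake_impl=scale_by_two_fake,
                              tags=_tags)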
