inductor: fix cpp wrapper ExternKernel check (pytorch#96799)
Fix cpp_wrapper functionality for ExternKernel. The changes in pytorch#91575 disabled cpp_wrapper for the ExternKernel cases.

1. The `cpp_kernel` attr needs to be set before calling `V.graph.register_buffer(self)`, because `register_buffer` invokes the check below:
https://github.com/pytorch/pytorch/blob/c6a82e433924b4d36fd571d36ce363cb1c622c76/torch/_inductor/graph.py#L220-L223
The current code sets `cpp_kernel` only after `V.graph.register_buffer(self)`, so the check never sees it and the cpp wrapper is always disabled (see the first sketch below).

2. Fix the missing `ordered_kwargs_for_cpp_kernel` attr for `at::addmm_out`; its `beta` and `alpha` kwargs must be emitted positionally, in that order, in the generated C++ call (see the second sketch below).

3. Enhance the UT to check that cpp_wrapper is actually enabled for the supported cases, so that future changes cannot silently disable it again (the cpp wrapper compiles its generated code through `load_inline`, so the presence of that call in the output signals the wrapper is active).
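
To make the ordering pitfall in item 1 concrete, here is a minimal, self-contained Python sketch. The `Graph` and `ExternKernelBuffer` classes below are simplified stand-ins for illustration, not the real inductor classes; only the registration-time check and the assignment order mirror the bug and the fix.

# Simplified sketch of the ordering pitfall; not the real inductor classes.
class Graph:
    def __init__(self):
        self.cpp_wrapper = True  # optimistic: use the C++ wrapper if possible

    def register_buffer(self, buffer):
        # The check runs at registration time: an extern kernel with no
        # C++ counterpart forces the whole graph back to the Python wrapper.
        if getattr(buffer, "cpp_kernel", None) is None:
            self.cpp_wrapper = False
        return "buf0"

class ExternKernelBuffer:
    def __init__(self, graph, cpp_kernel=None, buggy_order=False):
        if buggy_order:
            # Old (broken) order: register first, set cpp_kernel afterwards,
            # so the registration-time check never sees the C++ kernel name.
            self.name = graph.register_buffer(self)
            self.cpp_kernel = cpp_kernel
        else:
            # Fixed order: set cpp_kernel before registering.
            self.cpp_kernel = cpp_kernel
            self.name = graph.register_buffer(self)

g1 = Graph()
ExternKernelBuffer(g1, cpp_kernel="at::mm_out", buggy_order=True)
print(g1.cpp_wrapper)  # False: wrapper disabled despite a valid C++ kernel

g2 = Graph()
ExternKernelBuffer(g2, cpp_kernel="at::mm_out", buggy_order=False)
print(g2.cpp_wrapper)  # True: the check saw the C++ kernel name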

Pull Request resolved: pytorch#96799
Approved by: https://github.com/jgong5, https://github.com/EikanWang, https://github.com/jansel
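
For item 2, the kwarg order matters because C++ has no keyword arguments: when the cpp wrapper emits a call such as `at::addmm_out`, the Python-side kwargs `beta` and `alpha` must be appended to the argument list positionally, in the order the C++ signature expects. The sketch below shows the idea; `emit_cpp_call` is a hypothetical helper for illustration, not the actual inductor codegen function.

def emit_cpp_call(cpp_kernel, args, kwargs, ordered_kwargs_for_cpp_kernel):
    # C++ has no keyword arguments, so every Python kwarg is appended
    # positionally, in the order declared for the C++ kernel signature.
    rendered = list(args)
    for name in ordered_kwargs_for_cpp_kernel:
        rendered.append(str(kwargs[name]))
    return f"{cpp_kernel}({', '.join(rendered)});"

call = emit_cpp_call(
    "at::addmm_out",
    ["buf0", "arg0", "arg1", "arg2"],
    {"alpha": 1, "beta": 1},
    ordered_kwargs_for_cpp_kernel=("beta", "alpha"),
)
print(call)  # at::addmm_out(buf0, arg0, arg1, arg2, 1, 1);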
chunyuan-w authored and pytorchmergebot committed Mar 17, 2023
1 parent 13538c8 commit 238b060
Showing 4 changed files with 37 additions and 21 deletions.
test/inductor/test_torchinductor.py (22 additions, 18 deletions)
@@ -5725,29 +5725,33 @@ def test_cpp_wrapper(self):
             raise unittest.SkipTest("cpp_wrapper only supports cpu")
 
         device = "cpu"
-        for name in [
-            "test_as_strided",  # buffer reuse
-            "test_bitwise",  # int32
-            "test_bmm1",
-            "test_bmm2",
-            "test_cat",  # alias
-            "test_linear1",
-            "test_linear2",
-            "test_lowmem_dropout1",  # None as output
-            "test_mm_views",
-            "test_profiler_mark_wrapper_call",  # TODO: fallback to default wrapper for now
-            "test_reduction1",  # Reduction
-            "test_relu",  # multiple inputs
-            "test_silu",  # single input, single output
-            "test_sum_dtype",  # float64
-            "test_sum_int",  # bool, int64, int8, uint8
-            "test_transpose",  # multiple outputs, buffer clear
+        for name, supported in [
+            ["test_as_strided", True],  # buffer reuse
+            ["test_bitwise", True],  # int32
+            ["test_bmm1", True],
+            ["test_bmm2", True],
+            ["test_cat", True],  # alias
+            ["test_linear1", True],
+            ["test_linear2", True],
+            ["test_lowmem_dropout1", True],  # None as output
+            ["test_mm_views", True],
+            [
+                "test_profiler_mark_wrapper_call",
+                False,
+            ],  # TODO: fallback to default wrapper for now
+            ["test_reduction1", True],  # Reduction
+            ["test_relu", True],  # multiple inputs
+            ["test_silu", True],  # single input, single output
+            ["test_sum_dtype", True],  # float64
+            ["test_sum_int", True],  # bool, int64, int8, uint8
+            ["test_transpose", True],  # multiple outputs, buffer clear
         ]:
             test_name = f"{name}_{device}"
             assert hasattr(self, test_name), "undefined function"
             func = getattr(self, test_name)
             assert callable(func), "not a callable"
-            func()
+            code = run_and_get_cpp_code(func, [])
+            self.assertEqual("load_inline" in code, supported)
 
     @unittest.skipIf(IS_X86 and not HAS_AVX2, "Requires AVX2")
     def test_pixel_shuffle_channels_last(self):
torch/_inductor/ir.py (3 additions, 1 deletion)
@@ -2812,15 +2812,17 @@ def __init__(
         output_view=None,
         kernel=None,
         cpp_kernel=None,
+        ordered_kwargs_for_cpp_kernel=(),
     ):
         super().__init__(
             None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
         )
         self.output_view = output_view
+        self.cpp_kernel = cpp_kernel
+        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
         self.name = V.graph.register_buffer(self)
         if kernel is not None:
             self.kernel = kernel
-        self.cpp_kernel = cpp_kernel
 
     def should_allocate(self):
         return True
torch/_inductor/kernel/mm.py (1 addition, 1 deletion)
@@ -73,7 +73,7 @@
 aten_mm = ExternKernelChoice(torch.mm, "at::mm_out")
 
 
-aten_addmm = ExternKernelChoice(torch.addmm, "at::addmm_out")
+aten_addmm = ExternKernelChoice(torch.addmm, "at::addmm_out", ("beta", "alpha"))
 
 
 def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1):
torch/_inductor/select_algorithm.py (11 additions, 1 deletion)
@@ -468,13 +468,22 @@ def get_dtype(name):


 class ExternKernelChoice:
-    def __init__(self, kernel, cpp_kernel=None, *, name=None, has_out_variant=True):
+    def __init__(
+        self,
+        kernel,
+        cpp_kernel=None,
+        ordered_kwargs_for_cpp_kernel=(),
+        *,
+        name=None,
+        has_out_variant=True,
+    ):
         super().__init__()
         name = name or kernel.__name__
         assert callable(kernel)
         assert not hasattr(extern_kernels, name), "duplicate extern kernel"
         self.name = name
         self.cpp_kernel = cpp_kernel
+        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
         self.has_out_variant = has_out_variant
         setattr(extern_kernels, name, kernel)
 
@@ -623,6 +632,7 @@ def output_node(self):
                 inputs=self.input_nodes,
                 kernel=self.choice.call_name(),
                 cpp_kernel=self.choice.cpp_kernel,
+                ordered_kwargs_for_cpp_kernel=self.choice.ordered_kwargs_for_cpp_kernel,
                 kwargs=self.kwargs,
             )
         )
