
Commit f708e92

jiayisunx authored and pytorchmergebot committed
[Inductor] support Conv/Linear + broadcast add fusion (pytorch#138201)
Pull Request resolved: pytorch#138201
Approved by: https://github.com/jgong5, https://github.com/jansel
1 parent 5ab5a61 commit f708e92
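
For context, a minimal sketch (not part of the commit) of the kind of graph this change lets Inductor fuse on CPU: a Conv2d whose output is added to a tensor that broadcasts over the spatial dims, mirroring the shapes used in the new tests. The module and variable names below are illustrative only.

import torch

# Illustrative only: conv followed by a broadcast add (+ relu), as in the new tests.
class ConvBroadcastAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)

    def forward(self, x, other):
        # conv output is (1, 16, 54, 54); `other` is (1, 16, 1, 1) and broadcasts over H/W
        return (self.conv(x) + other).relu()

mod = ConvBroadcastAdd().eval()
x = torch.randn(1, 3, 56, 56).to(memory_format=torch.channels_last)
other = torch.randn(1, 16, 1, 1).to(memory_format=torch.channels_last)
with torch.no_grad():
    out = torch.compile(mod)(x, other)  # with this change, the add can fuse into the oneDNN conv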

4 files changed: +153 -12 lines changed


aten/src/ATen/native/mkldnn/Conv.cpp

Lines changed: 6 additions & 8 deletions
@@ -371,10 +371,9 @@ Tensor mkldnn_convolution_pointwise_binary(
 
   auto output_sizes = conv_output_size(
       input_t.sizes(), weight_t.sizes(), padding_expanded, stride_expanded, dilation_expanded);
-  // TODO: support broadcast binary fusion.
   TORCH_CHECK(
-      output_sizes == other_t.sizes(),
-      "Binary Fusion's inputs should have same shape");
+      input_t.dim() == other_t.dim(),
+      "Binary Fusion's inputs should have same dimensions");
   // Only calling fusion path for channels_last path.
   // TODO: OneDNN doesn't optimize well for groups > 1 case, it will be enabled
   // at next OneDNN release.
@@ -405,18 +404,17 @@ Tensor mkldnn_convolution_pointwise_binary(
   auto weight =
       weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format);
   auto other = other_t.contiguous(memory_format);
-  auto output = at::empty_like(other);
+  auto output = at::empty(output_sizes, input_t.options()).contiguous(memory_format);
   const ideep::tensor x = itensor_from_tensor(input);
   const ideep::tensor w = itensor_from_tensor(weight);
   const ideep::tensor z = itensor_from_tensor(other);
   ideep::tensor y = itensor_from_tensor(output);
-  auto output_size = other.sizes().vec();
   ideep::tag format_tag = ideep::tag::nhwc;
   if (input_t.ndimension() == 5) {
     format_tag = ideep::tag::ndhwc;
   }
   auto other_desc = ideep::tensor::desc(
-      output_size, get_mkldnn_dtype(weight.scalar_type()), format_tag);
+      other.sizes().vec(), get_mkldnn_dtype(other.scalar_type()), format_tag);
 
   ideep::attr_t op_attr;
   ideep::post_ops po;
@@ -433,7 +431,7 @@ Tensor mkldnn_convolution_pointwise_binary(
         z,
         w,
         b,
-        output_size,
+        output_sizes,
         y,
         stride_expanded,
         dilation_expanded,
@@ -447,7 +445,7 @@ Tensor mkldnn_convolution_pointwise_binary(
         x,
         z,
         w,
-        output_size,
+        output_sizes,
         y,
         stride_expanded,
         dilation_expanded,

aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 2 additions & 2 deletions
@@ -299,8 +299,8 @@ Tensor mkldnn_linear_pointwise_binary(
   }
 
   TORCH_CHECK(
-      output.sizes() == other_reshaped.sizes(),
-      "linear_binary_run expects the size of output and other tensor to be the same");
+      output.dim() == other_reshaped.dim(),
+      "linear_binary_run expects the dimension of output and other tensor to be the same");
 
   c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
   ideep::tensor mkldnn_output = itensor_from_tensor(output);

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 135 additions & 0 deletions
@@ -633,6 +633,92 @@ def test_conv2d_binary(self):
     def test_conv3d_binary(self):
         self._test_conv_binary_base(dim=5)
 
+    def _test_conv_binary_broadcast_shapes_base(self, dim=4):
+        assert dim == 4 or dim == 5
+
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+                binary_fn,
+                has_relu,
+                **kwargs,
+            ):
+                super().__init__()
+                if dim == 4:
+                    self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
+                else:
+                    self.conv = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
+                self.binary_fn = binary_fn
+                self.has_relu = has_relu
+
+            def forward(self, x, x2):
+                x1 = self.conv(x)
+                if has_relu:
+                    return self.binary_fn(x1, x2).relu()
+                else:
+                    return self.binary_fn(x1, x2)
+
+        dtypes = [
+            torch.float,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
+        test_memory_format = [torch.contiguous_format, cl_format]
+        options = itertools.product(
+            binary_list,
+            [True, False],
+            test_memory_format,
+            dtypes,
+        )
+
+        for (
+            binary_fn,
+            has_relu,
+            memory_format,
+            dtype,
+        ) in options:
+            metrics.reset()
+            if dim == 4:
+                x_shape = (1, 3, 56, 56)
+                other_shape = (1, 16, 1, 1)
+            else:
+                x_shape = (1, 3, 20, 56, 56)
+                other_shape = (1, 16, 1, 1, 1)
+            mod = M(binary_fn, has_relu).eval()
+            x = (
+                torch.randn(x_shape, dtype=torch.float32, requires_grad=True)
+                .add(1)
+                .to(memory_format=memory_format)
+            )
+            other = (
+                torch.randn(other_shape, dtype=torch.float32, requires_grad=True)
+                .add(1)
+                .to(memory_format=memory_format)
+                .to(dtype)
+            )
+            match_count = binary_list[binary_fn][0] + 1
+            match_nodes = binary_list[binary_fn][1]
+            if has_relu:
+                match_nodes += 1
+            self._test_common(
+                mod, (x, other), match_count, match_nodes + 1, check_autocast=dtype
+            )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv2d_binary_broadcast_shapes_cpu(self):
+        self._test_conv_binary_broadcast_shapes_base(dim=4)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv3d_binary_broadcast_shapes_cpu(self):
+        self._test_conv_binary_broadcast_shapes_base(dim=5)
+
     def test_linear_binary(self):
         class M(torch.nn.Module):
             def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
@@ -683,6 +769,55 @@ def forward(self, x, y):
             )
             self.assertEqual(metrics.generated_kernel_count, 1)
 
+    def test_linear_binary_broadcast_shapes_cpu(self):
+        class M(torch.nn.Module):
+            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
+                super().__init__()
+                self.linear = torch.nn.Linear(
+                    in_channels, out_channels, bias=bias, **kwargs
+                )
+                self.binary_fn = binary_fn
+
+            def forward(self, x, y):
+                x = self.linear(x)
+                x = self.binary_fn(x, y.clone())
+                return x
+
+        dtypes = []
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        options = itertools.product(
+            binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
+        )
+        out_feature = 30
+
+        for binary_fn, input_shape, bias, dtype in options:
+            metrics.reset()
+            # addmm(mm) + (linear+add)
+            match_count = 2
+            match_nodes = 3
+            if len(input_shape) == 3:
+                is_inplace = binary_list[binary_fn][2]
+                # view + linear + view(joint_graph+freeze pass)
+                match_count = match_count + 5 if is_inplace else match_count + 3
+                match_nodes = match_nodes + 8 if is_inplace else match_nodes + 5
+            mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
+            v = torch.randn(input_shape)
+            other = torch.randn(input_shape[:-1] + [1]).to(dtype)
+            self._test_common(
+                mod,
+                (
+                    v,
+                    other,
+                ),
+                match_count,
+                match_nodes,
+                check_autocast=dtype,
+            )
+            self.assertEqual(metrics.generated_kernel_count, 1)
+
     def test_multi_linear_share_same_input(self):
         # llama pattern.
         class M(torch.nn.Module):

torch/_inductor/fx_passes/mkldnn_fusion.py

Lines changed: 10 additions & 2 deletions
@@ -367,8 +367,14 @@ def get_meta_value(argument: torch.fx.node.Argument):
             for n in binary_nodes
         ):
             return False
+
         if any(
-            get_meta_value(n.args[0]).size() != get_meta_value(n.args[1]).size()
+            get_meta_value(n.args[0]).dim() != get_meta_value(n.args[1]).dim()
+            or not all(
+                get_meta_value(n.args[0]).size(i) == get_meta_value(n.args[1]).size(i)
+                or get_meta_value(match.kwargs["other"]).size(i) == 1
+                for i in range(get_meta_value(n.args[0]).dim())
+            )
             or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device
             or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype
             for n in binary_nodes
@@ -538,7 +544,9 @@ def fn(match, *args, **kwargs):
                computation_args += [1.0, None, [], None]
            # Make sure the other is not an alias or mutation(fx side doesn't has such info).
            other.realize()
-            if not _can_be_inplace(other):
+            if not _can_be_inplace(other) or other.data.shape != list(
+                match.nodes[0].meta["val"].size()
+            ):
                return L[outplace_fusion_op](*computation_args)
            return L[inplace_fusion_op](*computation_args)
 
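
For reference, the broadcast-compatibility condition the updated pattern-matcher check enforces can be read in isolation as the sketch below; `_broadcast_add_fusable` is a hypothetical helper written for illustration, not a function from this PR.

import torch

def _broadcast_add_fusable(conv_out: torch.Tensor, other: torch.Tensor) -> bool:
    # Same rank and same device/dtype, and each dim of `other` either matches
    # `conv_out` or is 1 (i.e. it broadcasts along that dim).
    if conv_out.dim() != other.dim():
        return False
    if conv_out.device != other.device or conv_out.dtype != other.dtype:
        return False
    return all(
        conv_out.size(i) == other.size(i) or other.size(i) == 1
        for i in range(conv_out.dim())
    )

# A (1, 16, 54, 54) conv output with a (1, 16, 1, 1) `other` passes the check.
assert _broadcast_add_fusable(torch.randn(1, 16, 54, 54), torch.randn(1, 16, 1, 1))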
