Skip to content

Commit

Permalink
[inductor] fix linear add bias pattern (pytorch#128473)
Browse files Browse the repository at this point in the history
Fix pytorch#128287.
Previously, the assertions in `linear_add_bias` were pretty bad
```
assert packed_weight_node.name == "_reorder_linear_weight"
assert transpose_weight_node.name == "permute_default"
```
because the `name` can change to `_reorder_linear_weight_{id}` / `permute_default_{id}` when there is more than one reorder/permute node.

Checking `target` instead of `name` solves this issue.

The UT is also updated to match more than one `linear_add_bias` pattern to cover this case.

Co-authored-by: Jiong Gong <jiong.gong@intel.com>
Pull Request resolved: pytorch#128473
Approved by: https://github.com/jgong5

(cherry picked from commit c53d65b)
  • Loading branch information
zhuhaozhe committed Jun 13, 2024
1 parent 1cd4199 commit 5f23f6d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
16 changes: 10 additions & 6 deletions test/inductor/test_mkldnn_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,16 @@ def test_linear_add_bias(self):
class M(torch.nn.Module):
def __init__(self, dtype, unary_fn):
super().__init__()
self.linear = torch.nn.Linear(10, 64, bias=False)
self.bias = torch.randn(64).to(dtype=dtype)
self.linear1 = torch.nn.Linear(10, 64, bias=False)
self.bias1 = torch.randn(64).to(dtype=dtype)
self.linear2 = torch.nn.Linear(10, 64, bias=False)
self.bias2 = torch.randn(64).to(dtype=dtype)
self.unary_fn = unary_fn

def forward(self, x):
x = self.linear(x) + self.bias
return self.unary_fn(x)
a = self.linear1(x) + self.bias1
b = self.linear2(x) + self.bias2
return self.unary_fn(a), self.unary_fn(b)

dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
Expand All @@ -419,13 +422,14 @@ def forward(self, x):
mod = M(dtype, unary_fn).eval()
v = torch.randn(2, 10)
matcher_count = 3
# Add 1 for weight packing pass, add 2 for bias folding pass.
# Add 1 for weight packing pass, add 2 for bias folding pass per linear.
matcher_nodes = unary_list[unary_fn] + 3
if self._check_unary_is_decomposed(unary_fn):
# Has extra dtype conversion nodes for autocast.
matcher_nodes += 2
# we have 2 linears, so we double the matcher_count/nodes
self._test_common(
mod, (v,), matcher_count, matcher_nodes, check_autocast=dtype
mod, (v,), matcher_count * 2, matcher_nodes * 2, check_autocast=dtype
)
self.assertEqual(metrics.generated_kernel_count, 1)

Expand Down
4 changes: 2 additions & 2 deletions torch/_inductor/fx_passes/mkldnn_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,9 +790,9 @@ def is_linear_add_bias(match):
add_node = match.output_node()
linear_node = add_node.args[0]
packed_weight_node = linear_node.args[1]
assert packed_weight_node.name == "_reorder_linear_weight"
assert packed_weight_node.target == mkldnn._reorder_linear_weight
transpose_weight_node = packed_weight_node.args[0]
assert transpose_weight_node.name == "permute_default"
assert transpose_weight_node.target == aten.permute.default
weight_meta = transpose_weight_node.args[0].meta.get("val")
bias_node = add_node.args[1]
if isinstance(bias_node, int):
Expand Down

0 comments on commit 5f23f6d

Please sign in to comment.