Skip to content

Commit

Permalink
Fuse tensor-scalar ops when scalar is constant (pytorch#10511)
Browse files Browse the repository at this point in the history
Summary:
This is on the way to resolving pytorch#9940.

Fixes pytorch#10501

This PR modifies graph fuser to fuse operations that have constant
scalar arguments. These constant scalar arguments are directly inlined
into the kernel body.

The context for this is that LSTM backward (in particular, sigmoid
backward) has many add(x, 1.) operations. This PR should be sufficient for
LSTM backward to get fused by the graph fuser.

cc apaszke zdevito
Pull Request resolved: pytorch#10511

Differential Revision: D9378896

Pulled By: zou3519

fbshipit-source-id: 6a7a2987f5b6e8edaaf4b599cd200df33361650f
  • Loading branch information
zou3519 authored and facebook-github-bot committed Aug 17, 2018
1 parent f3ac619 commit 86c9856
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 168 deletions.
16 changes: 8 additions & 8 deletions test/expect/TestJit.test_concat_fusion_invariant_cuda.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@ graph(%0 : Float(2, 2)
%1 : Float(2, 2)
%2 : Float(4, 2)) {
%3 : int = prim::Constant[value=1]()
%4 : Float(2, 2) = aten::sub(%0, %1, %3)
%5 : Float(4, 2) = prim::FusionGroup_0[device=0](%4, %0, %1)
%6 : Float(4, 2) = aten::add(%5, %2, %3)
return (%6);
%4 : Float(4, 2) = prim::FusionGroup_0[device=0](%0, %1)
%5 : Float(4, 2) = aten::add(%4, %2, %3)
return (%5);
}
with prim::FusionGroup_0 = graph(%1 : Float(2, 2)
%3 : Float(2, 2)
with prim::FusionGroup_0 = graph(%3 : Float(2, 2)
%4 : Float(2, 2)) {
%7 : int = prim::Constant[value=1]()
%8 : Float(2, 2) = aten::add(%3, %4, %7)
%5 : int = prim::Constant[value=1]()
%6 : Float(2, 2) = aten::add(%3, %4, %5)
%2 : Float(4, 2) = prim::FusedConcat[dim=0](%6, %1)
%6 : Float(2, 2) = aten::sub(%3, %4, %5)
%2 : Float(4, 2) = prim::FusedConcat[dim=0](%8, %6)
return (%2);
}
114 changes: 50 additions & 64 deletions test/expect/TestScript.test_lstm_fusion_cuda-backward.expect
Original file line number Diff line number Diff line change
Expand Up @@ -18,70 +18,56 @@ graph(%0 : Float(3, 20!)
%outgate : Float(3, 20)
%18 : Float(3, 20)) {
%19 : int = prim::Constant[value=1]()
%20 : Float(3, 20) = prim::FusionGroup_0[device=0](%18)
%21 : Float(3, 20) = aten::add(%20, %19, %19)
%22 : Float(3, 20) = prim::FusionGroup_1[device=0](%1, %21, %0, %outgate)
%23 : Float(3, 20) = aten::mul(%22, %19)
%24 : Float(3, 20) = aten::neg(%outgate)
%25 : Float(3, 20) = aten::add(%24, %19, %19)
%26 : Float(3, 20) = prim::FusionGroup_2[device=0](%cellgate)
%27 : Float(3, 20) = aten::add(%26, %19, %19)
%28 : Float(3, 20) = aten::neg(%forgetgate)
%29 : Float(3, 20) = aten::add(%28, %19, %19)
%30 : Float(3, 20) = aten::neg(%ingate)
%31 : Float(3, 20) = aten::add(%30, %19, %19)
%32 : Float(3, 80) = prim::FusionGroup_3[device=0](%31, %ingate, %29, %forgetgate, %27, %25, %outgate, %22, %cx, %23, %cellgate, %0, %18)
%33 : Float(3, 80) = aten::mul(%32, %19)
%34 : Float(80!, 3!) = aten::t(%33)
%35 : Float(80, 20) = aten::mm(%34, %hx)
%36 : Float(80!, 3!) = aten::t(%32)
%37 : Float(80, 10) = aten::mm(%36, %x)
return (%37, %35, %33, %33);
%20 : Float(3, 80) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %18, %0)
%21 : Float(3, 80) = aten::mul(%20, %19)
%22 : Float(80!, 3!) = aten::t(%21)
%23 : Float(80, 20) = aten::mm(%22, %hx)
%24 : Float(80!, 3!) = aten::t(%20)
%25 : Float(80, 10) = aten::mm(%24, %x)
return (%25, %23, %21, %21);
}
with prim::FusionGroup_0 = graph(%2 : Float(3, 20)) {
%3 : Float(3, 20) = aten::mul(%2, %2)
%1 : Float(3, 20) = aten::neg(%3)
return (%1);
}
with prim::FusionGroup_1 = graph(%0 : Float(3, 20!)
%5 : Float(3, 20)
%7 : Float(3, 20!)
%8 : Float(3, 20)) {
%9 : Float(3, 20) = aten::mul(%7, %8)
%6 : Float(3, 20) = aten::mul(%9, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(3, 20) = aten::add(%0, %6, %2)
return (%3);
}
with prim::FusionGroup_2 = graph(%2 : Float(3, 20)) {
%3 : Float(3, 20) = aten::mul(%2, %2)
%1 : Float(3, 20) = aten::neg(%3)
return (%1);
}
with prim::FusionGroup_3 = graph(%6 : Float(3, 20)
%9 : Float(3, 20)
%12 : Float(3, 20)
%15 : Float(3, 20)
%18 : Float(3, 20)
%21 : Float(3, 20)
%24 : Float(3, 20)
%26 : Float(3, 20)
%27 : Float(3, 20)
%29 : Float(3, 20)
%31 : Float(3, 20)
%33 : Float(3, 20!)
%34 : Float(3, 20)) {
%35 : Float(3, 20) = aten::mul(%33, %34)
%32 : Float(3, 20) = aten::mul(%29, %31)
%30 : Float(3, 20) = aten::mul(%29, %9)
%28 : Float(3, 20) = aten::mul(%26, %27)
%25 : Float(3, 20) = aten::mul(%35, %24)
%22 : Float(3, 20) = aten::mul(%25, %21)
%19 : Float(3, 20) = aten::mul(%30, %18)
%16 : Float(3, 20) = aten::mul(%28, %15)
%13 : Float(3, 20) = aten::mul(%16, %12)
%10 : Float(3, 20) = aten::mul(%32, %9)
%7 : Float(3, 20) = aten::mul(%10, %6)
%4 : Float(3, 80) = prim::FusedConcat[dim=1](%7, %13, %19, %22)
with prim::FusionGroup_0 = graph(%9 : Float(3, 20)
%19 : Float(3, 20)
%33 : Float(3, 20)
%39 : Float(3, 20)
%46 : Float(3, 20)
%53 : Float(3, 20!)
%65 : Float(3, 20)
%67 : Float(3, 20!)) {
%69 : Float(3, 20) = aten::mul(%67, %65)
%68 : Float(3, 20) = aten::mul(%67, %39)
%66 : Float(3, 20) = aten::mul(%65, %65)
%64 : Float(3, 20) = aten::neg(%66)
%61 : int = prim::Constant[value=1]()
%62 : Float(3, 20) = aten::add(%64, %61, %61)
%59 : Float(3, 20) = aten::mul(%68, %62)
%55 : int = prim::Constant[value=1]()
%56 : Float(3, 20) = aten::add(%53, %59, %55)
%51 : int = prim::Constant[value=1]()
%52 : Float(3, 20) = aten::mul(%56, %51)
%50 : Float(3, 20) = aten::mul(%52, %33)
%49 : Float(3, 20) = aten::mul(%52, %9)
%47 : Float(3, 20) = aten::mul(%56, %46)
%44 : Float(3, 20) = aten::neg(%39)
%42 : int = prim::Constant[value=1]()
%43 : Float(3, 20) = aten::add(%44, %42, %42)
%40 : Float(3, 20) = aten::mul(%69, %39)
%37 : Float(3, 20) = aten::mul(%40, %43)
%34 : Float(3, 20) = aten::mul(%33, %33)
%32 : Float(3, 20) = aten::neg(%34)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%32, %29, %29)
%27 : Float(3, 20) = aten::mul(%49, %30)
%24 : Float(3, 20) = aten::neg(%19)
%22 : int = prim::Constant[value=1]()
%23 : Float(3, 20) = aten::add(%24, %22, %22)
%20 : Float(3, 20) = aten::mul(%47, %19)
%17 : Float(3, 20) = aten::mul(%20, %23)
%14 : Float(3, 20) = aten::neg(%9)
%12 : int = prim::Constant[value=1]()
%13 : Float(3, 20) = aten::add(%14, %12, %12)
%10 : Float(3, 20) = aten::mul(%50, %9)
%7 : Float(3, 20) = aten::mul(%10, %13)
%4 : Float(3, 80) = prim::FusedConcat[dim=1](%7, %17, %27, %37)
return (%4);
}
140 changes: 66 additions & 74 deletions test/expect/TestScript.test_milstm_fusion_cuda-backward.expect
Original file line number Diff line number Diff line change
Expand Up @@ -27,90 +27,82 @@ graph(%0 : Float(3, 20!)
%outgate : Float(3, 20)
%27 : Float(3, 20)) {
%28 : int = prim::Constant[value=1]()
%29 : Float(3, 20) = prim::FusionGroup_0[device=0](%27)
%30 : Float(3, 20) = aten::add(%29, %28, %28)
%31 : Float(3, 20) = prim::FusionGroup_1[device=0](%1, %30, %0, %outgate)
%32 : Float(3, 20) = aten::mul(%31, %28)
%33 : Float(3, 20) = aten::neg(%outgate)
%34 : Float(3, 20) = aten::add(%33, %28, %28)
%35 : Float(3, 20) = prim::FusionGroup_2[device=0](%cellgate)
%36 : Float(3, 20) = aten::add(%35, %28, %28)
%37 : Float(3, 20) = aten::neg(%forgetgate)
%38 : Float(3, 20) = aten::add(%37, %28, %28)
%39 : Float(3, 20) = aten::neg(%ingate)
%40 : Float(3, 20) = aten::add(%39, %28, %28)
%41 : Float(3, 80) = prim::FusionGroup_3[device=0](%40, %ingate, %38, %forgetgate, %36, %34, %outgate, %31, %cx, %32, %cellgate, %0, %27)
%42 : Float(3, 80) = aten::mul(%41, %28)
%43 : Float(3, 80) = aten::mul(%42, %Uz)
%44 : Float(3, 80) = aten::mul(%42, %beta_h)
%45 : Float(3, 80) = aten::mul(%42, %Wx)
%46 : Float(3, 80) = aten::mul(%42, %beta_i)
%47 : Float(3, 80) = prim::FusionGroup_4[device=0](%44, %41, %22)
%48 : Float(3, 80), %49 : Float(3, 80) = prim::FusionGroup_5[device=0](%Wx, %41, %Uz)
%50 : Float(3, 80) = aten::mul(%49, %alpha)
%51 : Float(3, 80) = aten::add(%46, %50, %28)
%52 : Float(80!, 3!) = aten::t(%47)
%53 : Float(80, 20) = aten::mm(%52, %hx)
%54 : Float(80!, 3!) = aten::t(%51)
%55 : Float(80, 10) = aten::mm(%54, %x)
return (%55, %53, %48, %45, %43, %42);
%29 : Float(3, 80) = prim::FusionGroup_0[device=0](%ingate, %forgetgate, %cellgate, %outgate, %cx, %1, %27, %0)
%30 : Float(3, 80), %31 : Float(3, 80) = prim::FusionGroup_1[device=0](%Uz, %29)
%32 : Float(3, 80) = aten::mul(%31, %beta_h)
%33 : Float(3, 80) = aten::mul(%31, %Wx)
%34 : Float(3, 80) = aten::mul(%31, %beta_i)
%35 : Float(3, 80) = prim::FusionGroup_2[device=0](%32, %29, %22)
%36 : Float(3, 80), %37 : Float(3, 80) = prim::FusionGroup_3[device=0](%Wx, %29, %Uz)
%38 : Float(3, 80) = aten::mul(%37, %alpha)
%39 : Float(3, 80) = aten::add(%34, %38, %28)
%40 : Float(80!, 3!) = aten::t(%35)
%41 : Float(80, 20) = aten::mm(%40, %hx)
%42 : Float(80!, 3!) = aten::t(%39)
%43 : Float(80, 10) = aten::mm(%42, %x)
return (%43, %41, %36, %33, %30, %31);
}
with prim::FusionGroup_0 = graph(%2 : Float(3, 20)) {
%3 : Float(3, 20) = aten::mul(%2, %2)
%1 : Float(3, 20) = aten::neg(%3)
return (%1);
}
with prim::FusionGroup_1 = graph(%0 : Float(3, 20!)
%5 : Float(3, 20)
%7 : Float(3, 20!)
%8 : Float(3, 20)) {
%9 : Float(3, 20) = aten::mul(%7, %8)
%6 : Float(3, 20) = aten::mul(%9, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(3, 20) = aten::add(%0, %6, %2)
return (%3);
}
with prim::FusionGroup_2 = graph(%2 : Float(3, 20)) {
%3 : Float(3, 20) = aten::mul(%2, %2)
%1 : Float(3, 20) = aten::neg(%3)
return (%1);
}
with prim::FusionGroup_3 = graph(%6 : Float(3, 20)
%9 : Float(3, 20)
%12 : Float(3, 20)
%15 : Float(3, 20)
%18 : Float(3, 20)
%21 : Float(3, 20)
%24 : Float(3, 20)
%26 : Float(3, 20)
%27 : Float(3, 20)
%29 : Float(3, 20)
%31 : Float(3, 20)
%33 : Float(3, 20!)
%34 : Float(3, 20)) {
%35 : Float(3, 20) = aten::mul(%33, %34)
%32 : Float(3, 20) = aten::mul(%29, %31)
%30 : Float(3, 20) = aten::mul(%29, %9)
%28 : Float(3, 20) = aten::mul(%26, %27)
%25 : Float(3, 20) = aten::mul(%35, %24)
%22 : Float(3, 20) = aten::mul(%25, %21)
%19 : Float(3, 20) = aten::mul(%30, %18)
%16 : Float(3, 20) = aten::mul(%28, %15)
%13 : Float(3, 20) = aten::mul(%16, %12)
%10 : Float(3, 20) = aten::mul(%32, %9)
%7 : Float(3, 20) = aten::mul(%10, %6)
%4 : Float(3, 80) = prim::FusedConcat[dim=1](%7, %13, %19, %22)
with prim::FusionGroup_0 = graph(%9 : Float(3, 20)
%19 : Float(3, 20)
%33 : Float(3, 20)
%39 : Float(3, 20)
%46 : Float(3, 20)
%53 : Float(3, 20!)
%65 : Float(3, 20)
%67 : Float(3, 20!)) {
%69 : Float(3, 20) = aten::mul(%67, %65)
%68 : Float(3, 20) = aten::mul(%67, %39)
%66 : Float(3, 20) = aten::mul(%65, %65)
%64 : Float(3, 20) = aten::neg(%66)
%61 : int = prim::Constant[value=1]()
%62 : Float(3, 20) = aten::add(%64, %61, %61)
%59 : Float(3, 20) = aten::mul(%68, %62)
%55 : int = prim::Constant[value=1]()
%56 : Float(3, 20) = aten::add(%53, %59, %55)
%51 : int = prim::Constant[value=1]()
%52 : Float(3, 20) = aten::mul(%56, %51)
%50 : Float(3, 20) = aten::mul(%52, %33)
%49 : Float(3, 20) = aten::mul(%52, %9)
%47 : Float(3, 20) = aten::mul(%56, %46)
%44 : Float(3, 20) = aten::neg(%39)
%42 : int = prim::Constant[value=1]()
%43 : Float(3, 20) = aten::add(%44, %42, %42)
%40 : Float(3, 20) = aten::mul(%69, %39)
%37 : Float(3, 20) = aten::mul(%40, %43)
%34 : Float(3, 20) = aten::mul(%33, %33)
%32 : Float(3, 20) = aten::neg(%34)
%29 : int = prim::Constant[value=1]()
%30 : Float(3, 20) = aten::add(%32, %29, %29)
%27 : Float(3, 20) = aten::mul(%49, %30)
%24 : Float(3, 20) = aten::neg(%19)
%22 : int = prim::Constant[value=1]()
%23 : Float(3, 20) = aten::add(%24, %22, %22)
%20 : Float(3, 20) = aten::mul(%47, %19)
%17 : Float(3, 20) = aten::mul(%20, %23)
%14 : Float(3, 20) = aten::neg(%9)
%12 : int = prim::Constant[value=1]()
%13 : Float(3, 20) = aten::add(%14, %12, %12)
%10 : Float(3, 20) = aten::mul(%50, %9)
%7 : Float(3, 20) = aten::mul(%10, %13)
%4 : Float(3, 80) = prim::FusedConcat[dim=1](%7, %17, %27, %37)
return (%4);
}
with prim::FusionGroup_4 = graph(%0 : Float(3, 80)
with prim::FusionGroup_1 = graph(%1 : Float(3, 80)
%3 : Float(3, 80)) {
%4 : int = prim::Constant[value=1]()
%5 : Float(3, 80) = aten::mul(%3, %4)
%2 : Float(3, 80) = aten::mul(%5, %1)
return (%2, %5);
}
with prim::FusionGroup_2 = graph(%0 : Float(3, 80)
%4 : Float(3, 80)
%5 : Float(3, 80)) {
%6 : Float(3, 80) = aten::mul(%4, %5)
%2 : int = prim::Constant[value=1]()
%3 : Float(3, 80) = aten::add(%0, %6, %2)
return (%3);
}
with prim::FusionGroup_5 = graph(%1 : Float(3, 80)
with prim::FusionGroup_3 = graph(%1 : Float(3, 80)
%3 : Float(3, 80)
%4 : Float(3, 80)) {
%5 : Float(3, 80) = aten::mul(%3, %4)
Expand Down
11 changes: 11 additions & 0 deletions test/expect/TestScript.test_tensor_scalar_fusion_cuda-1.expect
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
graph(%x : Float(2, 2)) {
%1 : Float(2, 2) = prim::FusionGroup_0[device=0](%x)
return (%1);
}
with prim::FusionGroup_0 = graph(%0 : Float(2, 2)) {
%z : float = prim::Constant[value=3]()
%4 : int = prim::Constant[value=1]()
%y : Float(2, 2) = aten::add(%0, %z, %4)
%2 : Float(2, 2) = aten::mul(%0, %y)
return (%2);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
graph(%x : Float(2, 2)
%z : Float()) {
%2 : int = prim::TensorToNum(%z)
%3 : int = prim::Constant[value=1]()
%y : Dynamic = aten::add(%x, %2, %3)
%5 : Dynamic = aten::mul(%x, %y)
return (%5);
}
26 changes: 26 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,6 +2311,32 @@ def func2():

self.checkScript(func2, ())

@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@skipIfRocm
def test_tensor_scalar_fusion_cuda(self):
def should_fuse(x):
z = 3.
y = x + z
return x * y

# XXX: right now we only support fusing scalars if
# they're constant (#9940)
def should_not_fuse(x, z):
y = x + int(z)
return x * y

inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
ge = self.checkScript(should_fuse, inputs)
self.assertExpectedGraph(ge.graph_for(*inputs), subname='1')

inputs = [
torch.randn(2, 2, dtype=torch.float, device='cuda'),
torch.tensor(3., dtype=torch.float, device='cuda'),
]
ge = self.checkScript(should_not_fuse, inputs)
self.assertExpectedGraph(ge.graph_for(*inputs), subname='2')

def test_list_ops(self):
def test_equality():
a = [1, 2, 3]
Expand Down
Loading

0 comments on commit 86c9856

Please sign in to comment.