Hi folks,
A common pattern in LSTM/GRU-style cells is a structure like (for simplicity):
import tvm
from tvm import relay

rnn_dim = 10
X = relay.var("X", shape=(1, rnn_dim))
W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))  # weights for all three gates, packed
matmul = relay.nn.dense(X, W)
splitted = relay.split(matmul, indices_or_sections=3, axis=1)
out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
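For reproduction, here is a minimal sketch (continuing from the snippet above) of how one can build a module and inspect the fuser's output; it assumes the relay.transform pass API, and the fuse_opt_level value is illustrative:

# Sketch: wrap the expression, run type inference and FuseOps, print the result.
func = relay.Function(relay.analysis.free_vars(out), out)
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FuseOps(fuse_opt_level=2)(mod)
print(mod)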
Normally, when implementing this in Relay, we'd expect graph_fuse to fuse this entire sequence (matmul + split + sigmoid/tanh/exp/add/mul) into a single function; that's an entirely reasonable expectation, and it generates the highest-performance code. That is, we expect:
fn (%X: Tensor[(1, 10), float32],
%y: Tensor[(30, 10), float32])
-> Tensor[(1, 10), float32] {
%0 = nn.dense(%X, %y, units=None)
%1 = split(%0, indices_or_sections=int64(3), axis=1)
%2 = %1.0
%3 = sigmoid(%2)
%4 = %1.1
%5 = tanh(%4)
%6 = %1.2
%7 = exp(%6)
%8 = multiply(%5, %7)
%9 = add(%3, %8)
%9
}
Instead, Relay generates something like:
fn (%X: Tensor[(1, 10), float32],
%y: Tensor[(30, 10), float32])
-> Tensor[(1, 10), float32] {
%0 = fn(%p0: Tensor[(1, 10), float32],
%p1: Tensor[(30, 10), float32])
-> Tensor[(1, 30), float32] {
%1 = nn.dense(%p0, %p1, units=None) # ty=Tensor[(1, 30), float32]
%1
}
%2 = %0(%X, %y) # ty=Tensor[(1, 30), float32]
%3 = fn(%p01: Tensor[(1, 30), float32])
-> Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]] {
%4 = split(%p01, indices_or_sections=int64(3), axis=1) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
%4
}
%5 = %3(%2) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
%6 = %5.0
%7 = %5.1
%8 = %5.2
%9 = fn(%p02: Tensor[(1, 10), float32],
%p11: Tensor[(1, 10), float32],
%p2: Tensor[(1, 10), float32])
-> Tensor[(1, 10), float32] {
%10 = sigmoid(%p02) # ty=Tensor[(1, 10), float32]
%11 = tanh(%p11) # ty=Tensor[(1, 10), float32]
%12 = exp(%p2) # ty=Tensor[(1, 10), float32]
%13 = multiply(%11, %12) # ty=Tensor[(1, 10), float32]
%14 = add(%10, %13) # ty=Tensor[(1, 10), float32]
%14
}
%15 = %9(%6, %7, %8) # ty=Tensor[(1, 10), float32]
%15
}
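For context, FuseOps decides fusion boundaries from each operator's registered TOpPattern, and the tuple produced by split appears to be where the fuser stops. A small sketch to inspect how split is classified (assuming the tvm.ir.Op attribute API; the OpPatternKind enum values live in include/tvm/relay/op_attr_types.h):

# Sketch: query the pattern kind the fuser sees for split.
# The returned integer maps to OpPatternKind (kElemWise, kBroadcast,
# kInjective, ..., kTuple, kOpaque) in include/tvm/relay/op_attr_types.h.
split_op = relay.op.get("split")
print(split_op.get_attr("TOpPattern"))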
Of course, it would be possible to implement a "GateComputation" op or similar, which internally is just (split + pointwise functions), but it would be much more elegant to avoid that if possible.
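For illustration, here is roughly what the fully fused computation looks like when written directly in TE; this is only a sketch of the intent (the names and the exact compute expression are illustrative), not a proposed Relay op:

# Sketch: the whole gate computation as two TE stages; the split plus
# pointwise ops collapse into a single compute reading offset slices
# of the matmul buffer.
import tvm
from tvm import te

rnn_dim = 10
X = te.placeholder((1, rnn_dim), name="X")
W = te.placeholder((3 * rnn_dim, rnn_dim), name="W")
k = te.reduce_axis((0, rnn_dim), name="k")
# dense: matmul[i, j] = sum_k X[i, k] * W[j, k]
matmul = te.compute(
    (1, 3 * rnn_dim),
    lambda i, j: te.sum(X[i, k] * W[j, k], axis=k),
    name="matmul",
)
# split + sigmoid/tanh/exp/mul/add fused into one stage
gates = te.compute(
    (1, rnn_dim),
    lambda i, j: te.sigmoid(matmul[i, j])
    + te.tanh(matmul[i, j + rnn_dim]) * te.exp(matmul[i, j + 2 * rnn_dim]),
    name="gates",
)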
I'm not fluent in the Relay GraphFuser code, but I was hoping someone (@jroesch?) knows off the top of their head what needs to be modified inside the fuser, so that I or someone else can do the implementation work.