
[RELAY][PASS] Make Fusor to handle split gracefully #2890

@ajtulloch


Hi folks,

A common pattern in LSTM/GRU-style cells is a structure like the following (simplified):

        from tvm import relay

        rnn_dim = 10
        X = relay.var("X", shape=(1, rnn_dim))
        W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
        matmul = relay.nn.dense(X, W)
        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])

When implementing this in Relay, we'd expect graph_fuse to fuse this entire sequence (matmul + split + sigmoid/tanh/exp/add/mul) into a single function; that's an entirely reasonable expectation and it produces the highest-performance code. That is, we expect:

fn (%X: Tensor[(1, 10), float32],
    %y: Tensor[(30, 10), float32])
    -> Tensor[(1, 10), float32] {
  %0 = nn.dense(%X, %y, units=None)
  %1 = split(%0, indices_or_sections=int64(3), axis=1)
  %2 = %1.0
  %3 = sigmoid(%2)
  %4 = %1.1
  %5 = tanh(%4)
  %6 = %1.2
  %7 = exp(%6)
  %8 = multiply(%5, %7)
  %9 = add(%3, %8)
  %9
}

Instead, Relay generates something like:

fn (%X: Tensor[(1, 10), float32],
    %y: Tensor[(30, 10), float32])
    -> Tensor[(1, 10), float32] {
  %0 = fn(%p0: Tensor[(1, 10), float32],
          %p1: Tensor[(30, 10), float32])
          -> Tensor[(1, 30), float32] {
    %1 = nn.dense(%p0, %p1, units=None) # ty=Tensor[(1, 30), float32]
    %1
  }
  %2 = %0(%X, %y) # ty=Tensor[(1, 30), float32]
  %3 = fn(%p01: Tensor[(1, 30), float32])
          -> Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]] {
    %4 = split(%p01, indices_or_sections=int64(3), axis=1) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
    %4
  }
  %5 = %3(%2) # ty=Tuple[Tensor[(1, 10), float32], Tensor[(1, 10), float32], Tensor[(1, 10), float32]]
  %6 = %5.0
  %7 = %5.1
  %8 = %5.2
  %9 = fn(%p02: Tensor[(1, 10), float32],
          %p11: Tensor[(1, 10), float32],
          %p2: Tensor[(1, 10), float32])
          -> Tensor[(1, 10), float32] {
    %10 = sigmoid(%p02) # ty=Tensor[(1, 10), float32]
    %11 = tanh(%p11) # ty=Tensor[(1, 10), float32]
    %12 = exp(%p2) # ty=Tensor[(1, 10), float32]
    %13 = multiply(%11, %12) # ty=Tensor[(1, 10), float32]
    %14 = add(%10, %13) # ty=Tensor[(1, 10), float32]
    %14
  }
  %15 = %9(%6, %7, %8) # ty=Tensor[(1, 10), float32]
  %15
}
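
For reference, the output above can be reproduced with something along these lines (a sketch assuming a recent TVM build; older builds expose similar passes under relay.ir_pass rather than relay.transform):

        import tvm
        from tvm import relay

        # Wrap the expression built in the snippet above into a function and module.
        func = relay.Function(relay.analysis.free_vars(out), out)
        mod = tvm.IRModule.from_expr(func)

        # Run type inference and operator fusion, then print the fused IR.
        mod = relay.transform.InferType()(mod)
        mod = relay.transform.FuseOps(fuse_opt_level=2)(mod)
        print(mod)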

Of course it would be possible to implement a "GateComputation" op or similar, which internally is just (split + pointwise functions), but it would be more elegant to avoid that if possible.
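
For illustration, everything past the dense is just a split plus pointwise math; a hypothetical GateComputation op would merely wrap a body like this NumPy sketch (gate_computation is a made-up name):

        import numpy as np

        def gate_computation(X, W):
            # nn.dense(X, W) computes X @ W.T, giving shape (1, 3 * rnn_dim).
            m = X @ W.T
            s0, s1, s2 = np.split(m, 3, axis=1)
            # sigmoid(s0) + tanh(s1) * exp(s2), each piece of shape (1, rnn_dim).
            return 1.0 / (1.0 + np.exp(-s0)) + np.tanh(s1) * np.exp(s2)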

I'm not fluent in the Relay GraphFuser code, but I was hoping someone (@jroesch?) knows off the top of their head what needs to be modified inside the fuser, and then I or someone else can do the implementation work.

cc @jroesch, @yidawang, @tqchen
