 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay.op import add
 from tvm.relay.module import Module
+from tvm.relay.testing.config import ctx_list

 # @tq, @jr should we put this in testing ns?
 def check_rts(expr, args, expected_result, mod=None):
@@ -127,9 +128,51 @@ def test_plan_memory():
     assert len(device_types) == 1


+def test_gru():
+    def gru(rnn_dim):
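+        # GRU-style cell: one dense projection split into three gate activations.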
+        X = relay.var("X", shape=(1, rnn_dim))
+        W = relay.var("W", shape=(3 * rnn_dim, rnn_dim))
+        matmul = relay.nn.dense(X, W)
+        splits = relay.split(matmul, indices_or_sections=3, axis=1)
+        out = relay.sigmoid(splits[0]) + relay.tanh(splits[1]) * relay.exp(splits[2])
+        return relay.Function([X, W], out)
+
+    def sigmoid(x):
+        return 1 / (1 + np.exp(-x))
+
+    def gru_numpy(X, W):
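+        # NumPy reference for the Relay function above (dense computes X . W^T).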
+        prod = np.dot(X, W.transpose())
+        splits = np.split(prod, indices_or_sections=3, axis=1)
+        return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2])
+
+    dtype = "float32"
+    rnn_dim = 1000
+    x = np.random.rand(1, rnn_dim).astype(dtype)
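+    # Keep weights in [-0.005, 0.005] so exp() stays well-behaved in float32.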
+    y = np.random.rand(3 * rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
+    out_shape = (1, rnn_dim)
+    z = gru(rnn_dim)
+
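+    # Build and run on every enabled target, then compare against NumPy.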
+    for target, ctx in ctx_list():
+        with relay.build_config(opt_level=2):
+            graph, lib, params = relay.build(z, target)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input("X", tvm.nd.array(x))
+        m.set_input("W", tvm.nd.array(y))
+        m.set_input(**params)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        ref = gru_numpy(x, y)
+        tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     test_plan_memory()
     test_with_params()
     test_add_op_scalar()
     test_add_op_tensor()
     test_add_op_broadcast()
+    test_gru()