Commit d17aabb

Masahiro Masuda (masahi) authored and committed

add gru runtime test

1 parent d97a140 commit d17aabb

File tree

1 file changed: +39 −0 lines


tests/python/relay/test_backend_graph_runtime.py

Lines changed: 39 additions & 0 deletions
@@ -7,6 +7,7 @@
 from tvm.relay.scope_builder import ScopeBuilder
 from tvm.relay.op import add
 from tvm.relay.module import Module
+from tvm.relay.testing.config import ctx_list
 
 # @tq, @jr should we put this in testing ns?
 def check_rts(expr, args, expected_result, mod=None):
@@ -127,9 +128,47 @@ def test_plan_memory():
     assert len(device_types) == 1
 
 
+def test_gru():
+    def gru(rnn_dim):
+        X = relay.var("X", shape=(1, rnn_dim))
+        W = relay.var("y", shape=(3 * rnn_dim, rnn_dim))
+        matmul = relay.nn.dense(X, W)
+        splitted = relay.split(matmul, indices_or_sections=3, axis=1)
+        out = relay.sigmoid(splitted[0]) + relay.tanh(splitted[1]) * relay.exp(splitted[2])
+        return relay.Function([X, W], out)
+
+    def sigmoid(x):
+        return 1 / (1 + np.exp(-x))
+
+    def gru_numpy(X, W):
+        prod = np.dot(X, W.transpose())
+        splits = np.split(prod, indices_or_sections=3, axis=1)
+        return sigmoid(splits[0]) + np.tanh(splits[1]) * np.exp(splits[2])
+
+    dtype = "float32"
+    rnn_dim = 1000
+    x = np.random.rand(1, rnn_dim).astype(dtype)
+    y = np.random.rand(3*rnn_dim, rnn_dim).astype(dtype) * 0.01 - 0.005
+    out_shape = (1, rnn_dim)
+    z = gru(rnn_dim)
+
+    for target, ctx in ctx_list():
+        with relay.build_config(opt_level=2):
+            graph, lib, params = relay.build(z, target)
+        m = graph_runtime.create(graph, lib, ctx)
+        m.set_input("X", tvm.nd.array(x.astype(dtype)))
+        m.set_input("y", tvm.nd.array(y.astype(dtype)))
+        m.set_input(**params)
+        m.run()
+        out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy()
+        ref = gru_numpy(x, y)
+        tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     test_plan_memory()
     test_with_params()
     test_add_op_scalar()
     test_add_op_tensor()
     test_add_op_broadcast()
+    test_gru()
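For context, the NumPy reference the test compares against is a GRU-style gating expression: the input X (shape 1 × rnn_dim) is projected against the transposed weight W (shape 3·rnn_dim × rnn_dim), and the resulting (1, 3·rnn_dim) product is split into three equal gates s0, s1, s2, combined as sigmoid(s0) + tanh(s1) * exp(s2). Below is a minimal standalone sketch of that computation, an illustration using NumPy only and not part of the commit; the small rnn_dim here is a hypothetical value for a quick check, whereas the test uses 1000.

import numpy as np

def sigmoid(x):
    # Logistic function, matching the helper in the test.
    return 1 / (1 + np.exp(-x))

def gru_numpy(X, W):
    # Dense projection: (1, d) x (3d, d)^T -> (1, 3d).
    prod = np.dot(X, W.transpose())
    # Split into three (1, d) gates along the feature axis.
    s0, s1, s2 = np.split(prod, indices_or_sections=3, axis=1)
    return sigmoid(s0) + np.tanh(s1) * np.exp(s2)

rnn_dim = 4  # hypothetical small size for illustration
X = np.random.rand(1, rnn_dim).astype("float32")
W = np.random.rand(3 * rnn_dim, rnn_dim).astype("float32") * 0.01 - 0.005
print(gru_numpy(X, W).shape)  # (1, rnn_dim)

The TVM side of the test builds the equivalent Relay function at opt_level=2 for each enabled (target, ctx) pair, feeds the same x and y through the graph runtime, and asserts that the outputs agree within a 1e-5 tolerance.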
