
Commit 1650593

Huyuweitqchen authored and committed
register softmax (#16)
1 parent 7cbfb60 commit 1650593

File tree

3 files changed: +66 -2 lines changed


nnvm/python/nnvm/top/nn.py

Lines changed: 19 additions & 1 deletion
@@ -23,7 +23,6 @@ def compute_conv2d(attrs, inputs):
         out = topi.broadcast_add(out, bias)
     return out
 
-
 @reg.register_schedule("conv2d")
 def schedule_conv2d(_, outs, target):
     """Schedule definition of conv2d"""
@@ -33,3 +32,22 @@ def schedule_conv2d(_, outs, target):
     return tvm.create_schedule([x.op for x in outs])
 
 reg.register_pattern("conv2d", OpPattern.COMPLEX)
+
+
+# softmax
+@reg.register_compute("softmax")
+def compute_softmax(attrs, inputs):
+    """Compute definition of softmax"""
+    axis = attrs.get_int("axis")
+    assert axis == -1, "only support axis == -1 for now"
+    return topi.nn.softmax(inputs[0])
+
+@reg.register_schedule("softmax")
+def schedule_softmax(_, outs, target):
+    """Schedule definition of softmax"""
+    if target == "cuda":
+        return topi.cuda.schedule_softmax(outs)
+    # naive schedule
+    return tvm.create_schedule([x.op for x in outs])
+
+reg.register_pattern("softmax", OpPattern.COMPLEX)
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import numpy as np
+
+import tvm
+import topi
+import nnvm.symbol as sym
+import nnvm.compiler
+import nnvm.runtime
+
+USE_GPU=True
+
+def default_target():
+    if USE_GPU:
+        return 'cuda'
+    else:
+        return 'llvm'
+
+def default_ctx():
+    if USE_GPU:
+        return tvm.gpu(0)
+    else:
+        return tvm.cpu(0)
+
+def test_softmax():
+    x = sym.Variable("x")
+    y = sym.softmax(x)
+    dtype = "float32"
+    dshape = (10, 1000)
+    oshape = dshape
+    graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
+    m = nnvm.runtime.create(graph, lib, default_ctx())
+    # get member functions
+    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+    # set input
+    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+    set_input("x", data)
+    # execute
+    run()
+    # get outputs
+    out = tvm.nd.empty(oshape, dtype)
+    get_output(0, out)
+    y_np = topi.testing.softmax_python(data.asnumpy())
+    np.testing.assert_allclose(out.asnumpy(), y_np, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    test_softmax()
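Side note (my assumption, not part of the commit): the expected values in this test come from topi.testing.softmax_python, which for this 2-D input should be the ordinary softmax over the last axis. A plain NumPy sketch for hand-checking the reference output follows; softmax_ref is a hypothetical helper name, not a library function:

    import numpy as np

    def softmax_ref(x):
        # Numerically stable softmax over the last axis: shift by the row max,
        # exponentiate, then normalize each row to sum to 1.
        e = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return e / np.sum(e, axis=-1, keepdims=True)
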

nnvm/tests/python/unittest/test_top_level1.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ def test_dense():
     x1 = sym.dense(x, units=3, name="dense")
     x2 = sym.flatten(x1)
     x3 = sym.softmax(x2)
-    assert x2.list_input_names() == ['x', 'dense_weight', 'dense_bias']
+    assert x3.list_input_names() == ['x', 'dense_weight', 'dense_bias']
 
 
 def test_concatenate_split():
