Skip to content

Commit

Permalink
register softmax (apache#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
Huyuwei authored and tqchen committed May 29, 2018
1 parent 48038a9 commit 1bc5d0a
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
20 changes: 19 additions & 1 deletion nnvm/python/nnvm/top/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def compute_conv2d(attrs, inputs):
out = topi.broadcast_add(out, bias)
return out


@reg.register_schedule("conv2d")
def schedule_conv2d(_, outs, target):
"""Schedule definition of conv2d"""
Expand All @@ -33,3 +32,22 @@ def schedule_conv2d(_, outs, target):
return tvm.create_schedule([x.op for x in outs])

# Register conv2d's op pattern as COMPLEX (same category used for the
# other compute-heavy ops registered in this module, e.g. softmax).
reg.register_pattern("conv2d", OpPattern.COMPLEX)


# softmax
@reg.register_compute("softmax")
def compute_softmax(attrs, inputs):
    """Compute definition of softmax.

    Parameters
    ----------
    attrs : attribute dict
        Operator attributes; only ``axis`` is read here.
    inputs : list
        Single-element list holding the input tensor.

    Returns
    -------
    tensor
        ``topi.nn.softmax`` applied to ``inputs[0]``.

    Raises
    ------
    ValueError
        If ``axis`` is not -1 — the only axis supported here for now.
    """
    axis = attrs.get_int("axis")
    # `assert` is stripped when Python runs with -O, so validate with an
    # explicit raise to keep the restriction enforced in all modes.
    if axis != -1:
        raise ValueError("only support axis == -1 for now")
    return topi.nn.softmax(inputs[0])

@reg.register_schedule("softmax")
def schedule_softmax(_, outs, target):
    """Schedule definition of softmax.

    Dispatches to the hand-written CUDA schedule when targeting "cuda";
    every other target gets a default (naive) TVM schedule.
    """
    if target != "cuda":
        # naive schedule for non-CUDA targets
        return tvm.create_schedule([out.op for out in outs])
    return topi.cuda.schedule_softmax(outs)

# Register softmax's op pattern as COMPLEX, matching conv2d above.
reg.register_pattern("softmax", OpPattern.COMPLEX)
46 changes: 46 additions & 0 deletions nnvm/tests/python/compiler/test_top_level1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np

import tvm
import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime

# Toggle between the CUDA (GPU) and LLVM (CPU) test configurations
# selected by default_target()/default_ctx() below.
USE_GPU=True

def default_target(use_gpu=None):
    """Return the TVM build target string for the tests.

    Parameters
    ----------
    use_gpu : bool, optional
        When given, overrides the module-level ``USE_GPU`` flag; with the
        default ``None`` the flag decides (backward compatible with the
        original zero-argument call).

    Returns
    -------
    str
        ``'cuda'`` when the GPU configuration is selected, else ``'llvm'``.
    """
    if use_gpu is None:
        use_gpu = USE_GPU
    return 'cuda' if use_gpu else 'llvm'

def default_ctx(use_gpu=None):
    """Return the TVM device context matching ``default_target()``.

    Parameters
    ----------
    use_gpu : bool, optional
        When given, overrides the module-level ``USE_GPU`` flag; with the
        default ``None`` the flag decides (backward compatible with the
        original zero-argument call).

    Returns
    -------
    TVMContext
        ``tvm.gpu(0)`` when the GPU configuration is selected, else
        ``tvm.cpu(0)``.
    """
    if use_gpu is None:
        use_gpu = USE_GPU
    return tvm.gpu(0) if use_gpu else tvm.cpu(0)

def test_softmax():
    """End-to-end check: build sym.softmax, run it, compare against the
    numpy reference implementation (topi.testing.softmax_python)."""
    dtype = "float32"
    dshape = (10, 1000)
    # output shape equals input shape for softmax
    net = sym.softmax(sym.Variable("x"))
    graph, lib = nnvm.compiler.build(net, default_target(), {"x": dshape})
    module = nnvm.runtime.create(graph, lib, default_ctx())
    # member functions of the runtime module
    set_input = module["set_input"]
    run = module["run"]
    get_output = module["get_output"]
    # feed random input and execute
    data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
    set_input("x", data)
    run()
    # fetch the result and compare with the reference
    out = tvm.nd.empty(dshape, dtype)
    get_output(0, out)
    expected = topi.testing.softmax_python(data.asnumpy())
    np.testing.assert_allclose(out.asnumpy(), expected, rtol=1e-5)


# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_softmax()
2 changes: 1 addition & 1 deletion nnvm/tests/python/unittest/test_top_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def test_dense():
x1 = sym.dense(x, units=3, name="dense")
x2 = sym.flatten(x1)
x3 = sym.softmax(x2)
assert x2.list_input_names() == ['x', 'dense_weight', 'dense_bias']
assert x3.list_input_names() == ['x', 'dense_weight', 'dense_bias']


def test_concatenate_split():
Expand Down

0 comments on commit 1bc5d0a

Please sign in to comment.