
Commit 3b519c5

[TOP][COMPILER] sum, min, max, transpose, fix dense (apache#28)
1 parent 9809396 commit 3b519c5

File tree: 10 files changed, 180 additions, 47 deletions


nnvm/python/nnvm/frontend/mxnet.py  (13 additions, 10 deletions)

@@ -1,3 +1,4 @@
+# pylint: disable=invalid-name
 """MXNet symbol frontend."""
 from __future__ import absolute_import as _abs
 import json
@@ -155,14 +156,14 @@ def _split(attrs):
     return op_name, new_attrs


 _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
-                  '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
-                  '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
-                  '__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
-                  'broadcast_add', 'broadcast_div', 'broadcast_mul',
-                  'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
-                  'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
-                  'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
-                  'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
+                  '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
+                  '__pow_scalar__', '__rdiv_scalar__', '__rpow_scalar__',
+                  '__rsub_scalar__', '__sub_scalar__', '__sub_symbol__',
+                  'broadcast_add', 'broadcast_div', 'broadcast_mul',
+                  'broadcast_sub', 'broadcast_to', 'cast', 'elemwise_add',
+                  'elemwise_div', 'elemwise_mul', 'elemwise_sub', 'exp',
+                  'flatten', 'log', 'log_softmax', 'max', 'min', 'negative',
+                  'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
@@ -190,8 +191,8 @@ def _split(attrs):
 }

 def _convert_symbol(op_name, attrs,
-                    identity_list=_identity_list,
-                    convert_map=_convert_map):
+                    identity_list=None,
+                    convert_map=None):
     """Convert from mxnet op to nnvm op.
     The converter must specify some conversions explicitly to
     support gluon format ops such as conv2d...
@@ -214,6 +215,8 @@ def _convert_symbol(op_name, attrs,
     (op_name, attrs)
         Converted (op_name, attrs) for nnvm.
     """
+    identity_list = identity_list if identity_list else _identity_list
+    convert_map = convert_map if convert_map else _convert_map
     if op_name in identity_list:
         pass
     elif op_name in convert_map:
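The signature change above moves the module-level tables out of the default arguments and resolves them inside the function, so a caller can supply its own conversion tables while the defaults remain the single source of truth. A minimal stand-alone sketch of the idiom (hypothetical names, not the actual nnvm frontend):

    _DEFAULT_CONVERT = {"FullyConnected": "dense"}   # hypothetical table

    def convert(op_name, convert_map=None):
        # Resolve the default lazily instead of baking the dict into the signature.
        convert_map = convert_map if convert_map else _DEFAULT_CONVERT
        return convert_map.get(op_name, op_name)

    convert("FullyConnected")                              # "dense" (module default)
    convert("FullyConnected", {"FullyConnected": "fc"})    # "fc" (caller override)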

nnvm/python/nnvm/top/__init__.py  (1 addition, 0 deletions)

@@ -3,3 +3,4 @@
 from . import tensor
 from . import nn
 from . import transform
+from . import reduction

nnvm/python/nnvm/top/attr_dict.py  (1 addition, 3 deletions)

@@ -1,7 +1,5 @@
 # pylint: disable=invalid-name
 """Attr dictionary object used by schedule functions"""
-
-import json
 import tvm

 _dict_get = tvm.get_global_func("nnvm.compiler._dict_get")
@@ -51,7 +49,7 @@ def get_int_tuple(self, key):
         tuple : tuple of int
             The result tuple
         """
-        return tuple(json.loads(self[key]))
+        return tuple(int(x) for x in self[key][1:-1].split(",") if x)

     def get_int(self, key):
         """Get integer from attr dict

nnvm/python/nnvm/top/nn.py  (6 additions, 7 deletions)

@@ -20,9 +20,9 @@ def compute_relu(attrs, inputs, _):

 # leaky_relu
 @reg.register_compute("leaky_relu")
-def compute_relu(attrs, inputs, _):
+def compute_leaky_relu(attrs, inputs, _):
     """Compute definition of relu"""
-    return topi.nn.leaky_relu(inputs[0])
+    return topi.nn.leaky_relu(inputs[0], attrs.get_float("alpha"))

 reg.register_schedule("leaky_relu", _fschedule_broadcast)
 reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
@@ -62,20 +62,19 @@ def schedule_softmax(_, outs, target):
 def compute_dense(attrs, inputs, _):
     """Compute definition of dense"""
     if attrs.get_bool("use_bias"):
-        return topi.nn.fully_connected_with_bias(
-            inputs[0], inputs[1], inputs[2])
-    return topi.nn.fully_connected(inputs[0], inputs[1])
+        return topi.nn.dense(inputs[0], inputs[1], bias=inputs[2])
+    return topi.nn.dense(inputs[0], inputs[1])

 @reg.register_schedule("dense")
 def schedule_dense(_, outs, target):
     """Schedule definition of dense"""
     if target == "cuda":
-        raise ValueError("fully_connected not yet implemented")
+        return topi.cuda.schedule_dense(outs)
     # naive schedule
     return tvm.create_schedule([x.op for x in outs])

 # register extern for now, change me when fusion is enabled.
-reg.register_pattern("dense", OpPattern.OPAQUE)
+reg.register_pattern("dense", OpPattern.OUT_ELEMWISE_FUSABLE)


 # conv
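For reference, the dense op exercised by the updated test further down computes data times the transposed weight plus an optional bias, which is exactly what the test compares against. A numpy sketch of that contract (illustration only, not the topi implementation):

    import numpy as np

    def dense_ref(data, weight, bias=None):
        # data: (batch, in_dim), weight: (out_dim, in_dim), bias: (out_dim,)
        out = np.dot(data, weight.T)
        if bias is not None:
            out = out + bias
        return out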

nnvm/python/nnvm/top/reduction.py  (new file, 47 additions)

@@ -0,0 +1,47 @@
+# pylint: disable=invalid-name, unused-argument
+"""Reduction ops"""
+from __future__ import absolute_import
+
+import tvm
+import topi
+import topi.cuda
+from ..compiler import registry as reg
+from ..compiler import OpPattern
+
+def _schedule_reduce(_, outs, target):
+    """Generic schedule for reduce"""
+    if target == "cuda":
+        return topi.cuda.schedule_reduce(outs)
+    assert target.startswith("llvm")
+    s = tvm.create_schedule([x.op for x in outs])
+    x = outs[0]
+    tvm.schedule.AutoInlineInjective(s)
+    s[x].fuse(s[x].op.axis)
+    return s
+
+_fschedule_reduce = tvm.convert(_schedule_reduce)
+
+def _compute_reduce(f):
+    """auxiliary function"""
+    def _compute(attrs, inputs, out_info):
+        axis = attrs.get_int_tuple("axis")
+        keepdims = attrs.get_bool("keepdims")
+        if axis:
+            return f(inputs[0], axis=axis, keepdims=keepdims)
+        return f(inputs[0], keepdims=keepdims)
+    return _compute
+
+# sum
+reg.register_compute("sum", _compute_reduce(topi.sum))
+reg.register_pattern("sum", OpPattern.COMM_REDUCE)
+reg.register_schedule("sum", _fschedule_reduce)
+
+# max
+reg.register_compute("max", _compute_reduce(topi.max))
+reg.register_pattern("max", OpPattern.COMM_REDUCE)
+reg.register_schedule("max", _fschedule_reduce)
+
+# min
+reg.register_compute("min", _compute_reduce(topi.min))
+reg.register_pattern("min", OpPattern.COMM_REDUCE)
+reg.register_schedule("min", _fschedule_reduce)

nnvm/python/nnvm/top/tensor.py  (2 additions, 1 deletion)

@@ -19,9 +19,10 @@ def _schedule_injective(_, outs, target):
     s[x].fuse(s[x].op.axis)
     return s

+
 def _compute_binary_scalar(f):
     """auxiliary function"""
-    @tvm.tag_scope("ewise")
+    @tvm.tag_scope(topi.tag.ELEMWISE)
     def _compute(attrs, x, _):
         x = x[0]
         scalar = attrs.get_float("scalar")

nnvm/python/nnvm/top/transform.py  (13 additions, 3 deletions)

@@ -4,11 +4,11 @@

 import tvm
 import topi
-from .tensor import _fschedule_broadcast
+from .tensor import _fschedule_broadcast, _fschedule_injective
 from ..compiler import registry as reg
 from ..compiler import OpPattern

-# Need add reshape, transpose
+# Need add reshape
 @reg.register_compute("expand_dims")
 def compute_expand_dims(attrs, inputs, out_info):
     """Compute definition of expand_dims"""
@@ -19,6 +19,16 @@ def compute_expand_dims(attrs, inputs, out_info):
 reg.register_schedule("expand_dims", _fschedule_broadcast)


+@reg.register_compute("transpose")
+def compute_transpose(attrs, inputs, out_info):
+    """Compute definition of expand_dims"""
+    axes = attrs.get_int_tuple("axes")
+    axes = tuple(axes) if axes else None
+    return topi.transpose(inputs[0], axes)
+reg.register_pattern("transpose", OpPattern.INJECTIVE)
+reg.register_schedule("transpose", _fschedule_injective)
+
+
 def _flatten_index(indices, shape):
     """flatten the index to 1D"""
     idx = 0
@@ -38,4 +48,4 @@ def compute_reshape(attrs, inputs, out_info):
     x = inputs[0]
     return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
 reg.register_pattern("reshape", OpPattern.INJECTIVE)
-reg.register_schedule("reshape", _fschedule_broadcast)
+reg.register_schedule("reshape", _fschedule_injective)

nnvm/tests/python/compiler/test_op_fusion.py  (21 additions, 0 deletions)

@@ -56,6 +56,27 @@ def test_conv_ewise_injective():
     np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)


+def test_injective_reduce_injective():
+    x = sym.Variable("x")
+    x = sym.flatten(x) + 1
+    y = sym.sum(x, axis=1)
+    dtype = "float32"
+    dshape = (32, 1, 18, 18)
+    shape_dict = {"x": dshape}
+
+    for target, ctx in test_ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
+        m = nnvm.runtime.create(graph, lib, ctx)
+        assert graph.index.num_nodes == 2
+        data = np.random.uniform(size=dshape).astype(dtype)
+        m.run(x=data)
+        c_np = np.sum(data.reshape(32, 18 * 18) + 1, axis=1)
+        # get output
+        out = m.get_output(0, tvm.nd.empty(c_np.shape, dtype))
+        np.testing.assert_allclose(out.asnumpy(), c_np, rtol=1e-5)
+
+
 if __name__ == "__main__":
+    test_injective_reduce_injective()
     test_ewise_injective()
     test_conv_ewise_injective()

nnvm/tests/python/compiler/test_top_level1.py  (20 additions, 23 deletions)

@@ -8,25 +8,21 @@

 def test_relu():
     x = sym.Variable("x")
-    y = sym.relu(x)
+    y = sym.leaky_relu(x, alpha=0.3) - 0.2
+    y = sym.relu(y)
     dtype = "float32"
     dshape = (1, 3, 32, 32)
     oshape = dshape
     for target, ctx in test_ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = nnvm.runtime.create(graph, lib, ctx)
         # get member functions
-        set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-        # set input
-        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
-        set_input("x", data)
-        # execute
-        run()
-        # get output
-        out = tvm.nd.empty(oshape, dtype)
-        get_output(0, out)
-        y_np = np.maximum(data.asnumpy(), 0.0)
-        np.testing.assert_allclose(out.asnumpy(), y_np, atol=1e-5, rtol=1e-5)
+        data = np.random.uniform(size=dshape).astype(dtype)
+        m.run(x=data)
+        data = (data < 0) * data * 0.3 + (data>0) * data - 0.2
+        data = (data > 0) * data
+        out = m.get_output(0, tvm.nd.empty(oshape, dtype))
+        np.testing.assert_allclose(out.asnumpy(), data, atol=1e-5, rtol=1e-5)


 def test_exp():
@@ -157,17 +153,18 @@ def test_dense():
         "dense_weight" : (3, 100),
         "dense_bias" : (3,),
     }
-    graph, lib, _ = nnvm.compiler.build(y, "llvm", shape)
-    m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
-    x_np = np.random.uniform(size=shape["x"]).astype(dtype)
-    w_np = np.random.uniform(size=shape["dense_weight"]).astype(dtype)
-    b_np = np.random.uniform(size=shape["dense_bias"]).astype(dtype)
-    res = tvm.nd.empty((10, 3))
-    m.run(x=x_np, dense_weight=w_np, dense_bias=b_np)
-    m.get_output(0, res)
-    res_np = np.dot(x_np, w_np.T) + b_np
-    np.testing.assert_allclose(
-        res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)
+    for target, ctx in test_ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, shape)
+        m = nnvm.runtime.create(graph, lib, ctx)
+        x_np = np.random.uniform(size=shape["x"]).astype(dtype)
+        w_np = np.random.uniform(size=shape["dense_weight"]).astype(dtype)
+        b_np = np.random.uniform(size=shape["dense_bias"]).astype(dtype)
+        res = tvm.nd.empty((10, 3))
+        m.run(x=x_np, dense_weight=w_np, dense_bias=b_np)
+        m.get_output(0, res)
+        res_np = np.dot(x_np, w_np.T) + b_np
+        np.testing.assert_allclose(
+            res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)


 def test_batchnorm():
New test file (path not shown in this view)  (56 additions, 0 deletions)

@@ -0,0 +1,56 @@
+import numpy as np
+import tvm
+import topi
+import nnvm.symbol as sym
+import nnvm.compiler
+import nnvm.runtime
+from nnvm.testing.config import test_ctx_list
+
+def verify_transpose(dshape, axes):
+    x = sym.Variable("x")
+    if axes:
+        y = sym.transpose(x, axes=axes)
+    else:
+        y = sym.transpose(x)
+    y = y + 1
+    dtype = "float32"
+    for target, ctx in test_ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # set input
+        data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
+        m.run(x=data)
+        out_np = np.transpose(data.asnumpy(), axes=axes) + 1
+        out = m.get_output(0, tvm.nd.empty(out_np.shape))
+        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
+
+
+def verify_reduce(dshape, fnp, fsym, **kwargs):
+    x = sym.Variable("x")
+    y = fsym(x + 1, **kwargs)
+    dtype = "float32"
+    for target, ctx in test_ctx_list():
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
+        m = nnvm.runtime.create(graph, lib, ctx)
+        # set input
+        data = np.random.uniform(size=dshape).astype(dtype)
+        out_np = fnp(data + 1, **kwargs)
+        m.run(x=data)
+        out = m.get_output(0, tvm.nd.empty(out_np.shape))
+        np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)
+
+
+def test_tranpose():
+    verify_transpose((2, 3, 4), (0, 2, 1))
+    verify_transpose((2, 3, 4), None)
+
+
+def test_reduce():
+    verify_reduce((2, 3, 4), np.max, sym.max, axis=1, keepdims=True)
+    verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
+    verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))
+
+
+if __name__ == "__main__":
+    test_reduce()
+    test_tranpose()

0 commit comments
