
Commit 5615880

yuruofeifei authored and tqchen committed
[GRADIENT] Register more gradient operators (#300)
* Add conv2d max_pool backward op
* Added tests
* Fix testing
* Address comments
* Change dot to matmul
* Address comments
* Break down indicator function
* Make greater, less numpy compatible
1 parent fdf54ec commit 5615880

File tree

18 files changed: +1016 −162 lines
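At a high level, this commit registers backward ops for conv2d and max_pool, and moves the gradients API into nnvm.compiler.graph_util, so backward graphs can now be built through typical convnet layers. A minimal sketch of what that enables, assuming the nnvm Python frontend at this commit (keyword arguments such as channels, kernel_size, and pool_size are the usual nnvm ones, but treat the exact signatures as illustrative):

```python
import nnvm.symbol as sym
from nnvm.compiler import graph_util

# Forward graph: conv -> relu -> pool, i.e. the layers whose
# backward ops this commit registers.
x = sym.Variable("x")
w = sym.Variable("w")
y = sym.conv2d(data=x, weight=w, channels=8, kernel_size=(3, 3), use_bias=False)
y = sym.relu(y)
y = sym.max_pool2d(y, pool_size=(2, 2))

# One merged gradient symbol per input (see graph_util.gradients below).
dx, dw = graph_util.gradients(y, [x, w])
```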

nnvm/docs/top.rst

Lines changed: 12 additions & 0 deletions
@@ -28,6 +28,7 @@ This level enables fully connected multi-layer perceptron.
    :nosignatures:
 
    nnvm.symbol.dense
+   nnvm.symbol.matmul
    nnvm.symbol.relu
    nnvm.symbol.tanh
    nnvm.symbol.sigmoid
@@ -38,6 +39,7 @@ This level enables fully connected multi-layer perceptron.
    nnvm.symbol.elemwise_sub
    nnvm.symbol.elemwise_mul
    nnvm.symbol.elemwise_div
+   nnvm.symbol.elemwise_sum
    nnvm.symbol.full
    nnvm.symbol.full_like
    nnvm.symbol.ones
@@ -54,6 +56,8 @@ This level enables fully connected multi-layer perceptron.
    nnvm.symbol.softmax
    nnvm.symbol.log_softmax
    nnvm.symbol.pad
+   nnvm.symbol.block_grad
+   nnvm.symbol.indicator
 
 
 **Level 2: Convolutions**
@@ -77,6 +81,8 @@ This level enables typical convnet models.
    :nosignatures:
 
    nnvm.symbol.reshape
+   nnvm.symbol.reshape_like
+   nnvm.symbol.expand_like
    nnvm.symbol.copy
    nnvm.symbol.negative
    nnvm.symbol.leaky_relu
@@ -107,6 +113,7 @@ This level enables typical convnet models.
 Detailed Definitions
 --------------------
 .. autofunction:: nnvm.symbol.dense
+.. autofunction:: nnvm.symbol.matmul
 .. autofunction:: nnvm.symbol.relu
 .. autofunction:: nnvm.symbol.tanh
 .. autofunction:: nnvm.symbol.sigmoid
@@ -117,6 +124,7 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.elemwise_sub
 .. autofunction:: nnvm.symbol.elemwise_mul
 .. autofunction:: nnvm.symbol.elemwise_div
+.. autofunction:: nnvm.symbol.elemwise_sum
 .. autofunction:: nnvm.symbol.full
 .. autofunction:: nnvm.symbol.full_like
 .. autofunction:: nnvm.symbol.ones
@@ -133,6 +141,8 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.softmax
 .. autofunction:: nnvm.symbol.log_softmax
 .. autofunction:: nnvm.symbol.pad
+.. autofunction:: nnvm.symbol.block_grad
+.. autofunction:: nnvm.symbol.indicator
 
 .. autofunction:: nnvm.symbol.conv2d
 .. autofunction:: nnvm.symbol.conv2d_transpose
@@ -142,6 +152,8 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.global_avg_pool2d
 
 .. autofunction:: nnvm.symbol.reshape
+.. autofunction:: nnvm.symbol.reshape_like
+.. autofunction:: nnvm.symbol.expand_like
 .. autofunction:: nnvm.symbol.copy
 .. autofunction:: nnvm.symbol.negative
 .. autofunction:: nnvm.symbol.leaky_relu
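Of the newly documented symbols, block_grad deserves a quick note: by the usual convention for this kind of op it is an identity in the forward pass that contributes zero gradient, so it fences a subgraph off from backprop (the indicator op added alongside it backs the pooling gradients). A sketch, assuming the frontend names match the doc entries above and that block_grad has the stop-gradient semantics just described:

```python
import nnvm.symbol as sym
from nnvm.compiler import graph_util

x = sym.Variable("x")
w = sym.Variable("w")
# block_grad(x) passes x through unchanged but blocks gradient flow,
# so only the w branch receives a gradient.
y = sym.elemwise_mul(sym.block_grad(x), w)
(dw,) = graph_util.gradients(y, [w])
```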

nnvm/include/nnvm/top/tensor.h

Lines changed: 74 additions & 4 deletions
@@ -62,6 +62,13 @@ enum TypeFlag {
   kUint64 = 10,
 };
 
+enum IndicatorRuleFlag {
+  kGT0 = 0,
+  kLT0 = 1,
+  kMax = 2,
+  kMin = 3,
+};
+
 #define DMLC_DECLARE_DTYPE_FIELD(name) \
   DMLC_DECLARE_FIELD(name) \
   .add_enum("float16", kFloat16) \
@@ -84,6 +91,28 @@ struct CastParam : public dmlc::Parameter<CastParam> {
   }
 };
 
+struct IndicatorParam : public dmlc::Parameter<IndicatorParam> {
+  TShape axis;
+  bool exclude;
+  DMLC_DECLARE_PARAMETER(IndicatorParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(TShape())
+    .describe(R"code(The axis or axes along which to perform the indicator rule.
+
+The default, `axis=()`, will compute over all elements into a
+scalar array with shape `(1,)`.
+
+If `axis` is int, rule is applied on a particular axis.
+
+If `axis` is a tuple of ints, rule is applied on all the axes
+specified in the tuple.
+
+If `exclude` is true, rule will be applied on the axes that are
+NOT in axis instead.)code");
+    DMLC_DECLARE_FIELD(exclude).set_default(false)
+    .describe("Whether to apply rule on axis that are NOT in axis instead.");
+  }
+};
+
 struct ReshapeParam : public dmlc::Parameter<ReshapeParam> {
   Tuple<int64_t> shape;
 
@@ -97,8 +126,7 @@ struct SqueezeParam : public dmlc::Parameter<SqueezeParam> {
 
   DMLC_DECLARE_PARAMETER(SqueezeParam) {
     DMLC_DECLARE_FIELD(axis).set_default(TShape())
-    .describe("The axis to squeeze in the input tensor."
-              " If set to None, all size=1 axes will be squeezed");
+    .describe("The axis to squeeze in the input tensor.");
   }
 };
 
@@ -110,6 +138,15 @@ struct ScalarParam : public dmlc::Parameter<ScalarParam> {
   }
 };
 
+struct FillValueParam : public dmlc::Parameter<FillValueParam> {
+  double fill_value;
+
+  DMLC_DECLARE_PARAMETER(FillValueParam) {
+    DMLC_DECLARE_FIELD(fill_value)
+    .describe("Scalar value to be filled");
+  }
+};
+
 struct TransposeParam : public dmlc::Parameter<TransposeParam> {
   TShape axes;
 
@@ -158,16 +195,49 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
   }
 };
 
+struct InitOpWithScalarParam : public dmlc::Parameter<InitOpWithScalarParam> {
+  TShape shape;
+  int dtype;
+  double fill_value;
+
+  DMLC_DECLARE_PARAMETER(InitOpWithScalarParam) {
+    DMLC_DECLARE_FIELD(shape).set_default(TShape());
+    DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32)
+    .describe("Target data type.");
+    DMLC_DECLARE_FIELD(fill_value).describe("Scalar value to fill");
+  }
+};
+
 struct InitOpParam : public dmlc::Parameter<InitOpParam> {
   TShape shape;
   int dtype;
-  double value;
 
   DMLC_DECLARE_PARAMETER(InitOpParam) {
     DMLC_DECLARE_FIELD(shape).set_default(TShape());
     DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32)
    .describe("Target data type.");
-    DMLC_DECLARE_FIELD(value).describe("Value to fill");
+  }
+};
+
+struct ElementWiseReduceParam : public dmlc::Parameter<ElementWiseReduceParam> {
+  int num_args;
+  DMLC_DECLARE_PARAMETER(ElementWiseReduceParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
+    .describe("Number of inputs to be reduced.");
+  }
+};
+
+struct MatMulParam : public dmlc::Parameter<MatMulParam> {
+  bool transpose_a;
+  bool transpose_b;
+
+  DMLC_DECLARE_PARAMETER(MatMulParam) {
+    DMLC_DECLARE_FIELD(transpose_a)
+    .describe("If true then transpose the first input before dot.")
+    .set_default(false);
+    DMLC_DECLARE_FIELD(transpose_b)
+    .describe("If true then transpose the second input before dot.")
+    .set_default(false);
   }
 };
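The new MatMulParam is what backs the matmul symbol added to the docs above; its two fields surface as keyword arguments on the frontend. A small sketch, assuming the Python frontend forwards transpose_a/transpose_b straight through (shapes in the comments are illustrative):

```python
import nnvm.symbol as sym

a = sym.Variable("a")  # e.g. shape (m, k)
b = sym.Variable("b")  # e.g. shape (n, k)

# Both flags default to false per MatMulParam; with transpose_b=True
# this computes a dot b^T.
c = sym.matmul(a, b, transpose_b=True)
```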

nnvm/python/nnvm/compiler/build_module.py

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ def build(graph, target=None, shape=None, dtype="float32", params=None, target_h
         The input types to the graph
 
     params : dict of str to NDArray
-        Input parameetrs to the graph that do not change
+        Input parameters to the graph that do not change
         during inference time. Used for pre-compute
         folding optimization.
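For context, the corrected line documents the params argument of nnvm.compiler.build, whose signature appears in the hunk header. A minimal sketch of the pre-compute folding use case (the one-layer network and shapes are illustrative):

```python
import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler

# One dense layer whose weight is supplied as a constant parameter.
x = sym.Variable("x")
w = sym.Variable("w")
net = sym.dense(x, weight=w, units=4, use_bias=False)

# params holds weights fixed at inference time; build() can fold
# any subgraph that depends only on them ahead of time.
params = {"w": tvm.nd.array(np.ones((4, 8), dtype="float32"))}
graph, lib, params = nnvm.compiler.build(
    net, target="llvm", shape={"x": (1, 8)}, params=params)
```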

nnvm/python/nnvm/compiler/graph_util.py

Lines changed: 57 additions & 0 deletions
@@ -5,6 +5,9 @@
 import tvm
 from . import graph_attr
 
+from ..graph import create
+from ..symbol import Group, ones_like
+
 def infer_shape(graph, **shape):
     """Infer the shape given the shape of inputs.
 
@@ -89,3 +92,57 @@ def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
     err = _deep_compare(grapha, graphb, compare_variable_attrs)
     if err:
         raise ValueError("Graph compare error: " + err)
+
+def get_gradient_graph(ys, xs, grad_ys=None):
+    """Create gradient graph of ys with respect to xs.
+
+    Parameters
+    ----------
+    ys : Symbol or list of Symbol
+        Symbols from which the gradient is calculated.
+    xs : Symbol or list of Symbol
+        Symbols the gradient respect to.
+        For group symbol, gradients for all outputs will be calculated.
+    grad_ys : Symbol or list of Symbol
+        Head gradients for ys.
+
+    Returns
+    -------
+    ret : Graph
+        Generated gradient graph.
+    """
+    if isinstance(ys, list):
+        ys = Group(ys)
+    g = create(ys)
+    g._set_symbol_list_attr('grad_ys', ys)
+    g._set_symbol_list_attr('grad_xs', xs)
+    ny = len(ys.list_output_names())
+    if grad_ys is None:
+        grad_ys = [ones_like(ys[i]) for i in range(ny)]
+    g._set_symbol_list_attr('grad_ys_out_grad', grad_ys)
+    return g.apply('Gradient')
+
+def gradients(ys, xs, grad_ys=None):
+    """Create gradient symbol of ys respect to xs.
+
+    Parameters
+    ----------
+    ys : Symbol or list of Symbol
+        Symbols from which the gradient is calculated.
+    xs : Symbol or list of Symbol
+        Symbols the gradient respect to.
+        For group symbol, gradients for all outputs will be calculated.
+    grad_ys : Symbol or list of Symbol
+        Head gradients for ys.
+
+    Returns
+    -------
+    ret : list of Symbol
+        Generated gradient symbol. For each xs,
+        all gradients from ys are merged into a single symbol.
+    """
+    grad_g = get_gradient_graph(ys, xs, grad_ys)
+    nx = len(Group(xs).list_output_names()) \
+        if isinstance(xs, list) else len(xs.list_output_names())
+    ret = [grad_g.symbol[i] for i in range(nx)]
+    return ret
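graph_util.gradients is the symbol-level entry point (moved here from nnvm/python/nnvm/graph.py, see below), and get_gradient_graph now exposes the intermediate Graph for callers that want to inspect the Gradient pass output directly. A minimal usage sketch against the signature shown in the diff (the op choice is illustrative):

```python
import nnvm.symbol as sym
from nnvm.compiler import graph_util

x = sym.Variable("x")
w = sym.Variable("w")
y = sym.elemwise_mul(x, w)

# Returns one merged gradient symbol per entry of xs, in order.
# With grad_ys=None the head gradient defaults to ones_like(y).
dx, dw = graph_util.gradients(y, [x, w])
```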

nnvm/python/nnvm/graph.py

Lines changed: 0 additions & 36 deletions
@@ -13,7 +13,6 @@
 from ._base import GraphHandle, SymbolHandle
 from ._base import check_call
 from .symbol import Variable, Symbol, Group as _Group
-from .symbol import ones_like
 
 class GraphIndex(object):
     """Index for quickly accessing graph attributes.
@@ -271,38 +270,3 @@ def create(symbol):
     check_call(_LIB.NNGraphCreate(
         symbol.handle, ctypes.byref(ghandle)))
     return Graph(ghandle)
-
-
-def gradients(ys, xs, grad_ys=None):
-    """Create gradient symbol of ys respect to xs.
-
-    Parameters
-    ----------
-    ys : Symbol or list of Symbol
-        Symbols from which the gradient is calculated.
-    xs : Symbol or list of Symbol
-        Symbols the gradient respect to.
-        For group symbol, gradients for all outputs will be calculated.
-    grad_ys : Symbol or list of Symbol
-        Head gradients for ys.
-
-    Returns
-    -------
-    ret : list of Symbol
-        Generated gradient symbol. For each xs,
-        all gradients from ys are merged into a single symbol.
-    """
-    if isinstance(ys, list):
-        ys = _Group(ys)
-    g = create(ys)
-    g._set_symbol_list_attr('grad_ys', ys)
-    g._set_symbol_list_attr('grad_xs', xs)
-    ny = len(ys.list_output_names())
-    if grad_ys is None:
-        grad_ys = [ones_like(ys[i]) for i in range(ny)]
-    g._set_symbol_list_attr('grad_ys_out_grad', grad_ys)
-    sym = g.apply('Gradient').symbol
-    nx = len(_Group(xs).list_output_names()) \
-        if isinstance(xs, list) else len(xs.list_output_names())
-    ret = [sym[i] for i in range(nx)]
-    return ret

nnvm/src/pass/gradient.cc

Lines changed: 7 additions & 2 deletions
@@ -14,18 +14,23 @@ namespace pass {
 namespace {
 
 // default aggregate gradient function
-// require operator __zero__ and __ewise_sum__ to be presented.
+// require operator zeros and elemwise_sum to be presented.
 NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
   if (v.size() == 1) {
     return std::move(v[0]);
   } else if (v.size() == 0) {
     NodePtr zero_node = Node::Create();
-    zero_node->attrs.op = Op::Get("_zeros");
+    zero_node->attrs.op = Op::Get("zeros");
+    zero_node->attrs.name = "zero_grad";
+    zero_node->attrs.op->attr_parser(&(zero_node->attrs));
     return NodeEntry{zero_node, 0, 0};
   } else {
     NodePtr sum_node = Node::Create();
     sum_node->attrs.op = Op::Get("elemwise_sum");
     sum_node->inputs = std::move(v);
+    sum_node->attrs.name = "grad_sum";
+    sum_node->attrs.dict["num_args"] = std::to_string(sum_node->inputs.size());
+    sum_node->attrs.op->attr_parser(&(sum_node->attrs));
     return NodeEntry{sum_node, 0, 0};
   }
 }
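DefaultAggregateGradient merges the gradient contributions that reach a node: zero contributions become a zeros node (now named and run through its attr_parser), and two or more become an elemwise_sum whose num_args attribute must be set before parsing — which is what the added lines do, matching the new ElementWiseReduceParam above. A sketch of when aggregation triggers (Python side; op choice is illustrative):

```python
import nnvm.symbol as sym
from nnvm.compiler import graph_util

# x is consumed twice, so backprop yields two gradient contributions
# for it; the Gradient pass merges them with an elemwise_sum node
# (num_args=2, named "grad_sum" after this change).
x = sym.Variable("x")
y = sym.elemwise_mul(x, x)
(dx,) = graph_util.gradients(y, [x])
```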
