Skip to content

Commit

Permalink
Upgrade nnvm to use automatic correspondence guessing
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Nov 25, 2016
1 parent 2eb1b9c commit 5bc14f6
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 54 deletions.
23 changes: 3 additions & 20 deletions src/op_nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ NNVM_REGISTER_OP(_backward)
.describe("backward operator of NN module")
.set_num_outputs([] (const NodeAttrs& attrs) {
const NNBackwardParam& param = dmlc::get<NNBackwardParam>(attrs.parsed);
return param.forward_readonly_inputs;
return param.forward_readonly_inputs - param.num_no_grad_inputs;
})
.set_num_inputs([] (const NodeAttrs& attrs) {
const NNBackwardParam& param = dmlc::get<NNBackwardParam>(attrs.parsed);
Expand All @@ -77,17 +77,7 @@ NNVM_REGISTER_OP(_backward)
if (param.need_outputs) n += 1;
return n;
})
.set_attr<nnvm::FBackwardOutToInIndex>("FBackwardOutToInIndex", [](const NodeAttrs& attrs) {
const NNBackwardParam& param = dmlc::get<NNBackwardParam>(attrs.parsed);
std::vector<uint32_t> vec;
for (uint32_t i = 0; i < param.forward_readonly_inputs; ++i) {
vec.push_back(i);
}
return vec;
})
.set_attr<nnvm::FBackwardInGradIndex>("FBackwardInGradIndex", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
});
.set_attr<nnvm::TIsBackward>("TIsBackward", true);


// common attributes for nn module.
Expand Down Expand Up @@ -365,13 +355,6 @@ NNVM_REGISTER_OP(flatten_layer)
NNVM_REGISTER_OP(_flatten_backward)
.set_num_inputs(1)
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FBackwardOutToInIndex>(
"FBackwardOutToInIndex", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
})
.set_attr<FBackwardInGradIndex>(
"FBackwardInGradIndex", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
});
.set_attr<nnvm::TIsBackward>("TIsBackward", true);

} // namespace tinyflow
24 changes: 2 additions & 22 deletions src/op_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,20 +365,7 @@ NNVM_REGISTER_OP(matmul)
NNVM_REGISTER_OP(_matmul_backward)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<FBackwardOutToInIndex>(
"FBackwardOutToInIndex", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0, 1};
})
.set_attr<FBackwardInGradIndex>(
"FBackwardInGradIndex", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
})
.set_attr<FInplaceOption>(
"FInplaceOption", [](const NodeAttrs& attrs) {
// lhs->gradLhs
return std::vector<std::pair<int, int> >{{1, 0}};
});

.set_attr<nnvm::TIsBackward>("TIsBackward", true);

struct ReduceParam : public dmlc::Parameter<ReduceParam> {
Tuple<int> reduction_indices;
Expand Down Expand Up @@ -440,14 +427,7 @@ NNVM_REGISTER_OP(reduce_mean)


NNVM_REGISTER_OP_GROUP(ReduceBackwardIndeAttr)
.set_attr<FBackwardOutToInIndex>(
"FBackwardOutToInIndex", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
})
.set_attr<FBackwardInGradIndex>(
"FBackwardInGradIndex", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
});
.set_attr<nnvm::TIsBackward>("TIsBackward", true);


NNVM_REGISTER_OP(_reduce_sum_backward)
Expand Down
11 changes: 0 additions & 11 deletions tests/python/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,6 @@ def test_softmax():
np.testing.assert_almost_equal(
ay, ax / np.sum(ax, axis=1, keepdims=True))

def test_bias_add():
x = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
y = tf.nn.bias_add(x, b)
ax = np.random.uniform(size=(2, 3))
ab = np.random.uniform(size=(3, ))
sess = tf.Session()
ay = sess.run(y, feed_dict={x:ax, b:ab})
np.testing.assert_almost_equal(
ay, ax + ab)

def test_matmul():
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
Expand Down

0 comments on commit 5bc14f6

Please sign in to comment.