Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #115 from tqchen/master
Browse files Browse the repository at this point in the history
Allow partial positional arguments of input symbol
  • Loading branch information
antinucleon committed Sep 21, 2015
2 parents a5941d8 + ccd0abd commit 339085a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
29 changes: 18 additions & 11 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ Symbol Symbol::operator[] (size_t index) const {
}
}

// create a default variable name
inline std::string DefaultVarName(const std::string &op_name,
const std::string &arg_name) {
if (op_name.length() == 0) {
return arg_name;
} else {
return op_name + '_' + arg_name;
}
}

void Symbol::Compose(const std::vector<Symbol>& args,
const std::string& name) {
CHECK_EQ(NumOutputs(), 1) << "Only composition of value function is supported currently";
Expand All @@ -261,13 +271,17 @@ void Symbol::Compose(const std::vector<Symbol>& args,
if (this->is_atomic()) {
// atomic symbol do not have place holder for all the arguments
std::vector<std::string> req_args = heads_[0].source->op->ListArguments();
CHECK_EQ(args.size(), req_args.size())
CHECK_LE(args.size(), req_args.size())
<< "Incorrect number of arguments, requires " << req_args.size()
<< ", provided " << args.size();
heads_[0].source->inputs.resize(args.size());
heads_[0].source->inputs.resize(req_args.size());
for (size_t i = 0; i < args.size(); ++i) {
heads_[0].source->inputs[i] = args[i].heads_[0];
}
for (size_t i = args.size(); i < req_args.size(); ++i) {
heads_[0].source->inputs[i] = DataEntry(
std::make_shared<Node>(nullptr, DefaultVarName(name, req_args[i])), 0);
}
} else {
// find all the place holders
size_t arg_counter = 0;
Expand Down Expand Up @@ -325,15 +339,8 @@ void Symbol::Compose(const std::unordered_map<std::string, Symbol>& kwargs,
heads_[0].source->inputs[i] = iter->second.heads_[0];
++nmatched;
} else {
// create a variable node
// TODO(bing): think of naming convention
if (name.length() == 0) {
heads_[0].source->inputs[i] = DataEntry(
std::make_shared<Node>(nullptr, req_args[i]), 0);
} else {
heads_[0].source->inputs[i] = DataEntry(
std::make_shared<Node>(nullptr, name + '_' + req_args[i]), 0);
}
heads_[0].source->inputs[i] = DataEntry(
std::make_shared<Node>(nullptr, DefaultVarName(name, req_args[i])), 0);
}
}
// if things goes wrong recover the old state
Expand Down
12 changes: 6 additions & 6 deletions tests/python/train/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
# symbol net
batch_size = 100
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
softmax = mx.symbol.Softmax(data = fc3, name = 'sm')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
softmax = mx.symbol.Softmax(fc3, name = 'sm')

num_round = 4
prefix = './mlp'
Expand Down

0 comments on commit 339085a

Please sign in to comment.