Skip to content

Commit cd339f4

Browse files
committed
Place device now compatible and tested (apache#33)
1 parent aff2993 commit cd339f4

File tree

3 files changed

+55
-18
lines changed

3 files changed

+55
-18
lines changed

nnvm/include/dmlc/any.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ inline const std::type_info& any::type() const {
259259
template<typename T>
260260
inline void any::check_type() const {
261261
CHECK(type_ != nullptr)
262-
<< "The any container is empty";
262+
<< "The any container is empty"
263+
<< " requested=" << typeid(T).name();
263264
CHECK(type_->ptype_info == &typeid(T))
264265
<< "The stored type mismatch"
265266
<< " stored=" << type_->ptype_info->name()

nnvm/include/dmlc/parameter.h

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ class FieldEntry;
5757
// forward declare ParamManagerSingleton
5858
template<typename PType>
5959
struct ParamManagerSingleton;
60+
61+
/*! \brief option in parameter initialization */
62+
enum ParamInitOption {
63+
/*! \brief allow unknown parameters */
64+
kAllowUnknown,
65+
/*! \brief need to match exact parameters */
66+
kAllMatch
67+
};
6068
} // namespace parameter
6169
/*!
6270
* \brief Information about a parameter field in string representations.
@@ -108,13 +116,17 @@ struct Parameter {
108116
* and throw error if something wrong happens.
109117
*
110118
* \param kwargs map of keyword arguments, or vector of pairs
119+
* \parma option The option on initialization.
111120
* \tparam Container container type
112121
* \throw ParamError when something go wrong.
113122
*/
114123
template<typename Container>
115-
inline void Init(const Container &kwargs) {
124+
inline void Init(const Container &kwargs,
125+
parameter::ParamInitOption option = parameter::kAllowUnknown) {
116126
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
117-
kwargs.begin(), kwargs.end(), NULL);
127+
kwargs.begin(), kwargs.end(),
128+
NULL,
129+
option == parameter::kAllowUnknown);
118130
}
119131
/*!
120132
* \brief initialize the parameter by keyword arguments.
@@ -130,7 +142,8 @@ struct Parameter {
130142
InitAllowUnknown(const Container &kwargs) {
131143
std::vector<std::pair<std::string, std::string> > unknown;
132144
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
133-
kwargs.begin(), kwargs.end(), &unknown);
145+
kwargs.begin(), kwargs.end(),
146+
&unknown, true);
134147
return unknown;
135148
}
136149
/*!
@@ -355,7 +368,8 @@ class ParamManager {
355368
inline void RunInit(void *head,
356369
RandomAccessIterator begin,
357370
RandomAccessIterator end,
358-
std::vector<std::pair<std::string, std::string> > *unknown_args) const {
371+
std::vector<std::pair<std::string, std::string> > *unknown_args,
372+
bool allow_unknown) const {
359373
std::set<FieldAccessEntry*> selected_args;
360374
for (RandomAccessIterator it = begin; it != end; ++it) {
361375
FieldAccessEntry *e = Find(it->first);
@@ -367,11 +381,13 @@ class ParamManager {
367381
if (unknown_args != NULL) {
368382
unknown_args->push_back(*it);
369383
} else {
370-
std::ostringstream os;
371-
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
372-
os << "----------------\n";
373-
PrintDocString(os);
374-
throw dmlc::ParamError(os.str());
384+
if (!allow_unknown) {
385+
std::ostringstream os;
386+
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
387+
os << "----------------\n";
388+
PrintDocString(os);
389+
throw dmlc::ParamError(os.str());
390+
}
375391
}
376392
}
377393
}

nnvm/src/pass/place_device.cc

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ Graph PlaceDevice(Graph src) {
2525
const Op* copy_op = Op::Get(src.GetAttr<std::string>("device_copy_op"));
2626
auto& device_assign_map = src.GetAttr<DeviceAssignMap>("device_assign_map");
2727
const IndexedGraph& idx = src.indexed_graph();
28-
2928
DeviceVector device;
3029
// copy on write semanatics
3130
if (src.attrs.count("device") != 0) {
@@ -79,17 +78,27 @@ Graph PlaceDevice(Graph src) {
7978
src.attrs["device"] = std::make_shared<any>(std::move(device));
8079
return src;
8180
}
82-
8381
std::map<std::tuple<uint32_t, uint32_t, int>, NodePtr> copy_map;
8482
std::vector<NodePtr> new_node_map(idx.num_nodes(), nullptr);
8583
std::unordered_map<const Node*, int> new_device_map;
84+
static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
8685

8786
// insert copy node
8887
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
8988
int dev_id = device[nid];
9089
const auto& inode = idx[nid];
9190
// check if mutation is needed
9291
bool need_mutate = false;
92+
if (!inode.source->is_variable() && fmutate_inputs.count(inode.source->op())) {
93+
for (uint32_t index : fmutate_inputs[inode.source->op()](inode.source->attrs)) {
94+
auto e = inode.inputs[index];
95+
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
96+
LOG(FATAL) << " mutable state cannot go across device"
97+
<< " op=" << inode.source->op()->name
98+
<< " input_state_index=" << index;
99+
}
100+
}
101+
}
93102
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
94103
if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) {
95104
need_mutate = true; break;
@@ -102,6 +111,9 @@ Graph PlaceDevice(Graph src) {
102111
}
103112
}
104113
}
114+
if (inode.source->is_variable()) {
115+
CHECK(!need_mutate) << "consistency check";
116+
}
105117
if (need_mutate) {
106118
NodePtr new_node = Node::Create();
107119
new_node->attrs = inode.source->attrs;
@@ -120,7 +132,15 @@ Graph PlaceDevice(Graph src) {
120132
os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy";
121133
copy_node->attrs.op = copy_op;
122134
copy_node->attrs.name = os.str();
123-
copy_node->inputs.push_back(inode.source->inputs[i]);
135+
if (new_node_map[e.node_id] != nullptr) {
136+
copy_node->inputs.emplace_back(
137+
NodeEntry{new_node_map[e.node_id], e.index, 0});
138+
} else {
139+
copy_node->inputs.push_back(inode.source->inputs[i]);
140+
}
141+
if (copy_node->attrs.op->attr_parser != nullptr) {
142+
copy_node->attrs.op->attr_parser(&(copy_node->attrs));
143+
}
124144
copy_map[copy_key] = copy_node;
125145
new_device_map[copy_node.get()] = dev_id;
126146
new_node->inputs.emplace_back(
@@ -130,7 +150,7 @@ Graph PlaceDevice(Graph src) {
130150
if (new_node_map[e.node_id] != nullptr) {
131151
new_node->inputs.emplace_back(
132152
NodeEntry{new_node_map[e.node_id], e.index, 0});
133-
} else {
153+
} else {
134154
new_node->inputs.push_back(inode.source->inputs[i]);
135155
}
136156
}
@@ -150,7 +170,6 @@ Graph PlaceDevice(Graph src) {
150170
new_device_map[inode.source] = dev_id;
151171
}
152172
}
153-
154173
// make the new graph
155174
Graph ret;
156175
for (const NodeEntry& e : src.outputs) {
@@ -163,10 +182,11 @@ Graph PlaceDevice(Graph src) {
163182
}
164183
DeviceVector new_device_vec(ret.indexed_graph().num_nodes());
165184
for (uint32_t nid = 0; nid < ret.indexed_graph().num_nodes(); ++nid) {
166-
if (new_device_map.count(ret.indexed_graph()[nid].source) == 0) {
167-
LOG(INFO) << "canot find " << ret.indexed_graph()[nid].source->attrs.name;
185+
auto source = ret.indexed_graph()[nid].source;
186+
if (new_device_map.count(source) == 0) {
187+
LOG(FATAL) << "canot find " << source;
168188
}
169-
new_device_vec[nid] = new_device_map.at(ret.indexed_graph()[nid].source);
189+
new_device_vec[nid] = new_device_map.at(source);
170190
}
171191
ret.attrs["device"] = std::make_shared<any>(std::move(new_device_vec));
172192
return ret;

0 commit comments

Comments
 (0)