diff --git a/include/dmlc/any.h b/include/dmlc/any.h index 5707a363db99..d32e44e48c4d 100644 --- a/include/dmlc/any.h +++ b/include/dmlc/any.h @@ -259,7 +259,8 @@ inline const std::type_info& any::type() const { template inline void any::check_type() const { CHECK(type_ != nullptr) - << "The any container is empty"; + << "The any container is empty" + << " requested=" << typeid(T).name(); CHECK(type_->ptype_info == &typeid(T)) << "The stored type mismatch" << " stored=" << type_->ptype_info->name() diff --git a/include/dmlc/parameter.h b/include/dmlc/parameter.h index 2fbab2a44e32..0222df60664d 100644 --- a/include/dmlc/parameter.h +++ b/include/dmlc/parameter.h @@ -57,6 +57,14 @@ class FieldEntry; // forward declare ParamManagerSingleton template struct ParamManagerSingleton; + +/*! \brief option in parameter initialization */ +enum ParamInitOption { + /*! \brief allow unknown parameters */ + kAllowUnknown, + /*! \brief need to match exact parameters */ + kAllMatch +}; } // namespace parameter /*! * \brief Information about a parameter field in string representations. @@ -108,13 +116,17 @@ struct Parameter { * and throw error if something wrong happens. * * \param kwargs map of keyword arguments, or vector of pairs + * \parma option The option on initialization. * \tparam Container container type * \throw ParamError when something go wrong. */ template - inline void Init(const Container &kwargs) { + inline void Init(const Container &kwargs, + parameter::ParamInitOption option = parameter::kAllowUnknown) { PType::__MANAGER__()->RunInit(static_cast(this), - kwargs.begin(), kwargs.end(), NULL); + kwargs.begin(), kwargs.end(), + NULL, + option == parameter::kAllowUnknown); } /*! * \brief initialize the parameter by keyword arguments. @@ -130,7 +142,8 @@ struct Parameter { InitAllowUnknown(const Container &kwargs) { std::vector > unknown; PType::__MANAGER__()->RunInit(static_cast(this), - kwargs.begin(), kwargs.end(), &unknown); + kwargs.begin(), kwargs.end(), + &unknown, true); return unknown; } /*! @@ -355,7 +368,8 @@ class ParamManager { inline void RunInit(void *head, RandomAccessIterator begin, RandomAccessIterator end, - std::vector > *unknown_args) const { + std::vector > *unknown_args, + bool allow_unknown) const { std::set selected_args; for (RandomAccessIterator it = begin; it != end; ++it) { FieldAccessEntry *e = Find(it->first); @@ -367,11 +381,13 @@ class ParamManager { if (unknown_args != NULL) { unknown_args->push_back(*it); } else { - std::ostringstream os; - os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; - os << "----------------\n"; - PrintDocString(os); - throw dmlc::ParamError(os.str()); + if (!allow_unknown) { + std::ostringstream os; + os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n"; + os << "----------------\n"; + PrintDocString(os); + throw dmlc::ParamError(os.str()); + } } } } diff --git a/src/pass/place_device.cc b/src/pass/place_device.cc index 00d216512558..3fce4fa06d63 100644 --- a/src/pass/place_device.cc +++ b/src/pass/place_device.cc @@ -25,7 +25,6 @@ Graph PlaceDevice(Graph src) { const Op* copy_op = Op::Get(src.GetAttr("device_copy_op")); auto& device_assign_map = src.GetAttr("device_assign_map"); const IndexedGraph& idx = src.indexed_graph(); - DeviceVector device; // copy on write semanatics if (src.attrs.count("device") != 0) { @@ -79,10 +78,10 @@ Graph PlaceDevice(Graph src) { src.attrs["device"] = std::make_shared(std::move(device)); return src; } - std::map, NodePtr> copy_map; std::vector new_node_map(idx.num_nodes(), nullptr); std::unordered_map new_device_map; + static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); // insert copy node for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { @@ -90,6 +89,16 @@ Graph PlaceDevice(Graph src) { const auto& inode = idx[nid]; // check if mutation is needed bool need_mutate = false; + if (!inode.source->is_variable() && fmutate_inputs.count(inode.source->op())) { + for (uint32_t index : fmutate_inputs[inode.source->op()](inode.source->attrs)) { + auto e = inode.inputs[index]; + if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { + LOG(FATAL) << " mutable state cannot go across device" + << " op=" << inode.source->op()->name + << " input_state_index=" << index; + } + } + } for (const IndexedGraph::NodeEntry& e : inode.inputs) { if (new_node_map[e.node_id] != nullptr || dev_id != device[e.node_id]) { need_mutate = true; break; @@ -102,6 +111,9 @@ Graph PlaceDevice(Graph src) { } } } + if (inode.source->is_variable()) { + CHECK(!need_mutate) << "consistency check"; + } if (need_mutate) { NodePtr new_node = Node::Create(); new_node->attrs = inode.source->attrs; @@ -120,7 +132,15 @@ Graph PlaceDevice(Graph src) { os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; copy_node->attrs.op = copy_op; copy_node->attrs.name = os.str(); - copy_node->inputs.push_back(inode.source->inputs[i]); + if (new_node_map[e.node_id] != nullptr) { + copy_node->inputs.emplace_back( + NodeEntry{new_node_map[e.node_id], e.index, 0}); + } else { + copy_node->inputs.push_back(inode.source->inputs[i]); + } + if (copy_node->attrs.op->attr_parser != nullptr) { + copy_node->attrs.op->attr_parser(&(copy_node->attrs)); + } copy_map[copy_key] = copy_node; new_device_map[copy_node.get()] = dev_id; new_node->inputs.emplace_back( @@ -130,7 +150,7 @@ Graph PlaceDevice(Graph src) { if (new_node_map[e.node_id] != nullptr) { new_node->inputs.emplace_back( NodeEntry{new_node_map[e.node_id], e.index, 0}); - } else { + } else { new_node->inputs.push_back(inode.source->inputs[i]); } } @@ -150,7 +170,6 @@ Graph PlaceDevice(Graph src) { new_device_map[inode.source] = dev_id; } } - // make the new graph Graph ret; for (const NodeEntry& e : src.outputs) { @@ -163,10 +182,11 @@ Graph PlaceDevice(Graph src) { } DeviceVector new_device_vec(ret.indexed_graph().num_nodes()); for (uint32_t nid = 0; nid < ret.indexed_graph().num_nodes(); ++nid) { - if (new_device_map.count(ret.indexed_graph()[nid].source) == 0) { - LOG(INFO) << "canot find " << ret.indexed_graph()[nid].source->attrs.name; + auto source = ret.indexed_graph()[nid].source; + if (new_device_map.count(source) == 0) { + LOG(FATAL) << "canot find " << source; } - new_device_vec[nid] = new_device_map.at(ret.indexed_graph()[nid].source); + new_device_vec[nid] = new_device_map.at(source); } ret.attrs["device"] = std::make_shared(std::move(new_device_vec)); return ret;