-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Runtime] Allow for parameter sharing in GraphRuntime #3384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -184,6 +184,32 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { | |
} | ||
} | ||
|
||
void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { | ||
uint64_t header, reserved; | ||
CHECK(strm->Read(&header)) | ||
<< "Invalid parameters file format"; | ||
CHECK(header == kTVMNDArrayListMagic) | ||
<< "Invalid parameters file format"; | ||
CHECK(strm->Read(&reserved)) | ||
<< "Invalid parameters file format"; | ||
std::vector<std::string> names; | ||
CHECK(strm->Read(&names)) << "Invalid parameters file format"; | ||
uint64_t sz; | ||
strm->Read(&sz); | ||
size_t size = static_cast<size_t>(sz); | ||
CHECK(size == names.size()) << "Invalid parameters file format"; | ||
for (size_t i = 0; i < size; ++i) { | ||
int in_idx = GetInputIndex(names[i]); | ||
CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i]; | ||
uint32_t eid = this->entry_id(input_nodes_[in_idx], 0); | ||
CHECK_LT(eid, data_entry_.size()); | ||
CHECK_EQ(data_entry_[eid].use_count(), 1); | ||
data_entry_[eid] = other.GetInput(GetInputIndex(names[i])); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for reuse variable purpose , here should be data_entry_[eid] = other.GetInput(in_idx). |
||
CHECK_GT(data_entry_[eid].use_count(), 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. for consistency, 205,206 208 may need a "message" just like other check did. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. data_entry_[eid].use_count() should get increase 1 after one reference, this check just check if the count > 1 seems like not logically complete. how about some logic like following? |
||
} | ||
this->SetupOpExecs(); | ||
} | ||
|
||
void GraphRuntime::SetupStorage() { | ||
// Grab saved optimization plan from graph. | ||
std::vector<TVMType> vtype; | ||
|
@@ -372,6 +398,14 @@ PackedFunc GraphRuntime::GetFunction( | |
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { | ||
this->LoadParams(args[0].operator std::string()); | ||
}); | ||
} else if (name == "share_params") { | ||
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { | ||
const auto& module = args[0].operator Module(); | ||
CHECK_EQ(module.operator->()->type_key(), "GraphRuntime"); | ||
const auto& param_blob = args[1].operator std::string(); | ||
dmlc::MemoryStringStream strm(const_cast<std::string*>(¶m_blob)); | ||
this->ShareParams(dynamic_cast<const GraphRuntime&>(*module.operator->()), &strm); | ||
}); | ||
} else { | ||
return PackedFunc(); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
here may need a 'CHECK' just like line 189 and 193 did.