14 changes: 14 additions & 0 deletions python/tvm/contrib/graph_runtime.py
@@ -129,6 +129,7 @@ def __init__(self, module):
self._get_input = module["get_input"]
self._get_num_outputs = module["get_num_outputs"]
self._load_params = module["load_params"]
self._share_params = module["share_params"]

def set_input(self, key=None, value=None, **params):
"""Set inputs to the module via kwargs
@@ -234,6 +235,19 @@ def load_params(self, params_bytes):
"""
self._load_params(bytearray(params_bytes))

def share_params(self, other, params_bytes):
"""Share parameters from pre-existing GraphRuntime instance.

Parameters
----------
other: GraphRuntime
The parent GraphRuntime from which this instance should share
its parameters.
params_bytes : bytearray
The serialized parameter dict (used only for the parameter names).
"""
self._share_params(other.module, bytearray(params_bytes))

def __getitem__(self, key):
"""Get internal module function

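For context, a minimal usage sketch of the new method, mirroring the unit test added below. The names graph, lib, and params_bytes are illustrative placeholders for an already-built graph JSON, its compiled library, and a serialized parameter dict:

import tvm
from tvm.contrib import graph_runtime

# The parent module loads the parameters and owns their storage.
parent = graph_runtime.create(graph, lib, tvm.cpu(0))
parent.load_params(params_bytes)

# A second module reuses the parent's parameter arrays instead of copying them;
# params_bytes is consulted only for the parameter names.
child = graph_runtime.create(graph, lib, tvm.cpu(0))
child.share_params(parent, params_bytes)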
34 changes: 34 additions & 0 deletions src/runtime/graph/graph_runtime.cc
@@ -184,6 +184,32 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
}
}

void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
std::vector<std::string> names;
CHECK(strm->Read(&names)) << "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz);
Contributor
A 'CHECK' may be needed here, just like the ones on lines 189 and 193.

size_t size = static_cast<size_t>(sz);
CHECK(size == names.size()) << "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
int in_idx = GetInputIndex(names[i]);
CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
CHECK_LT(eid, data_entry_.size());
CHECK_EQ(data_entry_[eid].use_count(), 1);
data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
Contributor
To reuse the variable, this should be data_entry_[eid] = other.GetInput(in_idx).
One more question: is other.GetInputIndex(names[i]) always equal to this->GetInputIndex(names[i])? The logic here should arguably be other.GetInput(other.GetInputIndex(names[i])); if the two are guaranteed equal, in_idx can be reused.

CHECK_GT(data_entry_[eid].use_count(), 1);
Contributor
For consistency, the checks on lines 205, 206, and 208 may need a "message" like the other CHECKs have.

Contributor
data_entry_[eid].use_count() should increase by 1 after one additional reference; checking only that the count is > 1 does not seem logically complete. How about logic like the following?
int prev_idx = other.GetInput(in_idx).use_count();
...
CHECK_EQ(data_entry_[eid].use_count(), prev_idx + 1);

}
this->SetupOpExecs();
}

void GraphRuntime::SetupStorage() {
// Grab saved optimization plan from graph.
std::vector<TVMType> vtype;
@@ -372,6 +398,14 @@ PackedFunc GraphRuntime::GetFunction(
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0].operator std::string());
});
} else if (name == "share_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
const auto& module = args[0].operator Module();
CHECK_EQ(module.operator->()->type_key(), "GraphRuntime");
const auto& param_blob = args[1].operator std::string();
dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
this->ShareParams(dynamic_cast<const GraphRuntime&>(*module.operator->()), &strm);
});
} else {
return PackedFunc();
}
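Note that the Python share_params wrapper above simply forwards to this new PackedFunc; the function can also be fetched from the underlying runtime module directly. A sketch, reusing parent, child, and params_bytes from the earlier example:

# Fetch the registered PackedFunc and pass the parent's runtime module plus the
# serialized parameter blob (used only for the parameter names).
fshare = child.module["share_params"]
fshare(parent.module, bytearray(params_bytes))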
17 changes: 13 additions & 4 deletions src/runtime/graph/graph_runtime.h
@@ -147,10 +147,19 @@ class GraphRuntime : public ModuleNode {
* \param param_blob A binary blob of parameter.
*/
void LoadParams(const std::string& param_blob);
/*!
* \brief Get total number of nodes.
* \return Total number of nodes.
*/

/*!
* \brief Share parameters from a pre-existing GraphRuntime instance.
* \param other A GraphRuntime instance on which |LoadParams| was previously called
* with the identical input |param_blob|.
* \param strm The input stream.
*/
void ShareParams(const GraphRuntime& other, dmlc::Stream* strm);

/*!
* \brief Get total number of nodes.
* \return Total number of nodes.
*/
uint32_t GetNumOfNodes() const {
return static_cast<uint32_t>(nodes_.size());
}
38 changes: 38 additions & 0 deletions tests/python/unittest/test_runtime_graph.py
@@ -81,8 +81,46 @@ def check_remote():
out = mod.get_output(0, out)
np.testing.assert_equal(out.asnumpy(), a + 1)

def check_sharing():
from tvm import relay
x = relay.var('x', shape=(1, 10))
y = relay.var('y', shape=(1, 10))
z = relay.add(x, y)
func = relay.Function([x, y], z)

x_in = np.ones((1, 10)).astype("float32")
params = {'x': x_in}
graph, lib, params = relay.build(func, target="llvm", params=params)

if not tvm.module.enabled("llvm"):
print("Skip because llvm is not enabled")
return
mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
mod_shared.load_params(relay.save_param_dict(params))
num_mods = 10
mods = [graph_runtime.create(graph, lib, tvm.cpu(0))
for _ in range(num_mods)]

for mod in mods:
mod.share_params(mod_shared, relay.save_param_dict(params))

a = np.random.uniform(size=(1, 10)).astype("float32")
for mod in mods:
mod.run(y=a)
out = mod.get_output(0, tvm.nd.empty((1, 10)))
np.testing.assert_equal(out.asnumpy(), x_in + a)

# Explicitly delete the shared module and verify correctness.
del mod_shared
for mod in mods:
mod.run(y=a)
out = mod.get_output(0, tvm.nd.empty((1, 10)))
np.testing.assert_equal(out.asnumpy(), x_in + a)
del mod

check_verify()
check_remote()
check_sharing()

if __name__ == "__main__":
test_graph_simple()