Skip to content

Commit

Permalink
[Runtime] Make CSourceModule and StaticLibraryModule Binary Seria…
Browse files Browse the repository at this point in the history
…lizable (#15693)

make csource module and static libary module binary serializable
  • Loading branch information
sunggg authored Sep 8, 2023
1 parent 738c2e9 commit 666bd14
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 12 deletions.
30 changes: 28 additions & 2 deletions src/runtime/static_library.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,39 @@ class StaticLibraryNode final : public runtime::ModuleNode {
}
}

void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(data_);
std::vector<std::string> func_names;
for (const auto func_name : func_names_) func_names.push_back(func_name);
stream->Write(func_names);
}

static Module LoadFromBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
auto n = make_object<StaticLibraryNode>();
// load data
std::string data;
ICHECK(stream->Read(&data)) << "Loading data failed";
n->data_ = std::move(data);

// load func names
std::vector<std::string> func_names;
ICHECK(stream->Read(&func_names)) << "Loading func names failed";
for (auto func_name : func_names) n->func_names_.push_back(String(func_name));

return Module(n);
}

void SaveToFile(const String& file_name, const String& format) final {
VLOG(0) << "Saving static library of " << data_.size() << " bytes implementing " << FuncNames()
<< " to '" << file_name << "'";
SaveBinaryToFile(file_name, data_);
}

// TODO(tvm-team): Make this module serializable
/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const override { return ModulePropertyMask::kDSOExportable; }
int GetPropertyMask() const override {
return runtime::ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kDSOExportable;
}

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
Expand Down Expand Up @@ -103,6 +127,8 @@ Module LoadStaticLibrary(const std::string& filename, Array<String> func_names)
}

TVM_REGISTER_GLOBAL("runtime.ModuleLoadStaticLibrary").set_body_typed(LoadStaticLibrary);
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_static_library")
.set_body_typed(StaticLibraryNode::LoadFromBinary);

} // namespace runtime
} // namespace tvm
3 changes: 1 addition & 2 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ class ModuleSerializer {
// we will not produce import_tree_.
bool has_import_tree = true;

if (mod_->IsDSOExportable()) {
ICHECK(export_dso) << "`export_dso` should be enabled for DSOExportable modules";
if (export_dso) {
has_import_tree = !mod_->imports().empty();
}

Expand Down
41 changes: 40 additions & 1 deletion src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,39 @@ class CSourceModuleNode : public runtime::ModuleNode {

String GetFormat() override { return fmt_; }

void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(code_);
stream->Write(fmt_);

std::vector<std::string> func_names;
for (const auto func_name : func_names_) func_names.push_back(func_name);
std::vector<std::string> const_vars;
for (auto const_var : const_vars_) const_vars.push_back(const_var);
stream->Write(func_names);
stream->Write(const_vars);
}

static runtime::Module LoadFromBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);

std::string code, fmt;
ICHECK(stream->Read(&code)) << "Loading code failed";
ICHECK(stream->Read(&fmt)) << "Loading format failed";

std::vector<std::string> tmp_func_names, tmp_const_vars;
CHECK(stream->Read(&tmp_func_names)) << "Loading func names failed";
CHECK(stream->Read(&tmp_const_vars)) << "Loading const vars failed";

Array<String> func_names;
for (auto func_name : tmp_func_names) func_names.push_back(String(func_name));

Array<String> const_vars;
for (auto const_var : tmp_const_vars) const_vars.push_back(String(const_var));

auto n = make_object<CSourceModuleNode>(code, fmt, func_names, const_vars);
return runtime::Module(n);
}

void SaveToFile(const String& file_name, const String& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
Expand All @@ -130,7 +163,10 @@ class CSourceModuleNode : public runtime::ModuleNode {
}
}

int GetPropertyMask() const override { return runtime::ModulePropertyMask::kDSOExportable; }
int GetPropertyMask() const override {
return runtime::ModulePropertyMask::kBinarySerializable |
runtime::ModulePropertyMask::kDSOExportable;
}

bool ImplementsFunction(const String& name, bool query_imports) final {
return std::find(func_names_.begin(), func_names_.end(), name) != func_names_.end();
Expand All @@ -151,6 +187,9 @@ runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
return runtime::Module(n);
}

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_c")
.set_body_typed(CSourceModuleNode::LoadFromBinary);

/*!
* \brief A concrete class to get access to base methods of CodegenSourceBase.
*
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_roundtrip_runtime_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@


def test_csource_module():
mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], None)
# source module that is not binary serializable.
# Thus, it would raise an error.
assert not mod.is_binary_serializable
with pytest.raises(TVMError):
tvm.ir.load_json(tvm.ir.save_json(mod))
mod = tvm.runtime._ffi_api.CSourceModuleCreate("", "cc", [], [])
assert mod.type_key == "c"
assert mod.is_binary_serializable
new_mod = tvm.ir.load_json(tvm.ir.save_json(mod))
assert new_mod.type_key == "c"
assert new_mod.is_binary_serializable


def test_aot_module():
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_runtime_module_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def create_aot_module():
def test_property():
checker(
create_csource_module(),
expected={"is_binary_serializable": False, "is_runnable": False, "is_dso_exportable": True},
expected={"is_binary_serializable": True, "is_runnable": False, "is_dso_exportable": True},
)

checker(
Expand Down

0 comments on commit 666bd14

Please sign in to comment.