Skip to content

Commit b53472c

Browse files
authored
Migrate C Interface API Generation to C++ (#9106)
Using the new name transformations added in #9088, the C interface API is now generated in C++ rather than in Python. This is intended to be a no-op for the actual users of this change and thus I've undone some of my overzealous sanitizing to match that expectation. Follow up PRs will clean up any remaining name transformation inconsistencies. Fixes #8792
1 parent c980db3 commit b53472c

File tree

12 files changed

+397
-117
lines changed

12 files changed

+397
-117
lines changed

python/tvm/micro/interface_api.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

python/tvm/micro/model_library_format.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
import tarfile
2626
import typing
2727

28+
import tvm
2829
from tvm.ir.type import TupleType
2930
from .._ffi import get_global_func
30-
from .interface_api import generate_c_interface_header
3131
from ..contrib import utils
3232
from ..driver import build_module
3333
from ..runtime import ndarray as _nd
3434
from ..relay.backend import executor_factory
35+
from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name
3536
from ..relay import param_dict
3637
from ..tir import expr
3738

@@ -43,6 +44,20 @@ class UnsupportedInModelLibraryFormatError(Exception):
4344
"""Raised when export_model_library_format does not support the given Module tree."""
4445

4546

47+
def generate_c_interface_header(module_name, inputs, outputs, include_path):
48+
"""Generate C Interface header to be included in MLF"""
49+
mangled_name = to_c_variable_style(prefix_generated_name(module_name))
50+
metadata_header = os.path.join(include_path, f"{mangled_name}.h")
51+
52+
interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate")
53+
interface_c_module = interface_c_create(module_name, inputs, outputs)
54+
55+
with open(metadata_header, "w") as header_file:
56+
header_file.write(interface_c_module.get_source())
57+
58+
return metadata_header
59+
60+
4661
def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
4762
"""Populate the codegen sub-directory as part of a Model Library Format export.
4863

python/tvm/relay/backend/name_transforms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ def to_c_variable_style(original_name: str):
4848
return _backend.ToCVariableStyle(original_name)
4949

5050

51+
def to_c_constant_style(original_name: str):
52+
"""Transform a name to the C constant style assuming it is
53+
appropriately constructed using the prefixing functions
54+
55+
Parameters
56+
----------
57+
original_name : str
58+
Original name to transform
59+
"""
60+
return _backend.ToCConstantStyle(original_name)
61+
62+
5163
def _preprocess_names(names: Union[List[str], str]):
5264
"""Preprocesses name strings into format for C++ functions
5365

src/relay/backend/name_transforms.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ std::string ToCVariableStyle(const std::string& original_name) {
6262
return variable_name;
6363
}
6464

65+
std::string ToCConstantStyle(const std::string& original_name) {
66+
ICHECK_EQ(original_name.find("TVM"), 0) << "Constant not TVM prefixed";
67+
std::string constant_name = ToCVariableStyle(original_name);
68+
69+
std::transform(constant_name.begin(), constant_name.end(), constant_name.begin(), ::toupper);
70+
return constant_name;
71+
}
72+
6573
std::string CombineNames(const Array<String>& names) {
6674
std::stringstream combine_stream;
6775
ICHECK(!names.empty()) << "Name segments empty";
@@ -79,22 +87,16 @@ std::string CombineNames(const Array<String>& names) {
7987
std::string SanitizeName(const std::string& name) {
8088
ICHECK(!name.empty()) << "Name is empty";
8189

82-
auto multipleSeparators = [](char before, char after) {
83-
return before == '_' && before == after;
84-
};
8590
auto isNotAlnum = [](char c) { return !std::isalnum(c); };
8691
std::string sanitized_input = name;
8792
std::replace_if(sanitized_input.begin(), sanitized_input.end(), isNotAlnum, '_');
8893

89-
sanitized_input.erase(
90-
std::unique(sanitized_input.begin(), sanitized_input.end(), multipleSeparators),
91-
sanitized_input.end());
92-
9394
return sanitized_input;
9495
}
9596

9697
TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle);
9798
TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle);
99+
TVM_REGISTER_GLOBAL("relay.backend.ToCConstantStyle").set_body_typed(ToCConstantStyle);
98100
TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName);
99101
TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName);
100102
TVM_REGISTER_GLOBAL("relay.backend.SanitizeName").set_body_typed(SanitizeName);

src/relay/backend/name_transforms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
* ToCVariableStyle(PrefixGeneratedName(CombineNames({"model", "Devices"})))
3636
* // tvmgen_model_devices
3737
*
38+
* ToCConstantStyle(PrefixGeneratedName(CombineNames({"model", "Devices"})))
39+
* // TVMGEN_MODEL_DEVICES
40+
*
3841
*/
3942

4043
#include <tvm/runtime/container/array.h>
@@ -68,6 +71,14 @@ std::string ToCFunctionStyle(const std::string& original_name);
6871
*/
6972
std::string ToCVariableStyle(const std::string& original_name);
7073

74+
/*!
75+
* \brief Transform a name to the C constant style assuming it is
76+
* appropriately constructed using the prefixing functions
77+
* \param name Original name
78+
* \return Transformed function in the C constant style
79+
*/
80+
std::string ToCConstantStyle(const std::string& original_name);
81+
7182
/*!
7283
* \brief Combine names together for use as a generated name
7384
* \param names Vector of strings to combine

src/target/source/interface_c.cc

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file interface_c.cc
22+
* \brief Generates a C interface header for a given modules inputs and outputs
23+
*/
24+
25+
#include <tvm/runtime/container/array.h>
26+
#include <tvm/runtime/container/string.h>
27+
#include <tvm/runtime/module.h>
28+
#include <tvm/runtime/packed_func.h>
29+
#include <tvm/runtime/registry.h>
30+
31+
#include <string>
32+
33+
#include "../../relay/backend/name_transforms.h"
34+
35+
namespace tvm {
36+
namespace codegen {
37+
38+
using runtime::PackedFunc;
39+
using namespace tvm::relay::backend;
40+
41+
class InterfaceCNode : public runtime::ModuleNode {
42+
public:
43+
InterfaceCNode(std::string module_name, Array<String> inputs, Array<String> outputs)
44+
: module_name_(module_name), inputs_(inputs), outputs_(outputs) {}
45+
const char* type_key() const { return "h"; }
46+
47+
std::string GetSource(const std::string& format) final {
48+
std::stringstream code;
49+
50+
EmitUpperHeaderGuard(code);
51+
EmitBrief(code, "Input tensor pointers");
52+
EmitStruct(code, "inputs", inputs_);
53+
EmitBrief(code, "Output tensor pointers");
54+
EmitStruct(code, "outputs", outputs_);
55+
EmitRunFunction(code);
56+
EmitLowerHeaderGuard(code);
57+
58+
return code.str();
59+
}
60+
61+
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
62+
return PackedFunc(nullptr);
63+
}
64+
65+
private:
66+
void EmitUpperHeaderGuard(std::stringstream& code_stream) {
67+
std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_, "H"}));
68+
code_stream << "#ifndef " << header_guard_name << "_\n"
69+
<< "#define " << header_guard_name << "_\n"
70+
<< "#include <stdint.h>\n\n"
71+
<< "#ifdef __cplusplus\n"
72+
<< "extern \"C\" {\n"
73+
<< "#endif\n\n";
74+
}
75+
76+
void EmitLowerHeaderGuard(std::stringstream& code_stream) {
77+
std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_, "H"}));
78+
code_stream << "\n#ifdef __cplusplus\n"
79+
<< "}\n"
80+
<< "#endif\n\n"
81+
<< "#endif // " << header_guard_name << "_\n";
82+
}
83+
84+
void EmitBrief(std::stringstream& code_stream, const std::string& description) {
85+
code_stream << "/*!\n"
86+
<< " * \\brief " << description << " for TVM module \"" << module_name_ << "\" \n"
87+
<< " */\n";
88+
}
89+
90+
void EmitStruct(std::stringstream& code_stream, const std::string& suffix,
91+
Array<String> properties) {
92+
std::string struct_name = ToCVariableStyle(PrefixGeneratedName({module_name_, suffix}));
93+
code_stream << "struct " << struct_name << " {\n";
94+
95+
std::vector<std::string> sanitized_properties;
96+
for (const String& property : properties) {
97+
std::string sanitized_property = SanitizeName(property);
98+
ICHECK(std::find(sanitized_properties.begin(), sanitized_properties.end(),
99+
sanitized_property) == sanitized_properties.end())
100+
<< "Sanitized input tensor name clash" << sanitized_property;
101+
code_stream << " void* " << sanitized_property << ";\n";
102+
sanitized_properties.push_back(sanitized_property);
103+
}
104+
code_stream << "};\n\n";
105+
}
106+
107+
void EmitRunFunction(std::stringstream& code_stream) {
108+
std::string run_function = ToCVariableStyle(PrefixGeneratedName({module_name_, "run"}));
109+
std::string inputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "inputs"}));
110+
std::string outputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "outputs"}));
111+
112+
code_stream << "/*!\n"
113+
<< " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n"
114+
<< " * \\param inputs Input tensors for the module \n"
115+
<< " * \\param outputs Output tensors for the module \n"
116+
<< " */\n"
117+
<< "int32_t " << run_function << "(\n"
118+
<< " struct " << inputs_struct << "* inputs,\n"
119+
<< " struct " << outputs_struct << "* outputs\n"
120+
<< ");\n";
121+
}
122+
123+
std::string module_name_;
124+
Array<String> inputs_;
125+
Array<String> outputs_;
126+
};
127+
128+
runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
129+
Array<String> outputs) {
130+
auto n = make_object<InterfaceCNode>(module_name, inputs, outputs);
131+
return runtime::Module(n);
132+
}
133+
134+
TVM_REGISTER_GLOBAL("runtime.InterfaceCCreate").set_body_typed(InterfaceCCreate);
135+
136+
} // namespace codegen
137+
} // namespace tvm

tests/cpp/name_transforms_test.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ TEST(NameTransforms, ToCVariableStyle) {
4242
EXPECT_THROW(ToCVariableStyle(""), InternalError);
4343
}
4444

45+
TEST(NameTransforms, ToCConstantStyle) {
46+
ASSERT_EQ(ToCConstantStyle("TVM_Woof"), "TVM_WOOF");
47+
ASSERT_EQ(ToCConstantStyle("TVM_woof"), "TVM_WOOF");
48+
ASSERT_EQ(ToCConstantStyle("TVM_woof_Woof"), "TVM_WOOF_WOOF");
49+
EXPECT_THROW(ToCConstantStyle("Cake_Bakery"), InternalError); // Incorrect prefix
50+
EXPECT_THROW(ToCConstantStyle(""), InternalError);
51+
}
52+
4553
TEST(NameTransforms, PrefixName) {
4654
ASSERT_EQ(PrefixName({"Woof"}), "TVM_Woof");
4755
ASSERT_EQ(PrefixName({"woof"}), "TVM_woof");
@@ -71,10 +79,10 @@ TEST(NameTransforms, CombineNames) {
7179
}
7280

7381
TEST(NameTransforms, SanitizeName) {
74-
ASSERT_EQ(SanitizeName("+_+ "), "_");
82+
ASSERT_EQ(SanitizeName("+_+ "), "____");
7583
ASSERT_EQ(SanitizeName("input+"), "input_");
7684
ASSERT_EQ(SanitizeName("input-"), "input_");
77-
ASSERT_EQ(SanitizeName("input++"), "input_");
85+
ASSERT_EQ(SanitizeName("input++"), "input__");
7886
ASSERT_EQ(SanitizeName("woof:1"), "woof_1");
7987
EXPECT_THROW(SanitizeName(""), InternalError);
8088
}

0 commit comments

Comments
 (0)