Skip to content

Commit fb27659

Browse files
committed
Add NameSupply and GlobalVarSupply
1 parent fc419df commit fb27659

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+707
-347
lines changed

include/tvm/driver/driver_api.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,15 @@ TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
9999
* \param args The arguments to the function.
100100
* \param name The name of the lowered function.
101101
* \param binds Buffer assignments.
102+
* \param global_var_supply The GlobalVarSupply to be used in the module.
102103
* \param simple_mode Disables the loop partition pass. Defaults to false.
103104
* \return The result module.
104105
*/
105106

106107
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
107108
const std::string& name,
108109
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
110+
GlobalVarSupply global_var_supply,
109111
bool simple_mode = false);
110112

111113
/*!
@@ -115,12 +117,14 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
115117
* \param args The arguments to the function (Array of Tensor, Buffer and Vars)
116118
* \param name The name of the lowered function.
117119
* \param binds Buffer assignments.
120+
* \param global_var_supply The GlobalVarSupply to be used in the module.
118121
* \param simple_mode Disables the loop partition pass. Defaults to false.
119122
* \return The result module.
120123
*/
121124
TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
122125
const std::string& name,
123126
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
127+
GlobalVarSupply global_var_supply,
124128
bool simple_mode = false);
125129

126130
/*!
@@ -130,10 +134,13 @@ TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
130134
* \param args The arguments to the function.
131135
* \param name The name of the lowered function.
132136
* \param binds Buffer assignments.
137+
* \param global_var_supply The GlobalVarSupply to be used in the module and when creating
138+
* GlobalVars.
133139
* \return The result module.
134140
*/
135141
IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
136-
const std::unordered_map<te::Tensor, tir::Buffer>& binds);
142+
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
143+
GlobalVarSupply global_var_supply);
137144
/*!
138145
* \brief Build a device and host module for a specific target from an IRModule.
139146
* \param funcs The functions to be built.

include/tvm/ir/global_var_supply.h

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_
21+
#define TVM_IR_GLOBAL_VAR_SUPPLY_H_
22+
23+
#include <string>
24+
#include <unordered_map>
25+
26+
#include "tvm/ir/expr.h"
27+
#include "tvm/ir/name_supply.h"
28+
29+
namespace tvm {
30+
31+
class GlobalVarSupplyNode : public Object {
32+
public:
33+
GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}
34+
35+
explicit GlobalVarSupplyNode(NameSupply name_supply);
36+
37+
GlobalVar FreshGlobal(String name, bool add_prefix = true);
38+
39+
GlobalVar UniqueGlobalFor(const String& name, bool add_prefix = true);
40+
41+
void VisitAttrs(AttrVisitor* v) {
42+
v->Visit("name_supply", &name_supply_);
43+
}
44+
45+
NameSupply name_supply_;
46+
47+
static constexpr const char* _type_key = "GlobalVarSupply";
48+
static constexpr const bool _type_has_method_sequal_reduce = false;
49+
static constexpr const bool _type_has_method_shash_reduce = false;
50+
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object);
51+
52+
private:
53+
std::unordered_map<std::string, GlobalVar> name_to_var_map_;
54+
55+
friend class GlobalVarSupply;
56+
};
57+
58+
class GlobalVarSupply : public ObjectRef {
59+
public:
60+
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply =
61+
NameSupply::NameSupplyWithPrefix(""),
62+
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});
63+
64+
TVM_DLL static GlobalVarSupply GlobalVarSupplyFromNameSupply(const NameSupply& name_supply);
65+
66+
TVM_DLL static GlobalVarSupply EmptySupply();
67+
68+
explicit GlobalVarSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
69+
/*! \return mutable pointers to the node. */
70+
GlobalVarSupplyNode* operator->() const {
71+
auto* ptr = get_mutable();
72+
ICHECK(ptr != nullptr);
73+
return static_cast<GlobalVarSupplyNode*>(ptr);
74+
}
75+
76+
TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarSupplyNode);
77+
};
78+
79+
} // namespace tvm
80+
81+
#endif // TVM_IR_GLOBAL_VAR_SUPPLY_H_

include/tvm/ir/module.h

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/ir/adt.h>
2828
#include <tvm/ir/expr.h>
2929
#include <tvm/ir/function.h>
30+
#include <tvm/ir/global_var_supply.h>
3031
#include <tvm/ir/type.h>
3132
#include <tvm/parser/source_map.h>
3233
#include <tvm/runtime/container/array.h>
@@ -64,6 +65,8 @@ class IRModuleNode : public Object {
6465
/* \brief Additional attributes storing meta-data about the module. */
6566
DictAttrs attrs;
6667

68+
GlobalVarSupply global_var_supply;
69+
6770
/*!
6871
* \brief Get a module attribute.
6972
*
@@ -125,6 +128,7 @@ class IRModuleNode : public Object {
125128
v->Visit("global_type_var_map_", &global_type_var_map_);
126129
v->Visit("source_map", &source_map);
127130
v->Visit("attrs", &attrs);
131+
v->Visit("global_var_supply", &global_var_supply);
128132
}
129133

130134
TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
@@ -323,14 +327,6 @@ class IRModuleNode : public Object {
323327
/*! \brief Helper function for registering a typedef's constructors */
324328
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);
325329

326-
/*!
327-
* \brief Returns a version of \p name which is unique amongst all function definitions in module.
328-
*
329-
* \param name The original name.
330-
* \return Updated name which is unique.
331-
*/
332-
String GetUniqueName(const String& name);
333-
334330
/*! \brief A map from string names to global variables that
335331
* ensures global uniqueness.
336332
*/
@@ -362,12 +358,15 @@ class IRModule : public ObjectRef {
362358
/*!
363359
* \brief constructor
364360
* \param functions Functions in the module.
361+
* \param global_var_supply The GlobalVarSupply to be used in the module.
365362
* \param type_definitions Type definitions in the module.
366363
* \param import_set Set of imported files in the module.
367364
* \param map The module source map.
368365
* \param attrs The module attributes.
369366
*/
370367
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
368+
GlobalVarSupply global_var_supply =
369+
GlobalVarSupply::EmptySupply(),
371370
Map<GlobalTypeVar, TypeData> type_definitions = {},
372371
std::unordered_set<String> import_set = {}, parser::SourceMap map = {},
373372
DictAttrs attrs = {});
@@ -403,6 +402,7 @@ class IRModule : public ObjectRef {
403402
*
404403
* \param expr The expression to set as the main function to the module.
405404
* \param global_funcs The global function map. Default empty.
405+
* \param global_var_supply The GlobalVarSupply to be used in the module.
406406
* \param type_definitions The global type definition map. Default empty.
407407
* \param import_set Set of external modules already imported. Default empty.
408408
*
@@ -413,6 +413,7 @@ class IRModule : public ObjectRef {
413413
*/
414414
static std::pair<IRModule, GlobalVar> FromExprInContext(
415415
const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
416+
GlobalVarSupply global_var_supply = GlobalVarSupply::EmptySupply(),
416417
const Map<GlobalTypeVar, TypeData>& type_definitions = {},
417418
std::unordered_set<String> import_set = {});
418419

@@ -422,6 +423,8 @@ class IRModule : public ObjectRef {
422423
*/
423424
TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
424425
const Map<GlobalVar, BaseFunc>& global_funcs = {},
426+
GlobalVarSupply global_var_supply =
427+
GlobalVarSupply::EmptySupply(),
425428
const Map<GlobalTypeVar, TypeData>& type_definitions = {});
426429

427430
/*!

include/tvm/ir/name_supply.h

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_IR_NAME_SUPPLY_H_
21+
#define TVM_IR_NAME_SUPPLY_H_
22+
23+
#include <string>
24+
#include <unordered_map>
25+
#include "tvm/ir/expr.h"
26+
27+
namespace tvm {
28+
29+
class NameSupplyNode : public Object {
30+
public:
31+
NameSupplyNode() : NameSupplyNode("") {}
32+
33+
explicit NameSupplyNode(const String& prefix);
34+
35+
String FreshName(const String& name, bool add_prefix = true);
36+
37+
String ReserveName(const String& name, bool add_prefix = true);
38+
39+
bool ContainsName(const String& name, bool add_prefix = true);
40+
41+
void Clear();
42+
43+
void VisitAttrs(AttrVisitor* v) {
44+
v->Visit("prefix", &prefix_);
45+
}
46+
47+
// Prefix for all GlobalVar names. It can be empty.
48+
std::string prefix_;
49+
50+
static constexpr const char* _type_key = "NameSupply";
51+
static constexpr const bool _type_has_method_sequal_reduce = false;
52+
static constexpr const bool _type_has_method_shash_reduce = false;
53+
TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object);
54+
55+
private:
56+
String prefix_module_name(const String& name);
57+
58+
std::string GetUniqueName(std::string name);
59+
60+
// Key is function_name. Value is a counter.
61+
std::unordered_map<std::string, int> name_map;
62+
63+
friend class NameSupply;
64+
};
65+
66+
class NameSupply : public ObjectRef {
67+
public:
68+
TVM_DLL NameSupply();
69+
70+
TVM_DLL explicit NameSupply(const String& prefix,
71+
std::unordered_map<std::string, int> name_map = {});
72+
73+
TVM_DLL static NameSupply NameSupplyWithPrefix(const String& prefix = "");
74+
75+
TVM_DLL static NameSupply EmptySupply();
76+
77+
explicit NameSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
78+
/*! \return mutable pointers to the node. */
79+
NameSupplyNode* operator->() const {
80+
auto* ptr = get_mutable();
81+
ICHECK(ptr != nullptr);
82+
return static_cast<NameSupplyNode*>(ptr);
83+
}
84+
85+
TVM_DEFINE_OBJECT_REF_COW_METHOD(NameSupplyNode);
86+
};
87+
88+
} // namespace tvm
89+
90+
#endif // TVM_IR_NAME_SUPPLY_H_

include/tvm/relay/interpreter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,15 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
181181
*
182182
* \param expr An expression to evaluate.
183183
* \param type_definitions Global type definitions which \p expr may references.
184+
* \param global_var_supply The GlobalVarSupply to be used during evaluation.
184185
* \param import_set Already imported external modules.
185186
* \param device The device on which all primitives will be executed.
186187
* \param target The compiler target flag for compiling primitives.
187188
* \param attrs Attributes for the expression to be evaluated with
188189
* @return The object representing the result.
189190
*/
190191
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
192+
GlobalVarSupply global_var_supply,
191193
std::unordered_set<String> import_set, Device device, Target target,
192194
Map<String, ObjectRef> attrs = {});
193195

python/tvm/ir/module.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"""IRModule that holds the functions and type definitions."""
1818
from tvm._ffi.base import string_types
1919
import tvm._ffi
20+
from tvm.ir.supply import GlobalVarSupply
2021

2122
from .base import Node
2223
from . import expr as _expr
@@ -36,7 +37,7 @@ class IRModule(Node):
3637
Map of global var to BaseFunc
3738
"""
3839

39-
def __init__(self, functions=None, type_definitions=None):
40+
def __init__(self, functions=None, type_definitions=None, globar_var_supply=None):
4041
if functions is None:
4142
functions = {}
4243
elif isinstance(functions, dict):
@@ -59,7 +60,11 @@ def __init__(self, functions=None, type_definitions=None):
5960
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
6061
mapped_type_defs[k] = v
6162
type_definitions = mapped_type_defs
62-
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
63+
if globar_var_supply is None:
64+
globar_var_supply = GlobalVarSupply()
65+
self.__init_handle_by_constructor__(
66+
_ffi_api.IRModule, functions, type_definitions, globar_var_supply
67+
)
6368

6469
def __setitem__(self, var, val):
6570
"""Add a mapping to the module.
@@ -217,7 +222,7 @@ def get_type(self, name):
217222
return tuple([ty_var] + list(ty_data.constructors))
218223

219224
@staticmethod
220-
def from_expr(expr, functions=None, type_defs=None):
225+
def from_expr(expr, functions=None, type_defs=None, global_var_supply=None):
221226
"""Construct a module from a standalone expression.
222227
223228
Parameters
@@ -238,9 +243,12 @@ def from_expr(expr, functions=None, type_defs=None):
238243
where expr is set as the entry point
239244
(wrapped in a function if necessary)
240245
"""
246+
global_var_supply = (
247+
global_var_supply if global_var_supply is not None else GlobalVarSupply()
248+
)
241249
funcs = functions if functions is not None else {}
242250
defs = type_defs if type_defs is not None else {}
243-
return _ffi_api.Module_FromExpr(expr, funcs, defs)
251+
return _ffi_api.Module_FromExpr(expr, funcs, global_var_supply, defs)
244252

245253
def _import(self, file_to_import):
246254
return _ffi_api.Module_Import(self, file_to_import)

0 commit comments

Comments
 (0)