Skip to content

Commit 84ae2ef

Browse files
author
Florin-Gabriel Blanaru
committed
Construct GlobalVarSupply from IRModule
1 parent fa25ace commit 84ae2ef

23 files changed

+84
-141
lines changed

include/tvm/driver/driver_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#ifndef TVM_DRIVER_DRIVER_API_H_
3030
#define TVM_DRIVER_DRIVER_API_H_
3131

32+
#include <tvm/ir/global_var_supply.h>
3233
#include <tvm/ir/module.h>
3334
#include <tvm/ir/transform.h>
3435
#include <tvm/runtime/packed_func.h>

include/tvm/ir/global_var_supply.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <unordered_map>
2525

2626
#include "tvm/ir/expr.h"
27+
#include "tvm/ir/module.h"
2728
#include "tvm/ir/name_supply.h"
2829

2930
namespace tvm {
@@ -57,13 +58,12 @@ class GlobalVarSupplyNode : public Object {
5758

5859
class GlobalVarSupply : public ObjectRef {
5960
public:
60-
TVM_DLL explicit GlobalVarSupply(
61-
const NameSupply& name_supply = NameSupply::NameSupplyWithPrefix(""),
62-
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});
61+
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply(),
62+
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});
6363

64-
TVM_DLL static GlobalVarSupply GlobalVarSupplyFromNameSupply(const NameSupply& name_supply);
64+
TVM_DLL explicit GlobalVarSupply(const Array<IRModule>& modules);
6565

66-
TVM_DLL static GlobalVarSupply EmptySupply();
66+
TVM_DLL explicit GlobalVarSupply(const IRModule module);
6767

6868
explicit GlobalVarSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
6969
/*! \return mutable pointers to the node. */

include/tvm/ir/module.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include <tvm/ir/adt.h>
2828
#include <tvm/ir/expr.h>
2929
#include <tvm/ir/function.h>
30-
#include <tvm/ir/global_var_supply.h>
3130
#include <tvm/ir/type.h>
3231
#include <tvm/parser/source_map.h>
3332
#include <tvm/runtime/container/array.h>

include/tvm/ir/name_supply.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,11 @@ class NameSupplyNode : public Object {
6464

6565
class NameSupply : public ObjectRef {
6666
public:
67-
TVM_DLL NameSupply();
67+
TVM_DLL explicit NameSupply();
6868

6969
TVM_DLL explicit NameSupply(const String& prefix,
7070
std::unordered_map<std::string, int> name_map = {});
7171

72-
TVM_DLL static NameSupply NameSupplyWithPrefix(const String& prefix = "");
73-
74-
TVM_DLL static NameSupply EmptySupply();
75-
7672
explicit NameSupply(ObjectPtr<Object> n) : ObjectRef(n) {}
7773
/*! \return mutable pointers to the node. */
7874
NameSupplyNode* operator->() const {

python/tvm/ir/supply.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
"""Suppliers that are used to guarantee uniqueness of names and GlobalVars."""
1818
import tvm
19-
from tvm import Object
19+
from tvm import Object, IRModule
2020
from . import _ffi_api
2121

2222

@@ -38,6 +38,12 @@ def fresh_name(self, name, add_prefix=True):
3838
def reserve_name(self, name, add_prefix=True):
3939
return _ffi_api.NameSupply_ReserveName(self, name, add_prefix)
4040

41+
def contains_name(self, name, add_prefix=True):
42+
return _ffi_api.NameSupply_ContainsName(self, name, add_prefix)
43+
44+
def clear(self):
45+
return _ffi_api.NameSupply_Clear(self)
46+
4147

4248
@tvm._ffi.register_object("GlobalVarSupply")
4349
class GlobalVarSupply(Object):
@@ -48,15 +54,24 @@ class GlobalVarSupply(Object):
4854
4955
Parameters
5056
----------
51-
name_supply: The NameSupply to be used by this GlobalVarSupply.
57+
value: Union[List[IRModule], IRModule, NameSupply]
58+
The IRModules used to build this GlobalVarSupply or a NameSupply.
5259
"""
5360

54-
def __init__(self, name_supply=None):
55-
name_supply = name_supply if name_supply is not None else NameSupply("")
56-
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply, name_supply)
61+
def __init__(self, value=None):
62+
if value is None:
63+
name_supply = NameSupply("")
64+
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, name_supply)
65+
elif isinstance(value, (list, tvm.container.Array)):
66+
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value)
67+
elif isinstance(value, IRModule):
68+
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value)
5769

5870
def fresh_global(self, name, add_prefix=True):
5971
return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix)
6072

6173
def unique_global_for(self, name, add_prefix=True):
6274
return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix)
75+
76+
def reserve_global(self, global_var, allow_conflict=False):
77+
return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict)

src/auto_scheduler/feature.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/auto_scheduler/measure.h>
2828
#include <tvm/auto_scheduler/measure_record.h>
2929
#include <tvm/driver/driver_api.h>
30+
#include <tvm/ir/global_var_supply.h>
3031
#include <tvm/runtime/registry.h>
3132
#include <tvm/support/parallel_for.h>
3233
#include <tvm/te/operation.h>
@@ -1371,8 +1372,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
13711372
auto pass_ctx = tvm::transform::PassContext::Current();
13721373

13731374
auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
1374-
std::unordered_map<te::Tensor, te::Buffer>(),
1375-
GlobalVarSupply::EmptySupply());
1375+
std::unordered_map<te::Tensor, te::Buffer>(), GlobalVarSupply());
13761376

13771377
bool disable_vectorize =
13781378
pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();

src/contrib/hybrid/codegen_hybrid.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#ifndef TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
2525
#define TVM_CONTRIB_HYBRID_CODEGEN_HYBRID_H_
2626

27+
#include <tvm/ir/name_supply.h>
2728
#include <tvm/target/codegen.h>
2829
#include <tvm/te/operation.h>
2930
#include <tvm/te/schedule.h>
@@ -146,7 +147,7 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
146147
/*! \brief Print the current indent spaces. */
147148
inline void PrintIndent();
148149
/*! \brief NameSupply for allocated ids. */
149-
NameSupply ids_allocated = NameSupply::EmptySupply();
150+
NameSupply ids_allocated = NameSupply();
150151
/*!
151152
* \brief Keys are either (tensors, value_index) or (variables, 0).
152153
* Values are the corresponding IDs.*/

src/driver/driver_api.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
303303
c_binds.insert({kv.first, kv.second});
304304
}
305305
}
306-
IRModule mod =
307-
ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply::EmptySupply());
306+
IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply());
308307
return mod;
309308
});
310309

@@ -367,8 +366,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
367366
c_binds.insert({kv.first, kv.second});
368367
}
369368
}
370-
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply::EmptySupply(),
371-
simple_mode);
369+
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(), simple_mode);
372370
});
373371

374372
/**

src/ir/global_var_supply.cc

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,32 @@
2626
#include "tvm/ir/expr.h"
2727

2828
namespace tvm {
29-
3029
GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply,
3130
std::unordered_map<std::string, GlobalVar> name_to_var_map) {
3231
auto n = make_object<GlobalVarSupplyNode>(name_supply);
3332
n->name_to_var_map_ = std::move(name_to_var_map);
3433
data_ = std::move(n);
3534
}
3635

37-
GlobalVarSupply GlobalVarSupply::GlobalVarSupplyFromNameSupply(const NameSupply& name_supply) {
38-
auto global_var_supply = GlobalVarSupply(name_supply);
39-
return global_var_supply;
36+
std::string GetModuleName(const IRModule& module) {
37+
return module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default");
4038
}
4139

42-
GlobalVarSupply GlobalVarSupply::EmptySupply() {
43-
return GlobalVarSupplyFromNameSupply(NameSupply::NameSupplyWithPrefix(""));
40+
GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply() {
41+
if (!modules.empty()) {
42+
IRModule first_mod = modules.front();
43+
this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod);
44+
}
45+
for (auto& mod : modules) {
46+
for (auto kv : mod->functions) {
47+
this->operator->()->ReserveGlobalVar(kv.first);
48+
}
49+
}
4450
}
4551

52+
GlobalVarSupply::GlobalVarSupply(const IRModule module)
53+
: GlobalVarSupply(Array<IRModule>{module}) {}
54+
4655
void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) {
4756
name_supply_->ReserveName(var->name_hint, false);
4857
if (!allow_conflict) {
@@ -79,8 +88,15 @@ GlobalVar GlobalVarSupplyNode::FreshGlobal(String name, bool add_prefix) {
7988

8089
TVM_REGISTER_NODE_TYPE(GlobalVarSupplyNode);
8190

82-
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply").set_body_typed([](NameSupply name_supply) {
83-
return GlobalVarSupply(name_supply);
91+
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_NameSupply")
92+
.set_body_typed([](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); });
93+
94+
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModule").set_body_typed([](IRModule mod) {
95+
return GlobalVarSupply(std::move(mod));
96+
});
97+
98+
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_IRModules").set_body_typed([](const Array<IRModule>& mods) {
99+
return GlobalVarSupply(mods);
84100
});
85101

86102
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal")
@@ -89,4 +105,7 @@ TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_FreshGlobal")
89105
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_UniqueGlobalFor")
90106
.set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::UniqueGlobalFor);
91107

108+
TVM_REGISTER_GLOBAL("ir.GlobalVarSupply_ReserveGlobalVar")
109+
.set_body_method<GlobalVarSupply>(&GlobalVarSupplyNode::ReserveGlobalVar);
110+
92111
} // namespace tvm

src/ir/module.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
#include <sstream>
4141
#include <unordered_set>
4242

43-
#include "../relay/backend/supply_provider.h"
43+
#include "tvm/ir/global_var_supply.h"
4444

4545
namespace tvm {
4646

@@ -386,7 +386,7 @@ std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
386386
}
387387

388388
GlobalVar main_gv;
389-
auto global_var_supply = tvm::BuildGlobalVarSupply(mod);
389+
auto global_var_supply = GlobalVarSupply(mod);
390390
if (gv_name.empty()) {
391391
// Bind function to 'main' (though rename if would clash with existing 'main').
392392
main_gv = global_var_supply->FreshGlobal("main", false);

0 commit comments

Comments
 (0)