Skip to content

Commit c3182af

Browse files
author
Florin-Gabriel Blanaru
committed
Add tests for supply
1 parent 428a598 commit c3182af

File tree

4 files changed

+185
-1
lines changed

4 files changed

+185
-1
lines changed

python/tvm/ir/supply.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def __init__(self, value=None):
6262
if value is None:
6363
name_supply = NameSupply("")
6464
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, name_supply)
65+
elif isinstance(value, NameSupply):
66+
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, value)
6567
elif isinstance(value, (list, tvm.container.Array)):
6668
self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value)
6769
elif isinstance(value, IRModule):

src/ir/name_supply.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ String NameSupplyNode::ReserveName(const String& name, bool add_prefix) {
4040
if (add_prefix) {
4141
final_name = prefix_module_name(name);
4242
}
43-
name_map[final_name] = 1;
43+
name_map[final_name] = 0;
4444
return final_name;
4545
}
4646

tests/cpp/name_supply_test.cc

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <gtest/gtest.h>
21+
#include <tvm/ir/global_var_supply.h>
22+
#include <tvm/ir/module.h>
23+
#include <tvm/ir/name_supply.h>
24+
#include <tvm/relay/expr.h>
25+
#include <tvm/relay/function.h>
26+
27+
using namespace tvm;
28+
29+
NameSupply preambleNameSupply() {
30+
NameSupply name_supply = NameSupply("prefix");
31+
name_supply->FreshName("test");
32+
return name_supply;
33+
}
34+
35+
TEST(NameSupply, FreshName) {
36+
NameSupply name_supply = preambleNameSupply();
37+
String fresh = name_supply->FreshName("test");
38+
39+
EXPECT_EQ(fresh.compare("prefix_test1"), 0);
40+
}
41+
42+
TEST(NameSupply, ContainsName) {
43+
NameSupply name_supply = preambleNameSupply();
44+
45+
EXPECT_TRUE(name_supply->ContainsName("test"));
46+
EXPECT_FALSE(name_supply->ContainsName("test1"));
47+
}
48+
49+
TEST(NameSupply, ReserveName) {
50+
NameSupply name_supply = preambleNameSupply();
51+
name_supply->ReserveName("otherTest", false);
52+
53+
EXPECT_TRUE(name_supply->ContainsName("otherTest", false));
54+
EXPECT_FALSE(name_supply->ContainsName("otherTest"));
55+
}
56+
57+
GlobalVarSupply preambleVarSupply() {
58+
GlobalVarSupply global_var_supply = GlobalVarSupply();
59+
global_var_supply->FreshGlobal("test");
60+
return global_var_supply;
61+
}
62+
63+
TEST(GlobalVarSupply, FreshGlobal) {
64+
GlobalVarSupply global_var_supply = preambleVarSupply();
65+
GlobalVar first_var = global_var_supply->FreshGlobal("test");
66+
GlobalVar second_var = global_var_supply->FreshGlobal("test");
67+
68+
EXPECT_FALSE(tvm::StructuralEqual()(first_var, second_var));
69+
EXPECT_EQ(first_var->name_hint.compare("test1"), 0);
70+
EXPECT_EQ(second_var->name_hint.compare("test2"), 0);
71+
}
72+
73+
TEST(GlobalVarSupply, UniqueGlobalFor) {
74+
GlobalVarSupply global_var_supply = preambleVarSupply();
75+
GlobalVar first_var = global_var_supply->UniqueGlobalFor("someName");
76+
GlobalVar second_var = global_var_supply->UniqueGlobalFor("someName");
77+
78+
EXPECT_TRUE(tvm::StructuralEqual()(first_var, second_var));
79+
EXPECT_EQ(first_var->name_hint.compare("someName"), 0);
80+
EXPECT_EQ(second_var->name_hint.compare("someName"), 0);
81+
}
82+
83+
TEST(GlobalVarSupply, ReserveGlobal) {
84+
GlobalVarSupply global_var_supply = preambleVarSupply();
85+
GlobalVar var = GlobalVar("someName");
86+
global_var_supply->ReserveGlobalVar(var);
87+
GlobalVar second_var = global_var_supply->UniqueGlobalFor("someName");
88+
GlobalVar third_var = global_var_supply->FreshGlobal("someName");
89+
90+
EXPECT_TRUE(tvm::StructuralEqual()(var, second_var));
91+
EXPECT_FALSE(tvm::StructuralEqual()(var, third_var));
92+
EXPECT_EQ(second_var->name_hint.compare("someName"), 0);
93+
EXPECT_EQ(third_var->name_hint.compare("someName1"), 0);
94+
}
95+
96+
TEST(GlobalVarSupply, BuildIRModule) {
97+
auto x = relay::Var("x", relay::Type());
98+
auto f = relay::Function(tvm::Array<relay::Var>{x}, x, relay::Type(), {});
99+
GlobalVar var = GlobalVar("test");
100+
IRModule module = IRModule({{var, f}});
101+
102+
GlobalVarSupply global_var_supply = GlobalVarSupply(module);
103+
GlobalVar second_var = global_var_supply->UniqueGlobalFor("test", false);
104+
GlobalVar third_var = global_var_supply->FreshGlobal("test", false);
105+
106+
EXPECT_TRUE(tvm::StructuralEqual()(var, second_var));
107+
EXPECT_FALSE(tvm::StructuralEqual()(var, third_var));
108+
EXPECT_EQ(second_var->name_hint.compare("test"), 0);
109+
EXPECT_EQ(third_var->name_hint.compare("test1"), 0);
110+
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import tvm
18+
import tvm.testing
19+
20+
from tvm import relay
21+
from tvm.ir import GlobalVar, structural_equal
22+
from tvm.ir.supply import NameSupply
23+
from tvm.ir.supply import GlobalVarSupply
24+
25+
26+
def test_name_supply():
27+
name_supply = NameSupply("prefix")
28+
name_supply.reserve_name("test")
29+
30+
assert name_supply.contains_name("test")
31+
assert name_supply.fresh_name("test") == "prefix_test1"
32+
assert name_supply.contains_name("test1")
33+
assert not name_supply.contains_name("test1", False)
34+
assert not name_supply.contains_name("test2")
35+
36+
37+
def test_global_var_supply_from_none():
38+
var_supply = GlobalVarSupply()
39+
global_var = GlobalVar("test")
40+
var_supply.reserve_global(global_var)
41+
42+
assert structural_equal(var_supply.unique_global_for("test"), global_var)
43+
assert not structural_equal(var_supply.fresh_global("test"), global_var)
44+
45+
46+
def test_global_var_supply_from_name_supply():
47+
name_supply = NameSupply("prefix")
48+
var_supply = GlobalVarSupply(name_supply)
49+
global_var = GlobalVar("test")
50+
var_supply.reserve_global(global_var)
51+
52+
assert structural_equal(var_supply.unique_global_for("test", False), global_var)
53+
assert not structural_equal(var_supply.unique_global_for("test"), global_var)
54+
55+
56+
def test_global_var_supply_from_ir_mod():
57+
x = relay.var("x")
58+
y = relay.var("y")
59+
mod = tvm.IRModule()
60+
global_var = GlobalVar("test")
61+
mod[global_var] = relay.Function([x, y], relay.add(x, y))
62+
var_supply = GlobalVarSupply(mod)
63+
64+
second_global_var = var_supply.fresh_global("test", False)
65+
66+
assert structural_equal(var_supply.unique_global_for("test", False), global_var)
67+
assert not structural_equal(var_supply.unique_global_for("test"), global_var)
68+
assert not structural_equal(second_global_var, global_var)
69+
70+
71+
if __name__ == "__main__":
72+
tvm.testing.main()

0 commit comments

Comments
 (0)