Skip to content

Commit 11f1280

Browse files
committed
[TEST][FLAKY] fix random fail
1 parent f577aa6 commit 11f1280

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/runtime/contrib/random/random.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,5 +123,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args,
123123
entry->random_engine.RandomFill(out);
124124
});
125125

126+
TVM_REGISTER_GLOBAL("tvm.contrib.random.seed").set_body([](TVMArgs args, TVMRetValue* ret) {
127+
RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
128+
uint64_t seed = args[0];
129+
entry->random_engine.Seed(seed);
130+
});
131+
126132
} // namespace contrib
127133
} // namespace tvm

tests/python/contrib/test_random.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def enabled_ctx_list():
3838
def test_randint():
3939
m = 1024
4040
n = 1024
41+
if not tvm.get_global_func("tvm.contrib.random.seed", True):
42+
print("skip because extern function is not available")
43+
return
44+
seed = tvm.get_global_func("tvm.contrib.random.seed")
45+
seed(0)
4146
A = random.randint(-127, 128, size=(m, n), dtype='int32')
4247
s = te.create_schedule(A.op)
4348

@@ -62,6 +67,11 @@ def verify(target="llvm"):
6267
def test_uniform():
6368
m = 1024
6469
n = 1024
70+
if not tvm.get_global_func("tvm.contrib.random.seed", True):
71+
print("skip because extern function is not available")
72+
return
73+
seed = tvm.get_global_func("tvm.contrib.random.seed")
74+
seed(0)
6575
A = random.uniform(0, 1, size=(m, n))
6676
s = te.create_schedule(A.op)
6777

@@ -86,6 +96,11 @@ def verify(target="llvm"):
8696
def test_normal():
8797
m = 1024
8898
n = 1024
99+
if not tvm.get_global_func("tvm.contrib.random.seed", True):
100+
print("skip because extern function is not available")
101+
return
102+
seed = tvm.get_global_func("tvm.contrib.random.seed")
103+
seed(0)
89104
A = random.normal(3, 4, size=(m, n))
90105
s = te.create_schedule(A.op)
91106

0 commit comments

Comments
 (0)