@@ -38,6 +38,11 @@ def enabled_ctx_list():
38
38
def test_randint ():
39
39
m = 1024
40
40
n = 1024
41
+ if not tvm .get_global_func ("tvm.contrib.random.seed" , True ):
42
+ print ("skip because extern function is not available" )
43
+ return
44
+ seed = tvm .get_global_func ("tvm.contrib.random.seed" )
45
+ seed (0 )
41
46
A = random .randint (- 127 , 128 , size = (m , n ), dtype = 'int32' )
42
47
s = te .create_schedule (A .op )
43
48
@@ -62,6 +67,11 @@ def verify(target="llvm"):
62
67
def test_uniform ():
63
68
m = 1024
64
69
n = 1024
70
+ if not tvm .get_global_func ("tvm.contrib.random.seed" , True ):
71
+ print ("skip because extern function is not available" )
72
+ return
73
+ seed = tvm .get_global_func ("tvm.contrib.random.seed" )
74
+ seed (0 )
65
75
A = random .uniform (0 , 1 , size = (m , n ))
66
76
s = te .create_schedule (A .op )
67
77
@@ -86,6 +96,11 @@ def verify(target="llvm"):
86
96
def test_normal ():
87
97
m = 1024
88
98
n = 1024
99
+ if not tvm .get_global_func ("tvm.contrib.random.seed" , True ):
100
+ print ("skip because extern function is not available" )
101
+ return
102
+ seed = tvm .get_global_func ("tvm.contrib.random.seed" )
103
+ seed (0 )
89
104
A = random .normal (3 , 4 , size = (m , n ))
90
105
s = te .create_schedule (A .op )
91
106
0 commit comments