19
19
import topi
20
20
import topi .testing
21
21
from topi import util
22
+ from common import get_all_backend
22
23
23
24
24
25
def test_util ():
@@ -59,8 +60,7 @@ def check_device(device):
59
60
foo (a , b )
60
61
tvm .testing .assert_allclose (b .asnumpy (), b_np , rtol = 1e-5 , atol = 1e-5 )
61
62
62
- for device in ['cuda' , 'opencl' , 'metal' , 'rocm' , 'vulkan' , 'llvm' , 'nvptx' , 'sdaccel' ,
63
- 'aocl_sw_emu' ]:
63
+ for device in get_all_backend ():
64
64
check_device (device )
65
65
66
66
@@ -77,6 +77,46 @@ def check_device(device):
77
77
test_apply (topi .sqrt , "sqrt" , np .sqrt , 0 , 100 )
78
78
test_apply (topi .rsqrt , "rsqrt" , lambda x :np .ones_like (x )/ np .sqrt (x ), 0 , 100 , skip_name_check = True )
79
79
80
+
81
+ def test_cast ():
82
+ def verify (from_dtype , to_dtype , low = - 100 , high = 100 ):
83
+ shape = (5 , 4 )
84
+ A = tvm .placeholder (shape , dtype = from_dtype , name = "A" )
85
+ B = topi .cast (A , to_dtype )
86
+
87
+ if from_dtype == "bool" :
88
+ a_np = np .random .choice ([True , False ], size = shape )
89
+ else :
90
+ a_np = np .random .uniform (low , high , size = shape ).astype (from_dtype )
91
+ if to_dtype == "bool" :
92
+ a_np = a_np - a_np [2 , 3 ]
93
+ b_np = a_np .astype (to_dtype )
94
+
95
+ for device in get_all_backend ():
96
+ ctx = tvm .context (device , 0 )
97
+ if not ctx .exist :
98
+ print ("Skip because %s is not enabled" % device )
99
+ continue
100
+ print ("Running on target: %s" % device )
101
+ with tvm .target .create (device ):
102
+ s = topi .generic .schedule_injective (B )
103
+ foo = tvm .build (s , [A , B ], device )
104
+ a = tvm .nd .array (a_np , ctx )
105
+ b = tvm .nd .empty (shape = shape , dtype = to_dtype , ctx = ctx )
106
+ foo (a , b )
107
+ tvm .testing .assert_allclose (b .asnumpy (), b_np )
108
+
109
+ verify ("int32" , "float32" )
110
+ verify ("int32" , "float64" )
111
+ verify ("int32" , "bool" )
112
+ verify ("float32" , "int32" )
113
+ verify ("float32" , "float64" )
114
+ verify ("float32" , "bool" )
115
+ verify ("bool" , "float32" )
116
+ verify ("bool" , "int32" )
117
+
118
+
80
119
if __name__ == "__main__" :
81
120
test_util ()
82
121
test_ewise ()
122
+ test_cast ()
0 commit comments