@@ -144,7 +144,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     if groups == 1:
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
-            if data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8"):
+            if (
+                target.kind.name == "cuda"
+                and data.dtype in ("int8", "uint8")
+                and kernel.dtype in ("int8", "uint8")
+            ):
                 assert data.dtype == kernel.dtype
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
@@ -293,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                     "Unsupported shape for conv2d HWNC.\
                     Need to satisfy tensor core schedule."
                 )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
+        elif target.kind.name == "cuda" and layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
@@ -353,7 +357,8 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         ic_chunk = in_channels // 4
 
         if (
-            data.dtype in ["int8", "uint8"]
+            target.kind.name == "cuda"
+            and data.dtype in ["int8", "uint8"]
             and kernel.dtype in ["int8", "uint8"]
             and channels % groups == 0
             and out_channels % groups == 0
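
Aside (not part of the patch): `conv2d_strategy_cuda` can be reached by non-CUDA GPU targets, so the added `target.kind.name == "cuda"` guards presumably keep the CUDA-only int8 schedules from being selected on other targets such as rocm. A minimal sketch of what the guard compares, assuming a stock TVM install:

    import tvm

    # Target.kind.name is the registered target kind string, so only a
    # genuine CUDA target passes the new check; other GPU targets now
    # fall through to the generic (non-int8) implementations.
    print(tvm.target.Target("cuda").kind.name)  # -> "cuda"
    print(tvm.target.Target("rocm").kind.name)  # -> "rocm"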