@@ -144,7 +144,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     if groups == 1:
         if layout == "NCHW":
             assert kernel_layout == "OIHW"
-            if data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8"):
+            if (
+                target.kind.name == "cuda"
+                and data.dtype in ("int8", "uint8")
+                and kernel.dtype in ("int8", "uint8")
+            ):
                 assert data.dtype == kernel.dtype
                 strategy.add_implementation(
                     wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
@@ -293,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
293297 "Unsupported shape for conv2d HWNC.\
294298 Need to satisfy tensor core schedule."
295299 )
296- elif layout == "NCHW4c" and data .dtype in ["int8" , "uint8" ]:
300+ elif target . kind . name == "cuda" and layout == "NCHW4c" and data .dtype in ["int8" , "uint8" ]:
297301 assert kernel_layout == "OIHW4o4i"
298302 strategy .add_implementation (
299303 wrap_compute_conv2d (topi .cuda .conv2d_NCHWc_int8 , True ),
@@ -353,7 +357,8 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         ic_chunk = in_channels // 4

         if (
-            data.dtype in ["int8", "uint8"]
+            target.kind.name == "cuda"
+            and data.dtype in ["int8", "uint8"]
             and kernel.dtype in ["int8", "uint8"]
             and channels % groups == 0
             and out_channels % groups == 0
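Context for the hunks above: each one adds a target.kind.name == "cuda" guard in front of an int8/uint8 conv2d implementation, so those CUDA-only schedules are selected only when the target kind really is "cuda", not for other GPU targets that can reach the same strategy function. Below is a minimal, self-contained sketch of that dispatch pattern; pick_conv2d_impl and its return values are illustrative stand-ins (the names mirror the topi functions in the diff), not TVM APIs.

# Hypothetical sketch of the guard pattern introduced by this commit.
def pick_conv2d_impl(target_kind, data_dtype, kernel_dtype):
    # The int8 schedule is CUDA-specific, so require the target kind to be
    # "cuda" in addition to the dtype checks that were already present.
    if (
        target_kind == "cuda"
        and data_dtype in ("int8", "uint8")
        and kernel_dtype in ("int8", "uint8")
    ):
        return "conv2d_nchw_int8.cuda"
    return "conv2d_nchw.cuda"

# With the guard, a non-CUDA GPU target no longer picks the int8 path:
assert pick_conv2d_impl("cuda", "int8", "int8") == "conv2d_nchw_int8.cuda"
assert pick_conv2d_impl("vulkan", "int8", "int8") == "conv2d_nchw.cuda"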