Skip to content

Commit 9496a19

Browse files
masahiylc
authored and committed
disable cuda int8 schedule for non-cuda gpu target (apache#9014)
1 parent 17114c9 commit 9496a19

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

python/tvm/relay/op/strategy/cuda.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
144144
if groups == 1:
145145
if layout == "NCHW":
146146
assert kernel_layout == "OIHW"
147-
if data.dtype in ("int8", "uint8") and kernel.dtype in ("int8", "uint8"):
147+
if (
148+
target.kind.name == "cuda"
149+
and data.dtype in ("int8", "uint8")
150+
and kernel.dtype in ("int8", "uint8")
151+
):
148152
assert data.dtype == kernel.dtype
149153
strategy.add_implementation(
150154
wrap_compute_conv2d(topi.cuda.conv2d_nchw_int8),
@@ -293,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
293297
"Unsupported shape for conv2d HWNC.\
294298
Need to satisfy tensor core schedule."
295299
)
296-
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
300+
elif target.kind.name == "cuda" and layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
297301
assert kernel_layout == "OIHW4o4i"
298302
strategy.add_implementation(
299303
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
@@ -353,7 +357,8 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
353357
ic_chunk = in_channels // 4
354358

355359
if (
356-
data.dtype in ["int8", "uint8"]
360+
target.kind.name == "cuda"
361+
and data.dtype in ["int8", "uint8"]
357362
and kernel.dtype in ["int8", "uint8"]
358363
and channels % groups == 0
359364
and out_channels % groups == 0

tests/python/relay/test_op_level2.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,10 +325,6 @@ def test_run(
325325
kernel_size,
326326
):
327327
target = tvm.target.Target(target)
328-
if target.kind.name == "vulkan" and dtype == "int8":
329-
# The schedule selection incorrectly picks an
330-
# implementation that requires NCHWc packed input.
331-
pytest.xfail("Known failing test for vulkan")
332328

333329
x = relay.var("x", shape=dshape, dtype=dtype)
334330
w = relay.var("w", shape=kshape, dtype=dtype)

0 commit comments

Comments
 (0)