Skip to content

Commit 0f20125

Browse files
committed
fix miopen pad
1 parent 9687307 commit 0f20125

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

python/tvm/relay/op/strategy/rocm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
3636
layout = attrs.data_layout
3737
stride_h, stride_w = attrs.get_int_tuple("strides")
3838
kernel_layout = attrs.kernel_layout
39+
padding = attrs.get_int_tuple("padding")
3940
if dilation_h < 1 or dilation_w < 1:
4041
raise ValueError("dilation should be positive value")
4142

@@ -77,7 +78,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
7778
else:
7879
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
7980
# add miopen implementation
80-
if "miopen" in target.libs and layout == "NCHW":
81+
if "miopen" in target.libs and layout == "NCHW" and padding[0] == padding[2] and \
82+
padding[1] == padding[3]:
8183
strategy.add_implementation(
8284
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
8385
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),

topi/python/topi/rocm/conv2d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
6666
pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
6767
pad_h, pad_w = pt + pb, pl + pr
6868
dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
69-
69+
assert (pt == pb) and (pl == pr)
7070
OH = (H + 2 * pad_h - KH) // stride_h + 1
7171
OW = (W + 2 * pad_w - KW) // stride_w + 1
7272
cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
@@ -76,8 +76,8 @@ def conv2d_nchw_miopen(cfg, data, kernel, strides, padding, dilation,
7676
kernel,
7777
stride_h,
7878
stride_w,
79-
pad_h,
80-
pad_w,
79+
pt,
80+
pl,
8181
dilation_h,
8282
dilation_w,
8383
conv_mode=0,

0 commit comments

Comments
 (0)