
Commit b121278

Authored by Matthew Brookhart
int32 pooling with int64 shapes (#6687)
* Failing tests for Int32 avg_pooling with Int64 shapes
* fix pooling implementations
1 parent c7ff885 commit b121278
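
For context, the failing combination this commit addresses is int32 pooling over a tensor whose shape dimensions are carried as int64 exprs. Below is a minimal repro sketch in the spirit of the new tests in this diff; the executor/target plumbing is an assumption based on the TVM API of this era, not part of the commit itself:

import numpy as np
import tvm
from tvm import relay

dshape = (1, 3, 28, 28)
# int64 shape dims + int32 data: the mix that previously failed in TOPI pooling.
x = relay.var("x", shape=[tvm.tir.IntImm("int64", d) for d in dshape], dtype="int32")
y = relay.nn.avg_pool2d(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
func = relay.Function([x], y)

data = np.random.randint(low=-128, high=128, size=dshape).astype("int32")
intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
print(intrp.evaluate(func)(data).asnumpy().shape)  # (1, 3, 14, 14) once fixed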

File tree: 4 files changed, +133 −93 lines changed


include/tvm/topi/nn/pooling.h

Lines changed: 29 additions & 13 deletions
@@ -75,8 +75,8 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
   auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
   auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);

-  auto height = x->shape[height_axis];
-  auto width = x->shape[width_axis];
+  auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
+  auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);

   auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
   auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);

@@ -107,6 +107,9 @@ inline Tensor pool_impl(const Tensor& x, const Array<PrimExpr>& kernel_size,
   auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));

   Array<PrimExpr> out_shape = x->shape;
+  for (size_t i = 0; i < out_shape.size(); ++i) {
+    out_shape.Set(i, cast(DataType::DataType::Int(32), out_shape[i]));
+  }
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);

@@ -189,8 +192,8 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
   auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);

-  auto height = x->shape[height_axis];
-  auto width = x->shape[width_axis];
+  auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
+  auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);

   auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
   auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);

@@ -220,7 +223,12 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   auto dheight = tvm::te::reduce_axis(Range(0, kernel_height));
   auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width));

-  Array<PrimExpr> out_shape = x->shape;
+  Array<PrimExpr> data_shape = x->shape;
+  for (size_t i = 0; i < data_shape.size(); ++i) {
+    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
+  }
+
+  Array<PrimExpr> out_shape = data_shape;
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);

@@ -232,7 +240,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                 ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

   if (pool_type == kMaxPool) {
-    Array<PrimExpr> ravel_shape{x->shape.begin(), x->shape.end()};
+    Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
     ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
     ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

@@ -257,7 +265,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
     auto mp_inds = mp_argmax[0];

     return tvm::te::compute(
-        x->shape,
+        data_shape,
         [&](const Array<Var>& inds) {
           Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
           pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);

@@ -288,7 +296,7 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
         tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height));
     auto windoww = tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width));
     return tvm::te::compute(
-        x->shape,
+        data_shape,
         [&](const Array<Var>& inds) {
           PrimExpr pad_h_idx = inds[height_axis] + pad_top;
           PrimExpr pad_w_idx = inds[width_axis] + pad_left;

@@ -483,10 +491,14 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
   const auto n_dim = output_size.size();
   CHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";

-  Array<PrimExpr> out_shape = x->shape;
+  Array<PrimExpr> data_shape = x->shape;
+  for (size_t i = 0; i < data_shape.size(); ++i) {
+    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
+  }
+  Array<PrimExpr> out_shape = data_shape;
   Array<PrimExpr> in_size, out_size;
   for (size_t i = 0; i < n_dim; ++i) {
-    in_size.push_back(x->shape[axes[i]]);
+    in_size.push_back(data_shape[axes[i]]);
     out_size.push_back(cast(DataType::Int(32), output_size[i]));
     out_shape.Set(axes[i], out_size[i]);
   }

@@ -661,7 +673,11 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
   std::vector<PrimExpr> pad_tail(k_size);
   Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
   Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
-  Array<PrimExpr> out_shape = x->shape;
+  Array<PrimExpr> data_shape = x->shape;
+  for (size_t i = 0; i < data_shape.size(); ++i) {
+    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
+  }
+  Array<PrimExpr> out_shape = data_shape;

   bool do_pad = false;
   for (int i = 0; i < k_size; i++) {

@@ -687,7 +703,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,

       arith::Analyzer analyzer;
       auto out_dim = analyzer.Simplify(
-          indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);
+          indexdiv(data_shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1);

       out_shape.Set(ii, out_dim);
     }

@@ -746,7 +762,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            start[i] = output[ii] * stride[i] - pad_head[i];
-            end[i] = min(start[i] + kernel[i], x->shape[ii]);
+            end[i] = min(start[i] + kernel[i], data_shape[ii]);
            start[i] = max(start[i], make_const(DataType::Int(32), 0));
            kernel_size *= (end[i] - start[i]);
          }
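
The four fixes above share one pattern: before any index arithmetic, rebuild the input shape as a data_shape whose dims are all cast to Int(32) so they match the int32 kernel/stride/padding exprs, then derive out_shape (and ravel_shape) from it. A rough Python-side illustration of that dtype normalization, using assumed tvm.tir/te APIs for illustration only; the real change is the C++ above:

import tvm
from tvm import te

dims = [tvm.tir.IntImm("int64", d) for d in (1, 3, 28, 28)]
x = te.placeholder(dims, name="x", dtype="int32")

# Normalize every shape dim to int32 before mixing it with the int32
# kernel/stride/padding expressions, as the C++ loops above do.
data_shape = [tvm.tir.Cast("int32", d) for d in x.shape]
print([d.dtype for d in x.shape])    # ['int64', 'int64', 'int64', 'int64']
print([d.dtype for d in data_shape])  # ['int32', 'int32', 'int32', 'int32']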

tests/python/relay/test_op_grad_level2.py

Lines changed: 44 additions & 31 deletions
@@ -66,39 +66,43 @@ def test_max_pool2d_grad():
     )


-def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
-    x = relay.var("x", relay.TensorType(x_shape, "float32"))
-    y = tvm.relay.nn.avg_pool2d(
-        x,
-        pool_size=pool_size,
-        strides=strides,
-        padding=padding,
-        ceil_mode=ceil_mode,
-        count_include_pad=count_include_pad,
-    )
-
-    fwd_func = relay.Function([x], y)
-    fwd_func = run_infer_type(fwd_func)
-    bwd_func = run_infer_type(gradient(fwd_func))
+def verify_avg_pool2d_grad(
+    x_shape, pool_size, strides, padding, ceil_mode, count_include_pad, dtype="float32"
+):
+
+    for shape_dtype in ["int32", "int64"]:
+        x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in x_shape], dtype=dtype)
+        y = tvm.relay.nn.avg_pool2d(
+            x,
+            pool_size=pool_size,
+            strides=strides,
+            padding=padding,
+            ceil_mode=ceil_mode,
+            count_include_pad=count_include_pad,
+        )

-    data = np.random.rand(*x_shape).astype("float32")
-    ph, pw = padding
-    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
-    out_grad = np.ones(shape=y_shape)
-    ref_grad = tvm.topi.testing.pool_grad_nchw(
-        data,
-        out_grad,
-        pool_size=pool_size,
-        strides=strides,
-        padding=[ph, pw, ph, pw],
-        pool_type="avg",
-        ceil_mode=ceil_mode,
-    )
+        fwd_func = relay.Function([x], y)
+        fwd_func = run_infer_type(fwd_func)
+        bwd_func = run_infer_type(gradient(fwd_func))
+
+        data = np.random.rand(*x_shape).astype(dtype)
+        ph, pw = padding
+        y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
+        out_grad = np.ones(shape=y_shape)
+        ref_grad = tvm.topi.testing.pool_grad_nchw(
+            data,
+            out_grad,
+            pool_size=pool_size,
+            strides=strides,
+            padding=[ph, pw, ph, pw],
+            pool_type="avg",
+            ceil_mode=ceil_mode,
+        )

-    for target, ctx in tvm.testing.enabled_targets():
-        intrp = relay.create_executor(ctx=ctx, target=target)
-        op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
-        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+        for target, ctx in tvm.testing.enabled_targets():
+            intrp = relay.create_executor(ctx=ctx, target=target)
+            op_res, (op_grad,) = intrp.evaluate(bwd_func)(data)
+            np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


 @tvm.testing.uses_gpu

@@ -119,6 +123,15 @@ def test_avg_pool2d_grad():
         ceil_mode=False,
         count_include_pad=False,
     )
+    verify_avg_pool2d_grad(
+        (1, 4, 16, 16),
+        pool_size=(1, 1),
+        strides=(1, 1),
+        padding=(1, 1),
+        ceil_mode=False,
+        count_include_pad=False,
+        dtype="int32",
+    )


 def verify_global_avg_pool2d_grad(x_shape):

tests/python/relay/test_op_level10.py

Lines changed: 13 additions & 9 deletions
@@ -425,17 +425,18 @@ def verify_ndarray_size(shape):


 def verify_adaptive_pool(dshape, out_size, pool_type, layout, dtype, opfunc):
-    x = relay.var("x", relay.TensorType(dshape, "float32"))
-    y = opfunc(x, out_size, layout)
-    func = relay.Function([x], y)
+    for shape_dtype in ["int32", "int64"]:
+        x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+        y = opfunc(x, out_size, layout)
+        func = relay.Function([x], y)

-    np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
-    np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
+        np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
+        np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)

-    for target, ctx in tvm.testing.enabled_targets():
-        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-        relay_out = intrp1.evaluate(func)(np_data)
-        tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5)
+        for target, ctx in tvm.testing.enabled_targets():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            relay_out = intrp1.evaluate(func)(np_data)
+            tvm.testing.assert_allclose(relay_out.asnumpy(), np_out, rtol=1e-5, atol=1e-5)


 def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):

@@ -452,13 +453,16 @@ def verify_adaptive_pool3d(dshape, out_size, pool_type, layout="NCHW", dtype="fl
 def test_adaptive_pool():
     verify_adaptive_pool2d((1, 9, 224, 224), (1, 1), "max")
     verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg")
+    verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", dtype="int32")
     verify_adaptive_pool2d((1, 14, 56, 78), (34, 13), "max")
     verify_adaptive_pool2d((1, 5, 46, 97), (4, 96), "avg")
     verify_adaptive_pool2d((1, 224, 224, 3), (1, 1), "max", layout="NHWC")
     verify_adaptive_pool2d((1, 3, 224, 224), (2, 3), "avg", layout="NHWC")
     verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "max", layout="NCDHW")
     verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW")
     verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC")
+    verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NCDHW", dtype="int32")
+    verify_adaptive_pool3d((1, 16, 32, 32, 32), (1, 1, 1), "avg", layout="NDHWC", dtype="int32")
     verify_adaptive_pool3d((1, 16, 32, 32, 32), (2, 4, 4), "max", layout="NDHWC")

tests/python/relay/test_op_level2.py

Lines changed: 47 additions & 40 deletions
@@ -959,15 +959,16 @@ def _test_pool2d_int(opfunc, reffunc, dtype):
     # test execution
     dtype = "int32"
     dshape = (1, 3, 28, 28)
-    x = relay.var("x", shape=dshape, dtype=dtype)
-    y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
-    func = relay.Function([x], y)
-    data = np.random.randint(low=-128, high=128, size=dshape)
-    ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype)
-    for target, ctx in tvm.testing.enabled_targets():
-        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-        op_res1 = intrp1.evaluate(func)(data)
-        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+    for shape_dtype in ["int32", "int64"]:
+        x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+        y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
+        func = relay.Function([x], y)
+        data = np.random.randint(low=-128, high=128, size=dshape)
+        ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype)
+        for target, ctx in tvm.testing.enabled_targets():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(data)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)


 def _test_global_pool2d(opfunc, reffunc):

@@ -1010,32 +1011,34 @@ def test_pool2d():

 @tvm.testing.uses_gpu
 def test_pool1d():
-    def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0)):
+    def _test_pool1d(opfunc, pool_size=(2,), strides=(2,), padding=(0, 0), dtype="float32"):
         n, c, w = te.var("n"), 10, 224
         x = relay.var("x", relay.TensorType((n, c, w), "float32"))
         y = opfunc(x, pool_size=(1,))
         assert "pool_size=" in y.astext()
         yy = run_infer_type(y)
         assert yy.checked_type == relay.TensorType((n, 10, 224), "float32")
         # test execution
-        dtype = "float32"
         dshape = (1, 3, 32)
-        x = relay.var("x", shape=dshape)
-        pool_type = "max" if "max" in str(opfunc) else "avg"
-        y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
-        func = relay.Function([x], y)
-        data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = tvm.topi.testing.pool1d_ncw_python(
-            data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
-        )
-        for target, ctx in tvm.testing.enabled_targets():
-            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-            op_res1 = intrp1.evaluate(func)(data)
-            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+        for shape_dtype in ["int32", "int64"]:
+            x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+            pool_type = "max" if "max" in str(opfunc) else "avg"
+            y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
+            func = relay.Function([x], y)
+            data = np.random.uniform(size=dshape).astype(dtype)
+            ref_res = tvm.topi.testing.pool1d_ncw_python(
+                data, (2,), (2,), (0, 0), (1, 3, 16), pool_type, False
+            )
+            for target, ctx in tvm.testing.enabled_targets():
+                intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+                op_res1 = intrp1.evaluate(func)(data)
+                tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)

     _test_pool1d(relay.nn.max_pool1d)
+    _test_pool1d(relay.nn.max_pool1d, dtype="int32")
     _test_pool1d(relay.nn.max_pool1d, pool_size=2, strides=2, padding=0)
     _test_pool1d(relay.nn.avg_pool1d)
+    _test_pool1d(relay.nn.avg_pool1d, dtype="int32")
     _test_pool1d(relay.nn.avg_pool1d, pool_size=2, strides=2, padding=0)

@@ -1047,6 +1050,7 @@ def _test_pool3d(
         strides=(2, 2, 2),
         padding=(0, 0, 0, 0, 0, 0),
         out_shape=(1, 3, 16, 16, 16),
+        dtype="float32",
     ):
         n, c, d, h, w = te.size_var("n"), 10, 5, 224, 224
         x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))

@@ -1057,30 +1061,33 @@
         # test execution
         dtype = "float32"
         dshape = (1, 3, 32, 32, 32)
-        x = relay.var("x", shape=dshape)
-        pool_type = "max" if "max" in str(opfunc) else "avg"
-        y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
-        func = relay.Function([x], y)
-        # check output shape
-        f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape))
-        assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format(
-            out_shape, f_out_shape
-        )
-        data = np.random.uniform(size=dshape).astype(dtype)
-        ref_res = tvm.topi.testing.pool3d_ncdhw_python(
-            data, pool_size, strides, padding, out_shape, pool_type, False
-        )
-        for target, ctx in tvm.testing.enabled_targets():
-            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
-            op_res1 = intrp1.evaluate(func)(data)
-            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+        for shape_dtype in ["int32", "int64"]:
+            x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
+            pool_type = "max" if "max" in str(opfunc) else "avg"
+            y = opfunc(x, pool_size=pool_size, strides=strides, padding=padding)
+            func = relay.Function([x], y)
+            # check output shape
+            f_out_shape = tuple(map(lambda x: int(x), run_infer_type(func).ret_type.shape))
+            assert out_shape == f_out_shape, "Output shape mismatch. expected {}, actual {}".format(
+                out_shape, f_out_shape
+            )
+            data = np.random.uniform(size=dshape).astype(dtype)
+            ref_res = tvm.topi.testing.pool3d_ncdhw_python(
+                data, pool_size, strides, padding, out_shape, pool_type, False
+            )
+            for target, ctx in tvm.testing.enabled_targets():
+                intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+                op_res1 = intrp1.evaluate(func)(data)
+                tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)

     _test_pool3d(relay.nn.max_pool3d)
+    _test_pool3d(relay.nn.max_pool3d, dtype="int32")
     _test_pool3d(relay.nn.max_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16))
     _test_pool3d(relay.nn.max_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16))
     _test_pool3d(relay.nn.max_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20))
     _test_pool3d(relay.nn.max_pool3d, pool_size=2, padding=0, strides=2)
     _test_pool3d(relay.nn.avg_pool3d)
+    _test_pool3d(relay.nn.avg_pool3d, dtype="int32")
     _test_pool3d(relay.nn.avg_pool3d, padding=(2, 0, 0, 2, 0, 0), out_shape=(1, 3, 18, 16, 16))
     _test_pool3d(relay.nn.avg_pool3d, padding=(0, 3, 0, 0, 3, 0), out_shape=(1, 3, 16, 19, 16))
     _test_pool3d(relay.nn.avg_pool3d, padding=(0, 0, 4, 0, 0, 4), out_shape=(1, 3, 16, 16, 20))
