Skip to content

Commit edceb6b

Browse files
committed
delete relay adaptive pooling changes
1 parent 9270ca6 commit edceb6b

File tree

5 files changed

+19
-5
lines changed

5 files changed

+19
-5
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,22 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
332332
}
333333
};
334334

335+
/*! \brief Attributes for adaptive pool operator */
336+
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
337+
Array<IndexExpr> output_size;
338+
std::string layout;
339+
340+
TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") {
341+
TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({}))
342+
.describe("Output height and width.");
343+
TVM_ATTR_FIELD(layout).set_default("NCHW")
344+
.describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
345+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
346+
"dimensions respectively. Convolution is applied on the 'H' and"
347+
"'W' dimensions.");
348+
}
349+
};
350+
335351

336352
/*! \brief Attributes for dense operator */
337353
struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {

python/tvm/relay/op/nn/_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def schedule_avg_pool2d(attrs, outs, target):
247247
def schedule_global_max_pool2d(_, outs, target):
248248
"""Schedule definition of global_max_pool2d"""
249249
with target:
250-
return topi.generic.schedule_global_pool(outs)
250+
return topi.generic.schedule_adaptive_pool(outs)
251251

252252

253253
reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -258,7 +258,7 @@ def schedule_global_max_pool2d(_, outs, target):
258258
def schedule_global_avg_pool2d(_, outs, target):
259259
"""Schedule definition of global_avg_pool2d"""
260260
with target:
261-
return topi.generic.schedule_global_pool(outs)
261+
return topi.generic.schedule_adaptive_pool(outs)
262262

263263

264264
reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)

src/relay/op/nn/pooling.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ bool Pool2DRel(const Array<Type>& types,
7272

7373
CHECK(data != nullptr);
7474
const auto dshape = data->shape;
75-
CHECK_NE(dshape.size(), 0);
7675
CHECK_GE(dshape.size(), 2U)
7776
<< "Pool2D only support input >= 2-D: input must have height and width";
7877
const auto param = attrs.as<AttrType>();
@@ -284,7 +283,6 @@ bool GlobalPool2DRel(const Array<Type>& types,
284283
const auto* data = types[0].as<TensorTypeNode>();
285284
if (data == nullptr) { return false; }
286285
const auto dshape = data->shape;
287-
CHECK_NE(dshape.size(), 0);
288286
CHECK_GE(dshape.size(), 2U)
289287
<< "Pool2D only support input >= 2-D: input must have height and width";
290288
const auto param = attrs.as<GlobalPool2DAttrs>();

tests/python/relay/test_op_level10.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def test_adaptive_pool2d():
251251

252252

253253
if __name__ == "__main__":
254+
test_adaptive_pool2d()
254255
test_collapse_sum_like()
255256
test_broadcast_to_like()
256257
test_slice_like()

tests/python/relay/test_op_level2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ def test_avg_pool2d_no_count_pad():
316316
op_res1 = intrp1.evaluate(func)(data)
317317
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
318318

319-
320319
def test_flatten_infer_type():
321320
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
322321
x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32"))

0 commit comments

Comments
 (0)