Skip to content

Commit b709a63

Browse files
notoraptor authored and trevor-m committed
[relay][topi] Add operation relay.nn.dilate() which calls topi.nn.dilate() (apache#5331)
* Add operation relay.nn.dilate() which calls topi.nn.dilate(). * Fix typo * Set op pattern to injective
1 parent 9db61fe commit b709a63

File tree

6 files changed

+134
-0
lines changed

6 files changed

+134
-0
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,16 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
442442
}
443443
};
444444

445+
/*! \brief Attributes used in dilate operator */
446+
struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
447+
Array<IndexExpr> strides;
448+
449+
TVM_DECLARE_ATTRS(DilateAttrs, "relay.attrs.DilateAttrs") {
450+
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
451+
.describe("Dilation stride on each dimension, 1 means no dilation.");
452+
}
453+
};
454+
445455
/*! \brief Attributes used in 1D transposed convolution operator */
446456
struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
447457
IndexExpr channels;

python/tvm/relay/op/nn/_nn.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,15 @@ def compute_cross_entropy(attrs, inputs, out_dtype):
502502
reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)
503503

504504

505+
# dilate
@reg.register_compute("nn.dilate")
def compute_dilate(attrs, inputs, out_dtype):
    """Compute definition of dilate: delegate to topi.nn.dilate."""
    return [topi.nn.dilate(inputs[0], attrs.strides)]

reg.register_broadcast_schedule("nn.dilate")
reg.register_pattern("nn.dilate", OpPattern.INJECTIVE)
512+
513+
505514
# cross_entropy_with_logits
506515
@reg.register_compute("nn.cross_entropy_with_logits")
507516
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype):
@@ -697,6 +706,21 @@ def pad_shape_func(attrs, inputs, _):
697706
pad_width.append(get_const_tuple(pair))
698707
return [_pad_shape_func(inputs[0], convert(pad_width))]
699708

709+
@script
def _dilate_shape_func(data_shape, strides):
    # out[d] = (in[d] - 1) * stride[d] + 1 for every dimension d.
    out = output_tensor((data_shape.shape[0],), "int64")
    for dim in const_range(out.shape[0]):
        out[dim] = (data_shape[dim] - 1) * strides[dim] + 1

    return out

@reg.register_shape_func("nn.dilate", False)
def dilate_shape_func(attrs, inputs, _):
    """
    Shape function for dilate op.
    """
    return [_dilate_shape_func(inputs[0], convert(attrs.strides))]
723+
700724
reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
701725
reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
702726
reg.register_shape_func("nn.relu", False, elemwise_shape_func)

python/tvm/relay/op/nn/nn.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,6 +1347,25 @@ def pad(data,
13471347
return _make.pad(data, pad_width, pad_value, pad_mode)
13481348

13491349

1350+
def dilate(data, strides):
    """Dilate data with zeros.

    Parameters
    ----------
    data : tvm.relay.Expr
        n-D, can be any layout.

    strides : tuple of int
        Dilation stride on each dimension, 1 means no dilation.

    Returns
    -------
    Output : tvm.relay.Expr
        The computed result
    """
    return _make.dilate(data, strides)
1367+
1368+
13501369
def mirror_pad(data,
13511370
pad_width,
13521371
mode="SYMMETRIC"):

python/tvm/relay/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,11 @@ class Conv2DTransposeAttrs(Attrs):
350350
"""Attributes used in Transposed Conv2D operators"""
351351

352352

353+
@tvm._ffi.register_object("relay.attrs.DilateAttrs")
class DilateAttrs(Attrs):
    """Attributes used in dilate operators"""
356+
357+
353358
@tvm._ffi.register_object("relay.attrs.SubPixelAttrs")
354359
class SubPixelAttrs(Attrs):
355360
"""Attributes used in depth to space and space to depth operators"""

src/relay/op/nn/nn.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,54 @@ Do log on the data - do not accept logits.
10351035
.add_type_rel("CrossEntropy", CrossEntropyRel);
10361036

10371037

1038+
// relay.nn.dilate
1039+
TVM_REGISTER_NODE_TYPE(DilateAttrs);
1040+
1041+
bool DilateRel(const Array<Type>& types,
1042+
int num_inputs,
1043+
const Attrs& attrs,
1044+
const TypeReporter& reporter) {
1045+
CHECK_EQ(types.size(), 2);
1046+
const auto* x = types[0].as<TensorTypeNode>();
1047+
const DilateAttrs* param = attrs.as<DilateAttrs>();
1048+
if (x == nullptr) return false;
1049+
CHECK_EQ(x->shape.size(), param->strides.size());
1050+
1051+
std::vector<IndexExpr> oshape;
1052+
for (size_t i = 0; i < param->strides.size(); ++i) {
1053+
if (!x->shape[i].as<tir::AnyNode>()) {
1054+
oshape.push_back((x->shape[i] - 1) * param->strides[i] + 1);
1055+
} else {
1056+
oshape.push_back(x->shape[i]);
1057+
}
1058+
}
1059+
1060+
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), x->dtype));
1061+
return true;
1062+
}
1063+
1064+
// Positional relay function to create dilate operator used by frontend FFI.
1065+
Expr MakeDilate(Expr data, Array<IndexExpr> strides) {
1066+
auto attrs = make_object<DilateAttrs>();
1067+
attrs->strides = std::move(strides);
1068+
static const Op& op = Op::Get("nn.dilate");
1069+
return Call(op, {data}, Attrs(attrs), {});
1070+
}
1071+
1072+
1073+
TVM_REGISTER_GLOBAL("relay.op.nn._make.dilate")
1074+
.set_body_typed(MakeDilate);
1075+
1076+
1077+
RELAY_REGISTER_OP("nn.dilate")
1078+
.describe(R"code(
1079+
Dilate data with zeros.
1080+
)code" TVM_ADD_FILELINE)
1081+
.set_num_inputs(1)
1082+
.add_argument("x", "1D Tensor", "Data to dilate.")
1083+
.set_support_level(10)
1084+
.add_type_rel("Dilate", DilateRel);
1085+
10381086
// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
10391087
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
10401088
static const Op& op = Op::Get("nn.cross_entropy_with_logits");

tests/python/relay/test_any.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,34 @@ def test_any_pad():
508508
verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3))
509509
verify_any_pad(any_dims(4), ((1, 0), (1, 3), (0, 2), (9, 0)), (13, 11, 3, 1))
510510

511+
def verify_any_dilate(data_shape, strides, static_data_shape):
    """Evaluate nn.dilate on a dynamically-shaped input and compare to a numpy reference."""
    assert len(data_shape) == len(strides)
    dtype = "float32"
    mod = tvm.IRModule()
    x = relay.var('data', shape=data_shape, dtype=dtype)
    mod["main"] = relay.Function([x], relay.nn.dilate(x, strides))
    x_np = np.random.uniform(size=static_data_shape).astype(dtype)

    # Expected output: zeros with the input scattered at stride intervals.
    out_shape = []
    scatter_idx = []
    for dim, stride in zip(static_data_shape, strides):
        out_shape.append((dim - 1) * stride + 1)
        scatter_idx.append(slice(None, None, stride))
    expected = np.zeros(shape=tuple(out_shape), dtype=dtype)
    expected[tuple(scatter_idx)] = x_np

    for kind in ["debug", "vm"]:
        ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
        result = ex.evaluate()(x_np)
        tvm.testing.assert_allclose(result.asnumpy(), expected)
528+
529+
def test_any_dilate():
    """Exercise nn.dilate over several ranks, strides, and runtime shapes."""
    cases = [
        (1, (1,), (1,)),
        (1, (1,), (5,)),
        (1, (5,), (5,)),
        (3, (1, 1, 1), (1, 2, 3)),
        (3, (1, 1, 2), (1, 2, 3)),
        (3, (1, 1, 5), (1, 2, 3)),
        (3, (3, 7, 5), (1, 2, 3)),
        (4, (3, 7, 1, 5), (1, 2, 3, 4)),
    ]
    for ndim, strides, static_shape in cases:
        verify_any_dilate(any_dims(ndim), strides, static_shape)
538+
511539
def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape):
512540
mod = tvm.IRModule()
513541
dtype = "float32"

0 commit comments

Comments
 (0)