
Commit 8eb39d2

Laurawlywweic authored and committed
[Relay/TOPI][Frontend] Add tile and repeat operators in Relay and TOPI (apache#2720)
* tile and repeat operator added in relay
* fix pylint
* fix make warnings
* comments addressed
* fix lint error
* comment addressed
1 parent b139a1f commit 8eb39d2
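For orientation before the per-file diff, a minimal usage sketch of the two new operators. This snippet is not part of the commit; it assumes the Relay Python API of this period (relay.var, relay.Function, relay.create_executor) and simply exercises the semantics documented in the diff below.

    import numpy as np
    from tvm import relay

    # Build a small graph with the two operators added by this commit.
    x = relay.var("x", shape=(2, 2), dtype="float32")
    y = relay.tile(relay.repeat(x, repeats=2, axis=0), reps=(1, 3))
    func = relay.Function([x], y)

    # repeat doubles axis 0 (2 -> 4); tile triples axis 1 (2 -> 6).
    data = np.array([[1, 2], [3, 4]], dtype="float32")
    intrp = relay.create_executor()
    out = intrp.evaluate(func)(data)
    print(out.asnumpy().shape)  # (4, 6)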

File tree: 11 files changed, +506 -0 lines changed


docs/api/python/topi.rst

Lines changed: 4 additions & 0 deletions

@@ -73,6 +73,8 @@ List of operators
    topi.logical_not
    topi.arange
    topi.stack
+   topi.repeat
+   topi.tile
    topi.layout_transform
    topi.image.resize

@@ -132,6 +134,8 @@ topi
 .. autofunction:: topi.less
 .. autofunction:: topi.arange
 .. autofunction:: topi.stack
+.. autofunction:: topi.repeat
+.. autofunction:: topi.tile
 .. autofunction:: topi.layout_transform

 topi.nn

docs/langref/relay_op.rst

Lines changed: 4 additions & 0 deletions

@@ -97,6 +97,8 @@ This level enables additional math and transform operators.
    tvm.relay.split
    tvm.relay.arange
    tvm.relay.stack
+   tvm.relay.repeat
+   tvm.relay.tile


 **Level 4: Broadcast and Reductions**
@@ -225,6 +227,8 @@ Level 3 Definitions
 .. autofunction:: tvm.relay.split
 .. autofunction:: tvm.relay.arange
 .. autofunction:: tvm.relay.stack
+.. autofunction:: tvm.relay.repeat
+.. autofunction:: tvm.relay.tile


 Level 4 Definitions

include/tvm/relay/attrs/transform.h

Lines changed: 22 additions & 0 deletions

@@ -124,6 +124,28 @@ struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
   }
 };  // struct StackAttrs

+/*! \brief Attributes used in repeat operators */
+struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
+  Integer repeats;
+  Integer axis;
+  TVM_DECLARE_ATTRS(RepeatAttrs, "relay.attrs.RepeatAttrs") {
+    TVM_ATTR_FIELD(repeats)
+        .describe("The number of repetitions for each element.");
+    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+        .describe("The axis along which to repeat values.");
+  }
+};  // struct RepeatAttrs
+
+/*! \brief Attributes used in tile operators */
+struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
+  Array<Integer> reps;
+  TVM_DECLARE_ATTRS(TileAttrs, "relay.attrs.TileAttrs") {
+    TVM_ATTR_FIELD(reps)
+        .describe("The number of times to repeat the tensor data. "
+                  "Each dim size of reps must be a positive integer.");
+  }
+};  // struct TileAttrs
+
 /*! \brief Attributes used in squeeze operators */
 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   // use axis to make the name numpy compatible.
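As a quick check of how these attribute nodes surface in Python, a small sketch (not part of the commit; it relies on the relay.repeat wrapper added later in this diff):

    from tvm import relay

    x = relay.var("x", shape=(2, 2))
    call = relay.repeat(x, repeats=3, axis=1)
    # The CallNode carries a RepeatAttrs instance with the fields declared above.
    print(call.attrs.repeats)  # 3
    print(call.attrs.axis)     # 1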

python/tvm/relay/frontend/mxnet.py

Lines changed: 22 additions & 0 deletions

@@ -166,6 +166,10 @@ def _mx_dropout(inputs, attrs):
     return _op.nn.dropout(inputs[0], rate=rate)


+def _mx_BlockGrad(inputs, attrs):  # pylint: disable=unused-argument
+    return inputs
+
+
 def _mx_batch_norm(inputs, attrs):
     if attrs.get_bool("output_mean_var", False):
         raise RuntimeError("batch_norm do not support output_mean_var")
@@ -357,6 +361,21 @@ def _mx_arange(inputs, attrs):
     return _op.arange(**new_attrs)


+def _mx_repeat(inputs, attrs):
+    assert len(inputs) == 1
+    new_attrs = {}
+    new_attrs["repeats"] = attrs.get_int("repeats")
+    new_attrs["axis"] = attrs.get_int("axis", 0)
+    return _op.repeat(inputs[0], **new_attrs)
+
+
+def _mx_tile(inputs, attrs):
+    assert len(inputs) == 1
+    new_attrs = {}
+    new_attrs["reps"] = attrs.get_int_tuple("reps")
+    return _op.tile(inputs[0], **new_attrs)
+
+
 def _mx_roi_align(inputs, attrs):
     new_attrs = {}
     new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
@@ -490,6 +509,9 @@ def _mx_proposal(inputs, attrs):
     "batch_dot"     : _mx_batch_dot,
     "LeakyReLU"     : _mx_leaky_relu,
     "_arange"       : _mx_arange,
+    "repeat"        : _mx_repeat,
+    "tile"          : _mx_tile,
+    "BlockGrad"     : _mx_BlockGrad,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
     # vision
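To see the new converter entries in action, a hedged import sketch (assumes an MXNet installation; relay.frontend.from_mxnet returns the converted expression plus its parameters):

    import mxnet as mx
    from tvm import relay

    data = mx.sym.var("data")
    sym = mx.sym.tile(mx.sym.repeat(data, repeats=2, axis=0), reps=(1, 2))

    # _mx_repeat and _mx_tile map these onto the new Relay operators.
    net, params = relay.frontend.from_mxnet(sym, shape={"data": (2, 2)})
    print(net)  # the printed IR should contain calls to repeat(...) and tile(...)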

python/tvm/relay/op/_transform.py

Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,8 @@
 _reg.register_schedule("full", schedule_injective)
 _reg.register_schedule("full_like", schedule_injective)
 _reg.register_schedule("arange", schedule_injective)
+_reg.register_schedule("repeat", schedule_broadcast)
+_reg.register_schedule("tile", schedule_broadcast)
 _reg.register_schedule("cast", schedule_injective)
 _reg.register_schedule("strided_slice", schedule_injective)
 _reg.register_schedule("slice_like", schedule_injective)

python/tvm/relay/op/transform.py

Lines changed: 69 additions & 0 deletions

@@ -316,6 +316,75 @@ def stack(data, axis):
     return _make.stack(data, axis)


+def repeat(data, repeats, axis):
+    """Repeats elements of an array.
+    By default, repeat flattens the input array into 1-D and then repeats the elements.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    repeats : int
+        The number of repetitions for each element.
+
+    axis : int
+        The axis along which to repeat values. Negative numbers are
+        interpreted as counting from the end. By default, use the flattened
+        input array, and return a flat output array.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[1, 2], [3, 4]]
+        relay.repeat(x, repeats=2) = [1., 1., 2., 2., 3., 3., 4., 4.]
+
+        relay.repeat(x, repeats=2, axis=1) = [[1., 1., 2., 2.],
+                                              [3., 3., 4., 4.]]
+    """
+    return _make.repeat(data, repeats, axis)
+
+
+def tile(data, reps):
+    """Repeats the whole array multiple times.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    reps : tuple of int
+        The number of times to repeat the tensor data.
+
+    .. note::
+        Each dim size of reps must be a positive integer. If reps has length d,
+        the result will have dimension of max(d, data.ndim); if data.ndim < d,
+        data is promoted to be d-dimensional by prepending new axes.
+        If data.ndim >= d, reps is promoted to data.ndim by prepending 1's to it.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[1, 2], [3, 4]]
+        relay.tile(x, reps=(2, 3)) = [[1., 2., 1., 2., 1., 2.],
+                                      [3., 4., 3., 4., 3., 4.],
+                                      [1., 2., 1., 2., 1., 2.],
+                                      [3., 4., 3., 4., 3., 4.]]
+
+        relay.tile(x, reps=(2,)) = [[1., 2., 1., 2.],
+                                    [3., 4., 3., 4.]]
+    """
+    return _make.tile(data, reps)
+
+
 def where(condition, x, y):
     """Selecting elements from either x or y depending on the value of the
     condition.
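The docstring examples above deliberately track NumPy semantics; for reference, plain NumPy reproduces each of them:

    import numpy as np

    x = np.array([[1, 2], [3, 4]])
    np.repeat(x, 2)           # [1 1 2 2 3 3 4 4]
    np.repeat(x, 2, axis=1)   # [[1 1 2 2], [3 3 4 4]]
    np.tile(x, (2, 3)).shape  # (4, 6)
    np.tile(x, (2,))          # [[1 2 1 2], [3 4 3 4]]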

src/relay/op/tensor/transform.cc

Lines changed: 169 additions & 0 deletions

@@ -1035,6 +1035,175 @@ RELAY_REGISTER_OP("arange")
 .set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
 .set_attr<TOpPattern>("TOpPattern", kInjective);

+// repeat operator
+TVM_REGISTER_NODE_TYPE(RepeatAttrs);
+
+bool RepeatRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "repeat: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<RepeatAttrs>();
+  const int ndim = static_cast<int>(data->shape.size());
+  const int repeats = param->repeats;
+  const int axis = param->axis;
+  CHECK(repeats >= 1)
+      << "repeat only accepts `repeats >= 1`"
+      << ", but got repeats = " << repeats;
+  CHECK(-ndim <= axis && axis < ndim)
+      << "repeat only accepts `axis` in [-data.ndim, data.ndim)"
+      << ", but got axis = " << axis
+      << ", and data.ndim = " << ndim;
+  const int pivot = axis < 0 ? ndim + axis : axis;
+  // Output shape equals the input shape with the pivot axis scaled by `repeats`.
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(ndim);
+  for (int i = 0; i < pivot; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  oshape.emplace_back(data->shape[pivot] * repeats);
+  for (int i = pivot + 1; i < ndim; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Array<Tensor> RepeatCompute(const Attrs& attrs,
+                            const Array<Tensor>& inputs,
+                            const Type& out_type,
+                            const Target& target) {
+  const RepeatAttrs* param = attrs.as<RepeatAttrs>();
+  CHECK(param != nullptr);
+  return { topi::repeat(inputs[0], param->repeats, param->axis) };
+}
+
+Expr MakeRepeat(Expr data,
+                int repeats,
+                int axis) {
+  auto attrs = make_node<RepeatAttrs>();
+  attrs->repeats = repeats;
+  attrs->axis = axis;
+  static const Op& op = Op::Get("repeat");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.repeat")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
+  });
+
+RELAY_REGISTER_OP("repeat")
+.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`.
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.RepeatAttrs")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Repeat", RepeatRel)
+.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
+.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+
+// tile operator
+TVM_REGISTER_NODE_TYPE(TileAttrs);
+
+bool TileRel(const Array<Type>& types,
+             int num_inputs,
+             const Attrs& attrs,
+             const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "tile: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<TileAttrs>();
+  const size_t ndim = data->shape.size();
+  const Array<Integer>& reps = param->reps;
+  // The repetition array must be defined.
+  CHECK(reps.defined())
+      << "repetition array is not defined. data.ndim = " << ndim;
+  const size_t rndim = reps.size();
+  size_t tndim = (ndim > rndim) ? ndim : rndim;
+  // Promote the shorter of the data shape and reps to tndim by prepending 1's.
+  std::vector<IndexExpr> data_shape;
+  std::vector<IndexExpr> reps_shape;
+  data_shape.reserve(tndim);
+  reps_shape.reserve(tndim);
+  if (ndim == rndim) {
+    for (size_t i = 0; i < tndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+      reps_shape.emplace_back(reps[i]);
+    }
+  } else if (ndim > rndim) {
+    for (size_t i = 0; i < ndim; ++i)
+      data_shape.emplace_back(data->shape[i]);
+    for (size_t i = 0; i < (ndim - rndim); ++i)
+      reps_shape.emplace_back(1);
+    for (size_t i = 0; i < rndim; ++i)
+      reps_shape.emplace_back(reps[i]);
+  } else {
+    for (size_t i = 0; i < (rndim - ndim); ++i)
+      data_shape.emplace_back(1);
+    for (size_t i = 0; i < ndim; ++i)
+      data_shape.emplace_back(data->shape[i]);
+    for (size_t i = 0; i < rndim; ++i)
+      reps_shape.emplace_back(reps[i]);
+  }
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(tndim);
+  for (size_t i = 0; i < tndim; ++i) {
+    oshape.emplace_back(data_shape[i] * reps_shape[i]);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Array<Tensor> TileCompute(const Attrs& attrs,
+                          const Array<Tensor>& inputs,
+                          const Type& out_type,
+                          const Target& target) {
+  const TileAttrs* param = attrs.as<TileAttrs>();
+  CHECK(param != nullptr);
+  return { topi::tile(inputs[0], param->reps) };
+}
+
+Expr MakeTile(Expr data,
+              Array<Integer> reps) {
+  auto attrs = make_node<TileAttrs>();
+  attrs->reps = reps;
+  static const Op& op = Op::Get("tile");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.tile")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
+  });
+
+RELAY_REGISTER_OP("tile")
+.describe(R"code(Repeat the whole array multiple times.
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.TileAttrs")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Tile", TileRel)
+.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
+.set_attr<TOpPattern>("TOpPattern", kBroadcast);
+
 // where operator
 bool WhereRel(const Array<Type>& types,
               int num_inputs,
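The shape arithmetic in TileRel is easy to cross-check with a small Python mirror (illustrative only; the helper name tile_out_shape is invented here):

    def tile_out_shape(data_shape, reps):
        # Mirror of TileRel: promote the shorter tuple with leading 1's,
        # then multiply the two shapes elementwise.
        ndim, rndim = len(data_shape), len(reps)
        tndim = max(ndim, rndim)
        dshape = [1] * (tndim - ndim) + list(data_shape)
        rshape = [1] * (tndim - rndim) + list(reps)
        return [d * r for d, r in zip(dshape, rshape)]

    assert tile_out_shape((2, 2), (2,)) == [2, 4]
    assert tile_out_shape((2, 2), (2, 3)) == [4, 6]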
