@@ -1035,6 +1035,175 @@ RELAY_REGISTER_OP("arange")
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

+// repeat operator
+TVM_REGISTER_NODE_TYPE(RepeatAttrs);
+
+bool RepeatRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "repeat: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<RepeatAttrs>();
+  const int ndim = static_cast<int>(data->shape.size());
+  const int repeats = param->repeats;
+  const int axis = param->axis;
+  CHECK(repeats >= 1)
+      << "repeat only accepts `repeats >= 1`"
+      << ", but got repeats = " << repeats;
+  CHECK(-ndim - 1 <= axis && axis <= ndim)
+      << "repeat only accepts `axis` in [-data.ndim - 1, data.ndim]"
+      << ", but got axis = " << axis
+      << ", and data.ndim = " << ndim;
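+  // The output keeps the input rank; only the dimension selected by `axis`
+  // is multiplied by `repeats`.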
+  const int pivot = axis < 0 ? ndim + axis : axis;
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(ndim + repeats);
+  for (int i = 0; i < pivot; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  oshape.emplace_back(data->shape[pivot] * repeats);
+  for (int i = pivot + 1; i < ndim; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Array<Tensor> RepeatCompute(const Attrs& attrs,
+                            const Array<Tensor>& inputs,
+                            const Type& out_type,
+                            const Target& target) {
+  const RepeatAttrs* param = attrs.as<RepeatAttrs>();
+  CHECK(param != nullptr);
+  return { topi::repeat(inputs[0], param->repeats, param->axis) };
+}
+
+Expr MakeRepeat(Expr data,
+                int repeats,
+                int axis) {
+  auto attrs = make_node<RepeatAttrs>();
+  attrs->repeats = repeats;
+  attrs->axis = axis;
+  static const Op& op = Op::Get("repeat");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
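+// Expose the constructor as the global packed function "relay.op._make.repeat",
+// which the Python-side relay.repeat wrapper calls.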
+TVM_REGISTER_API("relay.op._make.repeat")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
+});
+
+RELAY_REGISTER_OP("repeat")
+.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`.
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.Repeat")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(1)
+.add_type_rel("Repeat", RepeatRel)
+.set_attr<FTVMCompute>("FTVMCompute", RepeatCompute)
+.set_attr<TOpPattern>("TOpPattern", kBroadcast);
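+// Usage sketch (illustration, assuming the standard relay.repeat Python
+// wrapper): a (2, 3) tensor with repeats=2 and axis=1 infers to shape (2, 6),
+// i.e. relay.repeat(data, repeats=2, axis=1).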
+
+// tile operator
+TVM_REGISTER_NODE_TYPE(TileAttrs);
+
+bool TileRel(const Array<Type>& types,
+             int num_inputs,
+             const Attrs& attrs,
+             const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "tile: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<TileAttrs>();
+  const size_t ndim = data->shape.size();
+  const Array<Integer>& reps = param->reps;
+  // check dimension match
+  CHECK(reps.defined())
+      << "repetition array is not defined. data.ndim = " << ndim;
+  const size_t rndim = reps.size();
+  size_t tndim = (ndim > rndim) ? ndim : rndim;
+  // re-construct data shape or reps shape
+  std::vector<IndexExpr> data_shape;
+  std::vector<IndexExpr> reps_shape;
+  data_shape.reserve(tndim);
+  reps_shape.reserve(tndim);
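+  // Bring both shapes to the common rank tndim by left-padding the shorter
+  // one with 1s (NumPy-style tile semantics), then multiply elementwise.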
+  if (ndim == rndim) {
+    for (size_t i = 0; i < tndim; ++i) {
+      data_shape.emplace_back(data->shape[i]);
+      reps_shape.emplace_back(reps[i]);
+    }
+  } else if (ndim > rndim) {
+    for (size_t i = 0; i < ndim; ++i)
+      data_shape.emplace_back(data->shape[i]);
+    for (size_t i = 0; i < (ndim - rndim); ++i)
+      reps_shape.emplace_back(1);
+    for (size_t i = 0; i < rndim; ++i)
+      reps_shape.emplace_back(reps[i]);
+  } else {
+    for (size_t i = 0; i < (rndim - ndim); ++i)
+      data_shape.emplace_back(1);
+    for (size_t i = 0; i < ndim; ++i)
+      data_shape.emplace_back(data->shape[i]);
+    for (size_t i = 0; i < rndim; ++i)
+      reps_shape.emplace_back(reps[i]);
+  }
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(tndim);
+  for (size_t i = 0; i < tndim; ++i) {
+    oshape.emplace_back(data_shape[i] * reps_shape[i]);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Array<Tensor> TileCompute(const Attrs& attrs,
+                          const Array<Tensor>& inputs,
+                          const Type& out_type,
+                          const Target& target) {
+  const TileAttrs* param = attrs.as<TileAttrs>();
+  CHECK(param != nullptr);
+  return { topi::tile(inputs[0], param->reps) };
+}
+
+Expr MakeTile(Expr data,
+              Array<Integer> reps) {
+  auto attrs = make_node<TileAttrs>();
+  attrs->reps = reps;
+  static const Op& op = Op::Get("tile");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.tile")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
+});
+
+RELAY_REGISTER_OP("tile")
+.describe(R"code(Repeat the whole array multiple times.
+
+- **data**: The input data to the operator.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.Tile")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(1)
+.add_type_rel("Tile", TileRel)
+.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
+.set_attr<TOpPattern>("TOpPattern", kBroadcast);
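+// Usage sketch (illustration, assuming the standard relay.tile Python
+// wrapper): tiling a (2, 3) tensor with reps=(2, 3) infers to shape (4, 9),
+// i.e. relay.tile(data, reps=(2, 3)).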
+
// where operator
bool WhereRel(const Array<Type>& types,
              int num_inputs,