@@ -206,6 +206,13 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
206
206
return Call (op, {data, weight}, Attrs (attrs), {});
207
207
}
208
208
209
+ InferCorrectLayoutOutput DenseInferCorrectLayout (const Attrs& attrs,
210
+ const Array<Layout>& new_in_layouts,
211
+ const Array<Layout>& old_in_layouts,
212
+ const Array<tvm::relay::Type>& old_in_types) {
213
+ return InferCorrectLayoutOutput ({" NC" , " NK" }, {" NC" }, attrs);
214
+ }
215
+
209
216
TVM_REGISTER_GLOBAL (" relay.op.nn._make.dense" ).set_body_typed(MakeDense);
210
217
211
218
RELAY_REGISTER_OP (" nn.dense" )
@@ -221,35 +228,75 @@ RELAY_REGISTER_OP("nn.dense")
221
228
.add_argument(" data" , " nD Tensor" , " Input data." )
222
229
.add_argument(" weight" , " 2D Tensor" , " Weight matrix." )
223
230
.set_support_level(1 )
231
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , DenseInferCorrectLayout)
224
232
.add_type_rel(" Dense" , MatmulRel<DenseAttrs>);
225
233
// ------------------- relay.nn.dense
226
234
227
235
// ------------------- relay.nn.contrib_dense_pack
236
+ TVM_REGISTER_NODE_TYPE (DensePackAttrs);
237
+
228
238
// Positional relay function to create dense_pack operator used by frontend FFI.
229
- Expr MakeDensePack (Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
230
- auto attrs = make_object<DenseAttrs>();
239
+ Expr MakeDensePack (Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
240
+ DataType out_dtype) {
241
+ auto attrs = make_object<DensePackAttrs>();
231
242
attrs->units = units;
232
243
attrs->out_dtype = out_dtype;
244
+ attrs->weight_layout = std::move (weight_layout);
233
245
static const Op& op = Op::Get (" nn.contrib_dense_pack" );
234
246
return Call (op, {data, weight}, Attrs (attrs), {});
235
247
}
236
248
237
249
TVM_REGISTER_GLOBAL (" relay.op.nn._make.contrib_dense_pack" ).set_body_typed(MakeDensePack);
238
250
251
+ bool DensePackRel (const Array<Type>& types, int num_inputs, const Attrs& attrs,
252
+ const TypeReporter& reporter) {
253
+ ICHECK_EQ (types.size (), 3 );
254
+ const auto * data = types[0 ].as <TensorTypeNode>();
255
+ const auto * weight = types[1 ].as <TensorTypeNode>();
256
+ if (data == nullptr || weight == nullptr ) return false ;
257
+
258
+ const DensePackAttrs* param = attrs.as <DensePackAttrs>();
259
+ ICHECK (param != nullptr );
260
+
261
+ ICHECK_EQ (data->shape .size (), 2 ) << " Only 2D data is supported" ;
262
+ ICHECK_EQ (weight->shape .size (), 3 ) << " Weight is not packed" ;
263
+
264
+ Array<tvm::PrimExpr> oshape = data->shape ;
265
+ oshape.Set (1 , weight->shape [0 ] * weight->shape [2 ]);
266
+
267
+ DataType out_dtype = param->out_dtype ;
268
+ if (out_dtype.bits () == 0 ) {
269
+ out_dtype = data->dtype ;
270
+ }
271
+ // assign output type
272
+ reporter->Assign (types[2 ], TensorType (oshape, out_dtype));
273
+ return true ;
274
+ }
275
+
276
+ InferCorrectLayoutOutput DensePackInferCorrectLayout (const Attrs& attrs,
277
+ const Array<Layout>& new_in_layouts,
278
+ const Array<Layout>& old_in_layouts,
279
+ const Array<tvm::relay::Type>& old_in_types) {
280
+ auto params = attrs.as <DensePackAttrs>();
281
+ ICHECK (params);
282
+ return InferCorrectLayoutOutput ({" NC" , params->weight_layout }, {" NC" }, attrs);
283
+ }
284
+
239
285
RELAY_REGISTER_OP("nn.contrib_dense_pack")
    .describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.

- **data**: `(batch, input_dim)`
- **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
- **out**: `(batch, units)`.

)code" TVM_ADD_FILELINE)
    // NOTE(review): was DenseAttrs, which is inconsistent with this op:
    // MakeDensePack constructs DensePackAttrs, and both DensePackRel and
    // DensePackInferCorrectLayout cast attrs to DensePackAttrs. Registering
    // the wrong attrs type would break attrs reflection/printing for this op.
    .set_attrs_type<DensePackAttrs>()
    .set_num_inputs(2)
    .add_argument("data", "2D Tensor", "Input data.")
    .add_argument("weight", "3D Tensor", "Packed weight matrix.")
    .set_support_level(10)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", DensePackInferCorrectLayout)
    .add_type_rel("DensePack", DensePackRel);
// ------------------- relay.nn.contrib_dense_pack
301
255
302
// relay.leaky_relu
@@ -307,7 +354,6 @@ bool PReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
307
354
return true ;
308
355
}
309
356
310
- template <typename T>
311
357
InferCorrectLayoutOutput PReluInferCorrectLayout (const Attrs& attrs,
312
358
const Array<Layout>& new_in_layouts,
313
359
const Array<Layout>& old_in_layouts,
@@ -343,7 +389,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
343
389
.add_argument(" alpha" , " Tensor" , " Input channelwise alpha." )
344
390
.set_support_level(3 )
345
391
.add_type_rel(" PRelu" , PReluRel)
346
- .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , PReluInferCorrectLayout<PReluAttrs> )
392
+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , PReluInferCorrectLayout)
347
393
.set_attr<FTVMCompute>(" FTVMCompute" , [](const Attrs& attrs, const Array<te::Tensor>& inputs,
348
394
const Type& out_type) {
349
395
const auto * param = attrs.as <PReluAttrs>();
0 commit comments