@@ -233,10 +233,12 @@ RELAY_REGISTER_OP("nn.dense")
233
233
// ------------------- relay.nn.dense
234
234
235
235
// ------------------- relay.nn.contrib_dense_pack
236
+ TVM_REGISTER_NODE_TYPE (DensePackAttrs);
237
+
236
238
// Positional relay function to create dense_pack operator used by frontend FFI.
237
239
Expr MakeDensePack (Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
238
240
DataType out_dtype) {
239
- auto attrs = make_object<DenseAttrs >();
241
+ auto attrs = make_object<DensePackAttrs >();
240
242
attrs->units = units;
241
243
attrs->out_dtype = out_dtype;
242
244
attrs->weight_layout = std::move (weight_layout);
@@ -253,7 +255,7 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
253
255
const auto * weight = types[1 ].as <TensorTypeNode>();
254
256
if (data == nullptr || weight == nullptr ) return false ;
255
257
256
- const DenseAttrs * param = attrs.as <DenseAttrs >();
258
+ const DensePackAttrs * param = attrs.as <DensePackAttrs >();
257
259
ICHECK (param != nullptr );
258
260
259
261
ICHECK_EQ (data->shape .size (), 2 ) << " Only 2D data is supported" ;
@@ -275,7 +277,7 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
275
277
const Array<Layout>& new_in_layouts,
276
278
const Array<Layout>& old_in_layouts,
277
279
const Array<tvm::relay::Type>& old_in_types) {
278
- auto params = attrs.as <DenseAttrs >();
280
+ auto params = attrs.as <DensePackAttrs >();
279
281
ICHECK (params);
280
282
return InferCorrectLayoutOutput ({" NC" , params->weight_layout }, {" NC" }, attrs);
281
283
}
0 commit comments