Skip to content

Commit d2528b7

Browse files
committed
introduce DensePackAttrs to avoid breaking dense op
1 parent d676de7 commit d2528b7

File tree

2 files changed

+21
-4
lines changed
  • include/tvm/relay/attrs
  • src/relay/op/nn

2 files changed

+21
-4
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,11 +992,26 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
992992
IndexExpr units;
993993
tvm::String auto_scheduler_rewritten_layout; // The layout after auto-scheduler's layout rewrite
994994
DataType out_dtype;
995-
tvm::String weight_layout;
996995

997996
TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
998997
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
999998

999+
// use 0 bits to indicate none.
1000+
TVM_ATTR_FIELD(out_dtype)
1001+
.set_default(NullValue<DataType>())
1002+
.describe("Output data type, set to explicit type under mixed precision setting");
1003+
}
1004+
};
1005+
1006+
/*! \brief Attributes for dense_pack operator */
1007+
struct DensePackAttrs : public tvm::AttrsNode<DensePackAttrs> {
1008+
IndexExpr units;
1009+
DataType out_dtype;
1010+
tvm::String weight_layout;
1011+
1012+
TVM_DECLARE_ATTRS(DensePackAttrs, "relay.attrs.DensePackAttrs") {
1013+
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
1014+
10001015
// use 0 bits to indicate none.
10011016
TVM_ATTR_FIELD(out_dtype)
10021017
.set_default(NullValue<DataType>())

src/relay/op/nn/nn.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,12 @@ RELAY_REGISTER_OP("nn.dense")
233233
// ------------------- relay.nn.dense
234234

235235
// ------------------- relay.nn.contrib_dense_pack
236+
TVM_REGISTER_NODE_TYPE(DensePackAttrs);
237+
236238
// Positional relay function to create dense_pack operator used by frontend FFI.
237239
Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
238240
DataType out_dtype) {
239-
auto attrs = make_object<DenseAttrs>();
241+
auto attrs = make_object<DensePackAttrs>();
240242
attrs->units = units;
241243
attrs->out_dtype = out_dtype;
242244
attrs->weight_layout = std::move(weight_layout);
@@ -253,7 +255,7 @@ bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
253255
const auto* weight = types[1].as<TensorTypeNode>();
254256
if (data == nullptr || weight == nullptr) return false;
255257

256-
const DenseAttrs* param = attrs.as<DenseAttrs>();
258+
const DensePackAttrs* param = attrs.as<DensePackAttrs>();
257259
ICHECK(param != nullptr);
258260

259261
ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
@@ -275,7 +277,7 @@ InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
275277
const Array<Layout>& new_in_layouts,
276278
const Array<Layout>& old_in_layouts,
277279
const Array<tvm::relay::Type>& old_in_types) {
278-
auto params = attrs.as<DenseAttrs>();
280+
auto params = attrs.as<DensePackAttrs>();
279281
ICHECK(params);
280282
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
281283
}

0 commit comments

Comments
 (0)