Skip to content

Commit be74cbb

Browse files
masahi, lhutton1
authored and committed
[Relay] Dense alter layout fixed for packed input (apache#8669)
* clean up typerel
* add layout transform when input is 3D
* add test
* update doc to clarify that only 2D input data is supported
* add weight_layout attribute in dense
* remove explicit layout transform from dense_alter_op.py
* Add DensePackInferCorrectLayout to insert layout transform
* relax type rel
* revert type rel relax and add check on dim
* introduce DensePackAttrs to avoid breaking dense op
* try fixing arm compute lib test
* Update tests/python/contrib/test_arm_compute_lib/test_dense.py

  Co-authored-by: lhutton1 <35535092+lhutton1@users.noreply.github.com>
* formatting

Co-authored-by: lhutton1 <35535092+lhutton1@users.noreply.github.com>
1 parent 7402c89 commit be74cbb

File tree

8 files changed

+139
-44
lines changed

8 files changed

+139
-44
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1003,6 +1003,25 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
10031003
}
10041004
};
10051005

1006+
/*! \brief Attributes for dense_pack operator */
1007+
struct DensePackAttrs : public tvm::AttrsNode<DensePackAttrs> {
1008+
IndexExpr units;
1009+
DataType out_dtype;
1010+
tvm::String weight_layout;
1011+
1012+
TVM_DECLARE_ATTRS(DensePackAttrs, "relay.attrs.DensePackAttrs") {
1013+
TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
1014+
1015+
// use 0 bits to indicate none.
1016+
TVM_ATTR_FIELD(out_dtype)
1017+
.set_default(NullValue<DataType>())
1018+
.describe("Output data type, set to explicit type under mixed precision setting");
1019+
TVM_ATTR_FIELD(weight_layout)
1020+
.set_default("NK")
1021+
.describe("Dimension ordering of weight. Packed layouts, such as NK8n, are possible.");
1022+
}
1023+
};
1024+
10061025
/*! \brief Attributes for batch matmul operator. */
10071026
struct BatchMatmulAttrs : public tvm::AttrsNode<BatchMatmulAttrs> {
10081027
DataType out_dtype;

python/tvm/relay/op/nn/_nn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,9 +1259,9 @@ def dense_shape_func(attrs, inputs, _):
12591259
@script
12601260
def _dense_pack_shape_func(data_shape, weight_shape):
12611261
out = output_tensor((data_shape.shape[0],), "int64")
1262-
for i in const_range(out.shape[0] - 1):
1263-
out[i] = data_shape[i]
1264-
out[out.shape[0] - 1] = weight_shape[0] * weight_shape[2]
1262+
assert data_shape.shape[0] == 2, "Input data must be 2D"
1263+
out[0] = data_shape[0]
1264+
out[1] = weight_shape[0] * weight_shape[2]
12651265

12661266
return out
12671267

python/tvm/relay/op/nn/nn.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,9 +1548,9 @@ def dense(data, weight, units=None, out_dtype=""):
15481548
return _make.dense(data, weight, units, out_dtype)
15491549

15501550

1551-
def contrib_dense_pack(data, weight, units=None, out_dtype=""):
1551+
def contrib_dense_pack(data, weight, weight_layout="NK", units=None, out_dtype=""):
15521552
"""Dense operator.
1553-
Applies a linear transformation
1553+
Applies a linear transformation with packed weight
15541554
15551555
.. math::
15561556
@@ -1560,25 +1560,27 @@ def contrib_dense_pack(data, weight, units=None, out_dtype=""):
15601560
----------
15611561
data : tvm.relay.Expr
15621562
The input data to the operator,
1563-
of shape `(d_1, d_2, ..., d_n, units_in)`.
1563+
of shape `(batch, units_in)`.
15641564
15651565
weight : tvm.relay.Expr
15661566
The transformed weight expressions, 3-D matrix,
15671567
of shape `(units // pack_weight_tile, units_in, pack_weight_tile)`.
15681568
1569+
weight_layout: str
1570+
The layout of weight, such as "NK" or "NK8n".
1571+
15691572
units : int, optional
15701573
Number of hidden units of the dense transformation.
15711574
15721575
out_dtype : str, optional
1573-
Specifies the output data type for mixed precision dense,
1574-
of shape `(d_1, d_2, ..., d_n, units)`.
1576+
Specifies the output data type for mixed precision dense.
15751577
15761578
Returns
15771579
-------
15781580
result : tvm.relay.Expr
15791581
The computed result.
15801582
"""
1581-
return _make.contrib_dense_pack(data, weight, units, out_dtype)
1583+
return _make.contrib_dense_pack(data, weight, weight_layout, units, out_dtype)
15821584

15831585

15841586
def fifo_buffer(data, buffer, axis):

python/tvm/topi/x86/dense_alter_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
6363
relay.op.get("nn.dense"), attrs, tinfos, out_type, target
6464
)
6565
workload = autotvm.task.get_workload(outs)
66+
6667
if workload:
6768
cfg = dispatch_ctx.query(target, workload)
6869
topi_impl = workload[0]
@@ -86,7 +87,6 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
8687
topi_impl,
8788
)
8889
dispatch_ctx.update(target, new_workload, cfg)
89-
weight_transform = relay.layout_transform(inputs[1], "NK", weight_layout)
90-
return relay.nn.contrib_dense_pack(inputs[0], weight_transform, None, out_dtype)
90+
return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)
9191

9292
return None

src/relay/op/nn/nn.cc

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,13 @@ Expr MakeDense(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
206206
return Call(op, {data, weight}, Attrs(attrs), {});
207207
}
208208

209+
InferCorrectLayoutOutput DenseInferCorrectLayout(const Attrs& attrs,
210+
const Array<Layout>& new_in_layouts,
211+
const Array<Layout>& old_in_layouts,
212+
const Array<tvm::relay::Type>& old_in_types) {
213+
return InferCorrectLayoutOutput({"NC", "NK"}, {"NC"}, attrs);
214+
}
215+
209216
TVM_REGISTER_GLOBAL("relay.op.nn._make.dense").set_body_typed(MakeDense);
210217

211218
RELAY_REGISTER_OP("nn.dense")
@@ -221,35 +228,75 @@ RELAY_REGISTER_OP("nn.dense")
221228
.add_argument("data", "nD Tensor", "Input data.")
222229
.add_argument("weight", "2D Tensor", "Weight matrix.")
223230
.set_support_level(1)
231+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout)
224232
.add_type_rel("Dense", MatmulRel<DenseAttrs>);
225233
// ------------------- relay.nn.dense
226234

227235
// ------------------- relay.nn.contrib_dense_pack
236+
TVM_REGISTER_NODE_TYPE(DensePackAttrs);
237+
228238
// Positional relay function to create dense_pack operator used by frontend FFI.
229-
Expr MakeDensePack(Expr data, Expr weight, IndexExpr units, DataType out_dtype) {
230-
auto attrs = make_object<DenseAttrs>();
239+
Expr MakeDensePack(Expr data, Expr weight, tvm::String weight_layout, IndexExpr units,
240+
DataType out_dtype) {
241+
auto attrs = make_object<DensePackAttrs>();
231242
attrs->units = units;
232243
attrs->out_dtype = out_dtype;
244+
attrs->weight_layout = std::move(weight_layout);
233245
static const Op& op = Op::Get("nn.contrib_dense_pack");
234246
return Call(op, {data, weight}, Attrs(attrs), {});
235247
}
236248

237249
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_dense_pack").set_body_typed(MakeDensePack);
238250

251+
bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
252+
const TypeReporter& reporter) {
253+
ICHECK_EQ(types.size(), 3);
254+
const auto* data = types[0].as<TensorTypeNode>();
255+
const auto* weight = types[1].as<TensorTypeNode>();
256+
if (data == nullptr || weight == nullptr) return false;
257+
258+
const DensePackAttrs* param = attrs.as<DensePackAttrs>();
259+
ICHECK(param != nullptr);
260+
261+
ICHECK_EQ(data->shape.size(), 2) << "Only 2D data is supported";
262+
ICHECK_EQ(weight->shape.size(), 3) << "Weight is not packed";
263+
264+
Array<tvm::PrimExpr> oshape = data->shape;
265+
oshape.Set(1, weight->shape[0] * weight->shape[2]);
266+
267+
DataType out_dtype = param->out_dtype;
268+
if (out_dtype.bits() == 0) {
269+
out_dtype = data->dtype;
270+
}
271+
// assign output type
272+
reporter->Assign(types[2], TensorType(oshape, out_dtype));
273+
return true;
274+
}
275+
276+
InferCorrectLayoutOutput DensePackInferCorrectLayout(const Attrs& attrs,
277+
const Array<Layout>& new_in_layouts,
278+
const Array<Layout>& old_in_layouts,
279+
const Array<tvm::relay::Type>& old_in_types) {
280+
auto params = attrs.as<DensePackAttrs>();
281+
ICHECK(params);
282+
return InferCorrectLayoutOutput({"NC", params->weight_layout}, {"NC"}, attrs);
283+
}
284+
239285
RELAY_REGISTER_OP("nn.contrib_dense_pack")
240286
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T`.
241287
242-
- **data**: `(x1, x2, ..., xn, input_dim)`
288+
- **data**: `(batch, input_dim)`
243289
- **weight**: `(units // pack_weight_tile, input_dim, pack_weight_tile)`
244-
- **out**: `(x1, x2, ..., xn, units)`.
290+
- **out**: `(batch, units)`.
245291
246292
)code" TVM_ADD_FILELINE)
247293
.set_attrs_type<DenseAttrs>()
248294
.set_num_inputs(2)
249-
.add_argument("data", "nD Tensor", "Input data.")
295+
.add_argument("data", "2D Tensor", "Input data.")
250296
.add_argument("weight", "3D Tensor", "Packed weight matrix.")
251297
.set_support_level(10)
252-
.add_type_rel("DensePack", DensePackRel<DenseAttrs>);
298+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", DensePackInferCorrectLayout)
299+
.add_type_rel("DensePack", DensePackRel);
253300
// ------------------- relay.nn.contrib_dense_pack
254301

255302
// relay.leaky_relu
@@ -307,7 +354,6 @@ bool PReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
307354
return true;
308355
}
309356

310-
template <typename T>
311357
InferCorrectLayoutOutput PReluInferCorrectLayout(const Attrs& attrs,
312358
const Array<Layout>& new_in_layouts,
313359
const Array<Layout>& old_in_layouts,
@@ -343,7 +389,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
343389
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
344390
.set_support_level(3)
345391
.add_type_rel("PRelu", PReluRel)
346-
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout<PReluAttrs>)
392+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PReluInferCorrectLayout)
347393
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
348394
const Type& out_type) {
349395
const auto* param = attrs.as<PReluAttrs>();

src/relay/op/nn/nn.h

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -116,29 +116,6 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
116116
return true;
117117
}
118118

119-
template <typename AttrType>
120-
bool DensePackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
121-
const TypeReporter& reporter) {
122-
ICHECK_EQ(types.size(), 3);
123-
const auto* data = types[0].as<TensorTypeNode>();
124-
const auto* weight = types[1].as<TensorTypeNode>();
125-
if (data == nullptr || weight == nullptr) return false;
126-
127-
const AttrType* param = attrs.as<AttrType>();
128-
ICHECK(param != nullptr);
129-
130-
Array<tvm::PrimExpr> oshape = data->shape;
131-
oshape.Set((oshape.size() - 1), weight->shape[0] * weight->shape[2]);
132-
133-
DataType out_dtype = param->out_dtype;
134-
if (out_dtype.bits() == 0) {
135-
out_dtype = data->dtype;
136-
}
137-
// assign output type
138-
reporter->Assign(types[2], TensorType(oshape, out_dtype));
139-
return true;
140-
}
141-
142119
template <typename AttrType>
143120
bool BatchMatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
144121
const TypeReporter& reporter) {

tests/python/contrib/test_arm_compute_lib/test_dense.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,16 @@ def _get_expected_codegen(shape, weight_shape, units, dtype, has_bias=False):
150150

151151
if has_bias:
152152
bias_dtype = "int32" if dtype == "uint8" else "float32"
153+
bias_shape = (
154+
[1, weight_shape[0]]
155+
if dtype == "float32" and weight_shape[0] != 1
156+
else [weight_shape[0]]
157+
)
153158
inputs.append(
154159
{
155160
"op": "const",
156161
"name": "",
157-
"attrs": {"shape": [[[weight_shape[0]]]], "dtype": [[bias_dtype]]},
162+
"attrs": {"shape": [[bias_shape]], "dtype": [[bias_dtype]]},
158163
}
159164
)
160165

tests/python/relay/test_pass_alter_op_layout.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1315,7 +1315,9 @@ def expected():
13151315
weight = relay.var("weight", shape=(48, 64))
13161316
target_layout = "NK16n"
13171317
weight_transform = relay.layout_transform(weight, "NK", target_layout)
1318-
y = relay.nn.contrib_dense_pack(x, weight_transform, units=None, out_dtype="float32")
1318+
y = relay.nn.contrib_dense_pack(
1319+
x, weight_transform, target_layout, units=None, out_dtype="float32"
1320+
)
13191321
y = relay.Function(analysis.free_vars(y), y)
13201322
return y
13211323

@@ -1353,6 +1355,49 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
13531355
assert before.body.attrs.layout == "NCHW"
13541356

13551357

1358+
def test_alter_op_dense_packed_data():
1359+
def before():
1360+
x = relay.var("x", shape=(1, 32, 8, 8))
1361+
weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
1362+
conv = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
1363+
pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0])
1364+
squeeze = relay.squeeze(pool, axis=[2, 3])
1365+
dense = relay.nn.dense(squeeze, relay.var("dense_weight", shape=(16, 32)))
1366+
return relay.Function(analysis.free_vars(dense), dense)
1367+
1368+
def expected():
1369+
x = relay.var("x", shape=(1, 32, 8, 8))
1370+
conv_weight = relay.var("conv2d_weight", shape=(32, 32, 3, 3))
1371+
dense_weight = relay.var("dense_weight", shape=(16, 32))
1372+
conv = relay.nn.contrib_conv2d_nchwc(
1373+
relay.layout_transform(x, "NCHW", "NCHW8c"),
1374+
relay.layout_transform(conv_weight, "OIHW", "OIHW8i8o"),
1375+
channels=32,
1376+
kernel_size=(3, 3),
1377+
padding=(1, 1),
1378+
data_layout="NCHW8c",
1379+
kernel_layout="OIHW8i8o",
1380+
out_layout="NCHW8c",
1381+
)
1382+
pool = relay.nn.avg_pool2d(conv, pool_size=[8, 8], padding=[0, 0, 0, 0], layout="NCHW8c")
1383+
squeeze = relay.squeeze(pool, axis=[2, 3])
1384+
dense = relay.nn.contrib_dense_pack(
1385+
relay.layout_transform(squeeze, "NC8c", "NC"),
1386+
relay.layout_transform(dense_weight, "NK", "NK16n"),
1387+
"NK16n",
1388+
out_dtype="float32",
1389+
)
1390+
return relay.Function(analysis.free_vars(dense), dense)
1391+
1392+
with tvm.target.Target("llvm"):
1393+
with TempOpAttr(
1394+
"nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout
1395+
):
1396+
a = run_opt_pass(before(), transform.AlterOpLayout())
1397+
b = run_opt_pass(expected(), transform.InferType())
1398+
assert tvm.ir.structural_equal(a, b)
1399+
1400+
13561401
if __name__ == "__main__":
13571402
test_alter_op()
13581403
test_alter_return_none()
@@ -1377,3 +1422,4 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
13771422
test_alter_op_dense()
13781423
test_alter_layout_strided_slice_axes_nhwc()
13791424
test_not_inplace_modify()
1425+
test_alter_op_dense_packed_data()

0 commit comments

Comments
 (0)