Skip to content

Commit 80f72d7

Browse files
committed
Fix winograd_nnpack_fp16
1 parent 6d52eca commit 80f72d7

File tree

9 files changed

+77
-43
lines changed

9 files changed

+77
-43
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,17 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
160160
struct Conv2DWinogradNNPACKWeightTransformAttrs
161161
: public tvm::AttrsNode<Conv2DWinogradNNPACKWeightTransformAttrs> {
162162
int convolution_algorithm;
163+
DataType out_dtype;
163164

164165
TVM_DECLARE_ATTRS(Conv2DWinogradNNPACKWeightTransformAttrs,
165166
"relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") {
166167
TVM_ATTR_FIELD(convolution_algorithm)
167168
.describe(
168169
"The convolution algorithm for Winograd NNPACK. E.g. 3 for WT_8x8, "
169170
"6 for WT_8x8_FP16");
171+
TVM_ATTR_FIELD(out_dtype)
172+
.set_default(NullValue<DataType>())
173+
.describe("Output data type, set to explicit type under mixed precision setting");
170174
}
171175
};
172176

nnvm/include/nnvm/top/nn.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,17 @@ struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTrans
186186
struct WinogradNNPACKWeightTransformParam
187187
: public dmlc::Parameter<WinogradNNPACKWeightTransformParam> {
188188
int convolution_algorithm;
189+
int out_dtype;
189190

190191
DMLC_DECLARE_PARAMETER(WinogradNNPACKWeightTransformParam) {
191192
DMLC_DECLARE_FIELD(convolution_algorithm)
192193
.describe(
193194
"The convolution algorithm for Winograd NNPACK. E.g. 3 for WT_8x8, "
194195
"6 for WT_8x8_FP16");
196+
DMLC_DECLARE_DTYPE_FIELD(out_dtype)
197+
.add_enum("same", -1)
198+
.set_default(-1)
199+
.describe("Output data type, set to explicit type under mixed precision setting");
195200
}
196201

197202
static const constexpr int kWeight = 0;

nnvm/python/nnvm/top/nn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,9 @@ def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, targe
278278

279279
@reg.register_compute("_contrib_conv2d_winograd_nnpack_weight_transform")
280280
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, _):
281-
return topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0],
282-
attrs.get_int('convolution_algorithm'))
281+
convolution_algorithm = attrs.get_int('convolution_algorithm')
282+
out_dype = attrs.get_str('out_dtype')
283+
return topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0], convolution_algorithm, out_dype)
283284

284285
@reg.register_schedule("_contrib_conv2d_winograd_nnpack_weight_transform")
285286
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):

nnvm/src/top/nn/convolution.cc

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,23 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
414414
DMLC_REGISTER_PARAMETER(WinogradConv2DParam);
415415

416416

417+
inline bool Conv2DWinogradNNPACKWTInferType(const nnvm::NodeAttrs& attrs,
418+
std::vector<int>* in_type,
419+
std::vector<int>* out_type) {
420+
const WinogradNNPACKWeightTransformParam& param =
421+
nnvm::get<WinogradNNPACKWeightTransformParam>(attrs.parsed);
422+
423+
CHECK_EQ(in_type->size(), 1U) << "Input:[weight]";
424+
CHECK_EQ(out_type->size(), 1U);
425+
printf("param.out_dtype: %d\n", param.out_dtype);
426+
if (param.out_dtype != -1) {
427+
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_type, 0, param.out_dtype);
428+
} else {
429+
ElemwiseType<1, 1>(attrs, in_type, out_type);
430+
}
431+
return true;
432+
}
433+
417434
NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_weight_transform)
418435
.describe(R"code(Weight transformation of winograd fast convolution algorithm.
419436
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
@@ -432,7 +449,7 @@ weight transformation in advance.
432449
TShape oshape({wshape[0], wshape[1], 8, 8});
433450
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
434451
return true;
435-
})
452+
})
436453
.set_attr<FCorrectLayout>("FCorrectLayout", [](const NodeAttrs& attrs,
437454
std::vector<Layout> *ilayouts,
438455
const std::vector<Layout> *last_ilayouts,
@@ -442,7 +459,7 @@ weight transformation in advance.
442459
NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
443460
return true;
444461
})
445-
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
462+
.set_attr<FInferType>("FInferType", Conv2DWinogradNNPACKWTInferType)
446463
.set_num_outputs(1)
447464
.set_num_inputs(1)
448465
.set_support_level(5);

python/tvm/relay/op/nn/_nn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,8 @@ def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs
341341
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
342342
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
343343
"""Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
344-
out = topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0], attrs.get_int('convolution_algorithm'))
344+
convolution_algorithm = attrs.get_int('convolution_algorithm')
345+
out = topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0], convolution_algorithm, out_dtype)
345346
return [out]
346347

347348
@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")

python/tvm/relay/op/nn/nn.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,8 @@ def contrib_conv2d_winograd_weight_transform(weight,
993993

994994

995995
def contrib_conv2d_winograd_nnpack_weight_transform(weight,
996-
convolution_algorithm):
996+
convolution_algorithm,
997+
out_dtype=""):
997998
r"""Weight Transformation part for 2D convolution with winograd algorithm.
998999
9991000
We separate this as a single op to enable pre-compute for inference.
@@ -1012,4 +1013,5 @@ def contrib_conv2d_winograd_nnpack_weight_transform(weight,
10121013
result : tvm.relay.Expr
10131014
The computed result.
10141015
"""
1015-
return _make.contrib_conv2d_winograd_nnpack_weight_transform(weight, convolution_algorithm)
1016+
return _make.contrib_conv2d_winograd_nnpack_weight_transform(
1017+
weight, convolution_algorithm, out_dtype)

src/relay/op/nn/convolution.cc

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "../../pass/alter_op_layout.h"
1111
#include "../layout.h"
1212

13+
1314
namespace tvm {
1415
namespace relay {
1516

@@ -499,8 +500,8 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight,
499500

500501
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
501502
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
502-
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
503-
});
503+
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
504+
});
504505

505506

506507
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
@@ -521,17 +522,17 @@ weight transformation in advance.
521522
// Positional relay function to create conv2d winograd nnpack operator
522523
// used by frontend FFI.
523524
Expr MakeConv2DWinogradNNPACK(Expr data,
524-
Expr weight,
525-
Array<IndexExpr> strides,
526-
Array<IndexExpr> padding,
527-
Array<IndexExpr> dilation,
528-
int groups,
529-
IndexExpr channels,
530-
Array<IndexExpr> kernel_size,
531-
std::string data_layout,
532-
std::string kernel_layout,
533-
std::string out_layout,
534-
DataType out_dtype) {
525+
Expr weight,
526+
Array<IndexExpr> strides,
527+
Array<IndexExpr> padding,
528+
Array<IndexExpr> dilation,
529+
int groups,
530+
IndexExpr channels,
531+
Array<IndexExpr> kernel_size,
532+
std::string data_layout,
533+
std::string kernel_layout,
534+
std::string out_layout,
535+
DataType out_dtype) {
535536
auto attrs = make_node<Conv2DAttrs>();
536537
attrs->strides = std::move(strides);
537538
attrs->padding = std::move(padding);
@@ -547,17 +548,15 @@ Expr MakeConv2DWinogradNNPACK(Expr data,
547548
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
548549
}
549550

550-
551551
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
552552
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
553-
runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
554-
});
555-
553+
runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
554+
});
556555

557556
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
558557
.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
559-
This operator assumes the weight tensor is already pre-transformed by
560-
nn.contrib_conv2d_winograd_nnpack_weight_transform.
558+
This operator assumes the weight tensor is already pre-transformed by
559+
nn.contrib_conv2d_winograd_nnpack_weight_transform.
561560
562561
- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
563562
- **weight**: Any shape
@@ -572,60 +571,63 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
572571
.add_argument("weight", "Tensor", "The weight tensor.")
573572
.set_support_level(10)
574573
.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
575-
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
576-
Conv2DInferCorrectLayout<Conv2DAttrs>);
574+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
577575

578576
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
579577
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
580578

581579
bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types,
582-
int num_inputs,
583-
const Attrs& attrs,
584-
const TypeReporter& reporter) {
580+
int num_inputs,
581+
const Attrs& attrs,
582+
const TypeReporter& reporter) {
585583
CHECK_EQ(types.size(), 2);
586584
const auto* data = types[0].as<TensorTypeNode>();
587585
if (data == nullptr) return false;
588586

589-
const Conv2DWinogradNNPACKWeightTransformAttrs* param = attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
587+
const Conv2DWinogradNNPACKWeightTransformAttrs* param =
588+
attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
590589
CHECK(param != nullptr);
591590

592591
CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
593592

594593
// each pad width element should be a pair of positive integers
595-
std::vector<IndexExpr> oshape {
594+
std::vector<IndexExpr> oshape{
596595
data->shape[0],
597596
data->shape[1],
598597
8,
599598
8,
600599
};
601600

602-
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
603-
data->dtype));
601+
DataType out_dtype = param->out_dtype;
602+
if (out_dtype.bits() == 0) {
603+
out_dtype = data->dtype;
604+
}
605+
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape), out_dtype));
604606
return true;
605607
}
606608

607609
Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
608-
int convolution_algorithm) {
610+
int convolution_algorithm,
611+
DataType out_dtype) {
609612
auto attrs = make_node<Conv2DWinogradNNPACKWeightTransformAttrs>();
610613
attrs->convolution_algorithm = convolution_algorithm;
614+
attrs->out_dtype = std::move(out_dtype);
611615
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform");
612616
return CallNode::make(op, {weight}, Attrs(attrs), {});
613617
}
614618

615-
616619
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
617620
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
618-
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
619-
});
620-
621+
runtime::detail::unpack_call<Expr, 3>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
622+
});
621623

622624
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
623625
.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
624-
625626
Separate this into another symbol in order to enable Precompute Pass to compute the
626627
weight transformation in advance.
627628
628629
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
630+
629631
)code" TVM_ADD_FILELINE)
630632
.set_attrs_type_key("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs")
631633
.set_num_inputs(1)

topi/python/topi/arm_cpu/conv2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,8 +760,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
760760
return F.nn.contrib_conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
761761
elif cfg.template_key == "winograd_nnpack_fp16" or cfg.template_key == "winograd_nnpack_fp32":
762762
# pre-compute winograd_nnpack transform
763+
# for winograd_nnpack_fp16, the the precomputeprune pass must run on device (where float16 is supported)
764+
weight_dtype = 'same' if cfg.template_key == "winograd_nnpack_fp32" else 'float16'
763765
transformed_kernel = F.nn.contrib_conv2d_winograd_nnpack_weight_transform(
764-
copy_inputs[1], convolution_algorithm=cfg['winograd_nnpack_algorithm'].val)
766+
copy_inputs[1], convolution_algorithm=cfg['winograd_nnpack_algorithm'].val, out_dtype=weight_dtype)
765767
copy_inputs[1] = transformed_kernel
766768
new_data = data
767769
new_kernel = tvm.placeholder((CO, CI, 8, 8), "float32")

topi/python/topi/nn/conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding, di
410410
raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
411411

412412

413-
def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm):
413+
def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_dtype):
414414
"""Weight transformation for winograd
415415
Parameters
416416
----------

0 commit comments

Comments
 (0)