Skip to content

Commit 3f9924d

Browse files
committed
winograd_nnpack
1 parent abe6f77 commit 3f9924d

File tree

14 files changed

+859
-28
lines changed

14 files changed

+859
-28
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,24 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
155155
}
156156
};
157157

158+
/*! \brief Attributes used in winograd weight transformation operators */
159+
struct Conv2DWinogradNNPACKWeightTransformAttrs
160+
: public tvm::AttrsNode<Conv2DWinogradNNPACKWeightTransformAttrs> {
161+
int convolution_algorithm;
162+
DataType out_dtype;
163+
164+
TVM_DECLARE_ATTRS(Conv2DWinogradNNPACKWeightTransformAttrs,
165+
"relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") {
166+
TVM_ATTR_FIELD(convolution_algorithm)
167+
.describe(
168+
"The convolution algorithm for Winograd NNPACK. "
169+
"E.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8, "
170+
"tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16");
171+
TVM_ATTR_FIELD(out_dtype)
172+
.set_default(NullValue<DataType>())
173+
.describe("Output data type, set to explicit type under mixed precision setting");
174+
}
175+
};
158176

159177
/*! \brief Attributes used in softmax operators */
160178
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {

nnvm/include/nnvm/top/nn.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,26 @@ struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTrans
183183
static const constexpr int kWeight = 0;
184184
};
185185

186+
struct WinogradNNPACKWeightTransformParam
187+
: public dmlc::Parameter<WinogradNNPACKWeightTransformParam> {
188+
int convolution_algorithm;
189+
int out_dtype;
190+
191+
DMLC_DECLARE_PARAMETER(WinogradNNPACKWeightTransformParam) {
192+
DMLC_DECLARE_FIELD(convolution_algorithm)
193+
.describe(
194+
"The convolution algorithm for Winograd NNPACK. "
195+
"E.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8, "
196+
"tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16");
197+
DMLC_DECLARE_DTYPE_FIELD(out_dtype)
198+
.add_enum("same", -1)
199+
.set_default(-1)
200+
.describe("Output data type, set to explicit type under mixed precision setting");
201+
}
202+
203+
static const constexpr int kWeight = 0;
204+
};
205+
186206
struct WinogradConv2DParam : public dmlc::Parameter<WinogradConv2DParam> {
187207
int channels;
188208
TShape kernel_size;

nnvm/python/nnvm/top/nn.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ def alter_conv2d_layout(attrs, inputs, tinfos):
161161
sym.contrib.conv2d_winograd_without_weight_transform
162162
sym.contrib_conv2d_winograd_weight_transform = \
163163
sym.contrib.conv2d_winograd_weight_transform
164+
sym.contrib_conv2d_winograd_nnpack_without_weight_transform = \
165+
sym.contrib.conv2d_winograd_nnpack_without_weight_transform
166+
sym.contrib_conv2d_winograd_nnpack_weight_transform = \
167+
sym.contrib.conv2d_winograd_nnpack_weight_transform
164168
sym.nn = sym
165169

166170
# map relay argument names to nnvm argument names
@@ -272,6 +276,47 @@ def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, targe
272276
OpPattern.OUT_ELEMWISE_FUSABLE)
273277

274278

279+
@reg.register_compute("_contrib_conv2d_winograd_nnpack_weight_transform")
280+
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, _):
281+
convolution_algorithm = attrs.get_int('convolution_algorithm')
282+
out_dtype = attrs.get_str('out_dtype')
283+
return topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0], convolution_algorithm, out_dtype)
284+
285+
@reg.register_schedule("_contrib_conv2d_winograd_nnpack_weight_transform")
286+
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
287+
with tvm.target.create(target):
288+
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
289+
290+
reg.register_pattern("_contrib_conv2d_winograd_nnpack_weight_transform", OpPattern.OPAQUE)
291+
292+
293+
@reg.register_compute("_contrib_conv2d_winograd_nnpack_without_weight_transform")
294+
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, inputs, _):
295+
padding = attrs.get_int_tuple("padding")
296+
strides = attrs.get_int_tuple("strides")
297+
dilation = attrs.get_int_tuple("dilation")
298+
groups = attrs.get_int("groups")
299+
layout = attrs.get_str("layout")
300+
out_dtype = attrs.get_str("out_dtype")
301+
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
302+
assert dilation == (1, 1), "Do not support dilate now"
303+
assert groups == 1, "Do not support arbitrary group number"
304+
305+
# pylint: disable=assignment-from-no-return
306+
out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
307+
inputs[0], inputs[1], inputs[2] if attrs.get_bool("use_bias") else None,
308+
strides, padding, dilation, layout, out_dtype)
309+
return out
310+
311+
@reg.register_schedule("_contrib_conv2d_winograd_nnpack_without_weight_transform")
312+
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
313+
with tvm.target.create(target):
314+
return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)
315+
316+
reg.register_pattern("_contrib_conv2d_winograd_nnpack_without_weight_transform",
317+
OpPattern.OPAQUE)
318+
319+
275320
# conv2d_transpose
276321
@reg.register_compute("conv2d_transpose")
277322
def compute_conv2d_transpose(attrs, inputs, _):

nnvm/src/top/nn/convolution.cc

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,14 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
130130
return true;
131131
}
132132

133+
template<class Param>
133134
inline bool WinogradConv2DInferShape(const nnvm::NodeAttrs& attrs,
134135
std::vector<TShape>* in_shape,
135136
std::vector<TShape>* out_shape) {
136137
static const Layout kNCHW("NCHW");
137138
static const Layout kOIHW("OIHW");
138139

139-
const WinogradConv2DParam& param = nnvm::get<WinogradConv2DParam>(attrs.parsed);
140+
const Param& param = nnvm::get<Param>(attrs.parsed);
140141

141142
const Layout in_layout(param.layout);
142143
const Layout kernel_layout(param.kernel_layout);
@@ -403,7 +404,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
403404
.set_attr_parser(ParamParser<WinogradConv2DParam>)
404405
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradConv2DParam>)
405406
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<WinogradConv2DParam>)
406-
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
407+
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape<WinogradConv2DParam>)
407408
.set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
408409
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<WinogradConv2DParam>)
409410
.set_num_outputs(1)
@@ -412,6 +413,82 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
412413

413414
DMLC_REGISTER_PARAMETER(WinogradConv2DParam);
414415

416+
417+
inline bool Conv2DWinogradNNPACKWTInferType(const nnvm::NodeAttrs& attrs,
418+
std::vector<int>* in_type,
419+
std::vector<int>* out_type) {
420+
const WinogradNNPACKWeightTransformParam& param =
421+
nnvm::get<WinogradNNPACKWeightTransformParam>(attrs.parsed);
422+
423+
CHECK_EQ(in_type->size(), 1U) << "Input:[weight]";
424+
CHECK_EQ(out_type->size(), 1U);
425+
426+
if (param.out_dtype != -1) {
427+
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_type, 0, param.out_dtype);
428+
} else {
429+
ElemwiseType<1, 1>(attrs, in_type, out_type);
430+
}
431+
return true;
432+
}
433+
434+
NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_weight_transform)
435+
.describe(R"code(Weight transformation of winograd fast convolution algorithm.
436+
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
437+
weight transformation in advance.
438+
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
439+
)code" NNVM_ADD_FILELINE)
440+
.add_argument("weight", "4D Tensor", "Weight tensor.")
441+
.add_arguments(WinogradNNPACKWeightTransformParam::__FIELDS__())
442+
.set_attr_parser(ParamParser<WinogradNNPACKWeightTransformParam>)
443+
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradNNPACKWeightTransformParam>)
444+
.set_attr<FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
445+
std::vector<TShape> *in_shape,
446+
std::vector<TShape> *out_shape) {
447+
const TShape &wshape = (*in_shape)[0];
448+
CHECK_EQ(wshape.ndim(), 4) << "Weight should be a 4 dimensional tensor";
449+
TShape oshape({wshape[0], wshape[1], 8, 8});
450+
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
451+
return true;
452+
})
453+
.set_attr<FCorrectLayout>("FCorrectLayout", [](const NodeAttrs& attrs,
454+
std::vector<Layout> *ilayouts,
455+
const std::vector<Layout> *last_ilayouts,
456+
std::vector<Layout> *olayouts) {
457+
Layout layout("OIHW");
458+
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, layout);
459+
NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
460+
return true;
461+
})
462+
.set_attr<FInferType>("FInferType", Conv2DWinogradNNPACKWTInferType)
463+
.set_num_outputs(1)
464+
.set_num_inputs(1)
465+
.set_support_level(5);
466+
467+
DMLC_REGISTER_PARAMETER(WinogradNNPACKWeightTransformParam);
468+
469+
NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_without_weight_transform)
470+
.describe(R"code(Compute conv2d with winograd nnpack.
471+
- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
472+
- **weight**: Any shape
473+
We do not check shape for this input tensor.
474+
- **bias**: (channels,)
475+
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
476+
)code" NNVM_ADD_FILELINE)
477+
.add_argument("data", "4D Tensor", "Input data.")
478+
.add_argument("weight", "4D Tensor", "Transformed weight tensor.")
479+
.add_argument("bias", "1D Tensor", "Bias parameter.")
480+
.add_arguments(Conv2DParam::__FIELDS__())
481+
.set_attr_parser(ParamParser<Conv2DParam>)
482+
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
483+
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
484+
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape<Conv2DParam>)
485+
.set_attr<FInferType>("FInferType", Conv2DInferType<Conv2DParam>)
486+
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<Conv2DParam>)
487+
.set_num_outputs(1)
488+
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
489+
.set_support_level(5);
490+
491+
415492
NNVM_REGISTER_OP(_conv2d_grad)
416493
.describe(R"code(2D convolution grad.
417494

python/tvm/contrib/nnpack.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .. import api as _api
55
from .. import intrin as _intrin
66
from .._ffi.function import _init_api
7+
from tvm import relay
78

89
def is_available():
910
"""Check whether NNPACK is available, that is, `nnp_initialize()`
@@ -149,11 +150,12 @@ def convolution_inference_without_weight_transform(
149150
ins[1],
150151
ins[2] if bias is not None else 0,
151152
outs[0], padding[0], padding[1], padding[2], padding[3],
152-
stride[0], stride[1], nthreads, algorithm), name="C")
153+
stride[0], stride[1], nthreads, algorithm), name="C", dtype='float32')
153154

154155
def convolution_inference_weight_transform(
155156
kernel, nthreads=1,
156-
algorithm=ConvolutionAlgorithm.AUTO):
157+
algorithm=ConvolutionAlgorithm.AUTO,
158+
dtype='float32'):
157159
"""Create an extern op to do inference convolution of 3D tensor data and
158160
4D tensor kernel and 1D tensor bias with nnpack.
159161
@@ -171,13 +173,14 @@ def convolution_inference_weight_transform(
171173
"""
172174
assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
173175
output_channels, input_channels, _, _ = kernel.shape
174-
175176
transform_tile_size = 8
177+
if isinstance(dtype, relay.ty.TensorType):
178+
dtype = dtype.dtype
176179
return _api.extern(
177180
(output_channels, input_channels, transform_tile_size, transform_tile_size),
178181
[kernel],
179182
lambda ins, outs: _intrin.call_packed(
180183
"tvm.contrib.nnpack.convolution_inference_weight_transform",
181-
ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
184+
ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
182185

183186
_init_api("tvm.contrib.nnpack")

python/tvm/relay/op/nn/_nn.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,55 @@ def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
321321
reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
322322
OpPattern.OUT_ELEMWISE_FUSABLE)
323323

324+
325+
# winograd nnpack related operators
326+
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
327+
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, inputs, out_dtype, target):
328+
"""Compute definition of conv2d_winograd_nnpack_without_weight_transform"""
329+
# pylint: disable=assignment-from-no-return
330+
padding = attrs.get_int_tuple("padding")
331+
strides = attrs.get_int_tuple("strides")
332+
dilation = attrs.get_int_tuple("dilation")
333+
groups = attrs.get_int("groups")
334+
data_layout = attrs.get_str("data_layout")
335+
out_dtype = attrs.get_str("out_dtype")
336+
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
337+
assert dilation == (1, 1), "Do not support dilate now"
338+
assert groups == 1, "Do not support arbitrary group number"
339+
340+
out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
341+
inputs[0], inputs[1], None, strides, padding, dilation, data_layout,
342+
out_dtype)
343+
344+
return [out]
345+
346+
@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
347+
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
348+
"""Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
349+
with target:
350+
return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)
351+
352+
reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
353+
OpPattern.OPAQUE)
354+
355+
356+
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
357+
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
358+
"""Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
359+
convolution_algorithm = attrs.get_int('convolution_algorithm')
360+
out = topi.nn.conv2d_winograd_nnpack_weight_transform(inputs[0], convolution_algorithm, out_dtype)
361+
return [out]
362+
363+
@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
364+
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
365+
"""Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
366+
with target:
367+
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
368+
369+
reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
370+
OpPattern.OPAQUE)
371+
372+
324373
@reg.register_compute("nn.contrib_conv2d_NCHWc")
325374
def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
326375
"""Compute definition of conv2d NCHWc"""

0 commit comments

Comments
 (0)