@@ -130,13 +130,14 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
130130 return true ;
131131}
132132
133+ template <class Param >
133134inline bool WinogradConv2DInferShape (const nnvm::NodeAttrs& attrs,
134135 std::vector<TShape>* in_shape,
135136 std::vector<TShape>* out_shape) {
136137 static const Layout kNCHW (" NCHW" );
137138 static const Layout kOIHW (" OIHW" );
138139
139- const WinogradConv2DParam & param = nnvm::get<WinogradConv2DParam >(attrs.parsed );
140+ const Param & param = nnvm::get<Param >(attrs.parsed );
140141
141142 const Layout in_layout (param.layout );
142143 const Layout kernel_layout (param.kernel_layout );
@@ -403,7 +404,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
403404.set_attr_parser(ParamParser<WinogradConv2DParam>)
404405.set_attr<FGetAttrDict>(" FGetAttrDict" , ParamGetAttrDict<WinogradConv2DParam>)
405406.set_attr<FListInputNames>(" FListInputNames" , UseBiasListInputNames<WinogradConv2DParam>)
406- .set_attr<FInferShape>(" FInferShape" , WinogradConv2DInferShape)
407+ .set_attr<FInferShape>(" FInferShape" , WinogradConv2DInferShape<WinogradConv2DParam> )
407408.set_attr<FInferType>(" FInferType" , Conv2DInferType<WinogradConv2DParam>)
408409.set_attr<FCorrectLayout>(" FCorrectLayout" , Conv2DCorrectLayout<WinogradConv2DParam>)
409410.set_num_outputs(1 )
@@ -412,6 +413,82 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
412413
413414DMLC_REGISTER_PARAMETER (WinogradConv2DParam);
414415
416+
417+ inline bool Conv2DWinogradNNPACKWTInferType (const nnvm::NodeAttrs& attrs,
418+ std::vector<int >* in_type,
419+ std::vector<int >* out_type) {
420+ const WinogradNNPACKWeightTransformParam& param =
421+ nnvm::get<WinogradNNPACKWeightTransformParam>(attrs.parsed );
422+
423+ CHECK_EQ (in_type->size (), 1U ) << " Input:[weight]" ;
424+ CHECK_EQ (out_type->size (), 1U );
425+
426+ if (param.out_dtype != -1 ) {
427+ NNVM_ASSIGN_OUTPUT_TYPE (attrs, *out_type, 0 , param.out_dtype );
428+ } else {
429+ ElemwiseType<1 , 1 >(attrs, in_type, out_type);
430+ }
431+ return true ;
432+ }
433+
434+ NNVM_REGISTER_OP (_contrib_conv2d_winograd_nnpack_weight_transform)
435+ .describe(R"code( Weight transformation of winograd fast convolution algorithm.
436+ Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
437+ weight transformation in advance.
438+ - **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
439+ )code" NNVM_ADD_FILELINE)
440+ .add_argument(" weight" , " 4D Tensor" , " Weight tensor." )
441+ .add_arguments(WinogradNNPACKWeightTransformParam::__FIELDS__())
442+ .set_attr_parser(ParamParser<WinogradNNPACKWeightTransformParam>)
443+ .set_attr<FGetAttrDict>(" FGetAttrDict" , ParamGetAttrDict<WinogradNNPACKWeightTransformParam>)
444+ .set_attr<FInferShape>(" FInferShape" , [](const nnvm::NodeAttrs& attrs,
445+ std::vector<TShape> *in_shape,
446+ std::vector<TShape> *out_shape) {
447+ const TShape &wshape = (*in_shape)[0 ];
448+ CHECK_EQ (wshape.ndim (), 4 ) << " Weight should be a 4 dimensional tensor" ;
449+ TShape oshape ({wshape[0 ], wshape[1 ], 8 , 8 });
450+ NNVM_ASSIGN_OUTPUT_SHAPE (attrs, *out_shape, 0 , oshape);
451+ return true ;
452+ })
453+ .set_attr<FCorrectLayout>(" FCorrectLayout" , [](const NodeAttrs& attrs,
454+ std::vector<Layout> *ilayouts,
455+ const std::vector<Layout> *last_ilayouts,
456+ std::vector<Layout> *olayouts) {
457+ Layout layout (" OIHW" );
458+ NNVM_ASSIGN_LAYOUT (*ilayouts, 0 , layout);
459+ NNVM_ASSIGN_LAYOUT (*olayouts, 0 , layout);
460+ return true ;
461+ })
462+ .set_attr<FInferType>(" FInferType" , Conv2DWinogradNNPACKWTInferType)
463+ .set_num_outputs(1 )
464+ .set_num_inputs(1 )
465+ .set_support_level(5 );
466+
467+ DMLC_REGISTER_PARAMETER (WinogradNNPACKWeightTransformParam);
468+
469+ NNVM_REGISTER_OP (_contrib_conv2d_winograd_nnpack_without_weight_transform)
470+ .describe(R"code( Compute conv2d with winograd nnpack.
471+ - **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
472+ - **weight**: Any shape
473+ We do not check shape for this input tensor.
474+ - **bias**: (channels,)
475+ - **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
476+ )code" NNVM_ADD_FILELINE)
477+ .add_argument(" data" , " 4D Tensor" , " Input data." )
478+ .add_argument(" weight" , " 4D Tensor" , " Transformed weight tensor." )
479+ .add_argument(" bias" , " 1D Tensor" , " Bias parameter." )
480+ .add_arguments(Conv2DParam::__FIELDS__())
481+ .set_attr_parser(ParamParser<Conv2DParam>)
482+ .set_attr<FGetAttrDict>(" FGetAttrDict" , ParamGetAttrDict<Conv2DParam>)
483+ .set_attr<FListInputNames>(" FListInputNames" , UseBiasListInputNames<Conv2DParam>)
484+ .set_attr<FInferShape>(" FInferShape" , WinogradConv2DInferShape<Conv2DParam>)
485+ .set_attr<FInferType>(" FInferType" , Conv2DInferType<Conv2DParam>)
486+ .set_attr<FCorrectLayout>(" FCorrectLayout" , Conv2DCorrectLayout<Conv2DParam>)
487+ .set_num_outputs(1 )
488+ .set_num_inputs(UseBiasNumInputs<Conv2DParam>)
489+ .set_support_level(5 );
490+
491+
415492NNVM_REGISTER_OP (_conv2d_grad)
416493 .describe(R"code( 2D convolution grad.
417494
0 commit comments