1010#include " ../../pass/alter_op_layout.h"
1111#include " ../layout.h"
1212
13+
1314namespace tvm {
1415namespace relay {
1516
@@ -499,8 +500,8 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight,
499500
500501TVM_REGISTER_API (" relay.op.nn._make.contrib_conv2d_winograd_weight_transform" )
501502.set_body([](const TVMArgs& args, TVMRetValue* rv) {
502- runtime::detail::unpack_call<Expr, 2 >(MakeConv2DWinogradWeightTransform, args, rv);
503- });
503+ runtime::detail::unpack_call<Expr, 2 >(MakeConv2DWinogradWeightTransform, args, rv);
504+ });
504505
505506
506507RELAY_REGISTER_OP (" nn.contrib_conv2d_winograd_weight_transform" )
@@ -521,17 +522,17 @@ weight transformation in advance.
521522// Positional relay function to create conv2d winograd nnpack operator
522523// used by frontend FFI.
523524Expr MakeConv2DWinogradNNPACK (Expr data,
524- Expr weight,
525- Array<IndexExpr> strides,
526- Array<IndexExpr> padding,
527- Array<IndexExpr> dilation,
528- int groups,
529- IndexExpr channels,
530- Array<IndexExpr> kernel_size,
531- std::string data_layout,
532- std::string kernel_layout,
533- std::string out_layout,
534- DataType out_dtype) {
525+ Expr weight,
526+ Array<IndexExpr> strides,
527+ Array<IndexExpr> padding,
528+ Array<IndexExpr> dilation,
529+ int groups,
530+ IndexExpr channels,
531+ Array<IndexExpr> kernel_size,
532+ std::string data_layout,
533+ std::string kernel_layout,
534+ std::string out_layout,
535+ DataType out_dtype) {
535536 auto attrs = make_node<Conv2DAttrs>();
536537 attrs->strides = std::move (strides);
537538 attrs->padding = std::move (padding);
@@ -547,17 +548,15 @@ Expr MakeConv2DWinogradNNPACK(Expr data,
547548 return CallNode::make (op, {data, weight}, Attrs (attrs), {});
548549}
549550
550-
551551TVM_REGISTER_API (" relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform" )
552552.set_body([](const TVMArgs& args, TVMRetValue* rv) {
553- runtime::detail::unpack_call<Expr, 12 >(MakeConv2DWinogradNNPACK, args, rv);
554- });
555-
553+ runtime::detail::unpack_call<Expr, 12 >(MakeConv2DWinogradNNPACK, args, rv);
554+ });
556555
557556RELAY_REGISTER_OP (" nn.contrib_conv2d_winograd_nnpack_without_weight_transform" )
558557.describe(R"code( Compute conv2d with winograd nnpack. Only supports NCHW layout.
559- This operator assumes the weight tensor is already pre-transformed by
560- nn.contrib_conv2d_winograd_nnpack_weight_transform.
558+ This operator assumes the weight tensor is already pre-transformed by
559+ nn.contrib_conv2d_winograd_nnpack_weight_transform.
561560
562561- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
563562- **weight**: Any shape
@@ -572,60 +571,63 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
572571.add_argument(" weight" , " Tensor" , " The weight tensor." )
573572.set_support_level(10 )
574573.add_type_rel(" Conv2DWinogradNNPACKRel" , Conv2DWinogradRel<Conv2DAttrs>)
575- .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" ,
576- Conv2DInferCorrectLayout<Conv2DAttrs>);
574+ .set_attr<FInferCorrectLayout>(" FInferCorrectLayout" , Conv2DInferCorrectLayout<Conv2DAttrs>);
577575
578576// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
579577TVM_REGISTER_NODE_TYPE (Conv2DWinogradNNPACKWeightTransformAttrs);
580578
581579bool Conv2DWinogradNNPACKWeightTransformRel (const Array<Type>& types,
582- int num_inputs,
583- const Attrs& attrs,
584- const TypeReporter& reporter) {
580+ int num_inputs,
581+ const Attrs& attrs,
582+ const TypeReporter& reporter) {
585583 CHECK_EQ (types.size (), 2 );
586584 const auto * data = types[0 ].as <TensorTypeNode>();
587585 if (data == nullptr ) return false ;
588586
589- const Conv2DWinogradNNPACKWeightTransformAttrs* param = attrs.as <Conv2DWinogradNNPACKWeightTransformAttrs>();
587+ const Conv2DWinogradNNPACKWeightTransformAttrs* param =
588+ attrs.as <Conv2DWinogradNNPACKWeightTransformAttrs>();
590589 CHECK (param != nullptr );
591590
592591 CHECK_EQ (data->shape .size (), 4 ) << " Only support NCHW normal kernel layout" ;
593592
594593 // each pad width element should be a pair of positive integers
595- std::vector<IndexExpr> oshape {
594+ std::vector<IndexExpr> oshape{
596595 data->shape [0 ],
597596 data->shape [1 ],
598597 8 ,
599598 8 ,
600599 };
601600
602- reporter->Assign (types[1 ], TensorTypeNode::make (Array<IndexExpr>(oshape),
603- data->dtype ));
601+ DataType out_dtype = param->out_dtype ;
602+ if (out_dtype.bits () == 0 ) {
603+ out_dtype = data->dtype ;
604+ }
605+ reporter->Assign (types[1 ], TensorTypeNode::make (Array<IndexExpr>(oshape), out_dtype));
604606 return true ;
605607}
606608
607609Expr MakeConv2DWinogradNNPACKWeightTransform (Expr weight,
608- int convolution_algorithm) {
610+ int convolution_algorithm,
611+ DataType out_dtype) {
609612 auto attrs = make_node<Conv2DWinogradNNPACKWeightTransformAttrs>();
610613 attrs->convolution_algorithm = convolution_algorithm;
614+ attrs->out_dtype = std::move (out_dtype);
611615 static const Op& op = Op::Get (" nn.contrib_conv2d_winograd_nnpack_weight_transform" );
612616 return CallNode::make (op, {weight}, Attrs (attrs), {});
613617}
614618
615-
616619TVM_REGISTER_API (" relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform" )
617620.set_body([](const TVMArgs& args, TVMRetValue* rv) {
618- runtime::detail::unpack_call<Expr, 2 >(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
619- });
620-
621+ runtime::detail::unpack_call<Expr, 3 >(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
622+ });
621623
622624RELAY_REGISTER_OP (" nn.contrib_conv2d_winograd_nnpack_weight_transform" )
623625.describe(R"code( Weight transformation of winograd fast convolution algorithm with NNPACK.
624-
625626Separate this into another symbol in order to enable Precompute Pass to compute the
626627weight transformation in advance.
627628
628629- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
630+
629631)code" TVM_ADD_FILELINE)
630632.set_attrs_type_key(" relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs" )
631633.set_num_inputs(1 )
0 commit comments