@@ -2739,6 +2739,122 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
27392739 return SmallVector<Value>{result};
27402740}
27412741
2742+ // ===----------------------------------------------------------------------===//
2743+ // WinogradFilterTransformOp
2744+ // ===----------------------------------------------------------------------===//
2745+
2746+ LogicalResult WinogradFilterTransformOp::verify () {
2747+ auto filterType = cast<ShapedType>(getFilter ().getType ());
2748+ ArrayRef<int64_t > filterShape = filterType.getShape ();
2749+ int64_t filterH = filterShape[1 ];
2750+ int64_t filterW = filterShape[2 ];
2751+ int64_t r = getR ();
2752+ int64_t m = getM ();
2753+
2754+ if (filterH != r && filterH != 1 )
2755+ return emitOpError (" expect filter height either equals to r or 1" );
2756+ if (filterW != r && filterW != 1 )
2757+ return emitOpError (" expect filter width either equals to r or 1" );
2758+ if (filterH == 1 && filterW == 1 )
2759+ return emitOpError (" expect either filter height or width equals to r" );
2760+
2761+ SmallVector<int64_t > expectedOutputShape;
2762+ expectedOutputShape.push_back (filterH == r ? m + r - 1 : 1 );
2763+ expectedOutputShape.push_back (filterW == r ? m + r - 1 : 1 );
2764+ expectedOutputShape.push_back (filterShape[3 ]);
2765+ expectedOutputShape.push_back (filterShape[0 ]);
2766+
2767+ auto outputType = cast<ShapedType>(getOutput ().getType ());
2768+ ArrayRef<int64_t > outputShape = outputType.getShape ();
2769+ if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
2770+ return emitOpError (" the output shape is not expected" );
2771+ }
2772+ return success ();
2773+ }
2774+
2775+ // ===----------------------------------------------------------------------===//
2776+ // WinogradInputTransformOp
2777+ // ===----------------------------------------------------------------------===//
2778+
2779+ LogicalResult WinogradInputTransformOp::verify () {
2780+ auto inputType = cast<ShapedType>(getInput ().getType ());
2781+ ArrayRef<int64_t > inputShape = inputType.getShape ();
2782+ int64_t inputH = inputShape[1 ];
2783+ int64_t inputW = inputShape[2 ];
2784+ int m = getM ();
2785+ int r = getR ();
2786+ int64_t tileSize = m + r - 1 ;
2787+ bool leftTransform = inputH != 1 ;
2788+ bool rightTransform = inputW != 1 ;
2789+
2790+ SmallVector<int64_t > expectedOutputShape (6 , inputH);
2791+ if (ShapedType::isDynamic (inputH)) {
2792+ expectedOutputShape[0 ] = tileSize;
2793+ expectedOutputShape[2 ] = ShapedType::kDynamic ;
2794+ } else {
2795+ expectedOutputShape[0 ] = leftTransform ? tileSize : 1 ;
2796+ expectedOutputShape[2 ] = leftTransform ? (inputH - (r - 1 )) / m : 1 ;
2797+ }
2798+ if (ShapedType::isDynamic (inputW)) {
2799+ expectedOutputShape[1 ] = tileSize;
2800+ expectedOutputShape[3 ] = ShapedType::kDynamic ;
2801+ } else {
2802+ expectedOutputShape[1 ] = rightTransform ? tileSize : 1 ;
2803+ expectedOutputShape[3 ] = rightTransform ? (inputW - (r - 1 )) / m : 1 ;
2804+ }
2805+ expectedOutputShape[4 ] = inputShape[0 ];
2806+ expectedOutputShape[5 ] = inputShape[3 ];
2807+
2808+ auto outputType = cast<ShapedType>(getOutput ().getType ());
2809+ ArrayRef<int64_t > outputShape = outputType.getShape ();
2810+ if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
2811+ return emitOpError (" the output shape is not expected" );
2812+ }
2813+ return success ();
2814+ }
2815+
2816+ // ===----------------------------------------------------------------------===//
2817+ // WinogradOutputTransformOp
2818+ // ===----------------------------------------------------------------------===//
2819+
2820+ LogicalResult WinogradOutputTransformOp::verify () {
2821+ auto valueType = cast<ShapedType>(getValue ().getType ());
2822+ ArrayRef<int64_t > valueShape = valueType.getShape ();
2823+ int64_t valueH = valueShape[0 ];
2824+ int64_t valueW = valueShape[1 ];
2825+ int64_t valueTileH = valueShape[2 ];
2826+ int64_t valueTileW = valueShape[3 ];
2827+ int m = getM ();
2828+ int r = getR ();
2829+ bool leftTransform = valueH != 1 ;
2830+ bool rightTransform = valueW != 1 ;
2831+
2832+ SmallVector<int64_t > expectedOutputShape (4 , valueH);
2833+ if (ShapedType::isDynamic (valueH) || ShapedType::isDynamic (valueTileH)) {
2834+ expectedOutputShape[1 ] = ShapedType::kDynamic ;
2835+ } else {
2836+ if (valueH != (leftTransform ? m + r - 1 : 1 ))
2837+ return emitOpError (" expect input height equals to input tile size" );
2838+ expectedOutputShape[1 ] = (leftTransform ? m : 1 ) * valueTileH;
2839+ }
2840+ if (ShapedType::isDynamic (valueW) || ShapedType::isDynamic (valueTileW)) {
2841+ expectedOutputShape[2 ] = ShapedType::kDynamic ;
2842+ } else {
2843+ if (valueW != (rightTransform ? m + r - 1 : 1 ))
2844+ return emitOpError (" expect input width equals to input tile size" );
2845+ expectedOutputShape[2 ] = (rightTransform ? m : 1 ) * valueTileW;
2846+ }
2847+ expectedOutputShape[0 ] = valueShape[4 ];
2848+ expectedOutputShape[3 ] = valueShape[5 ];
2849+
2850+ auto outputType = cast<ShapedType>(getOutput ().getType ());
2851+ ArrayRef<int64_t > outputShape = outputType.getShape ();
2852+ if (failed (verifyCompatibleShape (expectedOutputShape, outputShape))) {
2853+ return emitOpError (" the output shape is not expected" );
2854+ }
2855+ return success ();
2856+ }
2857+
27422858// ===----------------------------------------------------------------------===//
27432859// LinalgDialect
27442860// ===----------------------------------------------------------------------===//
0 commit comments