@@ -2885,15 +2885,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
28852885// WinogradInputTransformOp
28862886// ===----------------------------------------------------------------------===//
28872887
2888- Value getValueFromOpFoldResult (OpFoldResult opFoldResult, OpBuilder &builder,
2889- Location loc) {
2890- if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
2891- auto intAttr = cast<IntegerAttr>(attr);
2892- return builder.create <arith::ConstantOp>(loc, intAttr);
2893- }
2894- return opFoldResult.get <Value>();
2895- }
2896-
28972888LogicalResult WinogradInputTransformOp::verify () {
28982889 auto inputType = cast<ShapedType>(getInput ().getType ());
28992890 ArrayRef<int64_t > inputShape = inputType.getShape ();
@@ -2934,9 +2925,9 @@ LogicalResult WinogradInputTransformOp::verify() {
29342925SmallVector<Range>
29352926WinogradInputTransformOp::getIterationDomain (OpBuilder &builder) {
29362927 Location loc = getLoc ();
2937- auto indexType = builder.getIndexType ();
2938- auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2939- auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2928+ IndexType indexType = builder.getIndexType ();
2929+ IntegerAttr zeroAttr = builder.getIntegerAttr (indexType, 0 );
2930+ IntegerAttr oneAttr = builder.getIntegerAttr (indexType, 1 );
29402931 Value output = getOutput ();
29412932 SmallVector<Range> loopBounds (6 );
29422933 for (unsigned dim = 0 ; dim < 6 ; ++dim) {
@@ -2958,21 +2949,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
29582949 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
29592950 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
29602951 SmallVector<OpFoldResult> &resultSizes) {
2961- auto zeroAttr = builder.getI64IntegerAttr (0 );
2962- auto oneAttr = builder.getI64IntegerAttr (1 );
2952+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
2953+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
29632954
2964- resultOffsets.push_back (zeroAttr);
2965- resultOffsets.push_back (zeroAttr);
2966- resultOffsets.push_back (offsets[2 ]);
2967- resultOffsets.push_back (offsets[3 ]);
2968- resultOffsets.push_back (zeroAttr);
2969- resultOffsets.push_back (zeroAttr);
2970- resultSizes.push_back (sizes[0 ]);
2971- resultSizes.push_back (sizes[1 ]);
2972- resultSizes.push_back (oneAttr);
2973- resultSizes.push_back (oneAttr);
2974- resultSizes.push_back (sizes[4 ]);
2975- resultSizes.push_back (sizes[5 ]);
2955+ resultOffsets.append (
2956+ {zeroAttr, zeroAttr, offsets[2 ], offsets[3 ], zeroAttr, zeroAttr});
2957+ resultSizes.append (
2958+ {sizes[0 ], sizes[1 ], oneAttr, oneAttr, sizes[4 ], sizes[5 ]});
29762959
29772960 return success ();
29782961}
@@ -2981,41 +2964,37 @@ FailureOr<TilingResult>
29812964WinogradInputTransformOp::getTiledImplementation (OpBuilder &builder,
29822965 ArrayRef<OpFoldResult> offsets,
29832966 ArrayRef<OpFoldResult> sizes) {
2984- auto oneAttr = builder.getI64IntegerAttr (1 );
2985- auto zeroAttr = builder.getI64IntegerAttr (0 );
2967+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
2968+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
29862969 Value input = getInput ();
29872970 auto inputType = cast<ShapedType>(input.getType ());
2988- auto inputShape = inputType.getShape ();
2971+ ArrayRef< int64_t > inputShape = inputType.getShape ();
29892972 int64_t inputH = inputShape[1 ];
29902973 int64_t inputW = inputShape[2 ];
29912974 int64_t m = getM ();
29922975 int64_t r = getR ();
29932976 int64_t alpha = m + r - 1 ;
29942977 int64_t alphaH = inputH != 1 ? alpha : 1 ;
29952978 int64_t alphaW = inputW != 1 ? alpha : 1 ;
2996- auto alphaHAttr = builder.getI64IntegerAttr (alphaH);
2997- auto alphaWAttr = builder.getI64IntegerAttr (alphaW);
2979+ IntegerAttr alphaHAttr = builder.getI64IntegerAttr (alphaH);
2980+ IntegerAttr alphaWAttr = builder.getI64IntegerAttr (alphaW);
29982981
29992982 Location loc = getLoc ();
30002983 SmallVector<Value> tiledOperands;
30012984 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
30022985
3003- auto context = builder.getContext ();
2986+ MLIRContext * context = builder.getContext ();
30042987 auto affineMap =
30052988 AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
30062989 Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
3007- loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
2990+ loc, affineMap,
2991+ getValueOrCreateConstantIndexOp (builder, loc, offsets[2 ]));
30082992 Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
3009- loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
3010-
3011- sliceOffsets.push_back (zeroAttr);
3012- sliceOffsets.push_back (mappedOffset1);
3013- sliceOffsets.push_back (mappedOffset2);
3014- sliceOffsets.push_back (zeroAttr);
3015- sliceSizes.push_back (sizes[4 ]);
3016- sliceSizes.push_back (alphaHAttr);
3017- sliceSizes.push_back (alphaWAttr);
3018- sliceSizes.push_back (sizes[5 ]);
2993+ loc, affineMap,
2994+ getValueOrCreateConstantIndexOp (builder, loc, offsets[3 ]));
2995+
2996+ sliceOffsets.append ({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
2997+ sliceSizes.append ({sizes[4 ], alphaHAttr, alphaWAttr, sizes[5 ]});
30192998 SmallVector<OpFoldResult> inputStrides (4 , oneAttr);
30202999 tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
30213000 loc, getInput (), sliceOffsets, sliceSizes, inputStrides));
@@ -3030,7 +3009,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
30303009 tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
30313010 loc, getOutput (), sliceOffsets, sliceSizes, outputStrides));
30323011
3033- SmallVector<Type, 4 > resultTypes;
3012+ SmallVector<Type> resultTypes;
30343013 resultTypes.push_back (tiledOperands[1 ].getType ());
30353014 Operation *tiledOp =
30363015 mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
@@ -3083,9 +3062,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
30833062SmallVector<Range>
30843063WinogradOutputTransformOp::getIterationDomain (OpBuilder &builder) {
30853064 Location loc = getLoc ();
3086- auto indexType = builder.getIndexType ();
3087- auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
3088- auto oneAttr = builder.getIntegerAttr (indexType, 1 );
3065+ IndexType indexType = builder.getIndexType ();
3066+ IntegerAttr zeroAttr = builder.getIntegerAttr (indexType, 0 );
3067+ IntegerAttr oneAttr = builder.getIntegerAttr (indexType, 1 );
30893068 Value value = getValue ();
30903069 SmallVector<Range> loopBounds (6 );
30913070 for (unsigned dim = 0 ; dim < 6 ; ++dim) {
@@ -3107,57 +3086,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
31073086 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
31083087 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
31093088 SmallVector<OpFoldResult> &resultSizes) {
3110- auto zeroAttr = builder.getI64IntegerAttr (0 );
3089+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
31113090 Value output = getOutput ();
31123091 auto outputType = cast<ShapedType>(output.getType ());
3113- auto outputShape = outputType.getShape ();
3092+ ArrayRef< int64_t > outputShape = outputType.getShape ();
31143093 int64_t outputH = outputShape[1 ];
31153094 int64_t outputW = outputShape[2 ];
31163095 int64_t m = getM ();
3117- auto heightM = builder.getI64IntegerAttr (outputH != 1 ? m : 1 );
3118- auto widthM = builder.getI64IntegerAttr (outputW != 1 ? m : 1 );
3096+ IntegerAttr heightM = builder.getI64IntegerAttr (outputH != 1 ? m : 1 );
3097+ IntegerAttr widthM = builder.getI64IntegerAttr (outputW != 1 ? m : 1 );
31193098
31203099 Location loc = getLoc ();
3121- auto context = builder.getContext ();
3100+ MLIRContext * context = builder.getContext ();
31223101 auto affineMap =
31233102 AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
31243103 Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
3125- loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
3104+ loc, affineMap,
3105+ getValueOrCreateConstantIndexOp (builder, loc, offsets[2 ]));
31263106 Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
3127- loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
3128-
3129- resultOffsets.push_back (zeroAttr);
3130- resultOffsets.push_back (mappedOffset1);
3131- resultOffsets.push_back (mappedOffset2);
3132- resultOffsets.push_back (zeroAttr);
3133- resultSizes.push_back (sizes[4 ]);
3134- resultSizes.push_back (heightM);
3135- resultSizes.push_back (widthM);
3136- resultSizes.push_back (sizes[5 ]);
3107+ loc, affineMap,
3108+ getValueOrCreateConstantIndexOp (builder, loc, offsets[3 ]));
3109+
3110+ resultOffsets.append ({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
3111+ resultSizes.append ({sizes[4 ], heightM, widthM, sizes[5 ]});
31373112 return success ();
31383113}
31393114
31403115FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation (
31413116 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
31423117 ArrayRef<OpFoldResult> sizes) {
3143- auto oneAttr = builder.getI64IntegerAttr (1 );
3144- auto zeroAttr = builder.getI64IntegerAttr (0 );
3118+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
3119+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
31453120 Location loc = getLoc ();
31463121 SmallVector<Value> tiledOperands;
31473122 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
31483123
3149- sliceOffsets.push_back (zeroAttr);
3150- sliceOffsets.push_back (zeroAttr);
3151- sliceOffsets.push_back (offsets[2 ]);
3152- sliceOffsets.push_back (offsets[3 ]);
3153- sliceOffsets.push_back (zeroAttr);
3154- sliceOffsets.push_back (zeroAttr);
3155- sliceSizes.push_back (sizes[0 ]);
3156- sliceSizes.push_back (sizes[1 ]);
3157- sliceSizes.push_back (oneAttr);
3158- sliceSizes.push_back (oneAttr);
3159- sliceSizes.push_back (sizes[4 ]);
3160- sliceSizes.push_back (sizes[5 ]);
3124+ sliceOffsets.append (
3125+ {zeroAttr, zeroAttr, offsets[2 ], offsets[3 ], zeroAttr, zeroAttr});
3126+ sliceSizes.append ({sizes[0 ], sizes[1 ], oneAttr, oneAttr, sizes[4 ], sizes[5 ]});
31613127 SmallVector<OpFoldResult> sliceStrides (6 , oneAttr);
31623128 tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
31633129 loc, getValue (), sliceOffsets, sliceSizes, sliceStrides));
@@ -3172,7 +3138,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
31723138 tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
31733139 loc, getOutput (), sliceOffsets, sliceSizes, strides));
31743140
3175- SmallVector<Type, 4 > resultTypes;
3141+ SmallVector<Type> resultTypes;
31763142 resultTypes.push_back (tiledOperands[1 ].getType ());
31773143 Operation *tiledOp =
31783144 mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
0 commit comments