@@ -2776,15 +2776,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
27762776// WinogradInputTransformOp
27772777// ===----------------------------------------------------------------------===//
27782778
2779- Value getValueFromOpFoldResult (OpFoldResult opFoldResult, OpBuilder &builder,
2780- Location loc) {
2781- if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
2782- auto intAttr = cast<IntegerAttr>(attr);
2783- return builder.create <arith::ConstantOp>(loc, intAttr);
2784- }
2785- return opFoldResult.get <Value>();
2786- }
2787-
27882779LogicalResult WinogradInputTransformOp::verify () {
27892780 auto inputType = cast<ShapedType>(getInput ().getType ());
27902781 ArrayRef<int64_t > inputShape = inputType.getShape ();
@@ -2825,9 +2816,9 @@ LogicalResult WinogradInputTransformOp::verify() {
28252816SmallVector<Range>
28262817WinogradInputTransformOp::getIterationDomain (OpBuilder &builder) {
28272818 Location loc = getLoc ();
2828- auto indexType = builder.getIndexType ();
2829- auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2830- auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2819+ IndexType indexType = builder.getIndexType ();
2820+ IntegerAttr zeroAttr = builder.getIntegerAttr (indexType, 0 );
2821+ IntegerAttr oneAttr = builder.getIntegerAttr (indexType, 1 );
28312822 Value output = getOutput ();
28322823 SmallVector<Range> loopBounds (6 );
28332824 for (unsigned dim = 0 ; dim < 6 ; ++dim) {
@@ -2849,21 +2840,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
28492840 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
28502841 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
28512842 SmallVector<OpFoldResult> &resultSizes) {
2852- auto zeroAttr = builder.getI64IntegerAttr (0 );
2853- auto oneAttr = builder.getI64IntegerAttr (1 );
2843+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
2844+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
28542845
2855- resultOffsets.push_back (zeroAttr);
2856- resultOffsets.push_back (zeroAttr);
2857- resultOffsets.push_back (offsets[2 ]);
2858- resultOffsets.push_back (offsets[3 ]);
2859- resultOffsets.push_back (zeroAttr);
2860- resultOffsets.push_back (zeroAttr);
2861- resultSizes.push_back (sizes[0 ]);
2862- resultSizes.push_back (sizes[1 ]);
2863- resultSizes.push_back (oneAttr);
2864- resultSizes.push_back (oneAttr);
2865- resultSizes.push_back (sizes[4 ]);
2866- resultSizes.push_back (sizes[5 ]);
2846+ resultOffsets.append (
2847+ {zeroAttr, zeroAttr, offsets[2 ], offsets[3 ], zeroAttr, zeroAttr});
2848+ resultSizes.append (
2849+ {sizes[0 ], sizes[1 ], oneAttr, oneAttr, sizes[4 ], sizes[5 ]});
28672850
28682851 return success ();
28692852}
@@ -2872,41 +2855,37 @@ FailureOr<TilingResult>
28722855WinogradInputTransformOp::getTiledImplementation (OpBuilder &builder,
28732856 ArrayRef<OpFoldResult> offsets,
28742857 ArrayRef<OpFoldResult> sizes) {
2875- auto oneAttr = builder.getI64IntegerAttr (1 );
2876- auto zeroAttr = builder.getI64IntegerAttr (0 );
2858+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
2859+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
28772860 Value input = getInput ();
28782861 auto inputType = cast<ShapedType>(input.getType ());
2879- auto inputShape = inputType.getShape ();
2862+ ArrayRef< int64_t > inputShape = inputType.getShape ();
28802863 int64_t inputH = inputShape[1 ];
28812864 int64_t inputW = inputShape[2 ];
28822865 int64_t m = getM ();
28832866 int64_t r = getR ();
28842867 int64_t alpha = m + r - 1 ;
28852868 int64_t alphaH = inputH != 1 ? alpha : 1 ;
28862869 int64_t alphaW = inputW != 1 ? alpha : 1 ;
2887- auto alphaHAttr = builder.getI64IntegerAttr (alphaH);
2888- auto alphaWAttr = builder.getI64IntegerAttr (alphaW);
2870+ IntegerAttr alphaHAttr = builder.getI64IntegerAttr (alphaH);
2871+ IntegerAttr alphaWAttr = builder.getI64IntegerAttr (alphaW);
28892872
28902873 Location loc = getLoc ();
28912874 SmallVector<Value> tiledOperands;
28922875 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
28932876
2894- auto context = builder.getContext ();
2877+ MLIRContext * context = builder.getContext ();
28952878 auto affineMap =
28962879 AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
28972880 Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
2898- loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
2881+ loc, affineMap,
2882+ getValueOrCreateConstantIndexOp (builder, loc, offsets[2 ]));
28992883 Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
2900- loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
2901-
2902- sliceOffsets.push_back (zeroAttr);
2903- sliceOffsets.push_back (mappedOffset1);
2904- sliceOffsets.push_back (mappedOffset2);
2905- sliceOffsets.push_back (zeroAttr);
2906- sliceSizes.push_back (sizes[4 ]);
2907- sliceSizes.push_back (alphaHAttr);
2908- sliceSizes.push_back (alphaWAttr);
2909- sliceSizes.push_back (sizes[5 ]);
2884+ loc, affineMap,
2885+ getValueOrCreateConstantIndexOp (builder, loc, offsets[3 ]));
2886+
2887+ sliceOffsets.append ({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
2888+ sliceSizes.append ({sizes[4 ], alphaHAttr, alphaWAttr, sizes[5 ]});
29102889 SmallVector<OpFoldResult> inputStrides (4 , oneAttr);
29112890 tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
29122891 loc, getInput (), sliceOffsets, sliceSizes, inputStrides));
@@ -2921,7 +2900,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
29212900 tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
29222901 loc, getOutput (), sliceOffsets, sliceSizes, outputStrides));
29232902
2924- SmallVector<Type, 4 > resultTypes;
2903+ SmallVector<Type> resultTypes;
29252904 resultTypes.push_back (tiledOperands[1 ].getType ());
29262905 Operation *tiledOp =
29272906 mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
@@ -2974,9 +2953,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
29742953SmallVector<Range>
29752954WinogradOutputTransformOp::getIterationDomain (OpBuilder &builder) {
29762955 Location loc = getLoc ();
2977- auto indexType = builder.getIndexType ();
2978- auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2979- auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2956+ IndexType indexType = builder.getIndexType ();
2957+ IntegerAttr zeroAttr = builder.getIntegerAttr (indexType, 0 );
2958+ IntegerAttr oneAttr = builder.getIntegerAttr (indexType, 1 );
29802959 Value value = getValue ();
29812960 SmallVector<Range> loopBounds (6 );
29822961 for (unsigned dim = 0 ; dim < 6 ; ++dim) {
@@ -2998,57 +2977,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
29982977 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
29992978 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
30002979 SmallVector<OpFoldResult> &resultSizes) {
3001- auto zeroAttr = builder.getI64IntegerAttr (0 );
2980+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
30022981 Value output = getOutput ();
30032982 auto outputType = cast<ShapedType>(output.getType ());
3004- auto outputShape = outputType.getShape ();
2983+ ArrayRef< int64_t > outputShape = outputType.getShape ();
30052984 int64_t outputH = outputShape[1 ];
30062985 int64_t outputW = outputShape[2 ];
30072986 int64_t m = getM ();
3008- auto heightM = builder.getI64IntegerAttr (outputH != 1 ? m : 1 );
3009- auto widthM = builder.getI64IntegerAttr (outputW != 1 ? m : 1 );
2987+ IntegerAttr heightM = builder.getI64IntegerAttr (outputH != 1 ? m : 1 );
2988+ IntegerAttr widthM = builder.getI64IntegerAttr (outputW != 1 ? m : 1 );
30102989
30112990 Location loc = getLoc ();
3012- auto context = builder.getContext ();
2991+ MLIRContext * context = builder.getContext ();
30132992 auto affineMap =
30142993 AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
30152994 Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
3016- loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
2995+ loc, affineMap,
2996+ getValueOrCreateConstantIndexOp (builder, loc, offsets[2 ]));
30172997 Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
3018- loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
3019-
3020- resultOffsets.push_back (zeroAttr);
3021- resultOffsets.push_back (mappedOffset1);
3022- resultOffsets.push_back (mappedOffset2);
3023- resultOffsets.push_back (zeroAttr);
3024- resultSizes.push_back (sizes[4 ]);
3025- resultSizes.push_back (heightM);
3026- resultSizes.push_back (widthM);
3027- resultSizes.push_back (sizes[5 ]);
2998+ loc, affineMap,
2999+ getValueOrCreateConstantIndexOp (builder, loc, offsets[3 ]));
3000+
3001+ resultOffsets.append ({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
3002+ resultSizes.append ({sizes[4 ], heightM, widthM, sizes[5 ]});
30283003 return success ();
30293004}
30303005
30313006FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation (
30323007 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
30333008 ArrayRef<OpFoldResult> sizes) {
3034- auto oneAttr = builder.getI64IntegerAttr (1 );
3035- auto zeroAttr = builder.getI64IntegerAttr (0 );
3009+ IntegerAttr oneAttr = builder.getI64IntegerAttr (1 );
3010+ IntegerAttr zeroAttr = builder.getI64IntegerAttr (0 );
30363011 Location loc = getLoc ();
30373012 SmallVector<Value> tiledOperands;
30383013 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
30393014
3040- sliceOffsets.push_back (zeroAttr);
3041- sliceOffsets.push_back (zeroAttr);
3042- sliceOffsets.push_back (offsets[2 ]);
3043- sliceOffsets.push_back (offsets[3 ]);
3044- sliceOffsets.push_back (zeroAttr);
3045- sliceOffsets.push_back (zeroAttr);
3046- sliceSizes.push_back (sizes[0 ]);
3047- sliceSizes.push_back (sizes[1 ]);
3048- sliceSizes.push_back (oneAttr);
3049- sliceSizes.push_back (oneAttr);
3050- sliceSizes.push_back (sizes[4 ]);
3051- sliceSizes.push_back (sizes[5 ]);
3015+ sliceOffsets.append (
3016+ {zeroAttr, zeroAttr, offsets[2 ], offsets[3 ], zeroAttr, zeroAttr});
3017+ sliceSizes.append ({sizes[0 ], sizes[1 ], oneAttr, oneAttr, sizes[4 ], sizes[5 ]});
30523018 SmallVector<OpFoldResult> sliceStrides (6 , oneAttr);
30533019 tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
30543020 loc, getValue (), sliceOffsets, sliceSizes, sliceStrides));
@@ -3063,7 +3029,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
30633029 tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
30643030 loc, getOutput (), sliceOffsets, sliceSizes, strides));
30653031
3066- SmallVector<Type, 4 > resultTypes;
3032+ SmallVector<Type> resultTypes;
30673033 resultTypes.push_back (tiledOperands[1 ].getType ());
30683034 Operation *tiledOp =
30693035 mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
0 commit comments