Skip to content

Commit 0c121bb

Browse files
committed
Address ftynse's comments
1 parent bad2ae0 commit 0c121bb

File tree

2 files changed

+226
-228
lines changed

2 files changed

+226
-228
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 44 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2885,15 +2885,6 @@ LogicalResult WinogradFilterTransformOp::verify() {
28852885
// WinogradInputTransformOp
28862886
//===----------------------------------------------------------------------===//
28872887

2888-
Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
2889-
Location loc) {
2890-
if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
2891-
auto intAttr = cast<IntegerAttr>(attr);
2892-
return builder.create<arith::ConstantOp>(loc, intAttr);
2893-
}
2894-
return opFoldResult.get<Value>();
2895-
}
2896-
28972888
LogicalResult WinogradInputTransformOp::verify() {
28982889
auto inputType = cast<ShapedType>(getInput().getType());
28992890
ArrayRef<int64_t> inputShape = inputType.getShape();
@@ -2934,9 +2925,9 @@ LogicalResult WinogradInputTransformOp::verify() {
29342925
SmallVector<Range>
29352926
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
29362927
Location loc = getLoc();
2937-
auto indexType = builder.getIndexType();
2938-
auto zeroAttr = builder.getIntegerAttr(indexType, 0);
2939-
auto oneAttr = builder.getIntegerAttr(indexType, 1);
2928+
IndexType indexType = builder.getIndexType();
2929+
IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
2930+
IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
29402931
Value output = getOutput();
29412932
SmallVector<Range> loopBounds(6);
29422933
for (unsigned dim = 0; dim < 6; ++dim) {
@@ -2958,21 +2949,13 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
29582949
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
29592950
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
29602951
SmallVector<OpFoldResult> &resultSizes) {
2961-
auto zeroAttr = builder.getI64IntegerAttr(0);
2962-
auto oneAttr = builder.getI64IntegerAttr(1);
2952+
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2953+
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
29632954

2964-
resultOffsets.push_back(zeroAttr);
2965-
resultOffsets.push_back(zeroAttr);
2966-
resultOffsets.push_back(offsets[2]);
2967-
resultOffsets.push_back(offsets[3]);
2968-
resultOffsets.push_back(zeroAttr);
2969-
resultOffsets.push_back(zeroAttr);
2970-
resultSizes.push_back(sizes[0]);
2971-
resultSizes.push_back(sizes[1]);
2972-
resultSizes.push_back(oneAttr);
2973-
resultSizes.push_back(oneAttr);
2974-
resultSizes.push_back(sizes[4]);
2975-
resultSizes.push_back(sizes[5]);
2955+
resultOffsets.append(
2956+
{zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
2957+
resultSizes.append(
2958+
{sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
29762959

29772960
return success();
29782961
}
@@ -2981,41 +2964,37 @@ FailureOr<TilingResult>
29812964
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
29822965
ArrayRef<OpFoldResult> offsets,
29832966
ArrayRef<OpFoldResult> sizes) {
2984-
auto oneAttr = builder.getI64IntegerAttr(1);
2985-
auto zeroAttr = builder.getI64IntegerAttr(0);
2967+
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
2968+
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
29862969
Value input = getInput();
29872970
auto inputType = cast<ShapedType>(input.getType());
2988-
auto inputShape = inputType.getShape();
2971+
ArrayRef<int64_t> inputShape = inputType.getShape();
29892972
int64_t inputH = inputShape[1];
29902973
int64_t inputW = inputShape[2];
29912974
int64_t m = getM();
29922975
int64_t r = getR();
29932976
int64_t alpha = m + r - 1;
29942977
int64_t alphaH = inputH != 1 ? alpha : 1;
29952978
int64_t alphaW = inputW != 1 ? alpha : 1;
2996-
auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
2997-
auto alphaWAttr = builder.getI64IntegerAttr(alphaW);
2979+
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
2980+
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
29982981

29992982
Location loc = getLoc();
30002983
SmallVector<Value> tiledOperands;
30012984
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
30022985

3003-
auto context = builder.getContext();
2986+
MLIRContext *context = builder.getContext();
30042987
auto affineMap =
30052988
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
30062989
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
3007-
loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
2990+
loc, affineMap,
2991+
getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
30082992
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
3009-
loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
3010-
3011-
sliceOffsets.push_back(zeroAttr);
3012-
sliceOffsets.push_back(mappedOffset1);
3013-
sliceOffsets.push_back(mappedOffset2);
3014-
sliceOffsets.push_back(zeroAttr);
3015-
sliceSizes.push_back(sizes[4]);
3016-
sliceSizes.push_back(alphaHAttr);
3017-
sliceSizes.push_back(alphaWAttr);
3018-
sliceSizes.push_back(sizes[5]);
2993+
loc, affineMap,
2994+
getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
2995+
2996+
sliceOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
2997+
sliceSizes.append({sizes[4], alphaHAttr, alphaWAttr, sizes[5]});
30192998
SmallVector<OpFoldResult> inputStrides(4, oneAttr);
30202999
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
30213000
loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
@@ -3030,7 +3009,7 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
30303009
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
30313010
loc, getOutput(), sliceOffsets, sliceSizes, outputStrides));
30323011

3033-
SmallVector<Type, 4> resultTypes;
3012+
SmallVector<Type> resultTypes;
30343013
resultTypes.push_back(tiledOperands[1].getType());
30353014
Operation *tiledOp =
30363015
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
@@ -3083,9 +3062,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
30833062
SmallVector<Range>
30843063
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
30853064
Location loc = getLoc();
3086-
auto indexType = builder.getIndexType();
3087-
auto zeroAttr = builder.getIntegerAttr(indexType, 0);
3088-
auto oneAttr = builder.getIntegerAttr(indexType, 1);
3065+
IndexType indexType = builder.getIndexType();
3066+
IntegerAttr zeroAttr = builder.getIntegerAttr(indexType, 0);
3067+
IntegerAttr oneAttr = builder.getIntegerAttr(indexType, 1);
30893068
Value value = getValue();
30903069
SmallVector<Range> loopBounds(6);
30913070
for (unsigned dim = 0; dim < 6; ++dim) {
@@ -3107,57 +3086,44 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
31073086
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
31083087
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
31093088
SmallVector<OpFoldResult> &resultSizes) {
3110-
auto zeroAttr = builder.getI64IntegerAttr(0);
3089+
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
31113090
Value output = getOutput();
31123091
auto outputType = cast<ShapedType>(output.getType());
3113-
auto outputShape = outputType.getShape();
3092+
ArrayRef<int64_t> outputShape = outputType.getShape();
31143093
int64_t outputH = outputShape[1];
31153094
int64_t outputW = outputShape[2];
31163095
int64_t m = getM();
3117-
auto heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
3118-
auto widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
3096+
IntegerAttr heightM = builder.getI64IntegerAttr(outputH != 1 ? m : 1);
3097+
IntegerAttr widthM = builder.getI64IntegerAttr(outputW != 1 ? m : 1);
31193098

31203099
Location loc = getLoc();
3121-
auto context = builder.getContext();
3100+
MLIRContext *context = builder.getContext();
31223101
auto affineMap =
31233102
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
31243103
Value mappedOffset1 = builder.create<affine::AffineApplyOp>(
3125-
loc, affineMap, getValueFromOpFoldResult(offsets[2], builder, loc));
3104+
loc, affineMap,
3105+
getValueOrCreateConstantIndexOp(builder, loc, offsets[2]));
31263106
Value mappedOffset2 = builder.create<affine::AffineApplyOp>(
3127-
loc, affineMap, getValueFromOpFoldResult(offsets[3], builder, loc));
3128-
3129-
resultOffsets.push_back(zeroAttr);
3130-
resultOffsets.push_back(mappedOffset1);
3131-
resultOffsets.push_back(mappedOffset2);
3132-
resultOffsets.push_back(zeroAttr);
3133-
resultSizes.push_back(sizes[4]);
3134-
resultSizes.push_back(heightM);
3135-
resultSizes.push_back(widthM);
3136-
resultSizes.push_back(sizes[5]);
3107+
loc, affineMap,
3108+
getValueOrCreateConstantIndexOp(builder, loc, offsets[3]));
3109+
3110+
resultOffsets.append({zeroAttr, mappedOffset1, mappedOffset2, zeroAttr});
3111+
resultSizes.append({sizes[4], heightM, widthM, sizes[5]});
31373112
return success();
31383113
}
31393114

31403115
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
31413116
OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
31423117
ArrayRef<OpFoldResult> sizes) {
3143-
auto oneAttr = builder.getI64IntegerAttr(1);
3144-
auto zeroAttr = builder.getI64IntegerAttr(0);
3118+
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3119+
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
31453120
Location loc = getLoc();
31463121
SmallVector<Value> tiledOperands;
31473122
SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
31483123

3149-
sliceOffsets.push_back(zeroAttr);
3150-
sliceOffsets.push_back(zeroAttr);
3151-
sliceOffsets.push_back(offsets[2]);
3152-
sliceOffsets.push_back(offsets[3]);
3153-
sliceOffsets.push_back(zeroAttr);
3154-
sliceOffsets.push_back(zeroAttr);
3155-
sliceSizes.push_back(sizes[0]);
3156-
sliceSizes.push_back(sizes[1]);
3157-
sliceSizes.push_back(oneAttr);
3158-
sliceSizes.push_back(oneAttr);
3159-
sliceSizes.push_back(sizes[4]);
3160-
sliceSizes.push_back(sizes[5]);
3124+
sliceOffsets.append(
3125+
{zeroAttr, zeroAttr, offsets[2], offsets[3], zeroAttr, zeroAttr});
3126+
sliceSizes.append({sizes[0], sizes[1], oneAttr, oneAttr, sizes[4], sizes[5]});
31613127
SmallVector<OpFoldResult> sliceStrides(6, oneAttr);
31623128
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
31633129
loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
@@ -3172,7 +3138,7 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
31723138
tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
31733139
loc, getOutput(), sliceOffsets, sliceSizes, strides));
31743140

3175-
SmallVector<Type, 4> resultTypes;
3141+
SmallVector<Type> resultTypes;
31763142
resultTypes.push_back(tiledOperands[1].getType());
31773143
Operation *tiledOp =
31783144
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

0 commit comments

Comments
 (0)