Commit 6c7be41

Support buffers in LinalgFoldUnitExtentDims

This doesn't add any canonicalizations, but it performs the same simplification on buffer-semantics linalg.generic ops by using linalg::ReshapeOp instead of linalg::TensorReshapeOp.

Differential Revision: https://reviews.llvm.org/D103513

1 parent: d8c5a4d
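
As background (not part of this commit), these patterns are normally applied through MLIR's greedy rewrite driver. The sketch below shows one way to exercise them from a function pass; the pass class name is hypothetical, while populateFoldUnitExtentDimsPatterns and applyPatternsAndFoldGreedily are the existing entry points used here.

// Sketch only: a hypothetical pass that runs the unit-extent-dim folding
// patterns, which after this commit also cover buffer-semantics linalg.generic
// ops. The class name is illustrative and not part of the commit.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct FoldUnitDimsSketchPass
    : public mlir::PassWrapper<FoldUnitDimsSketchPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Adds FoldUnitDimLoops, ReplaceUnitExtents, and the rank-reducing
    // subtensor patterns registered by the callee.
    mlir::linalg::populateFoldUnitExtentDimsPatterns(patterns);
    // Apply greedily until fixed point; the convergence result is ignored here.
    (void)mlir::applyPatternsAndFoldGreedily(getFunction(),
                                             std::move(patterns));
  }
};
} // namespace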

File tree: 2 files changed, +374 −30 lines

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 74 additions & 30 deletions
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/CommandLine.h"
@@ -256,7 +257,7 @@ struct UnitExtentReplacementInfo {
 } // namespace
 
 /// Utility function for replacing operands/results to a linalg generic
-/// operation on tensors with unit-extent dimensions. These can be replaced with
+/// operation with unit-extent dimensions. These can be replaced with
 /// an operand/result with the unit-extent dimension removed. This is only done
 /// if the indexing map used to access that dimension has a
 /// AffineConstantExpr of value 0. Given the `type` of an result/operand of a
@@ -301,10 +302,19 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
     ++dim;
   }
   // Compute the tensor or scalar replacement type.
+  Type actualType = opOperand->get().getType();
   Type elementType = getElementTypeOrSelf(opOperand->get());
-  Type replacementType = elementType == opOperand->get().getType()
-                             ? elementType
-                             : RankedTensorType::get(newShape, elementType);
+  Type replacementType;
+  if (elementType == opOperand->get().getType()) {
+    replacementType = elementType;
+  } else if (actualType.isa<RankedTensorType>()) {
+    replacementType = RankedTensorType::get(newShape, elementType);
+  } else if (actualType.isa<MemRefType>()) {
+    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
+           "unsupported strided memrefs");
+    replacementType = MemRefType::get(newShape, elementType);
+  }
+  assert(replacementType && "unsupported shaped type");
   UnitExtentReplacementInfo info = {replacementType,
                                     AffineMap::get(indexingMap.getNumDims(),
                                                    indexingMap.getNumSymbols(),
@@ -324,22 +334,60 @@ convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
   return reassociationExprs;
 }
 
-/// Pattern to replace tensors operands/results that are unit extents.
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+/// Pattern to replace tensor/buffer operands/results that are unit extents.
+struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  // Return the original value if the type is unchanged, or reshape it. Return a
+  // nullptr if this is an unsupported type.
+  Value maybeExpand(Value result, Type origResultType,
+                    ArrayAttr reassociationMap, Location loc,
+                    PatternRewriter &rewriter) const {
+    if (origResultType == result.getType())
+      return result;
+    if (origResultType.isa<RankedTensorType>()) {
+      return rewriter.create<linalg::TensorExpandShapeOp>(
+          loc, origResultType, result,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    if (origResultType.isa<MemRefType>()) {
+      return rewriter.create<linalg::ExpandShapeOp>(
+          loc, origResultType, result,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    return nullptr;
+  };
+
+  // Return the original value if the type is unchanged, or reshape it. Return a
+  // nullptr if this is an unsupported type.
+  Value maybeCollapse(Value operand, Type newInputOutputType,
+                      ArrayAttr reassociationMap, Location loc,
+                      PatternRewriter &rewriter) const {
+    auto operandType = operand.getType();
+    if (operandType == newInputOutputType)
+      return operand;
+    if (operandType.isa<MemRefType>()) {
+      return rewriter.create<linalg::CollapseShapeOp>(
+          loc, newInputOutputType, operand,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    if (operandType.isa<RankedTensorType>()) {
+      return rewriter.create<linalg::TensorCollapseShapeOp>(
+          loc, newInputOutputType, operand,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    return nullptr;
+  };
+
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();
 
     SmallVector<AffineMap> newIndexingMaps;
     SmallVector<ArrayAttr> reassociationMaps;
     SmallVector<Type> newInputOutputTypes;
     bool doCanonicalization = false;
-
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
       UnitExtentReplacementInfo replacementInfo =
           replaceUnitExtents(genericOp, opOperand, context);
@@ -362,14 +410,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     auto insertReshapes = [&](ValueRange values) {
       SmallVector<Value, 4> res;
       res.reserve(values.size());
-      for (auto operand : llvm::enumerate(values)) {
-        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
-          res.push_back(operand.value());
-        else {
-          res.push_back(rewriter.create<TensorCollapseShapeOp>(
-              loc, newInputOutputTypes[flattenedIdx], operand.value(),
-              convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
-        }
+      for (auto operand : values) {
+        auto reshapedValue =
+            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
+                          reassociationMaps[flattenedIdx], loc, rewriter);
+        assert(reshapedValue &&
+               "expected ranked MemRef or Tensor operand type");
+        res.push_back(reshapedValue);
         ++flattenedIdx;
       }
       return res;
@@ -396,15 +443,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
     SmallVector<Value, 4> resultReplacements;
     for (auto result : llvm::enumerate(replacementOp.getResults())) {
       unsigned index = result.index() + replacementOp.getNumInputs();
-      RankedTensorType origResultType = genericOp.getResult(result.index())
-                                            .getType()
-                                            .template cast<RankedTensorType>();
-      if (origResultType != result.value().getType()) {
-        resultReplacements.push_back(rewriter.create<TensorExpandShapeOp>(
-            loc, origResultType, result.value(),
-            convertAffineMapArrayToExprs(reassociationMaps[index])));
-      } else
-        resultReplacements.push_back(result.value());
+      auto origResultType = genericOp.getResult(result.index()).getType();
+
+      auto newResult = maybeExpand(result.value(), origResultType,
+                                   reassociationMaps[index], loc, rewriter);
+      assert(newResult &&
+             "unexpected output type other than ranked MemRef or Tensor");
+      resultReplacements.push_back(newResult);
     }
     rewriter.replaceOp(genericOp, resultReplacements);
     return success();
@@ -501,9 +546,8 @@ struct UseRankReducedSubTensorInsertOp
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
-               UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
-      context);
+  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, UseRankReducedSubTensorOp,
+               UseRankReducedSubTensorInsertOp>(context);
   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
 }
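
As a reading aid (not part of the commit), the new type-replacement logic in replaceUnitExtents amounts to the dispatch sketched below. The helper name getRankReducedType is made up, and newShape is assumed to already have the unit extents dropped: scalars pass through unchanged, ranked tensors are rebuilt with the smaller shape, and only identity-layout memrefs are rebuilt; anything else is rejected.

// Illustrative sketch, not in the commit: the tensor-vs-memref dispatch that
// replaceUnitExtents now performs when computing a replacement operand type.
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"

static mlir::Type getRankReducedType(mlir::Type actualType,
                                     llvm::ArrayRef<int64_t> newShape) {
  mlir::Type elementType = mlir::getElementTypeOrSelf(actualType);
  // Scalar (unshaped) operands keep their type.
  if (elementType == actualType)
    return elementType;
  // Tensor operands: rebuild a ranked tensor without the unit dimensions.
  if (actualType.isa<mlir::RankedTensorType>())
    return mlir::RankedTensorType::get(newShape, elementType);
  // Buffer operands: only identity-layout memrefs are handled; the commit
  // asserts against strided memrefs rather than recomputing their layout.
  if (auto memrefType = actualType.dyn_cast<mlir::MemRefType>())
    if (memrefType.getAffineMaps().empty())
      return mlir::MemRefType::get(newShape, elementType);
  // Unsupported shaped types yield a null Type.
  return mlir::Type();
}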
