20
20
#include " mlir/Dialect/Linalg/Utils/Utils.h"
21
21
#include " mlir/IR/AffineExpr.h"
22
22
#include " mlir/IR/AffineMap.h"
23
+ #include " mlir/IR/BuiltinTypes.h"
23
24
#include " mlir/Transforms/FoldUtils.h"
24
25
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26
#include " llvm/Support/CommandLine.h"
@@ -256,7 +257,7 @@ struct UnitExtentReplacementInfo {
256
257
} // namespace
257
258
258
259
// / Utility function for replacing operands/results to a linalg generic
259
- // / operation on tensors with unit-extent dimensions. These can be replaced with
260
+ // / operation with unit-extent dimensions. These can be replaced with
260
261
// / an operand/result with the unit-extent dimension removed. This is only done
261
262
// / if the indexing map used to access that didimensionmension has a
262
263
// / AffineConstantExpr of value 0. Given the `type` of an result/operand of a
@@ -301,10 +302,19 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
301
302
++dim;
302
303
}
303
304
// Compute the tensor or scalar replacement type.
305
+ Type actualType = opOperand->get ().getType ();
304
306
Type elementType = getElementTypeOrSelf (opOperand->get ());
305
- Type replacementType = elementType == opOperand->get ().getType ()
306
- ? elementType
307
- : RankedTensorType::get (newShape, elementType);
307
+ Type replacementType;
308
+ if (elementType == opOperand->get ().getType ()) {
309
+ replacementType = elementType;
310
+ } else if (actualType.isa <RankedTensorType>()) {
311
+ replacementType = RankedTensorType::get (newShape, elementType);
312
+ } else if (actualType.isa <MemRefType>()) {
313
+ assert (actualType.cast <MemRefType>().getAffineMaps ().empty () &&
314
+ " unsupported strided memrefs" );
315
+ replacementType = MemRefType::get (newShape, elementType);
316
+ }
317
+ assert (replacementType && " unsupported shaped type" );
308
318
UnitExtentReplacementInfo info = {replacementType,
309
319
AffineMap::get (indexingMap.getNumDims (),
310
320
indexingMap.getNumSymbols (),
@@ -324,22 +334,60 @@ convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
324
334
return reassociationExprs;
325
335
}
326
336
327
- // / Pattern to replace tensors operands/results that are unit extents.
328
- struct ReplaceUnitExtentTensors : public OpRewritePattern <GenericOp> {
337
+ // / Pattern to replace tensor/buffer operands/results that are unit extents.
338
+ struct ReplaceUnitExtents : public OpRewritePattern <GenericOp> {
329
339
using OpRewritePattern<GenericOp>::OpRewritePattern;
340
+
341
+ // Return the original value if the type is unchanged, or reshape it. Return a
342
+ // nullptr if this is an unsupported type.
343
+ Value maybeExpand (Value result, Type origResultType,
344
+ ArrayAttr reassociationMap, Location loc,
345
+ PatternRewriter &rewriter) const {
346
+ if (origResultType == result.getType ())
347
+ return result;
348
+ if (origResultType.isa <RankedTensorType>()) {
349
+ return rewriter.create <linalg::TensorExpandShapeOp>(
350
+ loc, origResultType, result,
351
+ convertAffineMapArrayToExprs (reassociationMap));
352
+ }
353
+ if (origResultType.isa <MemRefType>()) {
354
+ return rewriter.create <linalg::ExpandShapeOp>(
355
+ loc, origResultType, result,
356
+ convertAffineMapArrayToExprs (reassociationMap));
357
+ }
358
+ return nullptr ;
359
+ };
360
+
361
+ // Return the original value if the type is unchanged, or reshape it. Return a
362
+ // nullptr if this is an unsupported type.
363
+ Value maybeCollapse (Value operand, Type newInputOutputType,
364
+ ArrayAttr reassociationMap, Location loc,
365
+ PatternRewriter &rewriter) const {
366
+ auto operandType = operand.getType ();
367
+ if (operandType == newInputOutputType)
368
+ return operand;
369
+ if (operandType.isa <MemRefType>()) {
370
+ return rewriter.create <linalg::CollapseShapeOp>(
371
+ loc, newInputOutputType, operand,
372
+ convertAffineMapArrayToExprs (reassociationMap));
373
+ }
374
+ if (operandType.isa <RankedTensorType>()) {
375
+ return rewriter.create <linalg::TensorCollapseShapeOp>(
376
+ loc, newInputOutputType, operand,
377
+ convertAffineMapArrayToExprs (reassociationMap));
378
+ }
379
+ return nullptr ;
380
+ };
381
+
330
382
LogicalResult matchAndRewrite (GenericOp genericOp,
331
383
PatternRewriter &rewriter) const override {
332
- if (!genericOp.hasTensorSemantics ())
333
- return failure ();
334
-
335
384
MLIRContext *context = rewriter.getContext ();
336
385
Location loc = genericOp.getLoc ();
337
386
338
387
SmallVector<AffineMap> newIndexingMaps;
339
388
SmallVector<ArrayAttr> reassociationMaps;
340
389
SmallVector<Type> newInputOutputTypes;
341
390
bool doCanonicalization = false ;
342
-
343
391
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands ()) {
344
392
UnitExtentReplacementInfo replacementInfo =
345
393
replaceUnitExtents (genericOp, opOperand, context);
@@ -362,14 +410,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
362
410
auto insertReshapes = [&](ValueRange values) {
363
411
SmallVector<Value, 4 > res;
364
412
res.reserve (values.size ());
365
- for (auto operand : llvm::enumerate (values)) {
366
- if (operand.value ().getType () == newInputOutputTypes[flattenedIdx])
367
- res.push_back (operand.value ());
368
- else {
369
- res.push_back (rewriter.create <TensorCollapseShapeOp>(
370
- loc, newInputOutputTypes[flattenedIdx], operand.value (),
371
- convertAffineMapArrayToExprs (reassociationMaps[flattenedIdx])));
372
- }
413
+ for (auto operand : values) {
414
+ auto reshapedValue =
415
+ maybeCollapse (operand, newInputOutputTypes[flattenedIdx],
416
+ reassociationMaps[flattenedIdx], loc, rewriter);
417
+ assert (reshapedValue &&
418
+ " expected ranked MemRef or Tensor operand type" );
419
+ res.push_back (reshapedValue);
373
420
++flattenedIdx;
374
421
}
375
422
return res;
@@ -396,15 +443,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
396
443
SmallVector<Value, 4 > resultReplacements;
397
444
for (auto result : llvm::enumerate (replacementOp.getResults ())) {
398
445
unsigned index = result.index () + replacementOp.getNumInputs ();
399
- RankedTensorType origResultType = genericOp.getResult (result.index ())
400
- .getType ()
401
- .template cast <RankedTensorType>();
402
- if (origResultType != result.value ().getType ()) {
403
- resultReplacements.push_back (rewriter.create <TensorExpandShapeOp>(
404
- loc, origResultType, result.value (),
405
- convertAffineMapArrayToExprs (reassociationMaps[index])));
406
- } else
407
- resultReplacements.push_back (result.value ());
446
+ auto origResultType = genericOp.getResult (result.index ()).getType ();
447
+
448
+ auto newResult = maybeExpand (result.value (), origResultType,
449
+ reassociationMaps[index], loc, rewriter);
450
+ assert (newResult &&
451
+ " unexpected output type other than ranked MemRef or Tensor" );
452
+ resultReplacements.push_back (newResult);
408
453
}
409
454
rewriter.replaceOp (genericOp, resultReplacements);
410
455
return success ();
@@ -501,9 +546,8 @@ struct UseRankReducedSubTensorInsertOp
501
546
void mlir::linalg::populateFoldUnitExtentDimsPatterns (
502
547
RewritePatternSet &patterns) {
503
548
auto *context = patterns.getContext ();
504
- patterns.add <FoldUnitDimLoops, ReplaceUnitExtentTensors,
505
- UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
506
- context);
549
+ patterns.add <FoldUnitDimLoops, ReplaceUnitExtents, UseRankReducedSubTensorOp,
550
+ UseRankReducedSubTensorInsertOp>(context);
507
551
TensorCollapseShapeOp::getCanonicalizationPatterns (patterns, context);
508
552
TensorExpandShapeOp::getCanonicalizationPatterns (patterns, context);
509
553
}
0 commit comments