@@ -39,26 +39,13 @@ namespace gc {
39
39
#define GEN_PASS_DEF_ANYTILABLEFUSION
40
40
#include " gc/Transforms/Passes.h.inc"
41
41
42
- namespace {
43
-
44
42
/// Accessors for properties of the machine the generated code runs on.
struct SystemDesc {
  /// Returns the number of threads available at runtime (the original note
  /// ties this to OMP_NUM_THREADS).
  uint32_t getNumThreads();

  /// Returns the size of the cache at the requested `cacheLevel`.
  /// NOTE(review): units and the numbering of levels are not visible from
  /// this declaration -- confirm against the definition.
  size_t getCacheSize(uint8_t cacheLevel);
};
50
48
51
- SmallVector<LoopLikeOpInterface> static getOuterLoopsOfSliceOp (
52
- OffsetSizeAndStrideOpInterface sliceOp) {
53
- SmallVector<LoopLikeOpInterface> outerLoops;
54
- auto forOp = sliceOp->getParentOfType <LoopLikeOpInterface>();
55
- while (forOp) {
56
- outerLoops.push_back (forOp);
57
- forOp = forOp->getParentOfType <LoopLikeOpInterface>();
58
- }
59
- return {outerLoops.rbegin (), outerLoops.rend ()};
60
- }
61
-
62
49
template <typename T> class FusionAnchorBase {
63
50
static_assert (
64
51
llvm::is_one_of<T, tensor::ExtractSliceOp, tensor::InsertSliceOp,
@@ -139,15 +126,15 @@ verifyTilableOpTileSizesOnDimAndTileMap(RewriterBase &rewriter, Operation *op,
139
126
targetTileSizes =
140
127
llvm::to_vector (tileSizes.take_back (targetInnerTileSizes.size ()));
141
128
} else // tileSize comes from OpOperand
142
- if (std::is_same<T, OffsetSizeAndStrideOpInterface>::value) {
143
- targetTileSizes = llvm::to_vector (tileSizes);
144
- DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
145
- packOp.getDimAndTileMapping ();
146
- targetInnerTileSizes.resize (dimAndTileMapping.size ());
147
- for (const auto &dimAndTile : dimAndTileMapping) {
148
- targetInnerTileSizes[dimAndTile.first ] = dimAndTile.second ;
129
+ if (std::is_same<T, OffsetSizeAndStrideOpInterface>::value) {
130
+ targetTileSizes = llvm::to_vector (tileSizes);
131
+ DenseMap<int64_t , OpFoldResult> dimAndTileMapping =
132
+ packOp.getDimAndTileMapping ();
133
+ targetInnerTileSizes.resize (dimAndTileMapping.size ());
134
+ for (const auto &dimAndTile : dimAndTileMapping) {
135
+ targetInnerTileSizes[dimAndTile.first ] = dimAndTile.second ;
136
+ }
149
137
}
150
- }
151
138
} else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(op)) {
152
139
// tileSize comes from OpResult
153
140
if (std::is_same<T, tensor::ExtractSliceOp>::value) {
@@ -159,11 +146,11 @@ verifyTilableOpTileSizesOnDimAndTileMap(RewriterBase &rewriter, Operation *op,
159
146
targetInnerTileSizes[dimAndTile.first ] = dimAndTile.second ;
160
147
}
161
148
} else // tileSize comes from OpOperand
162
- if (std::is_same<T, OffsetSizeAndStrideOpInterface>::value) {
163
- targetInnerTileSizes = unPackOp.getInnerTiles ();
164
- targetTileSizes =
165
- llvm::to_vector (tileSizes.take_back (targetInnerTileSizes.size ()));
166
- }
149
+ if (std::is_same<T, OffsetSizeAndStrideOpInterface>::value) {
150
+ targetInnerTileSizes = unPackOp.getInnerTiles ();
151
+ targetTileSizes =
152
+ llvm::to_vector (tileSizes.take_back (targetInnerTileSizes.size ()));
153
+ }
167
154
}
168
155
169
156
// check tileSizes is full on or multiple of `inner_tile_size`
@@ -205,11 +192,7 @@ FusionAnchorBase<T>::selectCandidateByCostModel(RewriterBase &rewriter,
205
192
if (candidateSliceOpList.empty ())
206
193
return failure ();
207
194
/// TODO: use cost model
208
- if (std::is_same<T, tensor::ExtractSliceOp>::value) {
209
- return cast<T>(candidateSliceOpList.front ());
210
- } else {
211
- return cast<T>(candidateSliceOpList.back ());
212
- }
195
+ return cast<T>(candidateSliceOpList.front ());
213
196
}
214
197
215
198
// Target at tensor.extract_slice
@@ -361,98 +344,6 @@ getProducerFusionAnchorFromOpOperand(RewriterBase &rewriter,
361
344
return failure ();
362
345
}
363
346
364
- /* * Get the Result of top-level Loop which yield the target InsertSliceOp
365
- *
366
- * %1 = scf.for
367
- * %2 = scf.for
368
- * %3 = scf.for
369
- * ...
370
- * %4 = insert
371
- * yield %4
372
- * %5 = insert %3
373
- * yield %5
374
- * yield %2
375
- *
376
- * @param targetSliceOp: %4 = insert
377
- * @return Result Value: %1
378
- * Collected insertSliceOp List during walk including targetSliceOp:
379
- * %4 = insert and %5 = insert %3
380
- */
381
- FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
382
- getResultOfTopLevelLoopYieldInsertSliceOp (
383
- OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0 ) {
384
- // control recursive time in avoid of stack overflow
385
- if (curDepth > MAX_DEPTH)
386
- return failure ();
387
-
388
- SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
389
- candidateSliceOpList.push_back (targetSliceOp);
390
- Value resultOfLoop;
391
- if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
392
- targetSliceOp.getOperation ())) {
393
- Value destValue = sliceOp.getDest ();
394
- auto iterArg = cast<BlockArgument>(destValue);
395
- auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner ()->getParentOp ());
396
- if (!forallOp)
397
- return failure ();
398
- resultOfLoop = forallOp.getTiedOpResult (forallOp.getTiedOpOperand (iterArg));
399
- } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
400
- targetSliceOp.getOperation ())) {
401
- Value resultValue = sliceOp.getResult ();
402
- for (auto &useOperand : resultValue.getUses ()) {
403
- if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner ())) {
404
- if (llvm::detail::isPresent (resultOfLoop))
405
- return failure ();
406
- auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ());
407
- if (!forOp)
408
- return failure ();
409
- resultOfLoop = forOp->getResult (useOperand.getOperandNumber ());
410
- }
411
- }
412
- }
413
-
414
- if (!llvm::detail::isPresent (resultOfLoop))
415
- return failure ();
416
-
417
- while (true ) {
418
- bool walkThroughOuterLoop = false ;
419
- for (auto &useOperand : resultOfLoop.getUses ()) {
420
- if (auto sliceOp =
421
- dyn_cast<tensor::ParallelInsertSliceOp>(useOperand.getOwner ())) {
422
- auto resultAndSliceOpsPair =
423
- getResultOfTopLevelLoopYieldInsertSliceOp (sliceOp, curDepth + 1 );
424
- if (failed (resultAndSliceOpsPair))
425
- return failure ();
426
- candidateSliceOpList.append ((*resultAndSliceOpsPair).second .begin (),
427
- (*resultAndSliceOpsPair).second .end ());
428
- return std::make_pair ((*resultAndSliceOpsPair).first ,
429
- candidateSliceOpList);
430
- } else if (auto sliceOp =
431
- dyn_cast<tensor::InsertSliceOp>(useOperand.getOwner ())) {
432
- auto resultAndSliceOpsPair =
433
- getResultOfTopLevelLoopYieldInsertSliceOp (sliceOp, curDepth + 1 );
434
- if (failed (resultAndSliceOpsPair))
435
- return failure ();
436
- candidateSliceOpList.append ((*resultAndSliceOpsPair).second .begin (),
437
- (*resultAndSliceOpsPair).second .end ());
438
- return std::make_pair ((*resultAndSliceOpsPair).first ,
439
- candidateSliceOpList);
440
- } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner ())) {
441
- // walk through outer loop
442
- auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp ());
443
- if (!forOp)
444
- return failure ();
445
- resultOfLoop = forOp->getResult (useOperand.getOperandNumber ());
446
- walkThroughOuterLoop = true ;
447
- break ;
448
- }
449
- }
450
- if (!walkThroughOuterLoop)
451
- break ;
452
- }
453
- return std::make_pair (resultOfLoop, candidateSliceOpList);
454
- }
455
-
456
347
/* *
457
348
* Find the untiled Consumer op based on given OpResult of Tiled Op, E.g.
458
349
*
@@ -493,7 +384,7 @@ getConsumerFusionAnchorFromOpResult(RewriterBase &rewriter,
493
384
return failure ();
494
385
495
386
auto resultAndSliceOpsPair =
496
- getResultOfTopLevelLoopYieldInsertSliceOp (sliceOp);
387
+ scfX:: getResultOfTopLevelLoopYieldInsertSliceOp (sliceOp);
497
388
if (failed (resultAndSliceOpsPair))
498
389
return failure ();
499
390
@@ -530,7 +421,7 @@ static Operation *preOpFuseProducerOfOpOperand(
530
421
if (failed (candidateSliceOp)) {
531
422
return nullptr ;
532
423
}
533
- auto outerLoops = getOuterLoopsOfSliceOp (*candidateSliceOp);
424
+ auto outerLoops = scfX:: getOuterLoopsOfSliceOp (*candidateSliceOp);
534
425
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
535
426
scfX::tileAndFuseProducerOfSlice (rewriter, *candidateSliceOp, outerLoops);
536
427
@@ -562,7 +453,7 @@ static SmallVector<Operation *> postOpFuseConsumerOfOpResult(
562
453
std::optional<scf::SCFFuseConsumerOfSliceResult> fusedResult =
563
454
scfX::tileAndFuseConsumerOfSlice (rewriter, *candidateSliceOp);
564
455
if (fusedResult) {
565
- tiledConsumerList.push_back (fusedResult.value ().tiledAndFusedConsumer );
456
+ tiledConsumerList.push_back (fusedResult.value ().tiledOps [ 0 ] );
566
457
rewriter.eraseOp (consAnchor.getFusableOp ());
567
458
}
568
459
}
@@ -712,6 +603,5 @@ struct AnyTilableFusion : public impl::AnyTilableFusionBase<AnyTilableFusion> {
712
603
}
713
604
};
714
605
715
- } // namespace
716
606
} // namespace gc
717
607
} // namespace mlir
0 commit comments