Skip to content

[mlir][linalg] Add pattern to bubble-up pack through expand shape op #93529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something in this new code is triggering a crash (compiler assert) in the downstream IREE project: iree-org/iree#17734. If I revert this PR locally, the crash goes away.

I don't have a reduced test case yet and the input program is large (12MB) + specific to our downstream project.

  • Assert + stack trace:

    Assertion failed: input.size() == permutation.size() && "expected input rank to equal permutation rank", file D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir/Dialect/Utils/IndexingUtils.h, line 204
    Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
    Stack dump:
    0.	Program arguments: D:\\dev\\projects\\iree-build\\tools\\iree-compile.exe D:/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.mlir --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu-features=host -o D:/tmp/open_llama_3b_v2/open-llama-3b-v2-f16_cpu.vmfb --mlir-print-ir-before-all --mlir-elide-elementsattrs-if-larger=8 --mlir-elide-resource-strings-if-larger=8 --mlir-disable-threading
    Exception Code: 0x80000003
     #0 0x00007ff64dcf8e95 HandleAbort D:\dev\projects\iree\third_party\llvm-project\llvm\lib\Support\Windows\Signals.inc:425:0
     #1 0x00007ffe7d561881 (C:\WINDOWS\System32\ucrtbase.dll+0x71881)
     #2 0x00007ffe7d562851 (C:\WINDOWS\System32\ucrtbase.dll+0x72851)
     #3 0x00007ffe7d5641b5 (C:\WINDOWS\System32\ucrtbase.dll+0x741b5)
     #4 0x00007ffe7d5644f1 (C:\WINDOWS\System32\ucrtbase.dll+0x744f1)
     #5 0x00007ff651c4f70f mlir::applyPermutation<class llvm::SmallVector<__int64, 2>>(class llvm::ArrayRef<class llvm::SmallVector<__int64, 2>>, class llvm::ArrayRef<__int64>) D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir\Dialect\Utils\IndexingUtils.h:205:0
     #6 0x00007ff651c49c11 mlir::applyPermutation<class llvm::SmallVector<__int64, 2>>(class llvm::SmallVectorImpl<class llvm::SmallVector<__int64, 2>> const &, class llvm::ArrayRef<__int64>) D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir\Dialect\Utils\IndexingUtils.h:214:0
     #7 0x00007ff651c41c2b mlir::applyPermutationToVector<class llvm::SmallVector<__int64, 2>, 1>(class llvm::SmallVector<class llvm::SmallVector<__int64, 2>, 1> &, class llvm::ArrayRef<__int64>) D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir\Dialect\Utils\IndexingUtils.h:225:0
     #8 0x00007ff6531d1eab `anonymous namespace'::applyPermutationAndReindexReassoc D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Dialect\Linalg\Transforms\DataLayoutPropagation.cpp:610:0
     #9 0x00007ff6531d268d `anonymous namespace'::bubbleUpPackOpThroughCollapseShape D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Dialect\Linalg\Transforms\DataLayoutPropagation.cpp:687:0
    #10 0x00007ff6531d3cd6 `anonymous namespace'::BubbleUpPackOpThroughReshapeOp::matchAndRewrite D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Dialect\Linalg\Transforms\DataLayoutPropagation.cpp:849:0
    #11 0x00007ff650b2bbe4 mlir::detail::OpOrInterfaceRewritePatternBase<class mlir::tensor::PackOp>::matchAndRewrite(class mlir::Operation *, class mlir::PatternRewriter &) const D:\dev\projects\iree\third_party\llvm-project\mlir\include\mlir\IR\PatternMatch.h:332:0
    #12 0x00007ff65209e8eb <lambda_033eed04a8a10a7b33015298d48d216a>::operator() D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Rewrite\PatternApplicator.cpp:212:0
    #13 0x00007ff65209c275 mlir::PatternApplicator::matchAndRewrite(class mlir::Operation *, class mlir::PatternRewriter &, class llvm::function_ref<(class mlir::Pattern const &)>, class llvm::function_ref<(class mlir::Pattern const &)>, class llvm::function_ref<(class mlir::Pattern const &)>) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Rewrite\PatternApplicator.cpp:233:0
    #14 0x00007ff650f1f91e `anonymous namespace'::GreedyPatternRewriteDriver::processWorklist D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Transforms\Utils\GreedyPatternRewriteDriver.cpp:617:0
    #15 0x00007ff650f220e2 llvm::function_ref<void __cdecl(void)>::callback_fn<<lambda_56efa1fe2231a48e07ce9bd5369059af> > D:\dev\projects\iree\third_party\llvm-project\llvm\include\llvm\ADT\STLFunctionalExtras.h:45:0
    #16 0x00007ff650f214ae `anonymous namespace'::RegionPatternRewriteDriver::simplify D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Transforms\Utils\GreedyPatternRewriteDriver.cpp:872:0
    #17 0x00007ff650f1d38e mlir::applyPatternsAndFoldGreedily(class mlir::Region &, class mlir::FrozenRewritePatternSet const &, class mlir::GreedyRewriteConfig, bool *) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Transforms\Utils\GreedyPatternRewriteDriver.cpp:920:0
    #18 0x00007ff651c78d1d mlir::iree_compiler::GlobalOptimization::`anonymous namespace'::DataLayoutPropagationPass::runOnOperation D:\dev\projects\iree\compiler\src\iree\compiler\GlobalOptimization\DataLayoutPropagation.cpp:31:0
    #19 0x00007ff64e0cead0 llvm::function_ref<void __cdecl(void)>::callback_fn<<lambda_e8f8990a45bf3495636c03506b9db479> > D:\dev\projects\iree\third_party\llvm-project\llvm\include\llvm\ADT\STLFunctionalExtras.h:45:0
    #20 0x00007ff64e0c8637 mlir::detail::OpToOpPassAdaptor::run(class mlir::Pass *, class mlir::Operation *, class mlir::AnalysisManager, bool, unsigned int) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:533:0
    #21 0x00007ff64e0c883d mlir::detail::OpToOpPassAdaptor::runPipeline(class mlir::OpPassManager &, class mlir::Operation *, class mlir::AnalysisManager, bool, unsigned int, class mlir::PassInstrumentor *, struct mlir::PassInstrumentation::PipelineParentInfo const *) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:593:0
    #22 0x00007ff64e0c77bb mlir::detail::OpToOpPassAdaptor::runOnOperationImpl(bool) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:734:0
    #23 0x00007ff64e0ceb23 llvm::function_ref<void __cdecl(void)>::callback_fn<<lambda_e8f8990a45bf3495636c03506b9db479> > D:\dev\projects\iree\third_party\llvm-project\llvm\include\llvm\ADT\STLFunctionalExtras.h:45:0
    #24 0x00007ff64e0c8637 mlir::detail::OpToOpPassAdaptor::run(class mlir::Pass *, class mlir::Operation *, class mlir::AnalysisManager, bool, unsigned int) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:533:0
    #25 0x00007ff64e0c883d mlir::detail::OpToOpPassAdaptor::runPipeline(class mlir::OpPassManager &, class mlir::Operation *, class mlir::AnalysisManager, bool, unsigned int, class mlir::PassInstrumentor *, struct mlir::PassInstrumentation::PipelineParentInfo const *) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:593:0
    #26 0x00007ff64e0c6d7b mlir::PassManager::runPasses(class mlir::Operation *, class mlir::AnalysisManager) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:904:0
    #27 0x00007ff64e0c6b3e mlir::PassManager::run(class mlir::Operation *) D:\dev\projects\iree\third_party\llvm-project\mlir\lib\Pass\Pass.cpp:883:0
    #28 0x00007ff64dca71c7 mlir::iree_compiler::embed::`anonymous namespace'::Invocation::runPipeline D:\dev\projects\iree\compiler\src\iree\compiler\API\Internal\CompilerDriver.cpp:995:0
    #29 0x00007ff64dc657ac <lambda_139d4d9eb9ed714e768e1c22e93f7b10>::operator() D:\dev\projects\iree\compiler\src\iree\compiler\Tools\iree_compile_lib.cc:254:0
    #30 0x00007ff64dc5ba18 mlir::iree_compiler::runIreecMain(int, char **) D:\dev\projects\iree\compiler\src\iree\compiler\Tools\iree_compile_lib.cc:355:0
    #31 0x00007ff658023d34 __scrt_common_main_seh d:\a01\_work\43\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl:288:0
    #32 0x00007ffe7e987344 (C:\WINDOWS\System32\KERNEL32.DLL+0x17344)
    #33 0x00007ffe7f9bcc91 (C:\WINDOWS\SYSTEM32\ntdll.dll+0x4cc91)
    
  • Here's a bit of printf debugging:

    // 100s of these, which are fine
    // Calling bubbleUpPackOpThroughCollapseShape with tensor::CollapseShapeOp:
    %collapsed_116 = tensor.collapse_shape %114 [[0, 1], [2], [3]] : tensor<4x32x100x?xf16> into tensor<128x100x?xf16>
    // ...and tensor::PackOp:
    %pack_119 = tensor.pack %collapsed_116 padding_value(%cst_27 : f16) outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [8, 1] into %117 : tensor<128x100x?xf16> -> tensor<128x?x100x8x1xf16>
    
    // crash right after this
    // Calling bubbleUpPackOpThroughCollapseShape with tensor::CollapseShapeOp:
    %collapsed_2681 = tensor.collapse_shape %expanded_2680 [[0], [1, 2], [3]] : tensor<4x32x1x100xf16> into tensor<4x32x100xf16>
    // ...and tensor::PackOp:
    %pack_2682 = tensor.pack %collapsed_2681 inner_dims_pos = [0, 2] inner_tiles = [1, 1] into %2209 : tensor<4x32x100xf16> -> tensor<4x32x100x1x1xf16>
  • Here's the IR before we call into this code and crash (13000 lines, can try reducing): https://gist.github.com/ScottTodd/d5f9721307e78cada067a81e60a471c0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I revert this PR locally, the crash goes away.

Interesting, looking at the stack dump it's calling bubbleUpPackOpThroughCollapseShape which is unrelated and should be untouched by this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, looking at the stack dump it's calling bubbleUpPackOpThroughCollapseShape which is unrelated and should be untouched by this PR.

There could be few issues. The propagation through expand_shape op changes the graph and trigger the failure. The issue could be either in expand_shape patterns or collapse_shape patterns.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First guess is that this PR bubbles pack through expand_shape ops that acted as a barrier before and now it allows more bubbling to occur.
So, it either bubbled pack through some expand that it shouldn't have or exposed an edge case in the collapse_shape part of code.

Copy link

@AmosLewis AmosLewis Jun 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got same issue for mit-b0 and 3 more models. The issue could be fixed by revert this commit. Here is the stacktrace

(gdb) bt
#0  __pthread_kill_implementation (no_tid=0, signo=6, threadid=140737352719616) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (signo=6, threadid=140737352719616) at ./nptl/pthread_kill.c:78
#2  __GI___pthread_kill (threadid=140737352719616, signo=signo@entry=6) at ./nptl/pthread_kill.c:89
#3  0x00007fffdd442476 in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffdd4287f3 in __GI_abort () at ./stdlib/abort.c:79
#5  0x00007fffdd42871b in __assert_fail_base (fmt=0x7fffdd5dd130 "%s%s%s:%u: %s%sAssertion `%s' failed.\n%n", 
    assertion=0x7fffe22dc65a "input.size() == permutation.size() && \"expected input rank to equal permutation rank\"", 
    file=0x7fffe12f935b "iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h", line=204, function=<optimized out>) at ./assert/assert.c:92
#6  0x00007fffdd439e96 in __GI___assert_fail (assertion=0x7fffe22dc65a "input.size() == permutation.size() && \"expected input rank to equal permutation rank\"", 
    file=0x7fffe12f935b "iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h", line=204, 
    function=0x7fffe23b6a01 "SmallVector<T> mlir::applyPermutation(ArrayRef<T>, ArrayRef<int64_t>) [T = llvm::SmallVector<long, 2>]") at ./assert/assert.c:101
#7  0x00007fffec368e98 in mlir::applyPermutation<llvm::SmallVector<long, 2u> > (input=..., permutation=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h:203

#8  0x00007fffec368dd9 in mlir::applyPermutation<llvm::SmallVector<long, 2u> > (input=..., permutation=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h:214
#9  0x00007fffec368979 in mlir::applyPermutationToVector<llvm::SmallVector<long, 2u>, 1u> (inVec=..., permutation=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/Dialect/Utils/IndexingUtils.h:225
#10 0x00007fffef5599c5 in (anonymous namespace)::applyPermutationAndReindexReassoc (reassocIndices=..., permutation=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp:608
#11 0x00007fffef5594e6 in (anonymous namespace)::bubbleUpPackOpThroughCollapseShape (collapseOp=..., packOp=..., rewriter=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp:685
#12 0x00007fffef559057 in (anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}::operat--Type <RET> for more, q to quit, c to continue without paging--
or()(mlir::tensor::CollapseShapeOp) const (this=0x7fffffff8fc0, op=...) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp:851
#13 0x00007fffef558fef in llvm::TypeSwitch<mlir::Operation*, mlir::LogicalResult>::Case<mlir::tensor::CollapseShapeOp, (anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}>((anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}&&) (this=0x7fffffff8fd0, caseFn=...) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/TypeSwitch.h:102
#14 0x00007fffef558b45 in llvm::detail::TypeSwitchBase<llvm::TypeSwitch<mlir::Operation*, mlir::LogicalResult>, mlir::Operation*>::Case<(anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}>((anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite(mlir::tensor::PackOp, mlir::PatternRewriter&) const::{lambda(mlir::tensor::CollapseShapeOp)#1}&&) (this=0x7fffffff8fd0, caseFn=...) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/TypeSwitch.h:60
#15 0x00007fffef558a91 in (anonymous namespace)::BubbleUpPackOpThroughReshapeOp::matchAndRewrite (this=0x5555557239d0, packOp=..., rewriter=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp:850
#16 0x00007fffec3b87eb in mlir::detail::OpOrInterfaceRewritePatternBase<mlir::tensor::PackOp>::matchAndRewrite (this=0x5555557239d0, op=0x555556649190, rewriter=...)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/PatternMatch.h:331
#17 0x00007ffff16cf52e in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_0::operator()() const (this=0x7fffffff92b0)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:212
#18 0x00007ffff16cf385 in llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>)::$_0>(long) (callable=140737488327344)
    at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
#19 0x00007fffe93074a9 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffff91f0) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#20 0x00007ffff16d0d55 in mlir::MLIRContext::executeAction<mlir::ApplyPatternAction, mlir::Pattern const&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pattern const&) (
    this=0x5555555ec710, actionFn=..., irUnits=..., args=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275
#21 0x00007ffff16cde27 in mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (this=0x7fffffff9ed0, op=0x555556649190, rewriter=..., canApply=..., onFailure=..., onSuccess=...)
--Type <RET> for more, q to quit, c to continue without paging--
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:195
#22 0x00007ffff16898db in (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist (this=0x7fffffff9dd0)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:615
#23 0x00007ffff1688b61 in (anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) &&::$_2::operator()() const (this=0x7fffffff9c80)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:874
#24 0x00007ffff1688b35 in llvm::function_ref<void ()>::callback_fn<(anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) &&::$_2>(long) (callable=140737488329856)
    at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
#25 0x00007fffe93074a9 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffff9c20) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#26 0x00007ffff1688285 in mlir::MLIRContext::executeAction<(anonymous namespace)::GreedyPatternRewriteIteration, long&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, long&) (
    this=0x5555555ec710, actionFn=..., irUnits=..., args=@0x7fffffff9d88: 2) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275
#27 0x00007ffff168670e in (anonymous namespace)::RegionPatternRewriteDriver::simplify(bool*) && (this=0x7fffffff9dd0, changed=0x7fffffff9fd7)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:872
#28 0x00007ffff16863f7 in mlir::applyPatternsAndFoldGreedily (region=..., patterns=..., config=..., changed=0x7fffffff9fd7)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:919
#29 0x00007fffe924a105 in mlir::applyPatternsAndFoldGreedily (op=0x5555557d1cf0, patterns=..., config=..., changed=0x0)
    at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h:159
#30 0x00007fffec31f23b in mlir::iree_compiler::GlobalOptimization::(anonymous namespace)::DataLayoutPropagationPass::runOnOperation (this=0x55555663f1d0)
    at /home/chi/src/iree/compiler/src/iree/compiler/GlobalOptimization/DataLayoutPropagation.cpp:31
#31 0x00007fffe980335b in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const (this=0x7fffffffa428)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:527
#32 0x00007fffe98032f5 in llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) (
    callable=140737488331816) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
--Type <RET> for more, q to quit, c to continue without paging--
#33 0x00007fffe93074a9 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffffa3b0) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#34 0x00007fffe9806175 in mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) (this=0x5555555ec710, 
    actionFn=..., irUnits=..., args=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275
#35 0x00007fffe97feab3 in mlir::detail::OpToOpPassAdaptor::run (pass=0x55555663f1d0, op=0x5555557d1cf0, am=..., verifyPasses=true, parentInitGeneration=1)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:521
#36 0x00007fffe97ff034 in mlir::detail::OpToOpPassAdaptor::runPipeline (pm=..., op=0x5555557d1cf0, am=..., verifyPasses=true, parentInitGeneration=1, instrumentor=0x5555557f5ee0, 
    parentInfo=0x7fffffffaae0) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:593
#37 0x00007fffe98045e5 in mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0::operator()(mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo&) const (
    this=0x7fffffffaa78, opInfo=...) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:813
#38 0x00007fffe9804269 in mlir::failableParallelForEach<__gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0&>(mlir::MLIRContext*, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0&) (context=0x5555555ec710, begin={passManagerIdx = 0, op = 0x5555557d1cf0, am = {impl = 0x5555557d2ce0}}, 
    end={passManagerIdx = 129, op = 0x55555578f6a0, am = {impl = 0x55555578ef00}}, func=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/Threading.h:46
#39 0x00007fffe98002eb in mlir::failableParallelForEach<std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> >&, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0&>(mlir::MLIRContext*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> >&, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_0&) (
    context=0x5555555ec710, range=std::vector of length 1, capacity 1 = {...}, func=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/Threading.h:92
#40 0x00007fffe97ffbfa in mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl (this=0x555555804b30, verifyPasses=true) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:823
#41 0x00007fffe97ff727 in mlir::detail::OpToOpPassAdaptor::runOnOperation (this=0x555555804b30, verifyPasses=true) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:714
--Type <RET> for more, q to quit, c to continue without paging--
#42 0x00007fffe9803346 in mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1::operator()() const (this=0x7fffffffade8)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:525
#43 0x00007fffe98032f5 in llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::$_1>(long) (
    callable=140737488334312) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
#44 0x00007fffe93074a9 in llvm::function_ref<void ()>::operator()() const (this=0x7fffffffad70) at /home/chi/src/iree/third_party/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
#45 0x00007fffe9806175 in mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) (this=0x5555555ec710, 
    actionFn=..., irUnits=..., args=...) at /home/chi/src/iree/third_party/llvm-project/mlir/include/mlir/IR/MLIRContext.h:275
#46 0x00007fffe97feab3 in mlir::detail::OpToOpPassAdaptor::run (pass=0x555555804b30, op=0x5555557f5690, am=..., verifyPasses=true, parentInitGeneration=1)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:521
#47 0x00007fffe97ff034 in mlir::detail::OpToOpPassAdaptor::runPipeline (pm=..., op=0x5555557f5690, am=..., verifyPasses=true, parentInitGeneration=1, instrumentor=0x0, parentInfo=0x0)
    at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:593
#48 0x00007fffe9800a78 in mlir::PassManager::runPasses (this=0x555555739cb0, op=0x5555557f5690, am=...) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:904
#49 0x00007fffe98009a2 in mlir::PassManager::run (this=0x555555739cb0, op=0x5555557f5690) at /home/chi/src/iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:884
#50 0x00007fffe9251cba in mlir::iree_compiler::embed::(anonymous namespace)::Invocation::runPipeline (this=0x55555565add0, pipeline=IREE_COMPILER_PIPELINE_STD)
    at /home/chi/src/iree/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp:995
#51 0x00007fffe9251593 in ireeCompilerInvocationPipeline (inv=0x55555565add0, pipeline=IREE_COMPILER_PIPELINE_STD)
    at /home/chi/src/iree/compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp:1430
#52 0x00007fffe978a88e in mlir::iree_compiler::runIreecMain(int, char**)::$_2::operator()(iree_compiler_source_t*) const (this=0x7fffffffc0e8, source=0x55555565aba0)
    at /home/chi/src/iree/compiler/src/iree/compiler/Tools/iree_compile_lib.cc:254
#53 0x00007fffe9789d1e in mlir::iree_compiler::runIreecMain (argc=4, argv=0x7fffffffdbd8) at /home/chi/src/iree/compiler/src/iree/compiler/Tools/iree_compile_lib.cc:355
#54 0x00007fffe929baab in ireeCompilerRunMain (argc=4, argv=0x7fffffffdbd8) at /home/chi/src/iree/compiler/src/iree/compiler/API/Internal/IREECompileToolEntryPoint.cpp:12
#55 0x00005555555557a2 in main (argc=4, argv=0x7fffffffdbd8) at /home/chi/src/iree/tools/iree-compile-main.cc:9

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adam-smnk I encountered the same bug on my end.
The root cause for the bug is that outerDimsPerm is an optional attribute that could be empty. However, when calling applyPermutationAndReindexReassoc, it assumes outerDimsPerm to be non-empty. One possible solution is to fill outerDimsPerm with default values ([0, 1, 2, ...]).

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
Expand Down Expand Up @@ -694,6 +696,131 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
return success();
}

/// Project dimsPos to their collapsed positions in the reassocIndices.
///
/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
ArrayRef<ReassociationIndices> reassocIndices) {
SmallVector<int64_t> projectedPos;

// Map each dimension to the position of corresponding reassociation index.
for (auto pos : dimsPos) {
for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
// If the dimension is present in the current indices group, the group
// position within the reassociation map is the desired projected
// dimension position.
if (llvm::any_of(indices,
[&](int64_t expandDim) { return expandDim == pos; })) {
projectedPos.push_back(idx);
break;
}
}
}
assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");

return projectedPos;
}

/// Bubble up pack op through expand shape op.
///
/// For example:
///
/// %expand = tensor.expand_shape %in [[0], [1, 2]]
/// : tensor<?x64xf32> into tensor<?x4x16xf32>
/// %pack = tensor.pack %expand outer_dims_perm = [0, 1]
/// inner_dims_pos = [2] inner_tiles = [8] into %empty
/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
/// inner_dims_pos = [1] inner_tiles = [8] into %empty
/// : tensor<?x64xf32> -> tensor<?x8x8xf32>
/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
tensor::PackOp packOp,
PatternRewriter &rewriter) {
// Outer dimensions permutation is not supported currently.
// TODO: Handle outer_dims_perm variants.
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
return rewriter.notifyMatchFailure(packOp,
"non-identity outer dims perm NYI");
}

// Validate dimensions' relations between shape expansion and packing.
SmallVector<ReassociationIndices, 4> reassoc =
expandOp.getReassociationIndices();
ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
packInnerDims.end());

for (auto [idx, indices] : llvm::enumerate(reassoc)) {
// For each expand_shape reassociation, figure out which dimensions get
// packed if any.
llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
llvm::SetVector<int64_t> packedDims =
llvm::set_intersection(packDimsPos, expandDimPos);

// The expanded dimension is not packed so, it does not affect moving pack
// before shape expansion - simply continue.
if (packedDims.empty())
continue;
// Shape expansion cannot be propagated when multiple expanded dimension are
// packed - in this case operation reordering would affect final element
// positions and/or shapes can no longer be projected.
if (packedDims.size() != 1)
return rewriter.notifyMatchFailure(
packOp, "only one of the expanded dimensions can be packed");
// Only the inner-most expanded dimension should be packed. Otherwise,
// elements order will be affected after operation reordering.
if (packedDims.front() != indices.back())
return rewriter.notifyMatchFailure(
packOp, "can only pack the inner-most expanded dimension");
}

// Project pack.inner_dims_pos to positions before shape expansion.
SmallVector<int64_t> projectedInnerDimsPos =
projectDimsPosIntoReassocPos(packInnerDims, reassoc);

// Project the shape expansion to new packed shape.
// The pack.outer_dims_perm is restricted to identity so, the permutation can
// be omitted for simplicity.
// TODO: Account for outer dimensions permutation.
//
// If reassociation is not possible, then reordering cannot happen.
// This can be caused by pack padding affecting previously expanded
// dimensions or packing extending dimensions.
RankedTensorType newPackType = tensor::PackOp::inferPackedType(
expandOp.getSrcType(), packOp.getStaticInnerTiles(),
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
auto reassocExpand =
getReassociationIndicesForReshape(newPackType, packOp.getDestType());
if (!reassocExpand)
return rewriter.notifyMatchFailure(
packOp, "could not reassociate dims after bubbling up");

Value destTensor = tensor::PackOp::createDestinationTensor(
rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
Value packedVal = rewriter.create<tensor::PackOp>(
packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
packOp.getMixedTiles(), packOp.getPaddingValue(),
/*outerDimsPerm=*/SmallVector<int64_t>{});

Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
rewriter.replaceOp(packOp, newExpandOp);

return success();
}

class BubbleUpPackOpThroughReshapeOp final
: public OpRewritePattern<tensor::PackOp> {
public:
Expand Down Expand Up @@ -723,6 +850,9 @@ class BubbleUpPackOpThroughReshapeOp final
.Case([&](tensor::CollapseShapeOp op) {
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
})
.Case([&](tensor::ExpandShapeOp op) {
return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
})
.Default([](Operation *) { return failure(); });
}

Expand Down
Loading