expand_shape fails to compile with more than one dynamic dim in a reassoc group #17760

@zjgarvey

Description

What happened?

Trying to compile

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<()[s0, s1] -> (s0 * s1)>
#map2 = affine_map<(d0, d1) -> (d0)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @gather_nd_graph(%arg0: tensor<5x3xf32>, %arg1: tensor<?x?xi64>) -> tensor<?x?x3xf32> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
    %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
    %collapsed = tensor.collapse_shape %arg1 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
    %0 = affine.apply #map1()[%dim, %dim_0]
    %1 = tensor.empty(%0) : tensor<?x3xf32>
    %2 = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%1 : tensor<?x3xf32>) {
    ^bb0(%in: i64, %out: f32):
      %12 = arith.index_cast %in : i64 to index
      %13 = linalg.index 1 : index
      %extracted = tensor.extract %arg0[%12, %13] : tensor<5x3xf32>
      linalg.yield %extracted : f32
    } -> tensor<?x3xf32>
    %expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [%dim, %dim_0, 3] : tensor<?x3xf32> into tensor<?x?x3xf32>
    return %expanded : tensor<?x?x3xf32>
  }
}

with iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-type=tm_tensor gives:

iree-compile: /home/zjgar/code/iree/third_party/llvm-project/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp:1913: virtual mlir::LogicalResult (anonymous namespace)::FoldDimOfExpandShape::matchAndRewrite(mlir::tensor::DimOp, mlir::PatternRewriter &) const: Assertion `!resultType.isDynamicDim(d) && "expected static dim"' failed.
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0.      Program arguments: /home/zjgar/code/iree-build/tools/iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-type=tm_tensor some.mlir -o some.vmfb --mlir-print-ir-before-all
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  libIREECompiler.so 0x00007fcfc0609787 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 39
1  libIREECompiler.so 0x00007fcfc06079b0 llvm::sys::RunSignalHandlers() + 80
2  libIREECompiler.so 0x00007fcfc0609e4a
3  libc.so.6          0x00007fcfba1dc520
4  libc.so.6          0x00007fcfba2309fc pthread_kill + 300
5  libc.so.6          0x00007fcfba1dc476 raise + 22
6  libc.so.6          0x00007fcfba1c27f3 abort + 211
7  libc.so.6          0x00007fcfba1c271b
8  libc.so.6          0x00007fcfba1d3e96
9  libIREECompiler.so 0x00007fcfc455ba09
10 libIREECompiler.so 0x00007fcfc439e95e
11 libIREECompiler.so 0x00007fcfc439b8af mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) + 943
12 libIREECompiler.so 0x00007fcfc437c855
13 libIREECompiler.so 0x00007fcfc4378f20 mlir::applyPatternsAndFoldGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) + 1088
14 libIREECompiler.so 0x00007fcfc432253b
15 libIREECompiler.so 0x00007fcfc07ab685 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 629
16 libIREECompiler.so 0x00007fcfc07abe08 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 328
17 libIREECompiler.so 0x00007fcfc07b0331
18 libIREECompiler.so 0x00007fcfc4380e6a mlir::Inliner::Impl::optimizeCallable(mlir::CallGraphNode*, llvm::StringMap<mlir::OpPassManager, llvm::MallocAllocator>&) + 266
19 libIREECompiler.so 0x00007fcfc4385fad
20 libIREECompiler.so 0x00007fcfc4380be5 mlir::Inliner::Impl::optimizeSCCAsync(llvm::MutableArrayRef<mlir::CallGraphNode*>, mlir::MLIRContext*) + 1077
21 libIREECompiler.so 0x00007fcfc438157d mlir::Inliner::doInlining() + 1453
22 libIREECompiler.so 0x00007fcfc4326af1
23 libIREECompiler.so 0x00007fcfc07ab685 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 629
24 libIREECompiler.so 0x00007fcfc07abe08 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 328
25 libIREECompiler.so 0x00007fcfc07ae239 mlir::PassManager::run(mlir::Operation*) + 985
26 libIREECompiler.so 0x00007fcfc0560470 ireeCompilerInvocationPipeline + 3408
27 libIREECompiler.so 0x00007fcfc0771cf8
28 libIREECompiler.so 0x00007fcfc0771581
29 libc.so.6          0x00007fcfba1c3d90
30 libc.so.6          0x00007fcfba1c3e40 __libc_start_main + 128
31 iree-compile       0x000055b7175726b5
Aborted

The failure occurs during a canonicalization pass, in the FoldDimOfExpandShape pattern that folds a tensor.dim of a tensor.expand_shape result. Changing the assertion into a match failure only pushes the crash later, into the iree-global-opt-remove-zero-extent-tensors pass:

// -----// IR Dump Before RemoveZeroExtentTensors (iree-global-opt-remove-zero-extent-tensors) //----- //
util.func public @gather_nd_graph(%arg0: !hal.buffer_view, %arg1: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "sync func @gather_nd_graph(%input0: tensor<5x3xf32>, %input1: tensor<?x?xi64>) -> (%output0: tensor<?x?x3xf32>)"}} {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %0 = hal.tensor.import %arg0 "input0" : !hal.buffer_view -> tensor<5x3xf32>
  %1 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[0] : index
  %2 = hal.buffer_view.dim<%arg1 : !hal.buffer_view>[1] : index
  %3 = hal.tensor.import %arg1 "input1" : !hal.buffer_view -> tensor<?x?xi64>{%1, %2}
  %collapsed = tensor.collapse_shape %3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
  %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%1, %2]
  %5 = tensor.empty(%4) : tensor<?x3xf32>
  %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%5 : tensor<?x3xf32>) {
  ^bb0(%in: i64, %out: f32):
    %8 = arith.index_cast %in : i64 to index
    %9 = linalg.index 1 : index
    %extracted = tensor.extract %0[%8, %9] : tensor<5x3xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x3xf32>
  %expanded = tensor.expand_shape %6 [[0, 1], [2]] output_shape [%1, %2, 3] : tensor<?x3xf32> into tensor<?x?x3xf32>
  %dim = tensor.dim %expanded, %c0 : tensor<?x?x3xf32>
  %dim_0 = tensor.dim %expanded, %c1 : tensor<?x?x3xf32>
  %7 = hal.tensor.export %expanded "output0" : tensor<?x?x3xf32>{%dim, %dim_0} -> !hal.buffer_view
  util.return %7 : !hal.buffer_view
}

iree-compile: /home/zjgar/code/iree/third_party/llvm-project/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp:103: mlir::OpFoldResult getExpandedOutputDimFromInputShape(mlir::OpBuilder &, mlir::Location, int64_t, mlir::Value, ArrayRef<int64_t>, ArrayRef<mlir::AffineMap>, llvm::DenseMap<int64_t, int64_t> &): Assertion `!ShapedType::isDynamic(d.value()) && "single dimension cannot be expanded into multiple dynamic " "dimensions"' failed.
Please report issues to https://github.com/iree-org/iree/issues and include the crash backtrace.
Stack dump:
0.      Program arguments: /home/zjgar/code/iree-build/tools/iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-type=tm_tensor some.mlir -o some.vmfb --mlir-print-ir-before-all
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0  libIREECompiler.so 0x00007fcd84d786c7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 39
1  libIREECompiler.so 0x00007fcd84d768f0 llvm::sys::RunSignalHandlers() + 80
2  libIREECompiler.so 0x00007fcd84d78d8a
3  libc.so.6          0x00007fcd7e94b520
4  libc.so.6          0x00007fcd7e99f9fc pthread_kill + 300
5  libc.so.6          0x00007fcd7e94b476 raise + 22
6  libc.so.6          0x00007fcd7e9317f3 abort + 211
7  libc.so.6          0x00007fcd7e93171b
8  libc.so.6          0x00007fcd7e942e96
9  libIREECompiler.so 0x00007fcd866d5a6f
10 libIREECompiler.so 0x00007fcd866d4d9f
11 libIREECompiler.so 0x00007fcd891ec313 mlir::reifyResultShapes(mlir::OpBuilder&, mlir::Operation*, llvm::SmallVector<llvm::SmallVector<mlir::OpFoldResult, 6u>, 1u>&) + 99
12 libIREECompiler.so 0x00007fcd87c15b96
13 libIREECompiler.so 0x00007fcd88b0d89e
14 libIREECompiler.so 0x00007fcd88b0a7ef mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) + 943
15 libIREECompiler.so 0x00007fcd88aeb795
16 libIREECompiler.so 0x00007fcd88ae7e60 mlir::applyPatternsAndFoldGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) + 1088
17 libIREECompiler.so 0x00007fcd862547cb
18 libIREECompiler.so 0x00007fcd84f1a5c5 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 629
19 libIREECompiler.so 0x00007fcd84f1ad48 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 328
20 libIREECompiler.so 0x00007fcd84f1fea3
21 libIREECompiler.so 0x00007fcd84f1c33b mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) + 2459
22 libIREECompiler.so 0x00007fcd84f1a760 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 1040
23 libIREECompiler.so 0x00007fcd84f1ad48 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 328
24 libIREECompiler.so 0x00007fcd84f1d179 mlir::PassManager::run(mlir::Operation*) + 985
25 libIREECompiler.so 0x00007fcd84ccf3b0 ireeCompilerInvocationPipeline + 3408
26 libIREECompiler.so 0x00007fcd84ee0c38
27 libIREECompiler.so 0x00007fcd84ee04c1
28 libc.so.6          0x00007fcd7e932d90
29 libc.so.6          0x00007fcd7e932e40 __libc_start_main + 128
30 iree-compile       0x00005556202976b5
Aborted
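Both assertions appear to stem from the same limitation: the upstream patterns recover each dynamic output dim of an expand_shape by dividing the collapsed input extent by the product of the static dims in its reassociation group. That only works when a group contains at most one dynamic dim; with two (as in the [[0, 1]] group above), the factorization N = d0 * d1 is ambiguous. A minimal Python sketch of that per-group inference (illustrative helper names, not the actual MLIR code):

```python
# Sketch of the per-group shape inference done by patterns like
# FoldDimOfExpandShape / getExpandedOutputDimFromInputShape.
# DYN marks a dynamic extent; the function name is hypothetical.
DYN = -1

def infer_group_dims(collapsed_extent, group):
    """Recover the extents of one reassociation group from the collapsed
    input extent. Fails when the group contains more than one dynamic
    dim, since N = d0 * d1 admits many factorizations."""
    dynamic_positions = [i for i, d in enumerate(group) if d == DYN]
    if len(dynamic_positions) > 1:
        # This is the case both upstream assertions reject: a single
        # dimension cannot be expanded into multiple dynamic dimensions
        # without extra shape information.
        raise ValueError("ambiguous: more than one dynamic dim in group")
    static_product = 1
    for d in group:
        if d != DYN:
            static_product *= d
    result = list(group)
    if dynamic_positions:
        result[dynamic_positions[0]] = collapsed_extent // static_product
    return result

# One dynamic dim per group is fine: a collapsed extent of 12 expands
# uniquely into ?x3 as 4x3.
print(infer_group_dims(12, [DYN, 3]))  # [4, 3]
```

Note that the failing expand_shape does carry the answer in its output_shape operands [%dim, %dim_0, 3]; the crashing patterns just never consult them.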

Steps to reproduce your issue

Copy and paste the MLIR above into a file (named example.mlir here).

Then run

iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-type=tm_tensor example.mlir -o some.vmfb

What component(s) does this issue relate to?

No response

Version information

commit 3b5d269

Additional context

The sample IR came out of an aten.unflatten.int lowering when I tried specifying the output shape explicitly (the more generic builder for tensor::expand_shape tries to infer the output shape with too little information). As a temporary workaround, I've changed the lowering for that op to use tensor::reshape whenever a reassociation group contains more than one dynamic dim, since that path doesn't cause compilation issues.
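The workaround sidesteps the inference entirely because tensor.reshape takes the full target shape as an operand, so nothing is left to factor. A NumPy analogy (not the IREE code) of why an explicit shape is unambiguous while inferring two unknowns is not:

```python
import numpy as np

# Collapsed data of extent d0 * d1, where d0 and d1 are only known at
# runtime; concrete values here are purely for illustration.
d0, d1 = 4, 6
flat = np.arange(d0 * d1 * 3).reshape(d0 * d1, 3)

# With the full target shape supplied explicitly (as tensor.reshape
# does), the expansion is unambiguous.
expanded = flat.reshape(d0, d1, 3)
print(expanded.shape)  # (4, 6, 3)

# Asking the library to infer more than one unknown dim is rejected for
# the same reason the MLIR patterns assert: 24 = d0 * d1 has many
# factorizations (4x6, 2x12, 3x8, ...).
try:
    flat.reshape(-1, -1, 3)
except ValueError as e:
    print("rejected:", e)
```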

    Labels

    bug 🐞 Something isn't working
