Skip to content

Commit c3746ff

Browse files
hanhanWrkayaith
andauthored
[mlir][Affine] Handle null parent op in getAffineParallelInductionVarOwner (#142025)
The issue occurs during a downstream pass which does dialect conversion, where both [`FuncOpConversion`](https://github.com/llvm/llvm-project/blob/cde67b6663f994fcb4ded28fd79b23a13d347c4a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L480) and [`SubviewFolder`](https://github.com/llvm/llvm-project/blob/cde67b6663f994fcb4ded28fd79b23a13d347c4a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp#L187) are run together. The original starting IR is: ```mlir module { func.func @foo(%arg0: memref<100x100xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> memref<?x?xf32, strided<[100, 1], offset: ?>> { %subview = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1] : memref<100x100xf32> to memref<?x?xf32, strided<[100, 1], offset: ?>> return %subview : memref<?x?xf32, strided<[100, 1], offset: ?>> } } ``` After `FuncOpConversion` runs, the IR looks like: ```mlir "builtin.module"() ({ "llvm.func"() <{CConv = #llvm.cconv<ccc>, function_type = !llvm.func<struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> (ptr, ptr, i64, i64, i64, i64, i64, i64, i64, i64, i64)>, linkage = #llvm.linkage<external>, sym_name = "foo", visibility_ = 0 : i64}> ({ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64, %arg8: i64, %arg9: i64, %arg10: i64): %0 = "memref.subview"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) <{operandSegmentSizes = array<i32: 1, 2, 2, 0>, static_offsets = array<i64: -9223372036854775808, -9223372036854775808>, static_sizes = array<i64: -9223372036854775808, -9223372036854775808>, static_strides = array<i64: 1, 1>}> : (memref<100x100xf32>, index, index, index, index) -> memref<?x?xf32, strided<[100, 1], offset: ?>> "func.return"(%0) : (memref<?x?xf32, strided<[100, 1], offset: ?>>) -> () }) : () -> () "func.func"() <{function_type = (memref<100x100xf32>, index, index, index, index) -> memref<?x?xf32, strided<[100, 1], offset: ?>>, sym_name = "foo"}> ({ }) : () -> () }) {llvm.data_layout = "", llvm.target_triple = ""} : () -> () ``` The `<<UNKNOWN SSA VALUE>>`'s here are block arguments of a separate unlinked block, which is disconnected from the rest of the IR (so not only is the IR verifier-invalid, it can't even be parsed). This IR is created by signature conversion in the dialect conversion infra. Now `SubviewFolder` is applied, and the utility function here is called on one of these disconnected block arguments, causing a crash. The TestMemRefToLLVMWithTransforms pass is introduced to exercise the bug, and it can be reused by other contributors in the future. --------- Signed-off-by: hanhanW <hanhan0912@gmail.com> Co-authored-by: Rahul Kayaith <rkayaith@gmail.com>
1 parent b59c888 commit c3746ff

File tree

7 files changed

+98
-1
lines changed

7 files changed

+98
-1
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2667,7 +2667,7 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
26672667
if (!ivArg || !ivArg.getOwner())
26682668
return nullptr;
26692669
Operation *containingOp = ivArg.getOwner()->getParentOp();
2670-
auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2670+
auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
26712671
if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
26722672
return parallelOp;
26732673
return nullptr;
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-opt -test-memref-to-llvm-with-transforms %s | FileCheck %s
2+
3+
// Checks that the program does not crash. The functionality of the pattern is
4+
// already checked in test/Dialect/MemRef/*.mlir
5+
6+
func.func @subview_folder(%arg0: memref<100x100xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> memref<?x?xf32, strided<[100, 1], offset: ?>> {
7+
%subview = memref.subview %arg0[%arg1, %arg2] [%arg3, %arg4] [1, 1] : memref<100x100xf32> to memref<?x?xf32, strided<[100, 1], offset: ?>>
8+
return %subview : memref<?x?xf32, strided<[100, 1], offset: ?>>
9+
}
10+
// CHECK-LABEL: llvm.func @subview_folder
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_subdirectory(ConvertToSPIRV)
22
add_subdirectory(FuncToLLVM)
33
add_subdirectory(MathToVCIX)
4+
add_subdirectory(MemRefToLLVM)
45
add_subdirectory(VectorToSPIRV)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Exclude tests from libMLIR.so
2+
add_mlir_library(MLIRTestMemRefToLLVMWithTransforms
3+
TestMemRefToLLVMWithTransforms.cpp
4+
5+
EXCLUDE_FROM_LIBMLIR
6+
7+
LINK_LIBS PUBLIC
8+
MLIRTestDialect
9+
)
10+
mlir_target_link_libraries(MLIRTestFuncToLLVM PUBLIC
11+
MLIRLLVMCommonConversion
12+
MLIRLLVMDialect
13+
MLIRMemRefTransforms
14+
MLIRPass
15+
)
16+
17+
target_include_directories(MLIRTestFuncToLLVM
18+
PRIVATE
19+
${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
20+
${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
21+
)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- TestMemRefToLLVMWithTransforms.cpp ---------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
10+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
11+
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
12+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13+
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16+
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Pass/Pass.h"
19+
20+
using namespace mlir;
21+
22+
namespace {
23+
24+
struct TestMemRefToLLVMWithTransforms
25+
: public PassWrapper<TestMemRefToLLVMWithTransforms,
26+
OperationPass<ModuleOp>> {
27+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefToLLVMWithTransforms)
28+
29+
void getDependentDialects(DialectRegistry &registry) const final {
30+
registry.insert<LLVM::LLVMDialect>();
31+
}
32+
33+
StringRef getArgument() const final {
34+
return "test-memref-to-llvm-with-transforms";
35+
}
36+
37+
StringRef getDescription() const final {
38+
return "Tests conversion of MemRef dialects + `func.func` to LLVM dialect "
39+
"with MemRef transforms.";
40+
}
41+
42+
void runOnOperation() override {
43+
MLIRContext *ctx = &getContext();
44+
LowerToLLVMOptions options(ctx);
45+
LLVMTypeConverter typeConverter(ctx, options);
46+
RewritePatternSet patterns(ctx);
47+
memref::populateExpandStridedMetadataPatterns(patterns);
48+
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
49+
LLVMConversionTarget target(getContext());
50+
if (failed(applyPartialConversion(getOperation(), target,
51+
std::move(patterns))))
52+
signalPassFailure();
53+
}
54+
};
55+
56+
} // namespace
57+
58+
namespace mlir::test {
59+
void registerTestMemRefToLLVMWithTransforms() {
60+
PassRegistration<TestMemRefToLLVMWithTransforms>();
61+
}
62+
} // namespace mlir::test

mlir/tools/mlir-opt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ if(MLIR_INCLUDE_TESTS)
2828
MLIRMathTestPasses
2929
MLIRTestMathToVCIX
3030
MLIRMemRefTestPasses
31+
MLIRTestMemRefToLLVMWithTransforms
3132
MLIRMeshTest
3233
MLIRNVGPUTestPasses
3334
MLIRSCFTestPasses

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ void registerTestMathToVCIXPass();
130130
void registerTestIrdlTestDialectConversionPass();
131131
void registerTestMemRefDependenceCheck();
132132
void registerTestMemRefStrideCalculation();
133+
void registerTestMemRefToLLVMWithTransforms();
133134
void registerTestMeshReshardingSpmdizationPass();
134135
void registerTestMeshSimplificationsPass();
135136
void registerTestMultiBuffering();
@@ -275,6 +276,7 @@ void registerTestPasses() {
275276
mlir::test::registerTestMathToVCIXPass();
276277
mlir::test::registerTestMemRefDependenceCheck();
277278
mlir::test::registerTestMemRefStrideCalculation();
279+
mlir::test::registerTestMemRefToLLVMWithTransforms();
278280
mlir::test::registerTestMeshReshardingSpmdizationPass();
279281
mlir::test::registerTestMeshSimplificationsPass();
280282
mlir::test::registerTestMultiBuffering();

0 commit comments

Comments
 (0)