[MLIR][LLVM] Fix #llvm.constant_range crashing in storage uniquer #135772
This PR adds the bitwidth parameter to the constant range so that two instances of the constant range can be compared. This fixes a crash in the storage uniquer that occurred when two ranges with different bitwidths hashed to the same value and the subsequent comparison triggered an assert in APInt because of the differing bitwidths.
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-llvm
Author: Robert Konicar (Jezurko)
Changes
This PR adds the bitwidth parameter to the constant range to allow for comparing of two instances of constant range. This fixes a crash in storage uniquer when two ranges with different bitwidths hashed to the same value and then the comparison triggered an assert in APInt because of the different bitwidths.
Full diff: https://github.com/llvm/llvm-project/pull/135772.diff
3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 690243525ede4..69376061bac72 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -1095,6 +1095,7 @@ def LLVM_TBAATagArrayAttr
//===----------------------------------------------------------------------===//
def LLVM_ConstantRangeAttr : LLVM_Attr<"ConstantRange", "constant_range"> {
let parameters = (ins
+ "uint32_t":$width,
"::llvm::APInt":$lower,
"::llvm::APInt":$upper
);
@@ -1110,13 +1111,16 @@ def LLVM_ConstantRangeAttr : LLVM_Attr<"ConstantRange", "constant_range"> {
Syntax:
```
- `<` `i`(width($lower)) $lower `,` $upper `>`
+ `<` `i`(width) $lower `,` $upper `>`
```
}];
let builders = [
AttrBuilder<(ins "uint32_t":$bitWidth, "int64_t":$lower, "int64_t":$upper), [{
- return $_get($_ctxt, ::llvm::APInt(bitWidth, lower), ::llvm::APInt(bitWidth, upper));
+ return $_get($_ctxt, bitWidth, ::llvm::APInt(bitWidth, lower), ::llvm::APInt(bitWidth, upper));
+ }]>,
+ AttrBuilder<(ins "::llvm::APInt":$lower, "::llvm::APInt":$upper), [{
+ return $_get($_ctxt, lower.getBitWidth(), lower, upper);
}]>
];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
index e4f9d6f987401..6975c593d7f7e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp
@@ -278,13 +278,18 @@ Attribute ConstantRangeAttr::parse(AsmParser &parser, Type odsType) {
}
void ConstantRangeAttr::print(AsmPrinter &printer) const {
- printer << "<i" << getLower().getBitWidth() << ", " << getLower() << ", "
- << getUpper() << ">";
+ printer << "<i" << getWidth() << ", " << getLower() << ", " << getUpper()
+ << ">";
}
LogicalResult
ConstantRangeAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
- APInt lower, APInt upper) {
+ uint32_t width, llvm::APInt lower,
+ llvm::APInt upper) {
+ if (width != lower.getBitWidth())
+ return emitError()
+ << "expected type and value to have matching bitwidths but got "
+ << width << " vs. " << lower.getBitWidth();
if (lower.getBitWidth() != upper.getBitWidth())
return emitError()
<< "expected lower and upper to have matching bitwidths but got "
diff --git a/mlir/test/Dialect/LLVMIR/range-attr.mlir b/mlir/test/Dialect/LLVMIR/range-attr.mlir
new file mode 100644
index 0000000000000..5f2b67609743b
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/range-attr.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -o - | FileCheck %s
+
+// CHECK: #llvm.constant_range<i32, 0, 12>
+llvm.func external @foo1(!llvm.ptr, i64) -> (i32 {llvm.range = #llvm.constant_range<i32, 0, 12>})
+// CHECK: #llvm.constant_range<i8, 1, 10>
+llvm.func external @foo2(!llvm.ptr, i64) -> (i8 {llvm.range = #llvm.constant_range<i8, 1, 10>})
+// CHECK: #llvm.constant_range<i64, 0, 2147483648>
+llvm.func external @foo3(!llvm.ptr, i64) -> (i64 {llvm.range = #llvm.constant_range<i64, 0, 2147483648>})
+// CHECK: #llvm.constant_range<i32, 1, -2147483648>
+llvm.func external @foo4(!llvm.ptr, i64) -> (i32 {llvm.range = #llvm.constant_range<i32, 1, -2147483648>})
Fwiw, I will add a test that triggers this with the constant seed used in
Isn't the underlying issue in the hash function of the APInt that does not take the bit width into account?
Unless I'm missing something, I believe it does take it into account: https://github.com/llvm/llvm-project/blob/main/llvm/lib/Support/APInt.cpp#L590
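Mixing the bit width into the hash makes collisions less likely but cannot rule them out, and once two different keys share a hash a hash-based container falls back to its equality check. A minimal sketch of that point, using a deliberately weak toy hash (`toyHash` and `collides` are hypothetical names for illustration and are not LLVM's `hash_value` for APInt):

```cpp
#include <cstdint>

// Hypothetical toy hash, NOT LLVM's hash_value(const APInt &). It mixes the
// bit width into the result, yet distinct inputs can still collide.
inline uint32_t toyHash(uint32_t bitWidth, uint64_t value) {
  return static_cast<uint32_t>((bitWidth * 31u + value) % 16u);
}

// True when two keys differ but share a hash, which forces a hash-based
// container to run its equality comparison on them.
inline bool collides(uint32_t w1, uint64_t v1, uint32_t w2, uint64_t v2) {
  bool sameKey = (w1 == w2) && (v1 == v2);
  return !sameKey && toyHash(w1, v1) == toyHash(w2, v2);
}
```

With this toy hash, an 8-bit value 31 and a 9-bit value 0 hash identically, so the equality operator would be asked to compare values of different widths.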
Hmm, could it be that this is related to llvm-project/llvm/include/llvm/ADT/APInt.h, line 1056 in 9a6c001?
For some reason APInt seems to expect that the
I am not entirely sure, but I believe that the storage uniquer uses the equality operator to handle hash value collisions. So maybe this is the root of the problem? With regards to the approach, I would first like to fully understand the problem before moving forward. In theory APInt should work with the storage uniquer. If it doesn't, there may be more problematic attributes.
Yes, the core of the issue is that two APInts with different bit widths cannot be compared. I believe in other attributes the issue will often be avoided by the fact that they contain some information about the type, so the comparison short-circuits before comparing the APInts. I have a stack trace from when the assert is triggered: I will try to extract the hash seed that causes it with the input I added to the tests.
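The short-circuiting described above can be sketched in plain C++. This is an illustration under stated assumptions, not MLIR's generated storage code: `MockAPInt`, `RangeKey`, `keysEqual`, and `boundsEqualUnchecked` are hypothetical names, and the mock throws where the real APInt would assert. It also assumes, as the verifier in this PR enforces, that the stored width matches the width of the bounds.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for llvm::APInt: equality refuses to compare values
// of different bit widths (the real class asserts; this mock throws).
struct MockAPInt {
  uint32_t bitWidth;
  uint64_t value;
  bool operator==(const MockAPInt &other) const {
    if (bitWidth != other.bitWidth)
      throw std::logic_error("comparison requires equal bit widths");
    return value == other.value;
  }
};

// Key layout resembling the first revision of the PR: explicit width first.
struct RangeKey {
  uint32_t width;
  MockAPInt lower;
  MockAPInt upper;
};

// Compares the width before the bounds, so && short-circuits on a mismatch
// and the throwing operator== is never reached.
inline bool keysEqual(const RangeKey &a, const RangeKey &b) {
  return a.width == b.width && a.lower == b.lower && a.upper == b.upper;
}

// Comparison without the leading width check; reaches MockAPInt::operator==
// even when the widths differ.
inline bool boundsEqualUnchecked(const RangeKey &a, const RangeKey &b) {
  return a.lower == b.lower && a.upper == b.upper;
}
```

Comparing an i32 range with an i8 range through `keysEqual` simply yields `false`, while `boundsEqualUnchecked` hits the width check inside the mock's equality operator.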
An alternative solution is to add a custom StorageClass for this: trailofbits@83a8b2b
/// The hash key for this storage is the pair of lower and upper bounds.
using KeyTy = std::pair<llvm::APInt, llvm::APInt>;
/// Define the comparison function for the key type.
bool operator==(const KeyTy &key) const {
  if (lower.getBitWidth() != key.first.getBitWidth() ||
      upper.getBitWidth() != key.second.getBitWidth()) {
    return false;
  }
  return lower == key.first && upper == key.second;
}
Do you know if it is possible to provide a storage class for all APInts to override the strange equality operator? That way the problem could be solved for all attributes that use APInt. I also wonder why APInt has such a strange implementation of the equality operator? With regards to the two workarounds, the question is probably a trade-off between simplicity and storing four additional bytes? I have a slight tendency for using the custom storage class approach, but no strong opinion.
A hacky solution is to add a special case to handle APInt bitwidth in
For debugging purposes: making the
Another way I see would be defining
APFloat is already handled there. But that doesn't fix the StorageClass for users of the raw APInt.
Yeah I think this is my preferred solution so far.
9bb9bdc to de30d07
Adding an APInt parameter looks reasonable to me, any other uses of APInt as a parameter in-tree that should be updated?
let parameters = (ins APIntParameter<"">:$lower,
APIntParameter<"">:$upper
- let parameters = (ins APIntParameter<"">:$lower,
- APIntParameter<"">:$upper
+ let parameters = (ins
+ APIntParameter<"">:$lower,
+ APIntParameter<"">:$upper
I found only:
- https://github.com/trail-of-forks/llvm-project/blob/bba2507c19ff678c5d7b18e0b220406be87451fe/mlir/include/mlir/IR/BuiltinAttributes.td#L683
def My_IntegerAttr : MyDialect_Attr<"Integer", "int"> {
We discussed with @gysit in a side channel that it might be worthwhile to add a check to tablegen that warns on APInt use and suggests APIntParameter instead.
One conceptual alternative to that, I suppose, would be to detect APInt and use APIntParameter-equivalent logic instead. Either way, it would be nice to remove a footgun.
Thanks for fixing, I like this solution.
LGTM modulo nit comments.
If we can add guardrails in tablegen that would be great. But that would be something for a separate PR.
mlir/include/mlir/IR/AttrTypeBase.td
Outdated
@@ -383,6 +383,12 @@ class StringRefParameter<string desc = "", string value = ""> :
let defaultValue = value;
}

// For APInts, which require comparison over different bitwidths
- // For APInts, which require comparison over different bitwidths
+ // For APInts, which require comparison supporting different bitwidths. The default
+ // APInt comparison operator asserts when the bitwidths differ, so a custom
+ // implementation is necessary.
nit: Let's maybe expand a bit why this is necessary.
I have integrated the suggested changes and updated the PR description to match the current solution. If it's okay like this, can someone merge it for me, please? I do not have write access yet.
Add APIntParameter with custom implementation for comparison and use it in llvm.constant_range attribute. This is necessary because the default equality operator of APInt asserts when the bit widths of the compared APInts differ. The comparison is used by StorageUniquer when hashes of two ranges with different bit widths collide.
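The idea of a width-aware comparison can be illustrated with a standard hash set. This is a sketch under stated assumptions rather than the actual `APIntParameter` or `StorageUniquer` implementation: `Bits`, `Range`, `bitsEqual`, `RangeEq`, `RangeHash`, `RangeSet`, and `uniqueInsert` are hypothetical names, and the hash deliberately ignores the bit width so that ranges of different widths collide.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_set>

// Hypothetical stand-in for an APInt-like value (not llvm::APInt).
struct Bits {
  uint32_t bitWidth;
  uint64_t value;
};

// Width-aware equality: check the widths first, then the values.
inline bool bitsEqual(const Bits &lhs, const Bits &rhs) {
  return lhs.bitWidth == rhs.bitWidth && lhs.value == rhs.value;
}

struct Range {
  Bits lower;
  Bits upper;
};

// Equality functor for the set, built on the width-aware comparison.
struct RangeEq {
  bool operator()(const Range &a, const Range &b) const {
    return bitsEqual(a.lower, b.lower) && bitsEqual(a.upper, b.upper);
  }
};

// Deliberately ignores the bit width so that ranges of different widths
// collide and the equality functor has to tell them apart.
struct RangeHash {
  std::size_t operator()(const Range &r) const {
    return std::hash<uint64_t>{}(r.lower.value) ^
           (std::hash<uint64_t>{}(r.upper.value) << 1);
  }
};

using RangeSet = std::unordered_set<Range, RangeHash, RangeEq>;

// Returns true if the range was newly inserted, false if already present.
inline bool uniqueInsert(RangeSet &set, const Range &r) {
  return set.insert(r).second;
}
```

Inserting an i32 range `[0, 12)` and an i8 range `[0, 12)` produces two distinct entries even though their hashes match, and inserting the i32 range a second time is rejected as a duplicate.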