Skip to content

Commit 4bcc414

Browse files
Jezurkogysit
andauthored
[MLIR][TableGen] Error on APInt parameter without custom comparator (llvm#135970)
The error is triggered when an attribute or type uses an APInt typed parameter with the generated equality operator. If the compared APInts have different bit widths the equality operator triggers an assert. This is dangerous, since `StorageUniquer` for types and attributes uses the equality operator when a hash collision appears. As such, it is necessary to use custom provided comarator or `APIntParameter` that already has it. This commit also replaces uses of the raw `APInt` parameter with the `APIntParameter` and removes the no longer necessary custom StorageClass for the `BitVectorAttr` from the SMT dialect that was a workaround for the described issue. --------- Co-authored-by: Tobias Gysi <tobias.gysi@nextsilicon.com>
1 parent 278c429 commit 4bcc414

File tree

7 files changed

+39
-50
lines changed

7 files changed

+39
-50
lines changed

mlir/include/mlir/Dialect/SMT/IR/SMTAttributes.td

+1-11
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,11 @@ def BitVectorAttr : AttrDef<SMTDialect, "BitVector", [
4545
present).
4646
}];
4747

48-
let parameters = (ins "llvm::APInt":$value);
48+
let parameters = (ins APIntParameter<"">:$value);
4949

5050
let hasCustomAssemblyFormat = true;
5151
let genVerifyDecl = true;
5252

53-
// We need to manually define the storage class because the generated one is
54-
// buggy (because the APInt asserts matching bitwidth in the `==` operator and
55-
// the generated storage uses that directly.
56-
// Alternatively: add a type parameter to redundantly store the bitwidth of
57-
// of the attribute type, it it's in the order before the 'value' it will be
58-
// checked before the APInt equality (this is the reason it works for the
59-
// builtin integer attribute), but would be more fragile (and we'd store
60-
// duplicate data).
61-
let genStorageClass = false;
62-
6353
let builders = [
6454
AttrBuilder<(ins "llvm::StringRef":$value)>,
6555
AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>,

mlir/include/mlir/IR/BuiltinAttributes.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer",
700700
false // A bool, i.e. i1, value.
701701
```
702702
}];
703-
let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value);
703+
let parameters = (ins AttributeSelfTypeParameter<"">:$type, APIntParameter<"">:$value);
704704
let builders = [
705705
AttrBuilderWithInferredContext<(ins "Type":$type,
706706
"const APInt &":$value), [{

mlir/include/mlir/TableGen/AttrOrTypeDef.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ class AttrOrTypeParameter {
6868
/// If specified, get the custom allocator code for this parameter.
6969
std::optional<StringRef> getAllocator() const;
7070

71-
/// If specified, get the custom comparator code for this parameter.
71+
/// Return true if user defined comparator is specified.
72+
bool hasCustomComparator() const;
73+
74+
/// Get the custom comparator code for this parameter or fallback to the
75+
/// default.
7276
StringRef getComparator() const;
7377

7478
/// Get the C++ type of this parameter.

mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp

-36
Original file line numberDiff line numberDiff line change
@@ -21,42 +21,6 @@ using namespace mlir::smt;
2121
// BitVectorAttr
2222
//===----------------------------------------------------------------------===//
2323

24-
namespace mlir {
25-
namespace smt {
26-
namespace detail {
27-
struct BitVectorAttrStorage : public mlir::AttributeStorage {
28-
using KeyTy = APInt;
29-
BitVectorAttrStorage(APInt value) : value(std::move(value)) {}
30-
31-
KeyTy getAsKey() const { return value; }
32-
33-
// NOTE: the implementation of this operator is the reason we need to define
34-
// the storage manually. The auto-generated version would just do the direct
35-
// equality check of the APInt, but that asserts the bitwidth of both to be
36-
// the same, leading to a crash. This implementation, therefore, checks for
37-
// matching bit-width beforehand.
38-
bool operator==(const KeyTy &key) const {
39-
return (value.getBitWidth() == key.getBitWidth() && value == key);
40-
}
41-
42-
static llvm::hash_code hashKey(const KeyTy &key) {
43-
return llvm::hash_value(key);
44-
}
45-
46-
static BitVectorAttrStorage *
47-
construct(mlir::AttributeStorageAllocator &allocator, KeyTy &&key) {
48-
return new (allocator.allocate<BitVectorAttrStorage>())
49-
BitVectorAttrStorage(std::move(key));
50-
}
51-
52-
APInt value;
53-
};
54-
} // namespace detail
55-
} // namespace smt
56-
} // namespace mlir
57-
58-
APInt BitVectorAttr::getValue() const { return getImpl()->value; }
59-
6024
LogicalResult BitVectorAttr::verify(
6125
function_ref<InFlightDiagnostic()> emitError,
6226
APInt value) { // NOLINT(performance-unnecessary-value-param)

mlir/lib/TableGen/AttrOrTypeDef.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,10 @@ std::optional<StringRef> AttrOrTypeParameter::getAllocator() const {
278278
return getDefValue<StringInit>("allocator");
279279
}
280280

281+
bool AttrOrTypeParameter::hasCustomComparator() const {
282+
return getDefValue<StringInit>("comparator").has_value();
283+
}
284+
281285
StringRef AttrOrTypeParameter::getComparator() const {
282286
return getDefValue<StringInit>("comparator").value_or("$_lhs == $_rhs");
283287
}
+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: not mlir-tblgen -gen-attrdef-decls -I %S/../../include %s 2>&1 | FileCheck %s
2+
3+
include "mlir/IR/AttrTypeBase.td"
4+
include "mlir/IR/OpBase.td"
5+
6+
def Test_Dialect: Dialect {
7+
let name = "TestDialect";
8+
let cppNamespace = "::test";
9+
}
10+
11+
def RawAPIntAttr : AttrDef<Test_Dialect, "RawAPInt"> {
12+
let mnemonic = "raw_ap_int";
13+
let parameters = (ins "APInt":$value);
14+
let hasCustomAssemblyFormat = 1;
15+
}
16+
17+
// CHECK: apint-param-error.td:11:5: error: Using a raw APInt parameter

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -678,8 +678,18 @@ void DefGen::emitStorageClass() {
678678
emitConstruct();
679679
// Emit the storage class members as public, at the very end of the struct.
680680
storageCls->finalize();
681-
for (auto &param : params)
681+
for (auto &param : params) {
682+
if (param.getCppType().contains("APInt") && !param.hasCustomComparator()) {
683+
PrintFatalError(
684+
def.getLoc(),
685+
"Using a raw APInt parameter without a custom comparator is "
686+
"not supported because an assert in the equality operator is "
687+
"triggered when the two APInts have different bit widths. This can "
688+
"lead to unexpected crashes. Use an `APIntParameter` or "
689+
"provide a custom comparator.");
690+
}
682691
storageCls->declare<Field>(param.getCppType(), param.getName());
692+
}
683693
}
684694

685695
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)