diff --git a/include/circt/Dialect/SMT/SMTAttributes.td b/include/circt/Dialect/SMT/SMTAttributes.td index 3023eebe061b..6733bd47de49 100644 --- a/include/circt/Dialect/SMT/SMTAttributes.td +++ b/include/circt/Dialect/SMT/SMTAttributes.td @@ -52,7 +52,7 @@ def BitVectorAttr : AttrDef, - AttrBuilder<(ins "unsigned":$value, "unsigned":$width)>, + AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>, ]; let extraClassDeclaration = [{ diff --git a/include/circt/Dialect/SMT/SMTBitVectorOps.td b/include/circt/Dialect/SMT/SMTBitVectorOps.td index ba1ffe77ddb4..f60a9d8d0303 100644 --- a/include/circt/Dialect/SMT/SMTBitVectorOps.td +++ b/include/circt/Dialect/SMT/SMTBitVectorOps.td @@ -52,7 +52,7 @@ def BVConstantOp : SMTBVOp<"constant", [ build($_builder, $_state, BitVectorAttr::get($_builder.getContext(), value)); }]>, - OpBuilder<(ins "unsigned":$value, "unsigned":$width), [{ + OpBuilder<(ins "uint64_t":$value, "unsigned":$width), [{ build($_builder, $_state, BitVectorAttr::get($_builder.getContext(), value, width)); }]>, diff --git a/lib/Dialect/SMT/SMTAttributes.cpp b/lib/Dialect/SMT/SMTAttributes.cpp index cc00ffa3d383..8bacfda48d42 100644 --- a/lib/Dialect/SMT/SMTAttributes.cpp +++ b/lib/Dialect/SMT/SMTAttributes.cpp @@ -86,16 +86,16 @@ BitVectorAttr::getChecked(function_ref emitError, return Base::getChecked(emitError, context, *maybeValue); } -BitVectorAttr BitVectorAttr::get(MLIRContext *context, unsigned value, +BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value, unsigned width) { return Base::get(context, APInt(width, value)); } BitVectorAttr BitVectorAttr::getChecked(function_ref emitError, - MLIRContext *context, unsigned value, + MLIRContext *context, uint64_t value, unsigned width) { - if ((~((1U << width) - 1U) & value) != 0U) { + if (width < 64 && value >= (UINT64_C(1) << width)) { emitError() << "value does not fit in a bit-vector of desired width"; return {}; } @@ -117,14 +117,20 @@ Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) { } unsigned width = llvm::cast(odsType).getWidth(); - if (width > val.getBitWidth()) - val = val.sext(width); - if (width < val.getBitWidth()) { - if ((val.isNegative() && val.getSignificantBits() > width) || - val.getActiveBits() > width) { + if (width > val.getBitWidth()) { + // sext is always safe here, even for unsigned values, because the + // parseOptionalInteger method will return something with a zero in the + // top bits if it is a positive number. + val = val.sext(width); + } else if (width < val.getBitWidth()) { + // The parser can return an unnecessarily wide result. + // This isn't a problem, but truncating off bits is bad. + unsigned neededBits = + val.isNegative() ? val.getSignificantBits() : val.getActiveBits(); + if (width < neededBits) { odsParser.emitError(loc) - << "integer value out of range for given bit-vector type"; + << "integer value out of range for given bit-vector type " << odsType; return {}; } val = val.trunc(width); diff --git a/test/Dialect/SMT/bitvectors.mlir b/test/Dialect/SMT/bitvectors.mlir index e0cdbb45d0cf..552870793a54 100644 --- a/test/Dialect/SMT/bitvectors.mlir +++ b/test/Dialect/SMT/bitvectors.mlir @@ -8,6 +8,10 @@ func.func @bitvectors() { %c92_bv8 = smt.bv.constant #smt.bv<0x5c> : !smt.bv<8> {smt.some_attr} // CHECK: %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8> %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8> + // CHECK: %c-1_bv1{{(_[0-9]+)?}} = smt.bv.constant #smt.bv<-1> : !smt.bv<1> + %c-1_bv1_neg = smt.bv.constant #smt.bv<-1> : !smt.bv<1> + // CHECK: %c-1_bv1{{(_[0-9]+)?}} = smt.bv.constant #smt.bv<-1> : !smt.bv<1> + %c-1_bv1_pos = smt.bv.constant #smt.bv<1> : !smt.bv<1> // CHECK: [[C0:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32> %c = smt.bv.constant #smt.bv<0> : !smt.bv<32> diff --git a/unittests/Dialect/SMT/AttributeTest.cpp b/unittests/Dialect/SMT/AttributeTest.cpp index f24dcc111670..7761105a2a7d 100644 --- a/unittests/Dialect/SMT/AttributeTest.cpp +++ b/unittests/Dialect/SMT/AttributeTest.cpp @@ -22,7 +22,7 @@ TEST(BitVectorAttrTest, MinBitWidth) { context.loadDialect(); Location loc(UnknownLoc::get(&context)); - auto attr = BitVectorAttr::getChecked(loc, &context, 0U, 0U); + auto attr = BitVectorAttr::getChecked(loc, &context, UINT64_C(0), 0U); ASSERT_EQ(attr, BitVectorAttr()); context.getDiagEngine().registerHandler([&](Diagnostic &diag) { ASSERT_EQ(diag.str(), "bit-width must be at least 1, but got 0"); @@ -96,12 +96,24 @@ TEST(BitVectorAttrTest, OutOfRange) { context.loadDialect(); Location loc(UnknownLoc::get(&context)); - auto attr = BitVectorAttr::getChecked(loc, &context, 2U, 1U); - ASSERT_EQ(attr, BitVectorAttr()); + auto attr1 = BitVectorAttr::getChecked(loc, &context, UINT64_C(2), 1U); + auto attr63 = + BitVectorAttr::getChecked(loc, &context, UINT64_C(3) << 62, 63U); + ASSERT_EQ(attr1, BitVectorAttr()); + ASSERT_EQ(attr63, BitVectorAttr()); context.getDiagEngine().registerHandler([&](Diagnostic &diag) { ASSERT_EQ(diag.str(), "value does not fit in a bit-vector of desired width"); }); } +TEST(BitVectorAttrTest, GetUInt64Max) { + MLIRContext context; + context.loadDialect(); + auto attr64 = BitVectorAttr::get(&context, UINT64_MAX, 64); + auto attr65 = BitVectorAttr::get(&context, UINT64_MAX, 65); + ASSERT_EQ(attr64.getValue(), APInt::getAllOnes(64)); + ASSERT_EQ(attr65.getValue(), APInt::getAllOnes(64).zext(65)); +} + } // namespace