Skip to content

Commit

Permalink
[SMT] Minor width related fixes for BitVectorAttr (llvm#6900)
Browse files Browse the repository at this point in the history
Some minor changes to the construction of BitVector attributes in the SMT dialect:
- Fix parsing of smt.bv.constant #smt.bv<-1> : !smt.bv<1> which currently trips the width check due to odsParser.parseInteger creating an unnecessarily wide APInt. The new logic is copy-pasted from the FIRRTL ConstantOp parser. See llvm#6794.
- Change the type of the attribute builder's value argument from unsigned to uint64_t matching the signature of the APInt constructor, to allow values up to 64 bits and avoid architecture dependent behavior.
- Prevent left-shifts wider than (or equal to) the shifted operand's number of bits in the width checking logic to avoid undefined behavior.
  • Loading branch information
fzi-hielscher authored and cepheus69 committed Apr 22, 2024
1 parent e6e13d6 commit ecedb3d
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 14 deletions.
2 changes: 1 addition & 1 deletion include/circt/Dialect/SMT/SMTAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def BitVectorAttr : AttrDef<SMTDialect, "BitVector", [

let builders = [
AttrBuilder<(ins "llvm::StringRef":$value)>,
AttrBuilder<(ins "unsigned":$value, "unsigned":$width)>,
AttrBuilder<(ins "uint64_t":$value, "unsigned":$width)>,
];

let extraClassDeclaration = [{
Expand Down
2 changes: 1 addition & 1 deletion include/circt/Dialect/SMT/SMTBitVectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def BVConstantOp : SMTBVOp<"constant", [
build($_builder, $_state,
BitVectorAttr::get($_builder.getContext(), value));
}]>,
OpBuilder<(ins "unsigned":$value, "unsigned":$width), [{
OpBuilder<(ins "uint64_t":$value, "unsigned":$width), [{
build($_builder, $_state,
BitVectorAttr::get($_builder.getContext(), value, width));
}]>,
Expand Down
24 changes: 15 additions & 9 deletions lib/Dialect/SMT/SMTAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
return Base::getChecked(emitError, context, *maybeValue);
}

BitVectorAttr BitVectorAttr::get(MLIRContext *context, unsigned value,
BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value,
unsigned width) {
return Base::get(context, APInt(width, value));
}

BitVectorAttr
BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, unsigned value,
MLIRContext *context, uint64_t value,
unsigned width) {
if ((~((1U << width) - 1U) & value) != 0U) {
if (width < 64 && value >= (UINT64_C(1) << width)) {
emitError() << "value does not fit in a bit-vector of desired width";
return {};
}
Expand All @@ -117,14 +117,20 @@ Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
}

unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
if (width > val.getBitWidth())
val = val.sext(width);

if (width < val.getBitWidth()) {
if ((val.isNegative() && val.getSignificantBits() > width) ||
val.getActiveBits() > width) {
if (width > val.getBitWidth()) {
// sext is always safe here, even for unsigned values, because the
// parseOptionalInteger method will return something with a zero in the
// top bits if it is a positive number.
val = val.sext(width);
} else if (width < val.getBitWidth()) {
// The parser can return an unnecessarily wide result.
// This isn't a problem, but truncating off bits is bad.
unsigned neededBits =
val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
if (width < neededBits) {
odsParser.emitError(loc)
<< "integer value out of range for given bit-vector type";
<< "integer value out of range for given bit-vector type " << odsType;
return {};
}
val = val.trunc(width);
Expand Down
4 changes: 4 additions & 0 deletions test/Dialect/SMT/bitvectors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ func.func @bitvectors() {
%c92_bv8 = smt.bv.constant #smt.bv<0x5c> : !smt.bv<8> {smt.some_attr}
// CHECK: %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8>
%c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8>
// CHECK: %c-1_bv1{{(_[0-9]+)?}} = smt.bv.constant #smt.bv<-1> : !smt.bv<1>
%c-1_bv1_neg = smt.bv.constant #smt.bv<-1> : !smt.bv<1>
// CHECK: %c-1_bv1{{(_[0-9]+)?}} = smt.bv.constant #smt.bv<-1> : !smt.bv<1>
%c-1_bv1_pos = smt.bv.constant #smt.bv<1> : !smt.bv<1>

// CHECK: [[C0:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
%c = smt.bv.constant #smt.bv<0> : !smt.bv<32>
Expand Down
18 changes: 15 additions & 3 deletions unittests/Dialect/SMT/AttributeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ TEST(BitVectorAttrTest, MinBitWidth) {
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));

auto attr = BitVectorAttr::getChecked(loc, &context, 0U, 0U);
auto attr = BitVectorAttr::getChecked(loc, &context, UINT64_C(0), 0U);
ASSERT_EQ(attr, BitVectorAttr());
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
ASSERT_EQ(diag.str(), "bit-width must be at least 1, but got 0");
Expand Down Expand Up @@ -96,12 +96,24 @@ TEST(BitVectorAttrTest, OutOfRange) {
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));

auto attr = BitVectorAttr::getChecked(loc, &context, 2U, 1U);
ASSERT_EQ(attr, BitVectorAttr());
auto attr1 = BitVectorAttr::getChecked(loc, &context, UINT64_C(2), 1U);
auto attr63 =
BitVectorAttr::getChecked(loc, &context, UINT64_C(3) << 62, 63U);
ASSERT_EQ(attr1, BitVectorAttr());
ASSERT_EQ(attr63, BitVectorAttr());
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
ASSERT_EQ(diag.str(),
"value does not fit in a bit-vector of desired width");
});
}

TEST(BitVectorAttrTest, GetUInt64Max) {
MLIRContext context;
context.loadDialect<SMTDialect>();
auto attr64 = BitVectorAttr::get(&context, UINT64_MAX, 64);
auto attr65 = BitVectorAttr::get(&context, UINT64_MAX, 65);
ASSERT_EQ(attr64.getValue(), APInt::getAllOnes(64));
ASSERT_EQ(attr65.getValue(), APInt::getAllOnes(64).zext(65));
}

} // namespace

0 comments on commit ecedb3d

Please sign in to comment.