Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Mar 12, 2024
1 parent 8b59934 commit 228bfdf
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 82 deletions.
17 changes: 10 additions & 7 deletions include/circt/Dialect/SMT/SMTAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@ def BitVectorAttr : AttrDef<SMTDialect, "BitVector", [
described in the [SMT bit-vector
theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml).

The constant is parsed and printed as #bX (binary) or #xX (hexadecimal)
The constant is as #bX (binary) or #xX (hexadecimal) in SMT-LIB
where X is the value in the corresponding format without any further
prefixing just as in SMT-LIB. The number of digits determines the bit-width
of the bit-vector. This means, leading zeros are important!
prefixing. Here, the bit-vector constant is given as a regular integer
literal and the associated bit-vector type indicating the bit-width.

Examples:
```mlir
#smt.bv<"#b0101"> : !smt.bv<4>
#smt.bv<"#x5c"> : !smt.bv<8>
#smt.bv<5> : !smt.bv<4>
#smt.bv<92> : !smt.bv<8>
```
The explicit type-suffix is optional.

The bit-width must be greater than zero (i.e., at least one digit as to be
The explicit type-suffix is mandatory to uniquely represent the attribute,
i.e., this attribute should always be used in the extended form (using the
`quantified` keyword in the operation assembly format string).

The bit-width must be greater than zero (i.e., at least one digit has to be
present).
}];

Expand Down
18 changes: 14 additions & 4 deletions include/circt/Dialect/SMT/SMTBitVectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,28 @@ def BVConstantOp : SMTBVOp<"constant", [

Examples:
```mlir
%bv_x5c = smt.bv.constant <"#x5c">
%bv_b0101 = smt.bv.constant <"#b0101">
%c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8>
%c5_bv4 = smt.bv.constant #smt.bv<5> : !smt.bv<4>
```
}];

let arguments = (ins BitVectorAttr:$value);
let results = (outs BitVectorType:$result);

let assemblyFormat = "$value attr-dict";
let assemblyFormat = "qualified($value) attr-dict";

let builders = [
OpBuilder<(ins "const llvm::APInt &":$value), [{
build($_builder, $_state,
BitVectorAttr::get($_builder.getContext(), value));
}]>,
OpBuilder<(ins "unsigned":$value, "unsigned":$width), [{
build($_builder, $_state,
BitVectorAttr::get($_builder.getContext(), value, width));
}]>,
];

let hasFolder = true;
let hasVerifier = true;
}

#endif // CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD
54 changes: 29 additions & 25 deletions lib/Dialect/SMT/SMTAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ parseBitVectorString(function_ref<InFlightDiagnostic()> emitError,
}

BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
InFlightDiagnostic diag;
auto maybeValue =
parseBitVectorString([&]() { return std::move(diag); }, value);
diag.abandon();
auto maybeValue = parseBitVectorString(nullptr, value);

assert(succeeded(maybeValue) && "string must have SMT-LIB format");
return Base::get(context, *maybeValue);
Expand All @@ -98,42 +95,49 @@ BitVectorAttr
BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, unsigned value,
unsigned width) {
if ((~((1U << width) - 1U) & value) != 0U) {
emitError() << "value does not fit in a bit-vector of desired width";
return {};
}
return Base::getChecked(emitError, context, APInt(width, value));
}

Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
llvm::SMLoc loc = odsParser.getCurrentLocation();

std::string val;
if (odsParser.parseLess() || odsParser.parseString(&val))
return {};

auto maybeVal =
parseBitVectorString([&]() { return odsParser.emitError(loc); }, val);
if (failed(maybeVal))
APInt val;
if (odsParser.parseLess() || odsParser.parseInteger(val) ||
odsParser.parseGreater())
return {};

if (odsParser.parseGreater())
// Requires the use of `quantified(<attr>)` in operation assembly formats.
if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
odsParser.emitError(loc) << "explicit bit-vector type required";
return {};
}

// We implement the TypedAttr interface, i.e., the AsmParser allows an
// optional explicit type as suffix which is provided here as 'odsType'. We
// can always build the type from the constant itself, therefore, we can
// ignore this explicit type. However, to ensure consistency we should verify
// that the correct type is provided.
auto bv = BitVectorAttr::get(odsParser.getContext(), *maybeVal);
if (odsType && bv.getType() != odsType) {
odsParser.emitError(loc, "expected type for constant does not match "
"explicitly provided attribute type, got ")
<< odsType << ", expected " << bv.getType();
return {};
unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
if (width > val.getBitWidth())
val = val.sext(width);

if (width < val.getBitWidth()) {
if ((val.isNegative() && val.getSignificantBits() > width) ||
val.getActiveBits() > width) {
odsParser.emitError(loc)
<< "integer value out of range for given bit-vector type";
return {};
}
val = val.trunc(width);
}

return bv;
return BitVectorAttr::get(odsParser.getContext(), val);
}

void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
odsPrinter << "<\"" << getValueAsString() << "\">";
// This printer only works for the extended format where the MLIR
// infrastructure prints the type for us. This means, the attribute should
// never be used without `quantified` in an assembly format.
odsPrinter << "<" << getValue() << ">";
}

Type BitVectorAttr::getType() const {
Expand Down
11 changes: 2 additions & 9 deletions lib/Dialect/SMT/SMTOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,6 @@ using namespace mlir;
// BVConstantOp
//===----------------------------------------------------------------------===//

LogicalResult BVConstantOp::verify() {
if (getValue().getType() != getType())
return emitError(
"smt.bv.constant attribute bitwidth doesn't match return type");

return success();
}

LogicalResult BVConstantOp::inferReturnTypes(
mlir::MLIRContext *context, std::optional<mlir::Location> location,
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
Expand All @@ -40,7 +32,8 @@ void BVConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
SmallVector<char, 128> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << "bv_" << getValue().getValueAsString(false);
specialName << "c" << getValue().getValue() << "_bv"
<< getValue().getValue().getBitWidth();
setNameFn(getResult(), specialName.str());
}

Expand Down
25 changes: 9 additions & 16 deletions test/Dialect/SMT/bitvector-errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,29 @@ func.func @at_least_size_one(%arg0: !smt.bv<0>) {
// -----

func.func @attr_type_and_return_type_match() {
// expected-error @below {{smt.bv.constant attribute bitwidth doesn't match return type}}
%c0_bv32 = "smt.bv.constant"() <{value = #smt.bv<"#b0"> : !smt.bv<1>}> : () -> !smt.bv<32>
return
}

// -----

func.func @implicit_constant_type_and_explicit_type_match() {
// expected-error @below {{expected type for constant does not match explicitly provided attribute type, got '!smt.bv<2>', expected '!smt.bv<1>'}}
%c0_bv2 = "smt.bv.constant"() <{value = #smt.bv<"#b0"> : !smt.bv<2>}> : () -> !smt.bv<1>
// expected-error @below {{inferred type(s) '!smt.bv<1>' are incompatible with return type(s) of operation '!smt.bv<32>'}}
// expected-error @below {{failed to infer returned types}}
%c0_bv32 = "smt.bv.constant"() <{value = #smt.bv<0> : !smt.bv<1>}> : () -> !smt.bv<32>
return
}

// -----

func.func @invalid_bitvector_attr() {
// expected-error @below {{expected at least one digit}}
smt.bv.constant #smt.bv<"#b">
// expected-error @below {{explicit bit-vector type required}}
smt.bv.constant #smt.bv<5>
}

// -----

func.func @invalid_bitvector_attr() {
// expected-error @below {{expected either 'b' or 'x'}}
smt.bv.constant #smt.bv<"#c0">
// expected-error @below {{integer value out of range for given bit-vector type}}
smt.bv.constant #smt.bv<32> : !smt.bv<2>
}

// -----

func.func @invalid_bitvector_attr() {
// expected-error @below {{expected '#'}}
smt.bv.constant #smt.bv<"b">
// expected-error @below {{integer value out of range for given bit-vector type}}
smt.bv.constant #smt.bv<-4> : !smt.bv<2>
}
25 changes: 6 additions & 19 deletions test/Dialect/SMT/bitvectors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,12 @@

// CHECK-LABEL: func @bitvectors
func.func @bitvectors() {
// A bit-width divisible by 4 is always printed in hex
// CHECK: %bv_x5a = smt.bv.constant <"#x5a"> {smt.some_attr}
%bv_x5a = smt.bv.constant <"#b01011010"> {smt.some_attr}

// A bit-width not divisible by 4 is always printed in binary
// Also, make sure leading zeros are printed
// CHECK: %bv_b0101101 = smt.bv.constant <"#b0101101"> {smt.some_attr}
%bv_b0101101 = smt.bv.constant <"#b0101101"> {smt.some_attr}

// CHECK: %bv_x3c = smt.bv.constant <"#x3c"> {smt.some_attr}
%bv_x3c = smt.bv.constant <"#x3c"> {smt.some_attr}

// Make sure leading zeros are printed
// CHECK: %bv_x03c = smt.bv.constant <"#x03c"> {smt.some_attr}
%bv_x03c = smt.bv.constant <"#x03c"> {smt.some_attr}

// It is allowed to fully quantify the attribute including an explicit type
// CHECK: %bv_x3cd = smt.bv.constant <"#x3cd">
%bv_x3cd = smt.bv.constant #smt.bv<"#x3cd"> : !smt.bv<12>
// CHECK: %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr}
%c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr}
// CHECK: %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8> {smt.some_attr}
%c92_bv8 = smt.bv.constant #smt.bv<0x5c> : !smt.bv<8> {smt.some_attr}
// CHECK: %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8>
%c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8>

return
}
75 changes: 73 additions & 2 deletions unittests/Dialect/SMT/AttributeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using namespace smt;

namespace {

TEST(AttributeTest, BitVectorAttr) {
TEST(BitVectorAttrTest, MinBitWidth) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));
Expand All @@ -27,10 +27,81 @@ TEST(AttributeTest, BitVectorAttr) {
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
ASSERT_EQ(diag.str(), "bit-width must be at least 1, but got 0");
});
}

TEST(BitVectorAttrTest, ParserAndPrinterCorrect) {
MLIRContext context;
context.loadDialect<SMTDialect>();

attr = BitVectorAttr::get(&context, "#b1010");
auto attr = BitVectorAttr::get(&context, "#b1010");
ASSERT_EQ(attr.getValue(), APInt(4, 10));
ASSERT_EQ(attr.getType(), BitVectorType::get(&context, 4));

// A bit-width divisible by 4 is always printed in hex
attr = BitVectorAttr::get(&context, "#b01011010");
ASSERT_EQ(attr.getValueAsString(), "#x5a");

// A bit-width not divisible by 4 is always printed in binary
// Also, make sure leading zeros are printed
attr = BitVectorAttr::get(&context, "#b0101101");
ASSERT_EQ(attr.getValueAsString(), "#b0101101");

attr = BitVectorAttr::get(&context, "#x3c");
ASSERT_EQ(attr.getValueAsString(), "#x3c");

attr = BitVectorAttr::get(&context, "#x03c");
ASSERT_EQ(attr.getValueAsString(), "#x03c");
}

TEST(BitVectorAttrTest, ExpectedOneDigit) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));

auto attr =
BitVectorAttr::getChecked(loc, &context, static_cast<StringRef>("#b"));
ASSERT_EQ(attr, BitVectorAttr());
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
ASSERT_EQ(diag.str(), "expected at least one digit");
});
}

TEST(BitVectorAttrTest, ExpectedBOrX) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));

auto attr =
BitVectorAttr::getChecked(loc, &context, static_cast<StringRef>("#c0"));
ASSERT_EQ(attr, BitVectorAttr());
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
ASSERT_EQ(diag.str(), "expected either 'b' or 'x'");
});
}

TEST(BitVectorAttrTest, ExpectedHashtag) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));

auto attr =
BitVectorAttr::getChecked(loc, &context, static_cast<StringRef>("b0"));
ASSERT_EQ(attr, BitVectorAttr());
context.getDiagEngine().registerHandler(
[&](Diagnostic &diag) { ASSERT_EQ(diag.str(), "expected '#'"); });
}

TEST(BitVectorAttrTest, OutOfRange) {
MLIRContext context;
context.loadDialect<SMTDialect>();
Location loc(UnknownLoc::get(&context));

auto attr = BitVectorAttr::getChecked(loc, &context, 2U, 1U);
ASSERT_EQ(attr, BitVectorAttr());
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
ASSERT_EQ(diag.str(),
"value does not fit in a bit-vector of desired width");
});
}

} // namespace

0 comments on commit 228bfdf

Please sign in to comment.