diff --git a/include/circt/Dialect/SMT/SMTAttributes.td b/include/circt/Dialect/SMT/SMTAttributes.td index 0e77075d05c1..3023eebe061b 100644 --- a/include/circt/Dialect/SMT/SMTAttributes.td +++ b/include/circt/Dialect/SMT/SMTAttributes.td @@ -26,19 +26,22 @@ def BitVectorAttr : AttrDef : !smt.bv<4> - #smt.bv<"#x5c"> : !smt.bv<8> + #smt.bv<5> : !smt.bv<4> + #smt.bv<92> : !smt.bv<8> ``` - The explicit type-suffix is optional. - The bit-width must be greater than zero (i.e., at least one digit as to be + The explicit type-suffix is mandatory to uniquely represent the attribute, + i.e., this attribute should always be used in the extended form (using the + `quantified` keyword in the operation assembly format string). + + The bit-width must be greater than zero (i.e., at least one digit has to be present). }]; diff --git a/include/circt/Dialect/SMT/SMTBitVectorOps.td b/include/circt/Dialect/SMT/SMTBitVectorOps.td index 166f20430d1f..c2a352bb0e6c 100644 --- a/include/circt/Dialect/SMT/SMTBitVectorOps.td +++ b/include/circt/Dialect/SMT/SMTBitVectorOps.td @@ -37,18 +37,28 @@ def BVConstantOp : SMTBVOp<"constant", [ Examples: ```mlir - %bv_x5c = smt.bv.constant <"#x5c"> - %bv_b0101 = smt.bv.constant <"#b0101"> + %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8> + %c5_bv4 = smt.bv.constant #smt.bv<5> : !smt.bv<4> ``` }]; let arguments = (ins BitVectorAttr:$value); let results = (outs BitVectorType:$result); - let assemblyFormat = "$value attr-dict"; + let assemblyFormat = "qualified($value) attr-dict"; + + let builders = [ + OpBuilder<(ins "const llvm::APInt &":$value), [{ + build($_builder, $_state, + BitVectorAttr::get($_builder.getContext(), value)); + }]>, + OpBuilder<(ins "unsigned":$value, "unsigned":$width), [{ + build($_builder, $_state, + BitVectorAttr::get($_builder.getContext(), value, width)); + }]>, + ]; let hasFolder = true; - let hasVerifier = true; } #endif // CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD diff --git a/lib/Dialect/SMT/SMTAttributes.cpp b/lib/Dialect/SMT/SMTAttributes.cpp index 3a3ca0f70730..cc00ffa3d383 100644 --- a/lib/Dialect/SMT/SMTAttributes.cpp +++ b/lib/Dialect/SMT/SMTAttributes.cpp @@ -70,10 +70,7 @@ parseBitVectorString(function_ref emitError, } BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) { - InFlightDiagnostic diag; - auto maybeValue = - parseBitVectorString([&]() { return std::move(diag); }, value); - diag.abandon(); + auto maybeValue = parseBitVectorString(nullptr, value); assert(succeeded(maybeValue) && "string must have SMT-LIB format"); return Base::get(context, *maybeValue); @@ -98,42 +95,49 @@ BitVectorAttr BitVectorAttr::getChecked(function_ref emitError, MLIRContext *context, unsigned value, unsigned width) { + if ((~((1U << width) - 1U) & value) != 0U) { + emitError() << "value does not fit in a bit-vector of desired width"; + return {}; + } return Base::getChecked(emitError, context, APInt(width, value)); } Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) { llvm::SMLoc loc = odsParser.getCurrentLocation(); - std::string val; - if (odsParser.parseLess() || odsParser.parseString(&val)) - return {}; - - auto maybeVal = - parseBitVectorString([&]() { return odsParser.emitError(loc); }, val); - if (failed(maybeVal)) + APInt val; + if (odsParser.parseLess() || odsParser.parseInteger(val) || + odsParser.parseGreater()) return {}; - if (odsParser.parseGreater()) + // Requires the use of `quantified()` in operation assembly formats. + if (!odsType || !llvm::isa(odsType)) { + odsParser.emitError(loc) << "explicit bit-vector type required"; return {}; + } - // We implement the TypedAttr interface, i.e., the AsmParser allows an - // optional explicit type as suffix which is provided here as 'odsType'. We - // can always build the type from the constant itself, therefore, we can - // ignore this explicit type. However, to ensure consistency we should verify - // that the correct type is provided. - auto bv = BitVectorAttr::get(odsParser.getContext(), *maybeVal); - if (odsType && bv.getType() != odsType) { - odsParser.emitError(loc, "expected type for constant does not match " - "explicitly provided attribute type, got ") - << odsType << ", expected " << bv.getType(); - return {}; + unsigned width = llvm::cast(odsType).getWidth(); + if (width > val.getBitWidth()) + val = val.sext(width); + + if (width < val.getBitWidth()) { + if ((val.isNegative() && val.getSignificantBits() > width) || + val.getActiveBits() > width) { + odsParser.emitError(loc) + << "integer value out of range for given bit-vector type"; + return {}; + } + val = val.trunc(width); } - return bv; + return BitVectorAttr::get(odsParser.getContext(), val); } void BitVectorAttr::print(AsmPrinter &odsPrinter) const { - odsPrinter << "<\"" << getValueAsString() << "\">"; + // This printer only works for the extended format where the MLIR + // infrastructure prints the type for us. This means, the attribute should + // never be used without `quantified` in an assembly format. + odsPrinter << "<" << getValue() << ">"; } Type BitVectorAttr::getType() const { diff --git a/lib/Dialect/SMT/SMTOps.cpp b/lib/Dialect/SMT/SMTOps.cpp index 5f9606dc142d..b086e77cac95 100644 --- a/lib/Dialect/SMT/SMTOps.cpp +++ b/lib/Dialect/SMT/SMTOps.cpp @@ -18,14 +18,6 @@ using namespace mlir; // BVConstantOp //===----------------------------------------------------------------------===// -LogicalResult BVConstantOp::verify() { - if (getValue().getType() != getType()) - return emitError( - "smt.bv.constant attribute bitwidth doesn't match return type"); - - return success(); -} - LogicalResult BVConstantOp::inferReturnTypes( mlir::MLIRContext *context, std::optional location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, @@ -40,7 +32,8 @@ void BVConstantOp::getAsmResultNames( function_ref setNameFn) { SmallVector specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); - specialName << "bv_" << getValue().getValueAsString(false); + specialName << "c" << getValue().getValue() << "_bv" + << getValue().getValue().getBitWidth(); setNameFn(getResult(), specialName.str()); } diff --git a/test/Dialect/SMT/bitvector-errors.mlir b/test/Dialect/SMT/bitvector-errors.mlir index 495f5105966b..7c2d1952bdf6 100644 --- a/test/Dialect/SMT/bitvector-errors.mlir +++ b/test/Dialect/SMT/bitvector-errors.mlir @@ -8,36 +8,29 @@ func.func @at_least_size_one(%arg0: !smt.bv<0>) { // ----- func.func @attr_type_and_return_type_match() { - // expected-error @below {{smt.bv.constant attribute bitwidth doesn't match return type}} - %c0_bv32 = "smt.bv.constant"() <{value = #smt.bv<"#b0"> : !smt.bv<1>}> : () -> !smt.bv<32> - return -} - -// ----- - -func.func @implicit_constant_type_and_explicit_type_match() { - // expected-error @below {{expected type for constant does not match explicitly provided attribute type, got '!smt.bv<2>', expected '!smt.bv<1>'}} - %c0_bv2 = "smt.bv.constant"() <{value = #smt.bv<"#b0"> : !smt.bv<2>}> : () -> !smt.bv<1> + // expected-error @below {{inferred type(s) '!smt.bv<1>' are incompatible with return type(s) of operation '!smt.bv<32>'}} + // expected-error @below {{failed to infer returned types}} + %c0_bv32 = "smt.bv.constant"() <{value = #smt.bv<0> : !smt.bv<1>}> : () -> !smt.bv<32> return } // ----- func.func @invalid_bitvector_attr() { - // expected-error @below {{expected at least one digit}} - smt.bv.constant #smt.bv<"#b"> + // expected-error @below {{explicit bit-vector type required}} + smt.bv.constant #smt.bv<5> } // ----- func.func @invalid_bitvector_attr() { - // expected-error @below {{expected either 'b' or 'x'}} - smt.bv.constant #smt.bv<"#c0"> + // expected-error @below {{integer value out of range for given bit-vector type}} + smt.bv.constant #smt.bv<32> : !smt.bv<2> } // ----- func.func @invalid_bitvector_attr() { - // expected-error @below {{expected '#'}} - smt.bv.constant #smt.bv<"b"> + // expected-error @below {{integer value out of range for given bit-vector type}} + smt.bv.constant #smt.bv<-4> : !smt.bv<2> } diff --git a/test/Dialect/SMT/bitvectors.mlir b/test/Dialect/SMT/bitvectors.mlir index 3645937675d6..15d07ef197d5 100644 --- a/test/Dialect/SMT/bitvectors.mlir +++ b/test/Dialect/SMT/bitvectors.mlir @@ -2,25 +2,12 @@ // CHECK-LABEL: func @bitvectors func.func @bitvectors() { - // A bit-width divisible by 4 is always printed in hex - // CHECK: %bv_x5a = smt.bv.constant <"#x5a"> {smt.some_attr} - %bv_x5a = smt.bv.constant <"#b01011010"> {smt.some_attr} - - // A bit-width not divisible by 4 is always printed in binary - // Also, make sure leading zeros are printed - // CHECK: %bv_b0101101 = smt.bv.constant <"#b0101101"> {smt.some_attr} - %bv_b0101101 = smt.bv.constant <"#b0101101"> {smt.some_attr} - - // CHECK: %bv_x3c = smt.bv.constant <"#x3c"> {smt.some_attr} - %bv_x3c = smt.bv.constant <"#x3c"> {smt.some_attr} - - // Make sure leading zeros are printed - // CHECK: %bv_x03c = smt.bv.constant <"#x03c"> {smt.some_attr} - %bv_x03c = smt.bv.constant <"#x03c"> {smt.some_attr} - - // It is allowed to fully quantify the attribute including an explicit type - // CHECK: %bv_x3cd = smt.bv.constant <"#x3cd"> - %bv_x3cd = smt.bv.constant #smt.bv<"#x3cd"> : !smt.bv<12> + // CHECK: %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr} + %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr} + // CHECK: %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8> {smt.some_attr} + %c92_bv8 = smt.bv.constant #smt.bv<0x5c> : !smt.bv<8> {smt.some_attr} + // CHECK: %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8> + %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8> return } diff --git a/unittests/Dialect/SMT/AttributeTest.cpp b/unittests/Dialect/SMT/AttributeTest.cpp index ebb5d8f6c2e6..f24dcc111670 100644 --- a/unittests/Dialect/SMT/AttributeTest.cpp +++ b/unittests/Dialect/SMT/AttributeTest.cpp @@ -17,7 +17,7 @@ using namespace smt; namespace { -TEST(AttributeTest, BitVectorAttr) { +TEST(BitVectorAttrTest, MinBitWidth) { MLIRContext context; context.loadDialect(); Location loc(UnknownLoc::get(&context)); @@ -27,10 +27,81 @@ TEST(AttributeTest, BitVectorAttr) { context.getDiagEngine().registerHandler([&](Diagnostic &diag) { ASSERT_EQ(diag.str(), "bit-width must be at least 1, but got 0"); }); +} + +TEST(BitVectorAttrTest, ParserAndPrinterCorrect) { + MLIRContext context; + context.loadDialect(); - attr = BitVectorAttr::get(&context, "#b1010"); + auto attr = BitVectorAttr::get(&context, "#b1010"); ASSERT_EQ(attr.getValue(), APInt(4, 10)); ASSERT_EQ(attr.getType(), BitVectorType::get(&context, 4)); + + // A bit-width divisible by 4 is always printed in hex + attr = BitVectorAttr::get(&context, "#b01011010"); + ASSERT_EQ(attr.getValueAsString(), "#x5a"); + + // A bit-width not divisible by 4 is always printed in binary + // Also, make sure leading zeros are printed + attr = BitVectorAttr::get(&context, "#b0101101"); + ASSERT_EQ(attr.getValueAsString(), "#b0101101"); + + attr = BitVectorAttr::get(&context, "#x3c"); + ASSERT_EQ(attr.getValueAsString(), "#x3c"); + + attr = BitVectorAttr::get(&context, "#x03c"); + ASSERT_EQ(attr.getValueAsString(), "#x03c"); +} + +TEST(BitVectorAttrTest, ExpectedOneDigit) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = + BitVectorAttr::getChecked(loc, &context, static_cast("#b")); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + ASSERT_EQ(diag.str(), "expected at least one digit"); + }); +} + +TEST(BitVectorAttrTest, ExpectedBOrX) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = + BitVectorAttr::getChecked(loc, &context, static_cast("#c0")); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + ASSERT_EQ(diag.str(), "expected either 'b' or 'x'"); + }); +} + +TEST(BitVectorAttrTest, ExpectedHashtag) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = + BitVectorAttr::getChecked(loc, &context, static_cast("b0")); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler( + [&](Diagnostic &diag) { ASSERT_EQ(diag.str(), "expected '#'"); }); +} + +TEST(BitVectorAttrTest, OutOfRange) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = BitVectorAttr::getChecked(loc, &context, 2U, 1U); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + ASSERT_EQ(diag.str(), + "value does not fit in a bit-vector of desired width"); + }); } } // namespace