From bd9d3361ada1c2715a59bd0b9d191401d9059191 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Sat, 13 Jan 2024 15:07:37 +0100 Subject: [PATCH] [SMT] Add bitvector type, attribute, and constant operation --- include/circt/Dialect/SMT/CMakeLists.txt | 6 + include/circt/Dialect/SMT/SMT.td | 2 + include/circt/Dialect/SMT/SMTAttributes.h | 19 +++ include/circt/Dialect/SMT/SMTAttributes.td | 61 ++++++++ include/circt/Dialect/SMT/SMTBitVectorOps.td | 54 +++++++ include/circt/Dialect/SMT/SMTDialect.td | 4 + include/circt/Dialect/SMT/SMTOps.h | 1 + include/circt/Dialect/SMT/SMTOps.td | 2 +- include/circt/Dialect/SMT/SMTTypes.td | 16 ++ lib/Dialect/SMT/CMakeLists.txt | 2 + lib/Dialect/SMT/SMTAttributes.cpp | 155 +++++++++++++++++++ lib/Dialect/SMT/SMTDialect.cpp | 18 +++ lib/Dialect/SMT/SMTOps.cpp | 35 +++++ lib/Dialect/SMT/SMTTypes.cpp | 12 ++ test/Dialect/SMT/basic.mlir | 4 +- test/Dialect/SMT/bitvector-errors.mlir | 43 +++++ test/Dialect/SMT/bitvectors.mlir | 26 ++++ unittests/Dialect/CMakeLists.txt | 1 + unittests/Dialect/SMT/AttributeTest.cpp | 36 +++++ unittests/Dialect/SMT/CMakeLists.txt | 8 + 20 files changed, 502 insertions(+), 3 deletions(-) create mode 100644 include/circt/Dialect/SMT/SMTAttributes.h create mode 100644 include/circt/Dialect/SMT/SMTAttributes.td create mode 100644 include/circt/Dialect/SMT/SMTBitVectorOps.td create mode 100644 lib/Dialect/SMT/SMTAttributes.cpp create mode 100644 test/Dialect/SMT/bitvector-errors.mlir create mode 100644 test/Dialect/SMT/bitvectors.mlir create mode 100644 unittests/Dialect/SMT/AttributeTest.cpp create mode 100644 unittests/Dialect/SMT/CMakeLists.txt diff --git a/include/circt/Dialect/SMT/CMakeLists.txt b/include/circt/Dialect/SMT/CMakeLists.txt index fbf84b401b09..1e50abc24992 100644 --- a/include/circt/Dialect/SMT/CMakeLists.txt +++ b/include/circt/Dialect/SMT/CMakeLists.txt @@ -3,6 +3,12 @@ add_circt_doc(SMTOps Dialects/SMTOps -gen-op-doc) add_circt_doc(SMTTypes Dialects/SMTTypes -gen-typedef-doc -dialect smt) set(LLVM_TARGET_DEFINITIONS SMT.td) + +mlir_tablegen(SMTAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(SMTAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(CIRCTSMTAttrIncGen) +add_dependencies(circt-headers CIRCTSMTAttrIncGen) + mlir_tablegen(SMTEnums.h.inc -gen-enum-decls) mlir_tablegen(SMTEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(CIRCTSMTEnumsIncGen) diff --git a/include/circt/Dialect/SMT/SMT.td b/include/circt/Dialect/SMT/SMT.td index dc3e9b7b1200..8473f72a7436 100644 --- a/include/circt/Dialect/SMT/SMT.td +++ b/include/circt/Dialect/SMT/SMT.td @@ -11,8 +11,10 @@ include "mlir/IR/OpBase.td" +include "circt/Dialect/SMT/SMTAttributes.td" include "circt/Dialect/SMT/SMTDialect.td" include "circt/Dialect/SMT/SMTTypes.td" include "circt/Dialect/SMT/SMTOps.td" +include "circt/Dialect/SMT/SMTBitVectorOps.td" #endif // CIRCT_DIALECT_SMT_SMT_TD diff --git a/include/circt/Dialect/SMT/SMTAttributes.h b/include/circt/Dialect/SMT/SMTAttributes.h new file mode 100644 index 000000000000..6d0c5c133cbe --- /dev/null +++ b/include/circt/Dialect/SMT/SMTAttributes.h @@ -0,0 +1,19 @@ +//===- SMTAttributes.h - Declare SMT dialect attributes ----------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_SMT_ATTRIBUTES_H +#define CIRCT_DIALECT_SMT_ATTRIBUTES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" + +#define GET_ATTRDEF_CLASSES +#include "circt/Dialect/SMT/SMTAttributes.h.inc" + +#endif // CIRCT_DIALECT_SMT_ATTRIBUTES_H diff --git a/include/circt/Dialect/SMT/SMTAttributes.td b/include/circt/Dialect/SMT/SMTAttributes.td new file mode 100644 index 000000000000..0e77075d05c1 --- /dev/null +++ b/include/circt/Dialect/SMT/SMTAttributes.td @@ -0,0 +1,61 @@ +//===- SMTAttributes.td - Attributes for SMT dialect -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines SMT dialect specific attributes. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD +#define CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD + +include "circt/Dialect/SMT/SMTDialect.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" + +def BitVectorAttr : AttrDef +]> { + let mnemonic = "bv"; + let description = [{ + This attribute represents a constant value of the `(_ BitVec width)` sort as + described in the [SMT bit-vector + theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml). + + The constant is parsed and printed as #bX (binary) or #xX (hexadecimal) + where X is the value in the corresponding format without any further + prefixing just as in SMT-LIB. The number of digits determines the bit-width + of the bit-vector. This means, leading zeros are important! + + Examples: + ```mlir + #smt.bv<"#b0101"> : !smt.bv<4> + #smt.bv<"#x5c"> : !smt.bv<8> + ``` + The explicit type-suffix is optional. + + The bit-width must be greater than zero (i.e., at least one digit as to be + present). + }]; + + let parameters = (ins "llvm::APInt":$value); + + let hasCustomAssemblyFormat = true; + let genVerifyDecl = true; + + let builders = [ + AttrBuilder<(ins "llvm::StringRef":$value)>, + AttrBuilder<(ins "unsigned":$value, "unsigned":$width)>, + ]; + + let extraClassDeclaration = [{ + /// Return the bit-vector constant as a SMT-LIB formatted string. + std::string getValueAsString(bool prefix = true) const; + }]; +} + +#endif // CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD diff --git a/include/circt/Dialect/SMT/SMTBitVectorOps.td b/include/circt/Dialect/SMT/SMTBitVectorOps.td new file mode 100644 index 000000000000..166f20430d1f --- /dev/null +++ b/include/circt/Dialect/SMT/SMTBitVectorOps.td @@ -0,0 +1,54 @@ +//===- SMTBitVectorOps.td - SMT bit-vector dialect ops -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD +#define CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD + +include "circt/Dialect/SMT/SMTDialect.td" +include "circt/Dialect/SMT/SMTAttributes.td" +include "circt/Dialect/SMT/SMTTypes.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class SMTBVOp traits = []> : + Op; + +def BVConstantOp : SMTBVOp<"constant", [ + Pure, + ConstantLike, + FirstAttrDerivedResultType, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods +]> { + let summary = "produce a constant bit-vector"; + let description = [{ + This operation produces an SSA value equal to the bit-vector constant + specified by the 'value' attribute. + Refer to the `BitVectorAttr` documentation for more information about + the semantics of bit-vector constants, their format, and associated sort. + The result type always matches the attribute's type. + + Examples: + ```mlir + %bv_x5c = smt.bv.constant <"#x5c"> + %bv_b0101 = smt.bv.constant <"#b0101"> + ``` + }]; + + let arguments = (ins BitVectorAttr:$value); + let results = (outs BitVectorType:$result); + + let assemblyFormat = "$value attr-dict"; + + let hasFolder = true; + let hasVerifier = true; +} + +#endif // CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD diff --git a/include/circt/Dialect/SMT/SMTDialect.td b/include/circt/Dialect/SMT/SMTDialect.td index 72e038b0d3b0..27de8dff4d71 100644 --- a/include/circt/Dialect/SMT/SMTDialect.td +++ b/include/circt/Dialect/SMT/SMTDialect.td @@ -16,9 +16,13 @@ def SMTDialect : Dialect { let summary = "a dialect that models satisfiability modulo theories"; let cppNamespace = "circt::smt"; + let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let hasConstantMaterializer = 1; + let extraClassDeclaration = [{ + void registerAttributes(); void registerTypes(); }]; } diff --git a/include/circt/Dialect/SMT/SMTOps.h b/include/circt/Dialect/SMT/SMTOps.h index aa747b43927c..6fda00d703bb 100644 --- a/include/circt/Dialect/SMT/SMTOps.h +++ b/include/circt/Dialect/SMT/SMTOps.h @@ -14,6 +14,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "circt/Dialect/SMT/SMTAttributes.h" #include "circt/Dialect/SMT/SMTDialect.h" #include "circt/Dialect/SMT/SMTTypes.h" diff --git a/include/circt/Dialect/SMT/SMTOps.td b/include/circt/Dialect/SMT/SMTOps.td index 4af067333e5d..5cbe55756982 100644 --- a/include/circt/Dialect/SMT/SMTOps.td +++ b/include/circt/Dialect/SMT/SMTOps.td @@ -10,6 +10,7 @@ #define CIRCT_DIALECT_SMT_SMTOPS_TD include "circt/Dialect/SMT/SMTDialect.td" +include "circt/Dialect/SMT/SMTAttributes.td" include "circt/Dialect/SMT/SMTTypes.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpAsmInterface.td" @@ -19,5 +20,4 @@ include "mlir/Interfaces/SideEffectInterfaces.td" class SMTOp traits = []> : Op; - #endif // CIRCT_DIALECT_SMT_SMTOPS_TD diff --git a/include/circt/Dialect/SMT/SMTTypes.td b/include/circt/Dialect/SMT/SMTTypes.td index 5dba084f7115..63588fad427c 100644 --- a/include/circt/Dialect/SMT/SMTTypes.td +++ b/include/circt/Dialect/SMT/SMTTypes.td @@ -19,4 +19,20 @@ def BoolType : SMTTypeDef<"Bool"> { let assemblyFormat = ""; } +def BitVectorType : SMTTypeDef<"BitVector"> { + let mnemonic = "bv"; + let description = [{ + This type represents the `(_ BitVec width)` sort as described in the + [SMT bit-vector + theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml). + + The bit-width must be strictly greater than zero. + }]; + + let parameters = (ins "unsigned":$width); + let assemblyFormat = "`<` $width `>`"; + + let genVerifyDecl = true; +} + #endif // CIRCT_DIALECT_SMT_SMTTYPES_TD diff --git a/lib/Dialect/SMT/CMakeLists.txt b/lib/Dialect/SMT/CMakeLists.txt index f3b7defdb58f..56c2dc1f4900 100644 --- a/lib/Dialect/SMT/CMakeLists.txt +++ b/lib/Dialect/SMT/CMakeLists.txt @@ -1,4 +1,5 @@ add_circt_dialect_library(CIRCTSMT + SMTAttributes.cpp SMTDialect.cpp SMTOps.cpp SMTTypes.cpp @@ -7,6 +8,7 @@ add_circt_dialect_library(CIRCTSMT ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/SMT DEPENDS + CIRCTSMTAttrIncGen CIRCTSMTEnumsIncGen MLIRSMTIncGen diff --git a/lib/Dialect/SMT/SMTAttributes.cpp b/lib/Dialect/SMT/SMTAttributes.cpp new file mode 100644 index 000000000000..bca45233fb65 --- /dev/null +++ b/lib/Dialect/SMT/SMTAttributes.cpp @@ -0,0 +1,155 @@ +//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/SMT/SMTAttributes.h" +#include "circt/Dialect/SMT/SMTDialect.h" +#include "circt/Dialect/SMT/SMTTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Format.h" + +using namespace circt; +using namespace circt::smt; + +//===----------------------------------------------------------------------===// +// BitVectorAttr +//===----------------------------------------------------------------------===// + +LogicalResult +BitVectorAttr::verify(function_ref emitError, + APInt value) { + if (value.getBitWidth() < 1) + return emitError() << "bit-width must be at least 1, but got " + << value.getBitWidth(); + return success(); +} + +std::string BitVectorAttr::getValueAsString(bool prefix) const { + unsigned width = getValue().getBitWidth(); + SmallVector toPrint; + StringRef pref = prefix ? "#" : ""; + if (width % 4 == 0) { + getValue().toString(toPrint, 16, false, false, false); + // APInt's 'toString' omits leading zeros. However, those are critical here + // because they determine the bit-width of the bit-vector. + SmallVector leadingZeros(width / 4 - toPrint.size(), '0'); + return (pref + "x" + Twine(leadingZeros) + toPrint).str(); + } + + getValue().toString(toPrint, 2, false, false, false); + // APInt's 'toString' omits leading zeros + SmallVector leadingZeros(width - toPrint.size(), '0'); + return (pref + "b" + Twine(leadingZeros) + toPrint).str(); +} + +/// Parse an SMT-LIB formatted bit-vector string. +static FailureOr +parseBitVectorString(function_ref emitError, + StringRef value) { + if (value[0] != '#') + return emitError() << "expected '#'"; + + if (value.size() < 3) + return emitError() << "expected at least one digit"; + + if (value[1] == 'b') + return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()), + 2); + + if (value[1] == 'x') + return APInt((value.size() - 2) * 4, + std::string(value.begin() + 2, value.end()), 16); + + return emitError() << "expected either 'b' or 'x'"; +} + +BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) { + InFlightDiagnostic diag; + auto maybeValue = + parseBitVectorString([&]() { return std::move(diag); }, value); + diag.abandon(); + + assert(succeeded(maybeValue) && "string must have SMT-LIB format"); + return Base::get(context, *maybeValue); +} + +BitVectorAttr +BitVectorAttr::getChecked(function_ref emitError, + MLIRContext *context, StringRef value) { + auto maybeValue = parseBitVectorString(emitError, value); + if (failed(maybeValue)) + return {}; + + return Base::getChecked(emitError, context, *maybeValue); +} + +BitVectorAttr BitVectorAttr::get(MLIRContext *context, unsigned value, + unsigned width) { + return Base::get(context, APInt(width, value)); +} + +BitVectorAttr +BitVectorAttr::getChecked(function_ref emitError, + MLIRContext *context, unsigned value, + unsigned width) { + return Base::getChecked(emitError, context, APInt(width, value)); +} + +Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) { + llvm::SMLoc loc = odsParser.getCurrentLocation(); + + std::string val; + if (odsParser.parseLess() || odsParser.parseString(&val)) + return {}; + + auto maybeVal = + parseBitVectorString([&]() { return odsParser.emitError(loc); }, val); + if (failed(maybeVal)) + return {}; + + if (odsParser.parseGreater()) + return {}; + + // We implement the TypedAttr interface, i.e., the AsmParser allows an + // optional explicit type as suffix which is provided here as 'odsType'. We + // can always build the type from the constant itself, therefore, we can + // ignore this explicit type. However, to ensure consistency we should verify + // that the correct type is provided. + auto bv = BitVectorAttr::get(odsParser.getContext(), *maybeVal); + if (odsType && bv.getType() != odsType) { + odsParser.emitError(loc, "expected type for constant does not match " + "explicitly provided attribute type, got ") + << odsType << ", expected " << bv.getType(); + return {}; + } + + return bv; +} + +void BitVectorAttr::print(AsmPrinter &odsPrinter) const { + odsPrinter << "<\"" << getValueAsString() << "\">"; +} + +Type BitVectorAttr::getType() const { + return BitVectorType::get(getContext(), getValue().getBitWidth()); +} + +//===----------------------------------------------------------------------===// +// ODS Boilerplate +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "circt/Dialect/SMT/SMTAttributes.cpp.inc" + +void SMTDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "circt/Dialect/SMT/SMTAttributes.cpp.inc" + >(); +} diff --git a/lib/Dialect/SMT/SMTDialect.cpp b/lib/Dialect/SMT/SMTDialect.cpp index 81d150704316..47d6a3c6b654 100644 --- a/lib/Dialect/SMT/SMTDialect.cpp +++ b/lib/Dialect/SMT/SMTDialect.cpp @@ -7,11 +7,15 @@ //===----------------------------------------------------------------------===// #include "circt/Dialect/SMT/SMTDialect.h" +#include "circt/Dialect/SMT/SMTAttributes.h" +#include "circt/Dialect/SMT/SMTOps.h" +#include "circt/Dialect/SMT/SMTTypes.h" using namespace circt; using namespace smt; void SMTDialect::initialize() { + registerAttributes(); registerTypes(); addOperations< #define GET_OP_LIST @@ -19,5 +23,19 @@ void SMTDialect::initialize() { >(); } +Operation *SMTDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + // BitVectorType constants can materialize into smt.bv.constant + if (auto bvType = type.dyn_cast()) { + if (auto attrValue = value.dyn_cast()) { + bool typesMatch = bvType == attrValue.getType(); + assert(typesMatch && "attribute and desired result types have to match"); + return builder.create(loc, attrValue); + } + } + + return nullptr; +} + #include "circt/Dialect/SMT/SMTDialect.cpp.inc" #include "circt/Dialect/SMT/SMTEnums.cpp.inc" diff --git a/lib/Dialect/SMT/SMTOps.cpp b/lib/Dialect/SMT/SMTOps.cpp index 2a94063867eb..5f9606dc142d 100644 --- a/lib/Dialect/SMT/SMTOps.cpp +++ b/lib/Dialect/SMT/SMTOps.cpp @@ -14,5 +14,40 @@ using namespace circt; using namespace smt; using namespace mlir; +//===----------------------------------------------------------------------===// +// BVConstantOp +//===----------------------------------------------------------------------===// + +LogicalResult BVConstantOp::verify() { + if (getValue().getType() != getType()) + return emitError( + "smt.bv.constant attribute bitwidth doesn't match return type"); + + return success(); +} + +LogicalResult BVConstantOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + inferredReturnTypes.push_back( + properties.as()->getValue().getType()); + return success(); +} + +void BVConstantOp::getAsmResultNames( + function_ref setNameFn) { + SmallVector specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << "bv_" << getValue().getValueAsString(false); + setNameFn(getResult(), specialName.str()); +} + +OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); + return getValueAttr(); +} + #define GET_OP_CLASSES #include "circt/Dialect/SMT/SMT.cpp.inc" diff --git a/lib/Dialect/SMT/SMTTypes.cpp b/lib/Dialect/SMT/SMTTypes.cpp index 96e8d9e855ad..c57a6ab12946 100644 --- a/lib/Dialect/SMT/SMTTypes.cpp +++ b/lib/Dialect/SMT/SMTTypes.cpp @@ -25,3 +25,15 @@ void SMTDialect::registerTypes() { #include "circt/Dialect/SMT/SMTTypes.cpp.inc" >(); } + +//===----------------------------------------------------------------------===// +// BitVectorType +//===----------------------------------------------------------------------===// + +LogicalResult +BitVectorType::verify(function_ref emitError, + unsigned width) { + if (width <= 0) + return emitError() << "bit-vector must have at least a width of one"; + return success(); +} diff --git a/test/Dialect/SMT/basic.mlir b/test/Dialect/SMT/basic.mlir index 5f3064bb48bb..fcde1a475f86 100644 --- a/test/Dialect/SMT/basic.mlir +++ b/test/Dialect/SMT/basic.mlir @@ -1,7 +1,7 @@ // RUN: circt-opt %s | circt-opt | FileCheck %s // CHECK-LABEL: func @types -// CHECK-SAME: (%{{.*}}: !smt.bool) -func.func @types(%arg0: !smt.bool) { +// CHECK-SAME: (%{{.*}}: !smt.bool, %{{.*}}: !smt.bv<32>) +func.func @types(%arg0: !smt.bool, %arg1: !smt.bv<32>) { return } diff --git a/test/Dialect/SMT/bitvector-errors.mlir b/test/Dialect/SMT/bitvector-errors.mlir new file mode 100644 index 000000000000..495f5105966b --- /dev/null +++ b/test/Dialect/SMT/bitvector-errors.mlir @@ -0,0 +1,43 @@ +// RUN: circt-opt %s --split-input-file --verify-diagnostics + +// expected-error @below {{bit-vector must have at least a width of one}} +func.func @at_least_size_one(%arg0: !smt.bv<0>) { + return +} + +// ----- + +func.func @attr_type_and_return_type_match() { + // expected-error @below {{smt.bv.constant attribute bitwidth doesn't match return type}} + %c0_bv32 = "smt.bv.constant"() <{value = #smt.bv<"#b0"> : !smt.bv<1>}> : () -> !smt.bv<32> + return +} + +// ----- + +func.func @implicit_constant_type_and_explicit_type_match() { + // expected-error @below {{expected type for constant does not match explicitly provided attribute type, got '!smt.bv<2>', expected '!smt.bv<1>'}} + %c0_bv2 = "smt.bv.constant"() <{value = #smt.bv<"#b0"> : !smt.bv<2>}> : () -> !smt.bv<1> + return +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{expected at least one digit}} + smt.bv.constant #smt.bv<"#b"> +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{expected either 'b' or 'x'}} + smt.bv.constant #smt.bv<"#c0"> +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{expected '#'}} + smt.bv.constant #smt.bv<"b"> +} diff --git a/test/Dialect/SMT/bitvectors.mlir b/test/Dialect/SMT/bitvectors.mlir new file mode 100644 index 000000000000..3645937675d6 --- /dev/null +++ b/test/Dialect/SMT/bitvectors.mlir @@ -0,0 +1,26 @@ +// RUN: circt-opt %s | circt-opt | FileCheck %s + +// CHECK-LABEL: func @bitvectors +func.func @bitvectors() { + // A bit-width divisible by 4 is always printed in hex + // CHECK: %bv_x5a = smt.bv.constant <"#x5a"> {smt.some_attr} + %bv_x5a = smt.bv.constant <"#b01011010"> {smt.some_attr} + + // A bit-width not divisible by 4 is always printed in binary + // Also, make sure leading zeros are printed + // CHECK: %bv_b0101101 = smt.bv.constant <"#b0101101"> {smt.some_attr} + %bv_b0101101 = smt.bv.constant <"#b0101101"> {smt.some_attr} + + // CHECK: %bv_x3c = smt.bv.constant <"#x3c"> {smt.some_attr} + %bv_x3c = smt.bv.constant <"#x3c"> {smt.some_attr} + + // Make sure leading zeros are printed + // CHECK: %bv_x03c = smt.bv.constant <"#x03c"> {smt.some_attr} + %bv_x03c = smt.bv.constant <"#x03c"> {smt.some_attr} + + // It is allowed to fully quantify the attribute including an explicit type + // CHECK: %bv_x3cd = smt.bv.constant <"#x3cd"> + %bv_x3cd = smt.bv.constant #smt.bv<"#x3cd"> : !smt.bv<12> + + return +} diff --git a/unittests/Dialect/CMakeLists.txt b/unittests/Dialect/CMakeLists.txt index aaa5a1db6a87..1b05fe1054f6 100644 --- a/unittests/Dialect/CMakeLists.txt +++ b/unittests/Dialect/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(Moore) add_subdirectory(FIRRTL) add_subdirectory(HW) add_subdirectory(OM) +add_subdirectory(SMT) diff --git a/unittests/Dialect/SMT/AttributeTest.cpp b/unittests/Dialect/SMT/AttributeTest.cpp new file mode 100644 index 000000000000..ebb5d8f6c2e6 --- /dev/null +++ b/unittests/Dialect/SMT/AttributeTest.cpp @@ -0,0 +1,36 @@ +//===- AttributeTest.cpp - SMT attribute unit tests -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/SMT/SMTAttributes.h" +#include "circt/Dialect/SMT/SMTDialect.h" +#include "circt/Dialect/SMT/SMTTypes.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace circt; +using namespace smt; + +namespace { + +TEST(AttributeTest, BitVectorAttr) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = BitVectorAttr::getChecked(loc, &context, 0U, 0U); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + ASSERT_EQ(diag.str(), "bit-width must be at least 1, but got 0"); + }); + + attr = BitVectorAttr::get(&context, "#b1010"); + ASSERT_EQ(attr.getValue(), APInt(4, 10)); + ASSERT_EQ(attr.getType(), BitVectorType::get(&context, 4)); +} + +} // namespace diff --git a/unittests/Dialect/SMT/CMakeLists.txt b/unittests/Dialect/SMT/CMakeLists.txt new file mode 100644 index 000000000000..2e9ba7eef89f --- /dev/null +++ b/unittests/Dialect/SMT/CMakeLists.txt @@ -0,0 +1,8 @@ +add_circt_unittest(CIRCTSMTTests + AttributeTest.cpp +) + +target_link_libraries(CIRCTSMTTests + PRIVATE + CIRCTSMT +)