From 7e0ae45817417b8ec94dc1622b906f5d804dad5f Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Tue, 12 Mar 2024 08:27:58 +0100 Subject: [PATCH] [SMT] Add bitvector type, attribute, and constant operation (#6804) To clearly separate semantics, define a bit-vector type and attribute instead of reusing the built-in integer attribute. The built-in integer is usually encoded using two SMT bit-vectors to model poison and the regular bit values. --- include/circt/Dialect/SMT/CMakeLists.txt | 6 + include/circt/Dialect/SMT/SMT.td | 2 + include/circt/Dialect/SMT/SMTAttributes.h | 19 +++ include/circt/Dialect/SMT/SMTAttributes.td | 64 ++++++++ include/circt/Dialect/SMT/SMTBitVectorOps.td | 64 ++++++++ include/circt/Dialect/SMT/SMTDialect.td | 4 + include/circt/Dialect/SMT/SMTOps.h | 1 + include/circt/Dialect/SMT/SMTOps.td | 2 +- include/circt/Dialect/SMT/SMTTypes.td | 16 ++ lib/Dialect/SMT/CMakeLists.txt | 2 + lib/Dialect/SMT/SMTAttributes.cpp | 159 +++++++++++++++++++ lib/Dialect/SMT/SMTDialect.cpp | 18 +++ lib/Dialect/SMT/SMTOps.cpp | 28 ++++ lib/Dialect/SMT/SMTTypes.cpp | 12 ++ test/Dialect/SMT/basic.mlir | 4 +- test/Dialect/SMT/bitvector-errors.mlir | 36 +++++ test/Dialect/SMT/bitvectors.mlir | 13 ++ unittests/Dialect/CMakeLists.txt | 1 + unittests/Dialect/SMT/AttributeTest.cpp | 107 +++++++++++++ unittests/Dialect/SMT/CMakeLists.txt | 8 + 20 files changed, 563 insertions(+), 3 deletions(-) create mode 100644 include/circt/Dialect/SMT/SMTAttributes.h create mode 100644 include/circt/Dialect/SMT/SMTAttributes.td create mode 100644 include/circt/Dialect/SMT/SMTBitVectorOps.td create mode 100644 lib/Dialect/SMT/SMTAttributes.cpp create mode 100644 test/Dialect/SMT/bitvector-errors.mlir create mode 100644 test/Dialect/SMT/bitvectors.mlir create mode 100644 unittests/Dialect/SMT/AttributeTest.cpp create mode 100644 unittests/Dialect/SMT/CMakeLists.txt diff --git a/include/circt/Dialect/SMT/CMakeLists.txt b/include/circt/Dialect/SMT/CMakeLists.txt index fbf84b401b09..1e50abc24992 100644 --- a/include/circt/Dialect/SMT/CMakeLists.txt +++ b/include/circt/Dialect/SMT/CMakeLists.txt @@ -3,6 +3,12 @@ add_circt_doc(SMTOps Dialects/SMTOps -gen-op-doc) add_circt_doc(SMTTypes Dialects/SMTTypes -gen-typedef-doc -dialect smt) set(LLVM_TARGET_DEFINITIONS SMT.td) + +mlir_tablegen(SMTAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(SMTAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(CIRCTSMTAttrIncGen) +add_dependencies(circt-headers CIRCTSMTAttrIncGen) + mlir_tablegen(SMTEnums.h.inc -gen-enum-decls) mlir_tablegen(SMTEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(CIRCTSMTEnumsIncGen) diff --git a/include/circt/Dialect/SMT/SMT.td b/include/circt/Dialect/SMT/SMT.td index dc3e9b7b1200..8473f72a7436 100644 --- a/include/circt/Dialect/SMT/SMT.td +++ b/include/circt/Dialect/SMT/SMT.td @@ -11,8 +11,10 @@ include "mlir/IR/OpBase.td" +include "circt/Dialect/SMT/SMTAttributes.td" include "circt/Dialect/SMT/SMTDialect.td" include "circt/Dialect/SMT/SMTTypes.td" include "circt/Dialect/SMT/SMTOps.td" +include "circt/Dialect/SMT/SMTBitVectorOps.td" #endif // CIRCT_DIALECT_SMT_SMT_TD diff --git a/include/circt/Dialect/SMT/SMTAttributes.h b/include/circt/Dialect/SMT/SMTAttributes.h new file mode 100644 index 000000000000..5e0a3f42a290 --- /dev/null +++ b/include/circt/Dialect/SMT/SMTAttributes.h @@ -0,0 +1,19 @@ +//===- SMTAttributes.h - Declare SMT dialect attributes ----------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_SMT_SMTATTRIBUTES_H +#define CIRCT_DIALECT_SMT_SMTATTRIBUTES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" + +#define GET_ATTRDEF_CLASSES +#include "circt/Dialect/SMT/SMTAttributes.h.inc" + +#endif // CIRCT_DIALECT_SMT_SMTATTRIBUTES_H diff --git a/include/circt/Dialect/SMT/SMTAttributes.td b/include/circt/Dialect/SMT/SMTAttributes.td new file mode 100644 index 000000000000..3023eebe061b --- /dev/null +++ b/include/circt/Dialect/SMT/SMTAttributes.td @@ -0,0 +1,64 @@ +//===- SMTAttributes.td - Attributes for SMT dialect -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines SMT dialect specific attributes. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD +#define CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD + +include "circt/Dialect/SMT/SMTDialect.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" + +def BitVectorAttr : AttrDef +]> { + let mnemonic = "bv"; + let description = [{ + This attribute represents a constant value of the `(_ BitVec width)` sort as + described in the [SMT bit-vector + theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml). + + The constant is as #bX (binary) or #xX (hexadecimal) in SMT-LIB + where X is the value in the corresponding format without any further + prefixing. Here, the bit-vector constant is given as a regular integer + literal and the associated bit-vector type indicating the bit-width. + + Examples: + ```mlir + #smt.bv<5> : !smt.bv<4> + #smt.bv<92> : !smt.bv<8> + ``` + + The explicit type-suffix is mandatory to uniquely represent the attribute, + i.e., this attribute should always be used in the extended form (using the + `quantified` keyword in the operation assembly format string). + + The bit-width must be greater than zero (i.e., at least one digit has to be + present). + }]; + + let parameters = (ins "llvm::APInt":$value); + + let hasCustomAssemblyFormat = true; + let genVerifyDecl = true; + + let builders = [ + AttrBuilder<(ins "llvm::StringRef":$value)>, + AttrBuilder<(ins "unsigned":$value, "unsigned":$width)>, + ]; + + let extraClassDeclaration = [{ + /// Return the bit-vector constant as a SMT-LIB formatted string. + std::string getValueAsString(bool prefix = true) const; + }]; +} + +#endif // CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD diff --git a/include/circt/Dialect/SMT/SMTBitVectorOps.td b/include/circt/Dialect/SMT/SMTBitVectorOps.td new file mode 100644 index 000000000000..c2a352bb0e6c --- /dev/null +++ b/include/circt/Dialect/SMT/SMTBitVectorOps.td @@ -0,0 +1,64 @@ +//===- SMTBitVectorOps.td - SMT bit-vector dialect ops -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD +#define CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD + +include "circt/Dialect/SMT/SMTDialect.td" +include "circt/Dialect/SMT/SMTAttributes.td" +include "circt/Dialect/SMT/SMTTypes.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class SMTBVOp traits = []> : + Op; + +def BVConstantOp : SMTBVOp<"constant", [ + Pure, + ConstantLike, + FirstAttrDerivedResultType, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods +]> { + let summary = "produce a constant bit-vector"; + let description = [{ + This operation produces an SSA value equal to the bit-vector constant + specified by the 'value' attribute. + Refer to the `BitVectorAttr` documentation for more information about + the semantics of bit-vector constants, their format, and associated sort. + The result type always matches the attribute's type. + + Examples: + ```mlir + %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8> + %c5_bv4 = smt.bv.constant #smt.bv<5> : !smt.bv<4> + ``` + }]; + + let arguments = (ins BitVectorAttr:$value); + let results = (outs BitVectorType:$result); + + let assemblyFormat = "qualified($value) attr-dict"; + + let builders = [ + OpBuilder<(ins "const llvm::APInt &":$value), [{ + build($_builder, $_state, + BitVectorAttr::get($_builder.getContext(), value)); + }]>, + OpBuilder<(ins "unsigned":$value, "unsigned":$width), [{ + build($_builder, $_state, + BitVectorAttr::get($_builder.getContext(), value, width)); + }]>, + ]; + + let hasFolder = true; +} + +#endif // CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD diff --git a/include/circt/Dialect/SMT/SMTDialect.td b/include/circt/Dialect/SMT/SMTDialect.td index 72e038b0d3b0..27de8dff4d71 100644 --- a/include/circt/Dialect/SMT/SMTDialect.td +++ b/include/circt/Dialect/SMT/SMTDialect.td @@ -16,9 +16,13 @@ def SMTDialect : Dialect { let summary = "a dialect that models satisfiability modulo theories"; let cppNamespace = "circt::smt"; + let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let hasConstantMaterializer = 1; + let extraClassDeclaration = [{ + void registerAttributes(); void registerTypes(); }]; } diff --git a/include/circt/Dialect/SMT/SMTOps.h b/include/circt/Dialect/SMT/SMTOps.h index aa747b43927c..6fda00d703bb 100644 --- a/include/circt/Dialect/SMT/SMTOps.h +++ b/include/circt/Dialect/SMT/SMTOps.h @@ -14,6 +14,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "circt/Dialect/SMT/SMTAttributes.h" #include "circt/Dialect/SMT/SMTDialect.h" #include "circt/Dialect/SMT/SMTTypes.h" diff --git a/include/circt/Dialect/SMT/SMTOps.td b/include/circt/Dialect/SMT/SMTOps.td index 4af067333e5d..5cbe55756982 100644 --- a/include/circt/Dialect/SMT/SMTOps.td +++ b/include/circt/Dialect/SMT/SMTOps.td @@ -10,6 +10,7 @@ #define CIRCT_DIALECT_SMT_SMTOPS_TD include "circt/Dialect/SMT/SMTDialect.td" +include "circt/Dialect/SMT/SMTAttributes.td" include "circt/Dialect/SMT/SMTTypes.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpAsmInterface.td" @@ -19,5 +20,4 @@ include "mlir/Interfaces/SideEffectInterfaces.td" class SMTOp traits = []> : Op; - #endif // CIRCT_DIALECT_SMT_SMTOPS_TD diff --git a/include/circt/Dialect/SMT/SMTTypes.td b/include/circt/Dialect/SMT/SMTTypes.td index 5dba084f7115..63588fad427c 100644 --- a/include/circt/Dialect/SMT/SMTTypes.td +++ b/include/circt/Dialect/SMT/SMTTypes.td @@ -19,4 +19,20 @@ def BoolType : SMTTypeDef<"Bool"> { let assemblyFormat = ""; } +def BitVectorType : SMTTypeDef<"BitVector"> { + let mnemonic = "bv"; + let description = [{ + This type represents the `(_ BitVec width)` sort as described in the + [SMT bit-vector + theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml). + + The bit-width must be strictly greater than zero. + }]; + + let parameters = (ins "unsigned":$width); + let assemblyFormat = "`<` $width `>`"; + + let genVerifyDecl = true; +} + #endif // CIRCT_DIALECT_SMT_SMTTYPES_TD diff --git a/lib/Dialect/SMT/CMakeLists.txt b/lib/Dialect/SMT/CMakeLists.txt index f3b7defdb58f..56c2dc1f4900 100644 --- a/lib/Dialect/SMT/CMakeLists.txt +++ b/lib/Dialect/SMT/CMakeLists.txt @@ -1,4 +1,5 @@ add_circt_dialect_library(CIRCTSMT + SMTAttributes.cpp SMTDialect.cpp SMTOps.cpp SMTTypes.cpp @@ -7,6 +8,7 @@ add_circt_dialect_library(CIRCTSMT ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/SMT DEPENDS + CIRCTSMTAttrIncGen CIRCTSMTEnumsIncGen MLIRSMTIncGen diff --git a/lib/Dialect/SMT/SMTAttributes.cpp b/lib/Dialect/SMT/SMTAttributes.cpp new file mode 100644 index 000000000000..cc00ffa3d383 --- /dev/null +++ b/lib/Dialect/SMT/SMTAttributes.cpp @@ -0,0 +1,159 @@ +//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/SMT/SMTAttributes.h" +#include "circt/Dialect/SMT/SMTDialect.h" +#include "circt/Dialect/SMT/SMTTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Format.h" + +using namespace circt; +using namespace circt::smt; + +//===----------------------------------------------------------------------===// +// BitVectorAttr +//===----------------------------------------------------------------------===// + +LogicalResult BitVectorAttr::verify( + function_ref emitError, + APInt value) { // NOLINT(performance-unnecessary-value-param) + if (value.getBitWidth() < 1) + return emitError() << "bit-width must be at least 1, but got " + << value.getBitWidth(); + return success(); +} + +std::string BitVectorAttr::getValueAsString(bool prefix) const { + unsigned width = getValue().getBitWidth(); + SmallVector toPrint; + StringRef pref = prefix ? "#" : ""; + if (width % 4 == 0) { + getValue().toString(toPrint, 16, false, false, false); + // APInt's 'toString' omits leading zeros. However, those are critical here + // because they determine the bit-width of the bit-vector. + SmallVector leadingZeros(width / 4 - toPrint.size(), '0'); + return (pref + "x" + Twine(leadingZeros) + toPrint).str(); + } + + getValue().toString(toPrint, 2, false, false, false); + // APInt's 'toString' omits leading zeros + SmallVector leadingZeros(width - toPrint.size(), '0'); + return (pref + "b" + Twine(leadingZeros) + toPrint).str(); +} + +/// Parse an SMT-LIB formatted bit-vector string. +static FailureOr +parseBitVectorString(function_ref emitError, + StringRef value) { + if (value[0] != '#') + return emitError() << "expected '#'"; + + if (value.size() < 3) + return emitError() << "expected at least one digit"; + + if (value[1] == 'b') + return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()), + 2); + + if (value[1] == 'x') + return APInt((value.size() - 2) * 4, + std::string(value.begin() + 2, value.end()), 16); + + return emitError() << "expected either 'b' or 'x'"; +} + +BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) { + auto maybeValue = parseBitVectorString(nullptr, value); + + assert(succeeded(maybeValue) && "string must have SMT-LIB format"); + return Base::get(context, *maybeValue); +} + +BitVectorAttr +BitVectorAttr::getChecked(function_ref emitError, + MLIRContext *context, StringRef value) { + auto maybeValue = parseBitVectorString(emitError, value); + if (failed(maybeValue)) + return {}; + + return Base::getChecked(emitError, context, *maybeValue); +} + +BitVectorAttr BitVectorAttr::get(MLIRContext *context, unsigned value, + unsigned width) { + return Base::get(context, APInt(width, value)); +} + +BitVectorAttr +BitVectorAttr::getChecked(function_ref emitError, + MLIRContext *context, unsigned value, + unsigned width) { + if ((~((1U << width) - 1U) & value) != 0U) { + emitError() << "value does not fit in a bit-vector of desired width"; + return {}; + } + return Base::getChecked(emitError, context, APInt(width, value)); +} + +Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) { + llvm::SMLoc loc = odsParser.getCurrentLocation(); + + APInt val; + if (odsParser.parseLess() || odsParser.parseInteger(val) || + odsParser.parseGreater()) + return {}; + + // Requires the use of `quantified()` in operation assembly formats. + if (!odsType || !llvm::isa(odsType)) { + odsParser.emitError(loc) << "explicit bit-vector type required"; + return {}; + } + + unsigned width = llvm::cast(odsType).getWidth(); + if (width > val.getBitWidth()) + val = val.sext(width); + + if (width < val.getBitWidth()) { + if ((val.isNegative() && val.getSignificantBits() > width) || + val.getActiveBits() > width) { + odsParser.emitError(loc) + << "integer value out of range for given bit-vector type"; + return {}; + } + val = val.trunc(width); + } + + return BitVectorAttr::get(odsParser.getContext(), val); +} + +void BitVectorAttr::print(AsmPrinter &odsPrinter) const { + // This printer only works for the extended format where the MLIR + // infrastructure prints the type for us. This means, the attribute should + // never be used without `quantified` in an assembly format. + odsPrinter << "<" << getValue() << ">"; +} + +Type BitVectorAttr::getType() const { + return BitVectorType::get(getContext(), getValue().getBitWidth()); +} + +//===----------------------------------------------------------------------===// +// ODS Boilerplate +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "circt/Dialect/SMT/SMTAttributes.cpp.inc" + +void SMTDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "circt/Dialect/SMT/SMTAttributes.cpp.inc" + >(); +} diff --git a/lib/Dialect/SMT/SMTDialect.cpp b/lib/Dialect/SMT/SMTDialect.cpp index 81d150704316..47d6a3c6b654 100644 --- a/lib/Dialect/SMT/SMTDialect.cpp +++ b/lib/Dialect/SMT/SMTDialect.cpp @@ -7,11 +7,15 @@ //===----------------------------------------------------------------------===// #include "circt/Dialect/SMT/SMTDialect.h" +#include "circt/Dialect/SMT/SMTAttributes.h" +#include "circt/Dialect/SMT/SMTOps.h" +#include "circt/Dialect/SMT/SMTTypes.h" using namespace circt; using namespace smt; void SMTDialect::initialize() { + registerAttributes(); registerTypes(); addOperations< #define GET_OP_LIST @@ -19,5 +23,19 @@ void SMTDialect::initialize() { >(); } +Operation *SMTDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + // BitVectorType constants can materialize into smt.bv.constant + if (auto bvType = type.dyn_cast()) { + if (auto attrValue = value.dyn_cast()) { + bool typesMatch = bvType == attrValue.getType(); + assert(typesMatch && "attribute and desired result types have to match"); + return builder.create(loc, attrValue); + } + } + + return nullptr; +} + #include "circt/Dialect/SMT/SMTDialect.cpp.inc" #include "circt/Dialect/SMT/SMTEnums.cpp.inc" diff --git a/lib/Dialect/SMT/SMTOps.cpp b/lib/Dialect/SMT/SMTOps.cpp index 2a94063867eb..b086e77cac95 100644 --- a/lib/Dialect/SMT/SMTOps.cpp +++ b/lib/Dialect/SMT/SMTOps.cpp @@ -14,5 +14,33 @@ using namespace circt; using namespace smt; using namespace mlir; +//===----------------------------------------------------------------------===// +// BVConstantOp +//===----------------------------------------------------------------------===// + +LogicalResult BVConstantOp::inferReturnTypes( + mlir::MLIRContext *context, std::optional location, + ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, + ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, + ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { + inferredReturnTypes.push_back( + properties.as()->getValue().getType()); + return success(); +} + +void BVConstantOp::getAsmResultNames( + function_ref setNameFn) { + SmallVector specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << "c" << getValue().getValue() << "_bv" + << getValue().getValue().getBitWidth(); + setNameFn(getResult(), specialName.str()); +} + +OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); + return getValueAttr(); +} + #define GET_OP_CLASSES #include "circt/Dialect/SMT/SMT.cpp.inc" diff --git a/lib/Dialect/SMT/SMTTypes.cpp b/lib/Dialect/SMT/SMTTypes.cpp index 96e8d9e855ad..c57a6ab12946 100644 --- a/lib/Dialect/SMT/SMTTypes.cpp +++ b/lib/Dialect/SMT/SMTTypes.cpp @@ -25,3 +25,15 @@ void SMTDialect::registerTypes() { #include "circt/Dialect/SMT/SMTTypes.cpp.inc" >(); } + +//===----------------------------------------------------------------------===// +// BitVectorType +//===----------------------------------------------------------------------===// + +LogicalResult +BitVectorType::verify(function_ref emitError, + unsigned width) { + if (width <= 0) + return emitError() << "bit-vector must have at least a width of one"; + return success(); +} diff --git a/test/Dialect/SMT/basic.mlir b/test/Dialect/SMT/basic.mlir index 5f3064bb48bb..fcde1a475f86 100644 --- a/test/Dialect/SMT/basic.mlir +++ b/test/Dialect/SMT/basic.mlir @@ -1,7 +1,7 @@ // RUN: circt-opt %s | circt-opt | FileCheck %s // CHECK-LABEL: func @types -// CHECK-SAME: (%{{.*}}: !smt.bool) -func.func @types(%arg0: !smt.bool) { +// CHECK-SAME: (%{{.*}}: !smt.bool, %{{.*}}: !smt.bv<32>) +func.func @types(%arg0: !smt.bool, %arg1: !smt.bv<32>) { return } diff --git a/test/Dialect/SMT/bitvector-errors.mlir b/test/Dialect/SMT/bitvector-errors.mlir new file mode 100644 index 000000000000..7c2d1952bdf6 --- /dev/null +++ b/test/Dialect/SMT/bitvector-errors.mlir @@ -0,0 +1,36 @@ +// RUN: circt-opt %s --split-input-file --verify-diagnostics + +// expected-error @below {{bit-vector must have at least a width of one}} +func.func @at_least_size_one(%arg0: !smt.bv<0>) { + return +} + +// ----- + +func.func @attr_type_and_return_type_match() { + // expected-error @below {{inferred type(s) '!smt.bv<1>' are incompatible with return type(s) of operation '!smt.bv<32>'}} + // expected-error @below {{failed to infer returned types}} + %c0_bv32 = "smt.bv.constant"() <{value = #smt.bv<0> : !smt.bv<1>}> : () -> !smt.bv<32> + return +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{explicit bit-vector type required}} + smt.bv.constant #smt.bv<5> +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{integer value out of range for given bit-vector type}} + smt.bv.constant #smt.bv<32> : !smt.bv<2> +} + +// ----- + +func.func @invalid_bitvector_attr() { + // expected-error @below {{integer value out of range for given bit-vector type}} + smt.bv.constant #smt.bv<-4> : !smt.bv<2> +} diff --git a/test/Dialect/SMT/bitvectors.mlir b/test/Dialect/SMT/bitvectors.mlir new file mode 100644 index 000000000000..15d07ef197d5 --- /dev/null +++ b/test/Dialect/SMT/bitvectors.mlir @@ -0,0 +1,13 @@ +// RUN: circt-opt %s | circt-opt | FileCheck %s + +// CHECK-LABEL: func @bitvectors +func.func @bitvectors() { + // CHECK: %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr} + %c5_bv32 = smt.bv.constant #smt.bv<5> : !smt.bv<32> {smt.some_attr} + // CHECK: %c92_bv8 = smt.bv.constant #smt.bv<92> : !smt.bv<8> {smt.some_attr} + %c92_bv8 = smt.bv.constant #smt.bv<0x5c> : !smt.bv<8> {smt.some_attr} + // CHECK: %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8> + %c-1_bv8 = smt.bv.constant #smt.bv<-1> : !smt.bv<8> + + return +} diff --git a/unittests/Dialect/CMakeLists.txt b/unittests/Dialect/CMakeLists.txt index aaa5a1db6a87..1b05fe1054f6 100644 --- a/unittests/Dialect/CMakeLists.txt +++ b/unittests/Dialect/CMakeLists.txt @@ -2,3 +2,4 @@ add_subdirectory(Moore) add_subdirectory(FIRRTL) add_subdirectory(HW) add_subdirectory(OM) +add_subdirectory(SMT) diff --git a/unittests/Dialect/SMT/AttributeTest.cpp b/unittests/Dialect/SMT/AttributeTest.cpp new file mode 100644 index 000000000000..f24dcc111670 --- /dev/null +++ b/unittests/Dialect/SMT/AttributeTest.cpp @@ -0,0 +1,107 @@ +//===- AttributeTest.cpp - SMT attribute unit tests -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/SMT/SMTAttributes.h" +#include "circt/Dialect/SMT/SMTDialect.h" +#include "circt/Dialect/SMT/SMTTypes.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace circt; +using namespace smt; + +namespace { + +TEST(BitVectorAttrTest, MinBitWidth) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = BitVectorAttr::getChecked(loc, &context, 0U, 0U); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + ASSERT_EQ(diag.str(), "bit-width must be at least 1, but got 0"); + }); +} + +TEST(BitVectorAttrTest, ParserAndPrinterCorrect) { + MLIRContext context; + context.loadDialect(); + + auto attr = BitVectorAttr::get(&context, "#b1010"); + ASSERT_EQ(attr.getValue(), APInt(4, 10)); + ASSERT_EQ(attr.getType(), BitVectorType::get(&context, 4)); + + // A bit-width divisible by 4 is always printed in hex + attr = BitVectorAttr::get(&context, "#b01011010"); + ASSERT_EQ(attr.getValueAsString(), "#x5a"); + + // A bit-width not divisible by 4 is always printed in binary + // Also, make sure leading zeros are printed + attr = BitVectorAttr::get(&context, "#b0101101"); + ASSERT_EQ(attr.getValueAsString(), "#b0101101"); + + attr = BitVectorAttr::get(&context, "#x3c"); + ASSERT_EQ(attr.getValueAsString(), "#x3c"); + + attr = BitVectorAttr::get(&context, "#x03c"); + ASSERT_EQ(attr.getValueAsString(), "#x03c"); +} + +TEST(BitVectorAttrTest, ExpectedOneDigit) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = + BitVectorAttr::getChecked(loc, &context, static_cast("#b")); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + ASSERT_EQ(diag.str(), "expected at least one digit"); + }); +} + +TEST(BitVectorAttrTest, ExpectedBOrX) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = + BitVectorAttr::getChecked(loc, &context, static_cast("#c0")); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + ASSERT_EQ(diag.str(), "expected either 'b' or 'x'"); + }); +} + +TEST(BitVectorAttrTest, ExpectedHashtag) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = + BitVectorAttr::getChecked(loc, &context, static_cast("b0")); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler( + [&](Diagnostic &diag) { ASSERT_EQ(diag.str(), "expected '#'"); }); +} + +TEST(BitVectorAttrTest, OutOfRange) { + MLIRContext context; + context.loadDialect(); + Location loc(UnknownLoc::get(&context)); + + auto attr = BitVectorAttr::getChecked(loc, &context, 2U, 1U); + ASSERT_EQ(attr, BitVectorAttr()); + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + ASSERT_EQ(diag.str(), + "value does not fit in a bit-vector of desired width"); + }); +} + +} // namespace diff --git a/unittests/Dialect/SMT/CMakeLists.txt b/unittests/Dialect/SMT/CMakeLists.txt new file mode 100644 index 000000000000..2e9ba7eef89f --- /dev/null +++ b/unittests/Dialect/SMT/CMakeLists.txt @@ -0,0 +1,8 @@ +add_circt_unittest(CIRCTSMTTests + AttributeTest.cpp +) + +target_link_libraries(CIRCTSMTTests + PRIVATE + CIRCTSMT +)