Skip to content

Commit

Permalink
[SMT] Add bitvector type, attribute, and constant operation
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Mar 9, 2024
1 parent 5703ed8 commit bd9d336
Show file tree
Hide file tree
Showing 20 changed files with 502 additions and 3 deletions.
6 changes: 6 additions & 0 deletions include/circt/Dialect/SMT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ add_circt_doc(SMTOps Dialects/SMTOps -gen-op-doc)
add_circt_doc(SMTTypes Dialects/SMTTypes -gen-typedef-doc -dialect smt)

set(LLVM_TARGET_DEFINITIONS SMT.td)

mlir_tablegen(SMTAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(SMTAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(CIRCTSMTAttrIncGen)
add_dependencies(circt-headers CIRCTSMTAttrIncGen)

mlir_tablegen(SMTEnums.h.inc -gen-enum-decls)
mlir_tablegen(SMTEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(CIRCTSMTEnumsIncGen)
Expand Down
2 changes: 2 additions & 0 deletions include/circt/Dialect/SMT/SMT.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

include "mlir/IR/OpBase.td"

include "circt/Dialect/SMT/SMTAttributes.td"
include "circt/Dialect/SMT/SMTDialect.td"
include "circt/Dialect/SMT/SMTTypes.td"
include "circt/Dialect/SMT/SMTOps.td"
include "circt/Dialect/SMT/SMTBitVectorOps.td"

#endif // CIRCT_DIALECT_SMT_SMT_TD
19 changes: 19 additions & 0 deletions include/circt/Dialect/SMT/SMTAttributes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===- SMTAttributes.h - Declare SMT dialect attributes ----------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_DIALECT_SMT_ATTRIBUTES_H
#define CIRCT_DIALECT_SMT_ATTRIBUTES_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"

#define GET_ATTRDEF_CLASSES
#include "circt/Dialect/SMT/SMTAttributes.h.inc"

#endif // CIRCT_DIALECT_SMT_ATTRIBUTES_H
61 changes: 61 additions & 0 deletions include/circt/Dialect/SMT/SMTAttributes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//===- SMTAttributes.td - Attributes for SMT dialect -------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines SMT dialect specific attributes.
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD
#define CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD

include "circt/Dialect/SMT/SMTDialect.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"

def BitVectorAttr : AttrDef<SMTDialect, "BitVector", [
DeclareAttrInterfaceMethods<TypedAttrInterface>
]> {
let mnemonic = "bv";
let description = [{
This attribute represents a constant value of the `(_ BitVec width)` sort as
described in the [SMT bit-vector
theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml).

The constant is parsed and printed as #bX (binary) or #xX (hexadecimal)
where X is the value in the corresponding format without any further
prefixing just as in SMT-LIB. The number of digits determines the bit-width
of the bit-vector. This means, leading zeros are important!

Examples:
```mlir
#smt.bv<"#b0101"> : !smt.bv<4>
#smt.bv<"#x5c"> : !smt.bv<8>
```
The explicit type-suffix is optional.

The bit-width must be greater than zero (i.e., at least one digit as to be
present).
}];

let parameters = (ins "llvm::APInt":$value);

let hasCustomAssemblyFormat = true;
let genVerifyDecl = true;

let builders = [
AttrBuilder<(ins "llvm::StringRef":$value)>,
AttrBuilder<(ins "unsigned":$value, "unsigned":$width)>,
];

let extraClassDeclaration = [{
/// Return the bit-vector constant as a SMT-LIB formatted string.
std::string getValueAsString(bool prefix = true) const;
}];
}

#endif // CIRCT_DIALECT_SMT_SMTATTRIBUTES_TD
54 changes: 54 additions & 0 deletions include/circt/Dialect/SMT/SMTBitVectorOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===- SMTBitVectorOps.td - SMT bit-vector dialect ops -----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD
#define CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD

include "circt/Dialect/SMT/SMTDialect.td"
include "circt/Dialect/SMT/SMTAttributes.td"
include "circt/Dialect/SMT/SMTTypes.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

class SMTBVOp<string mnemonic, list<Trait> traits = []> :
Op<SMTDialect, "bv." # mnemonic, traits>;

def BVConstantOp : SMTBVOp<"constant", [
Pure,
ConstantLike,
FirstAttrDerivedResultType,
DeclareOpInterfaceMethods<InferTypeOpInterface, ["inferReturnTypes"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>
]> {
let summary = "produce a constant bit-vector";
let description = [{
This operation produces an SSA value equal to the bit-vector constant
specified by the 'value' attribute.
Refer to the `BitVectorAttr` documentation for more information about
the semantics of bit-vector constants, their format, and associated sort.
The result type always matches the attribute's type.

Examples:
```mlir
%bv_x5c = smt.bv.constant <"#x5c">
%bv_b0101 = smt.bv.constant <"#b0101">
```
}];

let arguments = (ins BitVectorAttr:$value);
let results = (outs BitVectorType:$result);

let assemblyFormat = "$value attr-dict";

let hasFolder = true;
let hasVerifier = true;
}

#endif // CIRCT_DIALECT_SMT_SMTBITVECTOROPS_TD
4 changes: 4 additions & 0 deletions include/circt/Dialect/SMT/SMTDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ def SMTDialect : Dialect {
let summary = "a dialect that models satisfiability modulo theories";
let cppNamespace = "circt::smt";

let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;

let hasConstantMaterializer = 1;

let extraClassDeclaration = [{
void registerAttributes();
void registerTypes();
}];
}
Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/SMT/SMTOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "circt/Dialect/SMT/SMTAttributes.h"
#include "circt/Dialect/SMT/SMTDialect.h"
#include "circt/Dialect/SMT/SMTTypes.h"

Expand Down
2 changes: 1 addition & 1 deletion include/circt/Dialect/SMT/SMTOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define CIRCT_DIALECT_SMT_SMTOPS_TD

include "circt/Dialect/SMT/SMTDialect.td"
include "circt/Dialect/SMT/SMTAttributes.td"
include "circt/Dialect/SMT/SMTTypes.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpAsmInterface.td"
Expand All @@ -19,5 +20,4 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class SMTOp<string mnemonic, list<Trait> traits = []> :
Op<SMTDialect, mnemonic, traits>;


#endif // CIRCT_DIALECT_SMT_SMTOPS_TD
16 changes: 16 additions & 0 deletions include/circt/Dialect/SMT/SMTTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,20 @@ def BoolType : SMTTypeDef<"Bool"> {
let assemblyFormat = "";
}

def BitVectorType : SMTTypeDef<"BitVector"> {
let mnemonic = "bv";
let description = [{
This type represents the `(_ BitVec width)` sort as described in the
[SMT bit-vector
theory](https://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml).

The bit-width must be strictly greater than zero.
}];

let parameters = (ins "unsigned":$width);
let assemblyFormat = "`<` $width `>`";

let genVerifyDecl = true;
}

#endif // CIRCT_DIALECT_SMT_SMTTYPES_TD
2 changes: 2 additions & 0 deletions lib/Dialect/SMT/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_circt_dialect_library(CIRCTSMT
SMTAttributes.cpp
SMTDialect.cpp
SMTOps.cpp
SMTTypes.cpp
Expand All @@ -7,6 +8,7 @@ add_circt_dialect_library(CIRCTSMT
${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/SMT

DEPENDS
CIRCTSMTAttrIncGen
CIRCTSMTEnumsIncGen
MLIRSMTIncGen

Expand Down
155 changes: 155 additions & 0 deletions lib/Dialect/SMT/SMTAttributes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/SMT/SMTAttributes.h"
#include "circt/Dialect/SMT/SMTDialect.h"
#include "circt/Dialect/SMT/SMTTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Format.h"

using namespace circt;
using namespace circt::smt;

//===----------------------------------------------------------------------===//
// BitVectorAttr
//===----------------------------------------------------------------------===//

LogicalResult
BitVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError,
APInt value) {
if (value.getBitWidth() < 1)
return emitError() << "bit-width must be at least 1, but got "
<< value.getBitWidth();
return success();
}

std::string BitVectorAttr::getValueAsString(bool prefix) const {
unsigned width = getValue().getBitWidth();
SmallVector<char> toPrint;
StringRef pref = prefix ? "#" : "";
if (width % 4 == 0) {
getValue().toString(toPrint, 16, false, false, false);
// APInt's 'toString' omits leading zeros. However, those are critical here
// because they determine the bit-width of the bit-vector.
SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
return (pref + "x" + Twine(leadingZeros) + toPrint).str();
}

getValue().toString(toPrint, 2, false, false, false);
// APInt's 'toString' omits leading zeros
SmallVector<char> leadingZeros(width - toPrint.size(), '0');
return (pref + "b" + Twine(leadingZeros) + toPrint).str();
}

/// Parse an SMT-LIB formatted bit-vector string.
static FailureOr<APInt>
parseBitVectorString(function_ref<InFlightDiagnostic()> emitError,
StringRef value) {
if (value[0] != '#')
return emitError() << "expected '#'";

if (value.size() < 3)
return emitError() << "expected at least one digit";

if (value[1] == 'b')
return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()),
2);

if (value[1] == 'x')
return APInt((value.size() - 2) * 4,
std::string(value.begin() + 2, value.end()), 16);

return emitError() << "expected either 'b' or 'x'";
}

BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
InFlightDiagnostic diag;
auto maybeValue =
parseBitVectorString([&]() { return std::move(diag); }, value);
diag.abandon();

assert(succeeded(maybeValue) && "string must have SMT-LIB format");
return Base::get(context, *maybeValue);
}

BitVectorAttr
BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, StringRef value) {
auto maybeValue = parseBitVectorString(emitError, value);
if (failed(maybeValue))
return {};

return Base::getChecked(emitError, context, *maybeValue);
}

BitVectorAttr BitVectorAttr::get(MLIRContext *context, unsigned value,
unsigned width) {
return Base::get(context, APInt(width, value));
}

BitVectorAttr
BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
MLIRContext *context, unsigned value,
unsigned width) {
return Base::getChecked(emitError, context, APInt(width, value));
}

Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
llvm::SMLoc loc = odsParser.getCurrentLocation();

std::string val;
if (odsParser.parseLess() || odsParser.parseString(&val))
return {};

auto maybeVal =
parseBitVectorString([&]() { return odsParser.emitError(loc); }, val);
if (failed(maybeVal))
return {};

if (odsParser.parseGreater())
return {};

// We implement the TypedAttr interface, i.e., the AsmParser allows an
// optional explicit type as suffix which is provided here as 'odsType'. We
// can always build the type from the constant itself, therefore, we can
// ignore this explicit type. However, to ensure consistency we should verify
// that the correct type is provided.
auto bv = BitVectorAttr::get(odsParser.getContext(), *maybeVal);
if (odsType && bv.getType() != odsType) {
odsParser.emitError(loc, "expected type for constant does not match "
"explicitly provided attribute type, got ")
<< odsType << ", expected " << bv.getType();
return {};
}

return bv;
}

void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
odsPrinter << "<\"" << getValueAsString() << "\">";
}

Type BitVectorAttr::getType() const {
return BitVectorType::get(getContext(), getValue().getBitWidth());
}

//===----------------------------------------------------------------------===//
// ODS Boilerplate
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "circt/Dialect/SMT/SMTAttributes.cpp.inc"

void SMTDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "circt/Dialect/SMT/SMTAttributes.cpp.inc"
>();
}
Loading

0 comments on commit bd9d336

Please sign in to comment.