Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interaction with Z3 through Z3's Java API instead of StdIO #633

Closed
wants to merge 20 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Cleanup and support for builtin SMT functions and sorts
  • Loading branch information
marcoeilers committed Jul 26, 2022
commit 28626cd530e8c5faf7d13eb9b351b0c434dd3549
125 changes: 39 additions & 86 deletions src/main/scala/decider/TermToZ3APIConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ class TermToZ3APIConverter
var ctx: Context = _
val macros = mutable.HashMap[String, (Seq[Var], Term)]()

val termCache = mutable.HashMap[Term, Z3Expr]()
val sortCache = mutable.HashMap[Sort, Z3Sort]()
val funcDeclCache = mutable.HashMap[(String, Seq[Sort], Sort), Z3FuncDecl]()

def convert(s: Sort): Z3Sort = convertSort(s)
Expand Down Expand Up @@ -89,8 +87,6 @@ class TermToZ3APIConverter
var secondFunc: Z3FuncDecl = _

def convertSort(s: Sort): Z3Sort = {
if (sortCache.contains(s))
return sortCache(s)
val res = s match {
case sorts.Int => ctx.mkIntSort()
case sorts.Bool => ctx.mkBoolSort()
Expand All @@ -101,28 +97,33 @@ class TermToZ3APIConverter
case sorts.Seq(elementSort) => {
val res = ctx.mkUninterpretedSort("Seq<" + convertSortName(elementSort) + ">")
res
} // text("Seq<") <> doRender(elementSort, true) <> ">"
}
case sorts.Set(elementSort) => ctx.mkUninterpretedSort("Set<" + convertSortName(elementSort) + ">") // text("Set<") <> doRender(elementSort, true) <> ">"
case sorts.Multiset(elementSort) => ctx.mkUninterpretedSort("Multiset<" + convertSortName(elementSort) + ">") // // text("Multiset<") <> doRender(elementSort, true) <> ">"
case sorts.UserSort(id) => ctx.mkUninterpretedSort(convertId(id)) // render(id)
case sorts.SMTSort(id) => ??? // if (alwaysSanitize) render(id) else id.name
case sorts.SMTSort(id) => {
// workaround: since we cannot create a built-in sort from its name, we let Z3 parse
// a string that uses the sort, take the AST, and get the func decl from there, so that we can
// programmatically create a func app.
val workaround = ctx.parseSMTLIB2String(s"(declare-const workaround ${id}) (assert (= workaround workaround))", null, null, null, null)
val res = workaround(0).getArgs()(0).getSort
res
}

case sorts.Unit =>
/* Sort Unit corresponds to Scala's Unit type and is used, e.g., as the
* domain sort of nullary functions.
/* Should never occur
*/
???

case sorts.FieldValueFunction(codomainSort) => {
val name = convertSortName(codomainSort)
ctx.mkUninterpretedSort("$FVF<" + name + ">")
} // // text("$FVF<") <> doRender(codomainSort, true) <> ">"
}
case sorts.PredicateSnapFunction(codomainSort) => ctx.mkUninterpretedSort("$PSF<" + convertSortName(codomainSort) + ">") // text("$PSF<") <> doRender(codomainSort, true) <> ">"

case sorts.FieldPermFunction() => ctx.mkUninterpretedSort("$FPM") // text("$FPM")
case sorts.PredicatePermFunction() => ctx.mkUninterpretedSort("$PPM") // text("$PPM")
}
sortCache.update(s, res)
res
}

Expand All @@ -138,11 +139,12 @@ class TermToZ3APIConverter
case sorts.Set(elementSort) => Some(ctx.mkSymbol("Set<" + convertSortName(elementSort) + ">")) // text("Set<") <> doRender(elementSort, true) <> ">"
case sorts.Multiset(elementSort) => Some(ctx.mkSymbol("Multiset<" + convertSortName(elementSort) + ">")) // // text("Multiset<") <> doRender(elementSort, true) <> ">"
case sorts.UserSort(id) => Some(ctx.mkSymbol(convertId(id))) // render(id)
case sorts.SMTSort(id) => ??? // if (alwaysSanitize) render(id) else id.name
case sorts.SMTSort(id) => {
???
}

case sorts.Unit =>
/* Sort Unit corresponds to Scala's Unit type and is used, e.g., as the
* domain sort of nullary functions.
/* Should never occur
*/
???

Expand Down Expand Up @@ -187,30 +189,8 @@ class TermToZ3APIConverter
}

def convert(d: Decl): Unit = {
// not used
???
// d match {
// case SortDecl(sort: Sort) =>
// ??? // parens(text("declare-sort") <+> render(sort) <+> text("0"))
//
// case fd@FunctionDecl(fun: Function) =>
// convert(fd)
//
//
// case swd@SortWrapperDecl(from, to) =>
// // val id = swd.id
// // val fct = FunctionDecl(Fun(id, from, to))
// //
// // render(fct)
// ???
//
// case MacroDecl(id, args, body) =>
// //val idDoc = render(id)
// //val argDocs = (args map (v => parens(text(render(v.id)) <+> render(v.sort)))).to(collection.immutable.Seq)
// //val bodySortDoc = render(body.sort)
// //val bodyDoc = render(body)
//
// ??? // parens(text("define-fun") <+> idDoc <+> parens(ssep(argDocs, space)) <+> bodySortDoc <> nest(defaultIndent, line <> bodyDoc))
// }
}

def convert(t: Term): Z3Expr = {
Expand All @@ -219,8 +199,6 @@ class TermToZ3APIConverter


def convertTerm(term: Term): Z3Expr = {
//if (termCache.contains(term))
// return termCache(term)
val res = term match {
case l: Literal => {
l match {
Expand All @@ -233,7 +211,7 @@ class TermToZ3APIConverter
case True() => ctx.mkTrue()
case False() => ctx.mkFalse()
case Null() => ctx.mkConst("$Ref.null", ctx.mkUninterpretedSort("$Ref"))
case Unit => ctx.mkConst(getUnitConstructor)// ctx.mkConst("$Snap.unit", getSnapSort) //"$Snap.unit"
case Unit => ctx.mkConst(getUnitConstructor)
case _: SeqNil => renderApp("Seq_empty", Seq(), l.sort)
case _: EmptySet => renderApp("Set_empty", Seq(), l.sort)
case _: EmptyMultiset => renderApp("Multiset_empty", Seq(), l.sort)
Expand Down Expand Up @@ -406,10 +384,10 @@ class TermToZ3APIConverter

/* Quantified Permissions */

case Domain(id, fvf) => renderApp("$FVF.domain_" + id, Seq(fvf), term.sort) //parens(text("$FVF.domain_") <> id <+> render(fvf))
case Domain(id, fvf) => renderApp("$FVF.domain_" + id, Seq(fvf), term.sort)

case Lookup(field, fvf, at) =>
renderApp("$FVF.lookup_" + field, Seq(fvf, at), term.sort) // parens(text("$FVF.lookup_") <> field <+> render(fvf) <+> render(at))
renderApp("$FVF.lookup_" + field, Seq(fvf, at), term.sort)

case FieldTrigger(field, fvf, at) => renderApp("$FVF.loc_" + field, (fvf.sort match {
case sorts.FieldValueFunction(_) => Seq(Lookup(field, fvf, at), at)
Expand All @@ -434,53 +412,28 @@ class TermToZ3APIConverter

/* Other terms */

case First(t) => ctx.mkApp(firstFunc, convertTerm(t))//renderApp("$Snap.first", Seq(t), term.sort)//parens(text("$Snap.first") <+> render(t))
case Second(t) => ctx.mkApp(secondFunc, convertTerm(t))//renderApp("$Snap.second", Seq(t), term.sort)
case First(t) => ctx.mkApp(firstFunc, convertTerm(t))
case Second(t) => ctx.mkApp(secondFunc, convertTerm(t))

case bop: Combine =>
ctx.mkApp(combineConstructor, convertTerm(bop.p0), convertTerm(bop.p1))//renderApp("$Snap.combine", Seq(bop.p0, bop.p1), term.sort)
ctx.mkApp(combineConstructor, convertTerm(bop.p0), convertTerm(bop.p1))

case SortWrapper(t, to) =>
renderApp(convertId(SortWrapperId(t.sort, to)), Seq(t), to)
//parens(text(render(SortWrapperId(t.sort, to))) <+> render(t))

case Distinct(symbols) =>
ctx.mkDistinct(symbols.map(s => ctx.mkConst(convertId(s.id), convertSort(s.resultSort))).toSeq: _*)
//renderApp("distinct") <+> ssep((symbols.toSeq map (s => render(s.id): Cont)).to(collection.immutable.Seq), space))

case Let(bindings, body) =>
convert(body.replace(bindings))
//val docBindings = ssep((bindings.toSeq map (p => parens(render(p._1) <+> render(p._2)))).to(collection.immutable.Seq), space)
//parens(text("let") <+> parens(docBindings) <+> render(body))

case _: MagicWandChunkTerm
| _: Quantification =>
sys.error(s"Unexpected term $term cannot be translated to SMTLib code")
}
//termCache.update(term, res)
res
}

// @inline
// protected def renderUnaryOp(op: String, t: UnaryOp[Term]): Cont =
// parens(text(op) <> nest(defaultIndent, group(line <> render(t.p))))
//
// @inline
// protected def renderUnaryOp(op: String, doc: Cont): Cont =
// parens(text(op) <> nest(defaultIndent, group(line <> doc)))
//
// @inline
// protected def renderBinaryOp(op: String, t: BinaryOp[Term]): Cont =
// parens(text(op) <> nest(defaultIndent, group(line <> render(t.p0) <> line <> render(t.p1))))
//
// @inline
// protected def renderBinaryOp(op: String, left: Cont, right: Cont): Cont =
// parens(text(op) <> nest(defaultIndent, group(line <> left <> line <> right)))
//
// @inline
// protected def renderNAryOp(op: String, terms: Term*): Cont =
// parens(text(op) <> nest(defaultIndent, group(line <> ssep((terms map render).to(collection.immutable.Seq), line))))
//
@inline
protected def renderApp(functionName: String, args: Seq[Term], outSort: Sort): Z3Expr = {
ctx.mkApp(getFuncDecl(functionName, outSort, args.map(_.sort)), args.map(convertTerm(_)): _*)
Expand All @@ -497,23 +450,25 @@ class TermToZ3APIConverter

@inline
protected def renderSMTApp(functionName: String, args: Seq[Term], outSort: Sort): Z3Expr = {
// val docAppNoParens =
// text(functionName) <+> ssep((args map render).to(collection.immutable.Seq), space)
//
// if (args.nonEmpty)
// parens(docAppNoParens)
// else
// parens(text("as") <+> docAppNoParens <+> render(outSort))
// TODO: this needs to be special-cased unfortunately. Urgh.
???
// workaround: since we cannot create a function application with just the name, we let Z3 parse
// a string that uses the function, take the AST, and get the func decl from there, so that we can
// programmatically create a func app.
Comment on lines +460 to +462
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because we don't know the types for the arguments? Are we passing them to Z3 so it can figure out their sorts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the problem is that there is nothing like ctx.getInterpretedFuncDecl(name) in the API that I could give a string like "fp.add" to and get the declaration for floating point addition. Instead, for any built-in interpreted function, there is a separate creator method in the API (e.g., ctx.mkFloatAdd(arg1, arg2)). So I'd have to have a big match-case here that matches all names of all built-in functions and then calls the respectice API method. Which I didn't want to do now because there are a lot of them, and the way we currently use SMTFuncApps, sometimes, the strings aren't actually just a function name but a function name and some number of default argument (e.g., "fp.add RNE", which contains the function name and a first argument that determines the rounding mode).

val decls = args.zipWithIndex.map{case (a, i) => s"(declare-const workaround${i} ${smtlibConverter.convert(a.sort)})"}.mkString(" ")
val funcAppString = s"(${functionName} ${(0 until args.length).map(i => "workaround" + i).mkString(" ")})"
val assertion = decls + s" (assert (= ${funcAppString} ${funcAppString}))"
val workaround = ctx.parseSMTLIB2String(assertion, null, null, null, null)
val app = workaround(0).getArgs()(0)
val decl = app.getFuncDecl
val actualArgs = if (decl.getArity > args.length){
// the function name we got wasn't just a function name but also contained a first argument.
// this happens with float operations where functionName contains a rounding mode.
Comment on lines +470 to +471
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this happen for other reasons (e.g. accidentally calling a function with the wrong number of arguments)? Could you whitelist only float-related operations in this branch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it could, yes.

I think if we want to avoid this special case (which really shouldn't be here), we should just forbid including arguments in the function name in general, then we wouldn't have to handle this case here. It's just that that was unproblematic when we were dealing only with strings, so I did it when making the factory for floating point operations. If we rewrite that code, I think the only consequence is that users of floating point functions would always have to pass the rounding mode explicitly.

app.getArgs.toSeq.slice(0, decl.getArity - args.length) ++ args.map(convertTerm(_))
}else {
args.map(convertTerm(_))
}
ctx.mkApp(decl, actualArgs.toArray : _*)
}

// protected def render(q: Quantifier): Cont = q match {
// case Forall => "forall"
// case Exists => "exists"
// }
//


protected def renderAsReal(t: Term): RealExpr =
if (t.sort == sorts.Int)
Expand Down Expand Up @@ -546,9 +501,7 @@ class TermToZ3APIConverter
def reset(): Unit = {
sanitizedNamesCache.clear()
macros.clear()
sortCache.clear()
funcDeclCache.clear()
termCache.clear()
unitConstructor = null
combineConstructor = null
firstFunc = null
Expand Down