From 9efa177b7a687dd34bb2627309d59000a26b46b8 Mon Sep 17 00:00:00 2001 From: Nuno Lopes Date: Sun, 5 Sep 2021 18:30:53 +0100 Subject: [PATCH] experimental symbolic path call range --- ir/state.cpp | 115 +++++++++++++++++++++++++++++++++++++++----------- ir/state.h | 12 ++++-- smt/exprs.cpp | 9 ++++ smt/exprs.h | 10 ++++- 4 files changed, 117 insertions(+), 29 deletions(-) diff --git a/ir/state.cpp b/ir/state.cpp index 49dbdc985..a34983759 100644 --- a/ir/state.cpp +++ b/ir/state.cpp @@ -5,8 +5,10 @@ #include "ir/function.h" #include "ir/globals.h" #include "smt/smt.h" +#include "smt/solver.h" #include "util/errors.h" #include +#include using namespace smt; using namespace util; @@ -30,34 +32,39 @@ static T intersect_set(const T &a, const T &b) { return results; } +State::ValueAnalysis State::ValueAnalysis::jump(const expr &path) const { + auto ret = *this; + for (auto &[fn, calls] : ranges_fn_calls) { + auto &m = ret.ranges_fn_calls[fn]; + m.clear(); + for (auto &[n, c] : calls) { + m.emplace(n, c && path); + } + } + ret.ranges_fn_calls.path = path; + return ret; +} + void State::ValueAnalysis::meet_with(const State::ValueAnalysis &other) { non_poison_vals = intersect_set(non_poison_vals, other.non_poison_vals); non_undef_vals = intersect_set(non_undef_vals, other.non_undef_vals); unused_vars = intersect_set(unused_vars, other.unused_vars); for (auto &[fn, calls] : other.ranges_fn_calls) { - auto [I, inserted] = ranges_fn_calls.try_emplace(fn, calls); - if (inserted) { - I->second.emplace(0); - } else { - I->second.insert(calls.begin(), calls.end()); - } - } - - for (auto &[fn, calls] : ranges_fn_calls) { - if (!other.ranges_fn_calls.count(fn)) - calls.emplace(0); + ranges_fn_calls[fn].insert(calls.begin(), calls.end()); } + ranges_fn_calls.path.add(other.ranges_fn_calls.path); } -void State::ValueAnalysis::FnCallRanges::inc(const std::string &name) { +void State::ValueAnalysis::FnCallRanges::inc(const string &name, + const expr &cond) { auto [I, inserted] = try_emplace(name); if (inserted) { - I->second.emplace(1); + I->second.emplace(1, cond); } else { - set new_set; - for (unsigned n : I->second) { - new_set.emplace(n+1); + decltype(I->second) new_set; + for (auto &[n, c0] : I->second) { + new_set.emplace(n+1, c0 && cond); } I->second = move(new_set); } @@ -65,25 +72,69 @@ void State::ValueAnalysis::FnCallRanges::inc(const std::string &name) { bool State::ValueAnalysis::FnCallRanges::overlaps(const FnCallRanges &other) const { + auto may_overlap = [](const expr &p1, const expr &p2) { + return !(p1 && p2).simplify().isFalse() && !check_expr(p1 && p2).isUnsat(); + }; + auto must_overlap = [&](const OrExpr &p1, const expr &p2) { + return p1.contains(p2) || + (p1() && p2).simplify().isTrue() || + check_expr(p1() && p2).isSat(); + }; + for (auto &[fn, calls] : *this) { auto I = other.find(fn); if (I == other.end()) { - if (calls.count(0)) - continue; - return false; + for (auto &[n, p] : calls) { + if (must_overlap(other.path, p)) + return false; + } + continue; } - if (intersect_set(calls, I->second).empty()) + + bool overlaps = false; + for (auto &[n1, p1] : calls) { + for (auto &[n2, p2] : I->second) { + if ((overlaps |= n1 == n2 && may_overlap(p1, p2))) + break; + } + } + if (!overlaps) return false; } for (auto &[fn, calls] : other) { - if (!calls.count(0) && !count(fn)) - return false; + if (!count(fn)) { + for (auto &[n, p] : calls) { + if (must_overlap(path, p)) + return false; + } + } } return true; } +bool State::ValueAnalysis::FnCallRanges:: +operator==(const FnCallRanges &other) const { + if (size() != other.size()) + return false; + + for (auto &[fn, calls] : *this) { + auto I2 = other.find(fn); + if (I2 == other.end()) + return false; + + set s1, s2; + for (auto &[n, p] : calls) + s1.insert(n); + for (auto &[n, p] : I2->second) + s2.insert(n); + if (s1 != s2) + return false; + } + return true; +} + State::VarArgsData State::VarArgsData::mkIf(const expr &cond, const VarArgsData &then, const VarArgsData &els) { @@ -575,10 +626,10 @@ void State::addJump(const BasicBlock &dst0, expr &&cond) { auto &data = predecessor_data[dst][current_bb]; data.mem.add(memory, cond); data.UB.add(domain.UB(), cond); - data.path.add(move(cond)); + data.path.add(cond); data.undef_vars.insert(undef_vars.begin(), undef_vars.end()); data.undef_vars.insert(domain.undef_vars.begin(), domain.undef_vars.end()); - data.analysis = analysis; + data.analysis = analysis.jump(cond); data.var_args = var_args_data; } @@ -844,6 +895,8 @@ State::addFnCall(const string &name, vector &&inputs, data.add(out, move(refined)); } + //cout << "SIZE: " << data.size() << endl; + if (data) { auto [d, domain, qvar, pre] = data(); addUB(move(domain)); @@ -865,7 +918,19 @@ State::addFnCall(const string &name, vector &&inputs, } if (writes_memory) - analysis.ranges_fn_calls.inc(name); + analysis.ranges_fn_calls.inc(name, domain.path); + +#if 0 + cout << "CALL " << (isSource() ? "src\n" : "tgt\n"); + cout << "PATH: " << domain.path << endl; + for (auto &[name, ranges] : analysis.ranges_fn_calls) { + cout << name; + for (auto &[n,p] : ranges) + cout << "\t" << n << "/"< unused_vars; - // Possible number of calls per functio name that occurred so far + // Possible number of calls per function name that occurred so far + // Plus the path to the last call // This is an over-approximation, union over all predecessors - struct FnCallRanges : public std::map> { - void inc(const std::string &name); + struct FnCallRanges + : public std::map>> { + smt::OrExpr path { smt::expr(true) }; + void inc(const std::string &name, const smt::expr &cond); bool overlaps(const FnCallRanges &other) const; + // compare only number of fn calls; ignores path + bool operator==(const FnCallRanges &other) const; }; FnCallRanges ranges_fn_calls; + ValueAnalysis jump(const smt::expr &path) const; void meet_with(const ValueAnalysis &other); }; diff --git a/smt/exprs.cpp b/smt/exprs.cpp index 0348a1f58..c21c953b5 100644 --- a/smt/exprs.cpp +++ b/smt/exprs.cpp @@ -67,6 +67,11 @@ ostream &operator<<(ostream &os, const AndExpr &e) { } +void OrExpr::add(const expr &e) { + if (!e.isFalse()) + exprs.insert(e); +} + void OrExpr::add(expr &&e) { if (!e.isFalse()) exprs.insert(move(e)); @@ -76,6 +81,10 @@ void OrExpr::add(const OrExpr &other) { exprs.insert(other.exprs.begin(), other.exprs.end()); } +bool OrExpr::contains(const expr &e) const { + return exprs.count(e); +} + expr OrExpr::operator()() const { return expr::mk_or(exprs); } diff --git a/smt/exprs.h b/smt/exprs.h index ca9df3dcd..155fa0f33 100644 --- a/smt/exprs.h +++ b/smt/exprs.h @@ -20,7 +20,7 @@ class AndExpr { std::set exprs; public: - AndExpr() {} + AndExpr() = default; template AndExpr(T &&e) { add(std::forward(e)); } @@ -41,8 +41,14 @@ class OrExpr { std::set exprs; public: + OrExpr() = default; + template + OrExpr(T &&e) { add(std::forward(e)); } + + void add(const expr &e); void add(expr &&e); void add(const OrExpr &other); + bool contains(const expr &e) const; expr operator()() const; friend std::ostream &operator<<(std::ostream &os, const OrExpr &e); }; @@ -123,6 +129,8 @@ class ChoiceExpr { I->second |= std::forward(domain); } + auto size() const { return vals.size(); } + operator bool() const { return !vals.empty(); }