Skip to content

Commit

Permalink
switch to solve_eqs2 tactic
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Nov 8, 2022
1 parent f769e2f commit 3a37cfc
Show file tree
Hide file tree
Showing 24 changed files with 149 additions and 52 deletions.
5 changes: 5 additions & 0 deletions src/ast/rewriter/th_rewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,11 @@ void th_rewriter::get_param_descrs(param_descrs & r) {
rewriter_params::collect_param_descrs(r);
}

void th_rewriter::set_flat_and_or(bool f) {
m_imp->cfg().m_b_rw.set_flat_and_or(f);
}


th_rewriter::~th_rewriter() {
dealloc(m_imp);
}
Expand Down
3 changes: 3 additions & 0 deletions src/ast/rewriter/th_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class th_rewriter {

void updt_params(params_ref const & p);
static void get_param_descrs(param_descrs & r);

void set_flat_and_or(bool f);

unsigned get_cache_size() const;
unsigned get_num_steps() const;

Expand Down
11 changes: 11 additions & 0 deletions src/ast/simplifiers/extract_eqs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,23 @@ namespace euf {
class basic_extract_eq : public extract_eq {
ast_manager& m;
bool m_ite_solver = true;
bool m_allow_bool = true;

public:
basic_extract_eq(ast_manager& m) : m(m) {}

virtual void set_allow_booleans(bool f) {
m_allow_bool = f;
}

void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) override {
auto [f, d] = e();
expr* x, * y;
if (m.is_eq(f, x, y)) {
if (x == y)
return;
if (!m_allow_bool && m.is_bool(x))
return;
if (is_uninterp_const(x))
eqs.push_back(dependent_eq(e.fml(), to_app(x), expr_ref(y, m), d));
if (is_uninterp_const(y))
Expand All @@ -47,6 +54,8 @@ namespace euf {
expr* c, * th, * el, * x1, * y1, * x2, * y2;
if (m_ite_solver && m.is_ite(f, c, th, el)) {
if (m.is_eq(th, x1, y1) && m.is_eq(el, x2, y2)) {
if (!m_allow_bool && m.is_bool(x1))
return;
if (x1 == y2 && is_uninterp_const(x1))
std::swap(x2, y2);
if (x2 == y2 && is_uninterp_const(x2))
Expand All @@ -57,6 +66,8 @@ namespace euf {
eqs.push_back(dependent_eq(e.fml(), to_app(x1), expr_ref(m.mk_ite(c, y1, y2), m), d));
}
}
if (!m_allow_bool)
return;
if (is_uninterp_const(f))
eqs.push_back(dependent_eq(e.fml(), to_app(f), expr_ref(m.mk_true(), m), d));
if (m.is_not(f, x) && is_uninterp_const(x))
Expand Down
6 changes: 6 additions & 0 deletions src/ast/simplifiers/extract_eqs.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Module Name:

#pragma once

#include "ast/ast_pp.h"
#include "ast/simplifiers/dependent_expr_state.h"
#include "ast/rewriter/th_rewriter.h"
#include "ast/expr_substitution.h"
Expand All @@ -42,8 +43,13 @@ namespace euf {
virtual void get_eqs(dependent_expr const& e, dep_eq_vector& eqs) = 0;
virtual void pre_process(dependent_expr_state& fmls) {}
virtual void updt_params(params_ref const& p) {}
virtual void set_allow_booleans(bool f) {}
};

void register_extract_eqs(ast_manager& m, scoped_ptr_vector<extract_eq>& ex);

}

inline std::ostream& operator<<(std::ostream& out, euf::dependent_eq const& eq) {
return out << mk_pp(eq.var, eq.term.m()) << " = " << eq.term << "\n";
}
2 changes: 1 addition & 1 deletion src/ast/simplifiers/model_reconstruction_trail.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ void model_reconstruction_trail::replay(dependent_expr const& d, vector<dependen

added.push_back(d);


for (auto& t : m_trail) {
if (!t->m_active)
continue;
Expand Down Expand Up @@ -69,6 +68,7 @@ model_converter_ref model_reconstruction_trail::get_model_converter() {
// substituted variables by their terms.
//


scoped_ptr<expr_replacer> rp = mk_default_expr_replacer(m, false);
expr_substitution subst(m, true, false);
rp->set_substitution(&subst);
Expand Down
65 changes: 46 additions & 19 deletions src/ast/simplifiers/solve_context_eqs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ namespace euf {
if (!contains_v(f))
return true;
signed_expressions conjuncts;
if (contains_conjunctively(f, sign, e, conjuncts))
if (contains_conjunctively(f, sign, e, conjuncts))
return true;
if (recursion_depth > 3)
return false;
Expand All @@ -67,9 +67,9 @@ namespace euf {
/*
* Every disjunction in f that contains v also contains the equation e.
*/
bool solve_context_eqs::is_disjunctively_safe(unsigned recursion_depth, expr* f, bool sign, expr* e) {
bool solve_context_eqs::is_disjunctively_safe(unsigned recursion_depth, expr* f0, bool sign, expr* e) {
signed_expressions todo;
todo.push_back({sign, f});
todo.push_back({sign, f0});
while (!todo.empty()) {
auto [s, f] = todo.back();
todo.pop_back();
Expand All @@ -93,11 +93,21 @@ namespace euf {
todo.push_back({s, arg});
else if (m.is_not(f, f))
todo.push_back({!s, f});
else if (!is_conjunction(s, f))
return false;
else if (!is_safe_eq(recursion_depth + 1, f, s, e))
return false;
}
return true;
}

bool solve_context_eqs::is_conjunction(bool sign, expr* f) const {
if (!sign && m.is_and(f))
return true;
if (sign && m.is_or(f))
return true;
return false;
}

/**
* Determine whether some conjunction in f contains e.
Expand Down Expand Up @@ -140,29 +150,43 @@ namespace euf {
for (unsigned i = m_solve_eqs.m_qhead; i < m_fmls.size(); ++i)
collect_nested_equalities(m_fmls[i], visited, eqs);

std::stable_sort(eqs.begin(), eqs.end(), [&](dependent_eq const& e1, dependent_eq const& e2) {
return e1.var->get_id() < e2.var->get_id(); });
unsigned j = 0;
expr* last_var = nullptr;
for (auto const& eq : eqs) {

m_contains_v.reset();

// first check if v is in term. If it is, then the substitution candidate is unsafe
m_todo.push_back(eq.term);
mark_occurs(m_todo, eq.var, m_contains_v);
SASSERT(m_todo.empty());
if (m_contains_v.is_marked(eq.term))
continue;

// then mark occurrences
for (unsigned i = 0; i < m_fmls.size(); ++i)
m_todo.push_back(m_fmls[i].fml());
mark_occurs(m_todo, eq.var, m_contains_v);
SASSERT(m_todo.empty());
SASSERT(!m.is_bool(eq.var));

if (eq.var != last_var) {

m_contains_v.reset();

// first check if v is in term. If it is, then the substitution candidate is unsafe
m_todo.push_back(eq.term);
mark_occurs(m_todo, eq.var, m_contains_v);
SASSERT(m_todo.empty());
last_var = eq.var;
if (m_contains_v.is_marked(eq.term))
continue;

// then mark occurrences
for (unsigned i = 0; i < m_fmls.size(); ++i)
m_todo.push_back(m_fmls[i].fml());
mark_occurs(m_todo, eq.var, m_contains_v);
SASSERT(m_todo.empty());
}
else if (m_contains_v.is_marked(eq.term))
continue;

// subject to occurrences, check if equality is safe
if (is_safe_eq(eq.orig))
if (is_safe_eq(eq.orig))
eqs[j++] = eq;
}
eqs.shrink(j);
TRACE("solve_eqs",
for (auto const& eq : eqs)
tout << eq << "\n");
}

void solve_context_eqs::collect_nested_equalities(dependent_expr const& df, expr_mark& visited, dep_eq_vector& eqs) {
Expand Down Expand Up @@ -204,8 +228,11 @@ namespace euf {
else if (m.is_not(f, f))
todo.push_back({ !s, depth, f });
else if (!s && 1 == depth % 2) {
for (extract_eq* ex : m_solve_eqs.m_extract_plugins)
for (extract_eq* ex : m_solve_eqs.m_extract_plugins) {
ex->set_allow_booleans(false);
ex->get_eqs(dependent_expr(m, f, df.dep()), eqs);
ex->set_allow_booleans(true);
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ast/simplifiers/solve_context_eqs.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace euf {
bool is_safe_eq(expr* f, expr* e) { return is_safe_eq(0, f, false, e); }
bool is_disjunctively_safe(unsigned recursion_depth, expr* f, bool sign, expr* e);
bool contains_conjunctively(expr* f, bool sign, expr* e, signed_expressions& conjuncts);
bool is_conjunction(bool sign, expr* f) const;

void collect_nested_equalities(dependent_expr const& f, expr_mark& visited, dep_eq_vector& eqs);

Expand Down
15 changes: 9 additions & 6 deletions src/ast/simplifiers/solve_eqs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ namespace euf {
}

void solve_eqs::normalize() {
if (m_subst_ids.empty())
return;
scoped_ptr<expr_replacer> rp = mk_default_expr_replacer(m, false);
rp->set_substitution(m_subst.get());

Expand Down Expand Up @@ -152,15 +154,18 @@ namespace euf {
void solve_eqs::apply_subst(vector<dependent_expr>& old_fmls) {
if (!m.inc())
return;
if (m_subst_ids.empty())
return;

scoped_ptr<expr_replacer> rp = mk_default_expr_replacer(m, false);
rp->set_substitution(m_subst.get());

for (unsigned i = m_qhead; i < m_fmls.size() && !m_fmls.inconsistent(); ++i) {
auto [f, d] = m_fmls[i]();
auto [new_f, new_dep] = rp->replace_with_dep(f);
m_rewriter(new_f);
if (new_f == f)
continue;
m_rewriter(new_f);
new_dep = m.mk_join(d, new_dep);
old_fmls.push_back(m_fmls[i]);
m_fmls.update(i, dependent_expr(m, new_f, new_dep));
Expand All @@ -185,14 +190,13 @@ namespace euf {
normalize();
apply_subst(old_fmls);
++count;
save_subst({});
}
while (!m_subst_ids.empty() && count < 20 && m.inc());

if (!m.inc())
return;

save_subst({});

if (m_config.m_context_solve) {
old_fmls.reset();
m_subst_ids.reset();
Expand All @@ -211,7 +215,7 @@ namespace euf {

void solve_eqs::save_subst(vector<dependent_expr> const& old_fmls) {
if (!m_subst->empty())
m_fmls.model_trail().push(m_subst.detach(), old_fmls);
m_fmls.model_trail().push(m_subst.detach(), old_fmls);
}

void solve_eqs::filter_unsafe_vars() {
Expand All @@ -222,11 +226,10 @@ namespace euf {
m_unsafe_vars.mark(term);
}



solve_eqs::solve_eqs(ast_manager& m, dependent_expr_state& fmls) :
dependent_expr_simplifier(m, fmls), m_rewriter(m) {
register_extract_eqs(m, m_extract_plugins);
m_rewriter.set_flat_and_or(false);
}

void solve_eqs::updt_params(params_ref const& p) {
Expand Down
1 change: 1 addition & 0 deletions src/nlsat/tactic/qfnra_nlsat_tactic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Module Name:
#include "tactic/core/elim_uncnstr_tactic.h"
#include "tactic/core/propagate_values_tactic.h"
#include "tactic/core/solve_eqs_tactic.h"
#include "tactic/core/solve_eqs2_tactic.h"
#include "tactic/core/elim_term_ite_tactic.h"

tactic * mk_qfnra_nlsat_tactic(ast_manager & m, params_ref const & p) {
Expand Down
1 change: 1 addition & 0 deletions src/opt/opt_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Module Name:
#include "tactic/tactic.h"
#include "tactic/arith/lia2card_tactic.h"
#include "tactic/core/solve_eqs_tactic.h"
#include "tactic/core/solve_eqs2_tactic.h"
#include "tactic/core/simplify_tactic.h"
#include "tactic/core/propagate_values_tactic.h"
#include "tactic/core/solve_eqs_tactic.h"
Expand Down
12 changes: 9 additions & 3 deletions src/tactic/core/solve_eqs2_tactic.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,19 @@ class solve_eqs2_tactic_factory : public dependent_expr_simplifier_factory {
}
};

inline tactic * mk_solve_eqs2_tactic(ast_manager& m, params_ref const& p) {
return alloc(dependent_expr_state_tactic, m, p, alloc(solve_eqs2_tactic_factory), "solve-eqs2");
inline tactic * mk_solve_eqs2_tactic(ast_manager& m, params_ref const& p = params_ref()) {
return alloc(dependent_expr_state_tactic, m, p, alloc(solve_eqs2_tactic_factory), "solve-eqs");
}

#if 1
inline tactic * mk_solve_eqs_tactic(ast_manager & m, params_ref const & p = params_ref()) {
return mk_solve_eqs2_tactic(m, p);
}
#endif


/*
ADD_TACTIC("solve-eqs2", "solve for variables.", "mk_solve_eqs2_tactic(m, p)")
ADD_TACTIC("solve-eqs", "solve for variables.", "mk_solve_eqs2_tactic(m, p)")
*/


5 changes: 3 additions & 2 deletions src/tactic/core/solve_eqs_tactic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,8 @@ class solve_eqs_tactic : public tactic {
//
void operator()(goal_ref const & g, goal_ref_buffer & result) {
model_converter_ref mc;
std::function<void(statistics&)> coll = [&](statistics& st) { collect_statistics(st); };
statistics_report sreport(coll);
tactic_report report("solve_eqs", *g);
TRACE("goal", g->display(tout););
m_produce_models = g->models_enabled();
Expand Down Expand Up @@ -1042,7 +1044,6 @@ class solve_eqs_tactic : public tactic {
result.push_back(g.get());


IF_VERBOSE(10, statistics st; collect_statistics(st); st.display_smt2(verbose_stream()));
}
};

Expand Down Expand Up @@ -1103,6 +1104,6 @@ class solve_eqs_tactic : public tactic {

};

tactic * mk_solve_eqs_tactic(ast_manager & m, params_ref const & p) {
tactic * mk_solve_eqs1_tactic(ast_manager & m, params_ref const & p) {
return clean(alloc(solve_eqs_tactic, m, p, mk_expr_simp_replacer(m, p), true));
}
10 changes: 8 additions & 2 deletions src/tactic/core/solve_eqs_tactic.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ Revision History:
class ast_manager;
class tactic;

tactic * mk_solve_eqs_tactic(ast_manager & m, params_ref const & p = params_ref());
tactic * mk_solve_eqs1_tactic(ast_manager & m, params_ref const & p = params_ref());

#if 0
inline tactic * mk_solve_eqs_tactic(ast_manager & m, params_ref const & p = params_ref()) {
return mk_solve_eqs1_tactic(m, p);
}
#endif

/*
ADD_TACTIC("solve-eqs", "eliminate variables by solving equations.", "mk_solve_eqs_tactic(m, p)")
ADD_TACTIC("solve-eqs1", "eliminate variables by solving equations.", "mk_solve_eqs1_tactic(m, p)")
*/


Loading

0 comments on commit 3a37cfc

Please sign in to comment.