-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
Copy pathrecfun_decl_plugin.h
367 lines (279 loc) · 13.8 KB
/
recfun_decl_plugin.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
/*++
Copyright (c) 2017 Microsoft Corporation, Simon Cruanes
Module Name:
recfun_decl_plugin.h
Abstract:
Declaration and definition of (potentially recursive) functions
Author:
Simon Cruanes 2017-11
Revision History:
--*/
#pragma once
#include "ast/ast.h"
#include "ast/ast_pp.h"
#include "ast/ast_translation.h"
#include "util/obj_hashtable.h"
namespace recfun {
// Forward declarations.
class case_def; //!< one possible control path of a function
class case_pred; //!< a predicate guarding a given control flow path of a function
class util; //!< utilities for other modules
class def; //!< definition of a (recursive) function
class promise_def; //!< definition still to be completed (its RHS is missing)
/*! Operator kinds of the recfun family.
    `util` tests applications and declarations against these kinds
    (`is_defined`, `is_case_pred`, `is_num_rounds`).
*/
enum op_kind {
    //! defined function with one or more cases, possibly recursive
    OP_FUN_DEFINED,
    //! predicate guarding a given control flow path
    OP_FUN_CASE_PRED,
    //! predicate round
    OP_NUM_ROUNDS
};
/*! Case predicates.
A case predicate `p(t1...tn)` is a predicate that, when true, means that `f(t1...tn)`
follows a given path of its control flow and can be unrolled.
For example, `fact n := if n<2 then 1 else n * fact(n-1)` has two cases,
and therefore two case predicates `C_fact_0` and `C_fact_1`, where
`C_fact_0(t)=true` means `t<2` (first path) and `C_fact_1(t)=true` means `not (t<2)` (second path).
*/
// Vector of `var` references; `def` keeps the variables of a function in this type.
typedef var_ref_vector vars;
// Abstract replacement interface: pairs of expressions are registered with
// `insert`, and the object is then applied to an expression via `operator()`.
// NOTE(review): the exact semantics (presumably "replace d by r") are fixed by
// the implementations, which are not visible in this header — confirm there.
class replace {
public:
virtual ~replace() = default;
// Reset the object's state (presumably forgets all registered pairs — confirm).
virtual void reset() = 0;
// Register the pair (d, r).
virtual void insert(expr* d, expr* r) = 0;
// Apply the object to `e` and return the resulting expression.
virtual expr_ref operator()(expr* e) = 0;
};
/*! One case of a function definition, i.e. one control path of the function.

    A case stores its case predicate, the guards whose conjunction is equivalent
    to the case, and the right-hand side that `f(t1...tn)` equals when the guards
    hold. Construction and guard insertion are private; `def` is a friend.
*/
class case_def {
    friend class def;
    func_decl_ref       m_pred;      //!< predicate used for this case
    expr_ref_vector     m_guards;    //!< conjunction that is equivalent to this case
    expr_ref            m_rhs;       //!< if guard is true, `f(t1...tn) = rhs` holds
    def *               m_def;       //!< definition this is a part of
    bool                m_immediate; //!< does `rhs` contain no defined_fun/case_pred?

    // Build an empty case. The owning-definition pointer and the immediate flag are
    // initialized explicitly so that get_def()/is_immediate() never read an
    // indeterminate value on a case built through this constructor.
    case_def(ast_manager& m):
        m_pred(m),
        m_guards(m),
        m_rhs(m),
        m_def(nullptr),
        m_immediate(false)
    {}

    // Full constructor; defined out of line.
    case_def(ast_manager & m,
             family_id fid,
             def * d,
             std::string & name,
             unsigned case_index,
             sort_ref_vector const & arg_sorts,
             expr_ref_vector const& guards,
             expr* rhs);

    // Append one more guard to the conjunction.
    void add_guard(expr_ref && e) { m_guards.push_back(e); }

public:
    //! Declaration of the case predicate.
    func_decl* get_decl() const { return m_pred; }

    //! Build the application of the case predicate to `args`.
    app_ref apply_case_predicate(expr_ref_vector const & args) const {
        ast_manager& m = m_pred.get_manager();
        return app_ref(m.mk_app(m_pred, args.size(), args.data()), m);
    }

    //! Definition this case belongs to.
    def * get_def() const { return m_def; }

    //! All guards of this case.
    expr_ref_vector const & get_guards() const { return m_guards; }

    //! Dereferences the guard array, i.e. yields the first guard.
    //! The guard vector must not be empty.
    expr * get_guards_c_ptr() const { return *m_guards.data(); }

    //! The i-th guard.
    expr * get_guard(unsigned i) const { return m_guards[i]; }

    //! Right-hand side of this case.
    expr * get_rhs() const { return m_rhs; }

    //! Number of guards.
    unsigned num_guards() const { return m_guards.size(); }

    //! Value of the immediate flag (see m_immediate).
    bool is_immediate() const { return m_immediate; }

    //! Set the immediate flag.
    void set_is_immediate(bool b) { m_immediate = b; }
};
// closure for computing whether a `rhs` expression is immediate
// (case_def documents "immediate" as: `rhs` contains no defined_fun/case_pred).
struct is_immediate_pred {
// Decide whether `rhs` is immediate; supplied by the implementing subclass.
virtual bool operator()(expr * rhs) = 0;
virtual ~is_immediate_pred() = default;
};
/*! Definition of a (possibly recursive) function.

    Holds the name, signature, variables, declaration and right-hand side of the
    function together with its cases. Construction and case computation are
    private; `util` and `promise_def` are friends.
*/
class def {
    friend class util;
    friend class promise_def;
    using cases = vector<case_def>;

    ast_manager &       m;
    symbol              m_name;   //!< name of function
    sort_ref_vector     m_domain; //!< type of arguments
    sort_ref            m_range;  //!< return type
    vars                m_vars;   //!< variables of the function
    cases               m_cases;  //!< possible cases
    func_decl_ref       m_decl;   //!< generic declaration
    expr_ref            m_rhs;    //!< definition
    family_id           m_fid;

    def(ast_manager &m, family_id fid, symbol const & s, unsigned arity, sort *const * domain, sort* range, bool is_generated);

    // compute cases for a function, given its RHS (possibly containing `ite`).
    void compute_cases(util& u, replace& subst, is_immediate_pred &,
                       bool is_macro, unsigned n_vars, var *const * vars, expr* rhs);

    // register one case under the given name and index
    void add_case(std::string & name, unsigned case_index, expr_ref_vector const& conditions, expr* rhs, bool is_imm = false);

    bool contains_ite(util& u, expr* e); // expression contains a test over a def?
    bool contains_def(util& u, expr* e); // expression contains a def

public:
    //! Name of the function.
    symbol const & get_name() const { return m_name; }

    //! Variables of the function.
    vars const & get_vars() const { return m_vars; }

    //! Cases of the function (mutable access).
    cases & get_cases() { return m_cases; }

    //! Arity, i.e. the number of domain sorts.
    unsigned get_arity() const { return m_domain.size(); }

    //! Argument sorts.
    sort_ref_vector const & get_domain() const { return m_domain; }

    //! Return sort.
    sort_ref const & get_range() const { return m_range; }

    //! Generic declaration of the function.
    func_decl * get_decl() const { return m_decl.get(); }

    //! Right-hand side of the definition.
    expr * get_rhs() const { return m_rhs.get(); }

    //! True when the definition has exactly one case.
    bool is_fun_macro() const { return m_cases.size() == 1; }

    //! True when the definition does not have exactly one case.
    bool is_fun_defined() const { return m_cases.size() != 1; }

    //! Copy this definition for `dst` using translation `tr`; defined out of line.
    def* copy(util& dst, ast_translation& tr);
};
// definition to be complete (missing RHS)
class promise_def {
friend class util;
util * u;
def * d;
void set_definition(replace& r, bool is_macro, unsigned n_vars, var * const * vars, expr * rhs); // call only once
public:
promise_def(util * u, def * d) : u(u), d(d) {}
promise_def(promise_def const & from) : u(from.u), d(from.d) {}
def * get_def() const { return d; }
};
namespace decl {

    /*! Declaration plugin of the recfun family.

        Keeps two tables: function declaration -> `def` and case predicate
        declaration -> `case_def`, and exposes creation, lookup and removal
        of definitions.
    */
    class plugin : public decl_plugin {
        using def_map      = obj_map<func_decl, def*>;
        using case_def_map = obj_map<func_decl, case_def*>;

        mutable scoped_ptr<util> m_util;
        def_map                  m_defs;      // function->def
        case_def_map             m_case_defs; // case_pred->def

        ast_manager & m() { return *m_manager; }

        void compute_scores(expr* e, obj_map<expr, unsigned>& scores);

    public:
        plugin();
        ~plugin() override;
        void finalize() override;

        util & u() const; // build or return util

        // might depend on unin sorts
        bool is_fully_interp(sort * s) const override {
            return false;
        }

        decl_plugin * mk_fresh() override {
            return alloc(plugin);
        }

        // Not expected to be called: marked UNREACHABLE.
        sort * mk_sort(decl_kind k, unsigned num_parameters, parameter const * parameters) override {
            UNREACHABLE();
            return nullptr;
        }

        func_decl * mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters,
                                 unsigned arity, sort * const * domain, sort * range) override;

        promise_def mk_def(symbol const& name, unsigned n, sort *const * params, sort * range, bool is_generated = false);

        promise_def ensure_def(symbol const& name, unsigned n, sort *const * params, sort * range, bool is_generated = false);

        void set_definition(replace& r, promise_def & d, bool is_macro, unsigned n_vars, var * const * vars, expr * rhs);

        def* mk_def(replace& subst, bool is_macro, symbol const& name, unsigned n, sort ** params, sort * range, unsigned n_vars, var ** vars, expr * rhs);

        void erase_def(func_decl* f);

        // Is `f` a key of the function->def table?
        bool has_def(func_decl* f) const {
            return m_defs.contains(f);
        }

        bool has_defs() const;

        // Definition registered for `f` (read-only).
        def const& get_def(func_decl* f) const {
            return *m_defs[f];
        }

        // Promise wrapping this plugin's util and the definition registered for `f`.
        promise_def get_promise_def(func_decl* f) const {
            return promise_def(&u(), m_defs[f]);
        }

        // Definition registered for `f` (mutable).
        def& get_def(func_decl* f) {
            return *m_defs[f];
        }

        // Is `f` a key of the case_pred->def table?
        bool has_case_def(func_decl* f) const {
            return m_case_defs.contains(f);
        }

        // Case definition registered for the case predicate `f`.
        case_def& get_case_def(func_decl* f) {
            SASSERT(has_case_def(f));
            return *m_case_defs[f];
        }

        // Collect the keys of the function->def table.
        func_decl_ref_vector get_rec_funs() {
            func_decl_ref_vector funs(m());
            for (auto& entry : m_defs)
                funs.push_back(entry.m_key);
            return funs;
        }

        expr_ref redirect_ite(replace& subst, unsigned n, var * const* vars, expr * e);

        void inherit(decl_plugin* other, ast_translation& tr) override;
    };
}
// Various utils for recursive functions
class util {
friend class decl::plugin;
ast_manager & m_manager;
family_id m_fid;
decl::plugin * m_plugin;
bool compute_is_immediate(expr * rhs);
void set_definition(replace& r, promise_def & d, bool is_macro, unsigned n_vars, var * const * vars, expr * rhs);
public:
util(ast_manager &m);
~util();
ast_manager & m() { return m_manager; }
family_id get_family_id() const { return m_fid; }
decl::plugin& get_plugin() { return *m_plugin; }
bool is_case_pred(expr * e) const { return is_app_of(e, m_fid, OP_FUN_CASE_PRED); }
bool is_defined(expr * e) const { return is_app_of(e, m_fid, OP_FUN_DEFINED); }
bool is_defined(func_decl* f) const { return is_decl_of(f, m_fid, OP_FUN_DEFINED); }
bool is_generated(func_decl* f) const { return is_defined(f) && f->get_parameter(0).get_int() == 1; }
bool is_num_rounds(expr * e) const { return is_app_of(e, m_fid, OP_NUM_ROUNDS); }
bool owns_app(app * e) const { return e->get_family_id() == m_fid; }
//<! don't use native theory if recursive function declarations are not populated with defs
bool has_defs() const { return m_plugin->has_defs(); }
//<! add a function declaration
def * decl_fun(symbol const & s, unsigned n_args, sort *const * args, sort * range, bool is_generated);
bool has_def(func_decl* f) const {
return m_plugin->has_def(f);
}
def& get_def(func_decl* f) {
SASSERT(has_def(f));
return m_plugin->get_def(f);
}
case_def& get_case_def(expr* e) {
SASSERT(is_case_pred(e));
return m_plugin->get_case_def(to_app(e)->get_decl());
}
app* mk_fun_defined(def const & d, unsigned n_args, expr * const * args) {
return m().mk_app(d.get_decl(), n_args, args);
}
app* mk_fun_defined(def const & d, ptr_vector<expr> const & args) {
return mk_fun_defined(d, args.size(), args.data());
}
app* mk_fun_defined(def const & d, expr_ref_vector const & args) {
return mk_fun_defined(d, args.size(), args.data());
}
func_decl_ref_vector get_rec_funs() {
return m_plugin->get_rec_funs();
}
app_ref mk_num_rounds_pred(unsigned d);
};
// one case-expansion of `f(t1...tn)`
// Constructors and `display` are defined out of line.
struct case_expansion {
app_ref m_lhs; // the term to expand
recfun::def * m_def; // NOTE(review): presumably the definition of the function applied in m_lhs — confirm in the .cpp
expr_ref_vector m_args; // NOTE(review): presumably the arguments t1...tn of m_lhs — confirm in the .cpp
case_expansion(recfun::util& u, app * n);
case_expansion(case_expansion const & from);
case_expansion(case_expansion && from);
// Print this expansion to `out`.
std::ostream& display(std::ostream& out) const;
};
// Stream insertion for case_expansion; forwards to case_expansion::display
// and returns whatever stream reference it returns.
inline std::ostream& operator<<(std::ostream& out, case_expansion const & e) {
return e.display(out);
}
// one body-expansion of `f(t1...tn)` using a `C_f_i(t1...tn)`
struct body_expansion {
    app_ref                   m_pred; // the case-predicate application
    recfun::case_def const *  m_cdef; // case definition associated with m_pred (not owned)
    expr_ref_vector           m_args; // arguments of the application

    // Build the expansion for the case-predicate application `n`: look up its
    // case definition through `u` and copy the arguments of `n`.
    body_expansion(recfun::util& u, app * n) :
        m_pred(n, u.m()), m_cdef(nullptr), m_args(u.m()) {
        m_cdef = &u.get_case_def(n);
        m_args.append(n->get_num_args(), n->get_args());
    }

    // Build the expansion from already-known parts; `pred` and `args` are copied.
    body_expansion(app_ref & pred, recfun::case_def const & d, expr_ref_vector & args) :
        m_pred(pred), m_cdef(&d), m_args(args) {}

    body_expansion(body_expansion const & from):
        m_pred(from.m_pred), m_cdef(from.m_cdef), m_args(from.m_args) {}

    // Move both reference-holding members; previously m_pred was copied even
    // though the source object is being abandoned.
    body_expansion(body_expansion && from) noexcept :
        m_pred(std::move(from.m_pred)), m_cdef(from.m_cdef), m_args(std::move(from.m_args)) {}

    // Print this expansion to `out`; defined out of line.
    std::ostream& display(std::ostream& out) const;
};
// Stream insertion for body_expansion; forwards to body_expansion::display
// and returns whatever stream reference it returns.
inline std::ostream& operator<<(std::ostream& out, body_expansion const& e) {
return e.display(out);
}
/*! One queued propagation: a guard, a core, a case expansion or a body expansion.

    Each constructor sets exactly one of the four pointers; the others stay null,
    and the is_*() tests report which one is set. The item owns m_case, m_body and
    m_core (they are dealloc'ed in the destructor); m_guard is not dealloc'ed.
*/
struct propagation_item {
    case_expansion*           m_case { nullptr };
    body_expansion*           m_body { nullptr };
    expr_ref_vector*          m_core { nullptr };
    expr*                     m_guard { nullptr };

    ~propagation_item() {
        dealloc(m_case);
        dealloc(m_body);
        dealloc(m_core);
    }

    // A member-wise copy would leave two items dealloc'ing the same objects,
    // so copying is disabled.
    propagation_item(propagation_item const&) = delete;
    propagation_item& operator=(propagation_item const&) = delete;

    // Guard item; the expression is only referenced.
    propagation_item(expr* guard):
        m_guard(guard) {}

    // Core item; stores a private copy of `core`.
    propagation_item(expr_ref_vector const& core):
        m_core(alloc(expr_ref_vector, core)) {
    }

    // Body-expansion item; `b` is dealloc'ed by the destructor.
    propagation_item(body_expansion* b):
        m_body(b) {}

    // Case-expansion item; `c` is dealloc'ed by the destructor.
    propagation_item(case_expansion* c):
        m_case(c) {}

    bool is_guard() const { return m_guard != nullptr; }
    bool is_core() const { return m_core != nullptr; }
    bool is_case() const { return m_case != nullptr; }
    bool is_body() const { return m_body != nullptr; }

    // Accessors; each asserts that the item is of the matching kind.
    expr_ref_vector const& core() const { SASSERT(is_core()); return *m_core; }
    body_expansion & body() const { SASSERT(is_body()); return *m_body; }
    case_expansion & case_ex() const { SASSERT(is_case()); return *m_case; }
    expr* guard() const { SASSERT(is_guard()); return m_guard; }
};
}