Skip to content

Commit 9a9ccf1

Browse files
committed
introdure lar_term.ext_coeffs(), dio passes some tests
Signed-off-by: Lev Nachmanson <levnach@hotmail.com>
1 parent 083926c commit 9a9ccf1

File tree

2 files changed

+117
-33
lines changed

2 files changed

+117
-33
lines changed

src/math/lp/dioph_eq.cpp

+12-27
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,6 @@ namespace lp {
174174
else if (val != numeric_traits<mpq>::one())
175175
out << T_to_string(val);
176176
out << "x";
177-
if (is_fresh_var(j))
178-
out << "~";
179177
out << j;
180178
}
181179

@@ -374,7 +372,7 @@ namespace lp {
374372

375373
void register_columns_to_term(const lar_term& t) {
376374
TRACE("dioph_eq", tout << "register term:"; lra.print_term(t, tout););
377-
for (const auto &p: t) {
375+
for (const auto &p: t.ext_coeffs()) {
378376
auto it = m_columns_to_terms.find(p.var());
379377
if (it != m_columns_to_terms.end()) {
380378
it->second.insert(t.j());
@@ -405,7 +403,7 @@ namespace lp {
405403
m_e_matrix.add_row();
406404
SASSERT(m_e_matrix.row_count() == m_entries.size());
407405

408-
for (const auto& p : t) {
406+
for (const auto& p : t.ext_coeffs()) {
409407
SASSERT(p.coeff().is_int());
410408
if (is_fixed(p.var()))
411409
e.m_c += p.coeff() * lia.lower_bound(p.var()).x;
@@ -415,13 +413,6 @@ namespace lp {
415413
m_e_matrix.add_new_element(entry_index, lj, p.coeff());
416414
}
417415
}
418-
if (is_fixed(t.j())) {
419-
e.m_c -= lia.lower_bound(t.j()).x;
420-
} else {
421-
unsigned lj = add_var(t.j());
422-
m_e_matrix.add_columns_up_to(lj);
423-
m_e_matrix.add_new_element(entry_index, lj, -mpq(1));
424-
}
425416
SASSERT(entry_invariant(entry_index));
426417
}
427418

@@ -866,7 +857,7 @@ namespace lp {
866857
print_lar_term_L(term_to_tighten, tout) << std::endl;
867858
tout << "m_tmp_l:"; print_lar_term_L(m_tmp_l, tout) << std::endl;
868859
tout << "open_ml:";
869-
print_term_o(open_ml(m_tmp_l), tout) << std::endl;
860+
print_lar_term_L(open_ml(m_tmp_l), tout) << std::endl;
870861
tout << "term_to_tighten + open_ml:";
871862
print_term_o(term_to_tighten + open_ml(m_tmp_l), tout)
872863
<< std::endl;
@@ -1307,7 +1298,7 @@ namespace lp {
13071298
for (unsigned k = 0; k < lra.terms().size(); k ++ ) {
13081299
const lar_term* t = lra.terms()[k];
13091300
if (!all_vars_are_int(*t)) continue;
1310-
for (const auto& p: *t) {
1301+
for (const auto& p: (*t).ext_coeffs()) {
13111302
unsigned j = p.var();
13121303
auto it = c2t.find(j);
13131304
if (it == c2t.end()) {
@@ -1317,7 +1308,8 @@ namespace lp {
13171308
} else {
13181309
it->second.insert(t->j());
13191310
}
1320-
1311+
1312+
13211313
}
13221314
}
13231315
for (const auto & p : c2t) {
@@ -1502,14 +1494,14 @@ namespace lp {
15021494
{
15031495
tout << "get_term_from_entry(" << ei << "):";
15041496
print_term_o(get_term_from_entry(ei), tout) << std::endl;
1505-
tout << "remove_fresh_vars:";
1497+
tout << "ls:";
15061498
print_term_o(remove_fresh_vars(get_term_from_entry(ei)), tout)
15071499
<< std::endl;
15081500
tout << "e.m_l:"; print_lar_term_L(l_term_from_row(ei), tout) << std::endl;
15091501
tout << "open_ml(e.m_l):";
1510-
print_term_o(open_ml(l_term_from_row(ei)), tout) << std::endl;
1511-
tout << "fix_vars(open_ml(e.m_l)):";
1512-
print_term_o(fix_vars(open_ml(l_term_from_row(ei))), tout) << std::endl;
1502+
print_lar_term_L(open_ml(l_term_from_row(ei)), tout) << std::endl;
1503+
tout << "rs:";
1504+
print_term_o(fix_vars(open_ml(m_l_matrix.m_rows[ei])), tout) << std::endl;
15131505
}
15141506
);
15151507
return ret;
@@ -1544,7 +1536,7 @@ namespace lp {
15441536

15451537
std::ostream& print_ml(const lar_term& ml, std::ostream& out) {
15461538
term_o opened_ml = open_ml(ml);
1547-
return print_term_o(opened_ml, out);
1539+
return print_lar_term_L(opened_ml, out);
15481540
}
15491541

15501542
template <typename T> term_o open_ml(const T& ml) const {
@@ -1564,21 +1556,14 @@ namespace lp {
15641556
m_indexed_work_vector.clear();
15651557
for (const auto & p: m_l_matrix.m_rows[ei]) {
15661558
const lar_term& t = lra.get_term(p.var());
1567-
for (const auto & q: t) {
1559+
for (const auto & q: t.ext_coeffs()) {
15681560
if (is_fixed(q.var())) {
15691561
c += p.coeff()*q.coeff()*lia.lower_bound(q.var()).x;
15701562
} else {
15711563
make_space_in_work_vector(q.var());
15721564
m_indexed_work_vector.add_value_at_index(q.var(), p.coeff() * q.coeff());
15731565
}
15741566
}
1575-
if (is_fixed(t.j())) {
1576-
c -= lia.lower_bound(t.j()).x;
1577-
}
1578-
else {
1579-
make_space_in_work_vector(t.j());
1580-
m_indexed_work_vector.add_value_at_index(t.j(), -p.coeff());
1581-
}
15821567
}
15831568
}
15841569

src/math/lp/lar_term.h

+105-6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class lar_term {
3030
// the column index related to the term
3131
lpvar m_j = -1;
3232
public:
33+
34+
3335
// the column index related to the term
3436
lpvar j() const { return m_j; }
3537
void set_j(unsigned j) {
@@ -92,8 +94,6 @@ class lar_term {
9294
}
9395
// constructors
9496
lar_term() = default;
95-
<<<<<<< HEAD
96-
=======
9797
lar_term(lar_term&& other) noexcept = default;
9898
// copy assignment operator
9999
lar_term& operator=(const lar_term& other) = default;
@@ -106,7 +106,6 @@ class lar_term {
106106
}
107107
}
108108

109-
>>>>>>> 956229fb6 (test that pivoting is correct in dioph_eq.cpp)
110109
lar_term(const vector<std::pair<mpq, unsigned>>& coeffs) {
111110
for (auto const& p : coeffs) {
112111
add_monomial(p.first, p.second);
@@ -257,10 +256,9 @@ class lar_term {
257256
m_coeffs.reset();
258257
}
259258

260-
class ival {
259+
struct ival {
261260
lpvar m_var;
262261
const mpq & m_coeff;
263-
public:
264262
ival(lpvar var, const mpq & val) : m_var(var), m_coeff(val) { }
265263
lpvar j() const { return m_var; }
266264
lpvar var() const { return m_var; }
@@ -274,7 +272,12 @@ class lar_term {
274272
const_iterator operator++() { const_iterator i = *this; m_it++; return i; }
275273
const_iterator operator++(int) { m_it++; return *this; }
276274
const_iterator(u_map<mpq>::iterator it) : m_it(it) {}
277-
bool operator!=(const const_iterator &other) const { return m_it != other.m_it; }
275+
bool operator==(const const_iterator &other) const { return m_it == other.m_it; }
276+
bool operator!=(const const_iterator &other) const { return !(*this == other); }
277+
// Return a pointer to the same object returned by operator*.
278+
const ival* operator->() const {
279+
return &(**this);
280+
}
278281
};
279282

280283
bool is_normalized() const {
@@ -316,5 +319,101 @@ class lar_term {
316319
}
317320
const_iterator begin() const { return m_coeffs.begin();}
318321
const_iterator end() const { return m_coeffs.end(); }
322+
// This iterator yields all (coefficient, variable) pairs
323+
// plus one final pair: (mpq(-1), j()).
324+
class ext_const_iterator {
325+
// We'll store a reference to the lar_term, and an
326+
// iterator into m_coeffs. Once we reach end of m_coeffs,
327+
// we'll yield exactly one extra pair, then we are done.
328+
const lar_term& m_term;
329+
lar_term::const_iterator m_it;
330+
bool m_done; // Have we gone past m_coeffs?
331+
332+
public:
333+
// Construct either a "begin" iterator (end=false) or "end" iterator (end=true).
334+
ext_const_iterator(const lar_term& t, bool is_end)
335+
: m_term(t)
336+
, m_it(is_end ? t.end() : t.begin())
337+
, m_done(false)
338+
{
339+
// If it is_end == true, we represent a genuine end-iterator.
340+
if (is_end) {
341+
m_done = true;
342+
}
343+
}
344+
345+
// Compare iterators. Two iterators are equal if both are "done" or hold the same internal iterator.
346+
bool operator==(ext_const_iterator const &other) const {
347+
// They are equal if they are both at the special extra pair or both at the same spot in m_coeffs.
348+
if (m_done && other.m_done) {
349+
return true;
350+
}
351+
return (!m_done && !other.m_done && m_it == other.m_it);
352+
}
353+
354+
bool operator!=(ext_const_iterator const &other) const {
355+
return !(*this == other);
356+
}
357+
358+
// Return the element we point to:
359+
// 1) If we haven't finished m_coeffs, yield (coefficient, var).
360+
// 2) If we've iterated past m_coeffs exactly once, return (mpq(-1), j()).
361+
auto operator*() const {
362+
if (!m_done && m_it != m_term.end()) {
363+
// Normal monomial from m_coeffs
364+
// Each entry is of type { m_value, m_key } in this context
365+
return *m_it;
366+
}
367+
else {
368+
// We've gone past normal entries, so return the extra pair
369+
// (mpq(-1), j()).
370+
return ival(m_term.j(), rational::minus_one());
371+
}
372+
}
373+
374+
// Pre-increment
375+
ext_const_iterator& operator++() {
376+
if (!m_done && m_it != m_term.end()) {
377+
++m_it;
378+
}
379+
else {
380+
// We were about to return that extra pair:
381+
// after we move once more, we are done.
382+
m_done = true;
383+
}
384+
return *this;
385+
}
386+
387+
// Post-increment
388+
ext_const_iterator operator++(int) {
389+
ext_const_iterator temp(*this);
390+
++(*this);
391+
return temp;
392+
}
393+
};
394+
395+
// Return the begin/end of our extended iteration.
396+
// begin: starts at first real monomial
397+
// end: marks a finalized end of iteration
398+
ext_const_iterator ext_coeffs_begin() const {
399+
return ext_const_iterator(*this, /*is_end=*/false);
400+
}
401+
ext_const_iterator ext_coeffs_end() const {
402+
return ext_const_iterator(*this, /*is_end=*/true);
403+
}
404+
405+
// Provide a small helper for "range-based for":
406+
// for (auto & [coef, var] : myTerm.ext_coeffs()) { ... }
407+
struct ext_range {
408+
ext_const_iterator b, e;
409+
ext_const_iterator begin() const { return b; }
410+
ext_const_iterator end() const { return e; }
411+
};
412+
413+
// return an object that can be used in range-based for loops
414+
ext_range ext_coeffs() const {
415+
return { ext_coeffs_begin(), ext_coeffs_end() };
416+
}
417+
319418
};
320419
}

0 commit comments

Comments
 (0)