diff --git a/extern/riddle b/extern/riddle index a2872d9..c042d5b 160000 --- a/extern/riddle +++ b/extern/riddle @@ -1 +1 @@ -Subproject commit a2872d9cb6fc4754a37052b211c7c9966c067127 +Subproject commit c042d5b285e833bda4960af254cd56536132884f diff --git a/include/flaw.hpp b/include/flaw.hpp index bbcf151..e4fdc67 100644 --- a/include/flaw.hpp +++ b/include/flaw.hpp @@ -107,7 +107,7 @@ namespace ratio */ void init() noexcept; /** - * \brief Computes the resolvers for the flaw. + * @brief Computes the resolvers for the flaw. * * This function is a pure virtual function that must be implemented by derived classes. * It is responsible for computing the resolvers for the flaw. diff --git a/include/flaws/atom_flaw.hpp b/include/flaws/atom_flaw.hpp index f03d4e5..051aea7 100644 --- a/include/flaws/atom_flaw.hpp +++ b/include/flaws/atom_flaw.hpp @@ -7,7 +7,7 @@ namespace ratio { class atom; - class atom_flaw : public flaw + class atom_flaw final : public flaw { public: atom_flaw(solver &s, std::vector> &&causes, bool is_fact, riddle::predicate &pred, std::map> &&arguments) noexcept; diff --git a/include/flaws/bool_flaw.hpp b/include/flaws/bool_flaw.hpp index bf1f6d6..de12ff2 100644 --- a/include/flaws/bool_flaw.hpp +++ b/include/flaws/bool_flaw.hpp @@ -5,7 +5,7 @@ namespace ratio { - class bool_flaw : public flaw + class bool_flaw final : public flaw { public: bool_flaw(solver &s, std::vector> &&causes, std::shared_ptr b_item) noexcept; diff --git a/include/flaws/disj_flaw.hpp b/include/flaws/disj_flaw.hpp index f537a9f..38c145c 100644 --- a/include/flaws/disj_flaw.hpp +++ b/include/flaws/disj_flaw.hpp @@ -1,10 +1,11 @@ #pragma once #include "flaw.hpp" +#include "resolver.hpp" namespace ratio { - class disj_flaw : public flaw + class disj_flaw final : public flaw { public: disj_flaw(solver &s, std::vector> &&causes, std::vector &&lits, bool exclusive = false) noexcept; @@ -14,6 +15,14 @@ namespace ratio private: void compute_resolvers() override; + class choose_lit final : public resolver + { + public: + choose_lit(disj_flaw &ef, const utils::rational &cost, const utils::lit &l); + + void apply() override {} + }; + private: std::vector lits; }; diff --git a/include/flaws/disjunction_flaw.hpp b/include/flaws/disjunction_flaw.hpp index d95b88e..2aead74 100644 --- a/include/flaws/disjunction_flaw.hpp +++ b/include/flaws/disjunction_flaw.hpp @@ -2,20 +2,32 @@ #include "flaw.hpp" #include "conjunction.hpp" +#include "resolver.hpp" namespace ratio { - class disjunction_flaw : public flaw + class disjunction_flaw final : public flaw { public: - disjunction_flaw(solver &s, std::vector> &&causes, std::vector> &&xprs) noexcept; + disjunction_flaw(solver &s, std::vector> &&causes, std::vector> &&conjs) noexcept; - [[nodiscard]] const std::vector> &get_conjunctions() const noexcept { return xprs; } + [[nodiscard]] const std::vector> &get_conjunctions() const noexcept { return conjs; } private: void compute_resolvers() override; + class choose_conjunction final : public resolver + { + public: + choose_conjunction(disjunction_flaw &df, riddle::conjunction &conj, const utils::rational &cost); + + void apply() override; + + private: + riddle::conjunction &conj; + }; + private: - std::vector> xprs; + std::vector> conjs; }; } // namespace ratio diff --git a/include/graph.hpp b/include/graph.hpp index 396c1d1..6d39afb 100644 --- a/include/graph.hpp +++ b/include/graph.hpp @@ -52,7 +52,8 @@ namespace ratio { static_assert(std::is_base_of_v, "Tp must be a subclass of resolver"); auto r = new Tp(std::forward(args)...); - rhos[r->get_rho()].push_back(std::unique_ptr(r)); + rhos[variable(r->get_rho())].push_back(std::unique_ptr(r)); + r->get_flaw().resolvers.push_back(*r); return *r; } diff --git a/src/flaws/disj_flaw.cpp b/src/flaws/disj_flaw.cpp index 769f97c..700955d 100644 --- a/src/flaws/disj_flaw.cpp +++ b/src/flaws/disj_flaw.cpp @@ -1,16 +1,18 @@ #include #include "disj_flaw.hpp" #include "solver.hpp" +#include "graph.hpp" namespace ratio { - disj_flaw::disj_flaw(solver &s, std::vector> &&causes, std::vector &&lits, bool exclusive) noexcept : flaw(s, std::move(causes), exclusive), lits(lits) - { - assert(!lits.empty()); - } + disj_flaw::disj_flaw(solver &s, std::vector> &&causes, std::vector &&lits, bool exclusive) noexcept : flaw(s, std::move(causes), exclusive), lits(lits) { assert(!lits.empty()); } void disj_flaw::compute_resolvers() { - throw std::runtime_error("Not implemented yet"); + for (const auto &l : lits) + if (get_solver().get_sat().value(l) != utils::False) + get_solver().get_graph().new_resolver(*this, utils::rational::one / lits.size(), l); } + + disj_flaw::choose_lit::choose_lit(disj_flaw &ef, const utils::rational &cost, const utils::lit &l) : resolver(ef, l, cost) {} } // namespace ratio diff --git a/src/flaws/disjunction_flaw.cpp b/src/flaws/disjunction_flaw.cpp index c4a8eae..0a6adb6 100644 --- a/src/flaws/disjunction_flaw.cpp +++ b/src/flaws/disjunction_flaw.cpp @@ -1,16 +1,22 @@ #include #include "disjunction_flaw.hpp" #include "solver.hpp" +#include "graph.hpp" namespace ratio { - disjunction_flaw::disjunction_flaw(solver &s, std::vector> &&causes, std::vector> &&xprs) noexcept : flaw(s, std::move(causes)), xprs(std::move(xprs)) - { - assert(!xprs.empty()); - } + disjunction_flaw::disjunction_flaw(solver &s, std::vector> &&causes, std::vector> &&conjs) noexcept : flaw(s, std::move(causes)), conjs(std::move(conjs)) { assert(!conjs.empty()); } void disjunction_flaw::compute_resolvers() { - throw std::runtime_error("Not implemented yet"); + for (auto &conj : conjs) + { + auto cost = conj->compute_cost(); + get_solver().get_graph().new_resolver(*this, *conj, get_solver().arithmetic_value(*cost).get_rational()); + } } + + disjunction_flaw::choose_conjunction::choose_conjunction(disjunction_flaw &df, riddle::conjunction &conj, const utils::rational &cost) : resolver(df, cost), conj(conj) {} + + void disjunction_flaw::choose_conjunction::apply() { conj.execute(); } } // namespace ratio