Skip to content

Commit

Permalink
Update flaw initialization and resolver application logic
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardodebenedictis committed Apr 13, 2024
1 parent 4a30812 commit 15c23b9
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 20 deletions.
2 changes: 1 addition & 1 deletion extern/riddle
2 changes: 1 addition & 1 deletion include/flaw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ namespace ratio
*/
void init() noexcept;
/**
* \brief Computes the resolvers for the flaw.
* @brief Computes the resolvers for the flaw.
*
* This function is a pure virtual function that must be implemented by derived classes.
* It is responsible for computing the resolvers for the flaw.
Expand Down
2 changes: 1 addition & 1 deletion include/flaws/atom_flaw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace ratio
{
class atom;

class atom_flaw : public flaw
class atom_flaw final : public flaw
{
public:
atom_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, bool is_fact, riddle::predicate &pred, std::map<std::string, std::shared_ptr<riddle::item>> &&arguments) noexcept;
Expand Down
2 changes: 1 addition & 1 deletion include/flaws/bool_flaw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

namespace ratio
{
class bool_flaw : public flaw
class bool_flaw final : public flaw
{
public:
bool_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, std::shared_ptr<riddle::bool_item> b_item) noexcept;
Expand Down
11 changes: 10 additions & 1 deletion include/flaws/disj_flaw.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#pragma once

#include "flaw.hpp"
#include "resolver.hpp"

namespace ratio
{
class disj_flaw : public flaw
class disj_flaw final : public flaw
{
public:
disj_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<utils::lit> &&lits, bool exclusive = false) noexcept;
Expand All @@ -14,6 +15,14 @@ namespace ratio
private:
void compute_resolvers() override;

class choose_lit final : public resolver
{
public:
choose_lit(disj_flaw &ef, const utils::rational &cost, const utils::lit &l);

void apply() override {}
};

private:
std::vector<utils::lit> lits;
};
Expand Down
20 changes: 16 additions & 4 deletions include/flaws/disjunction_flaw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,32 @@

#include "flaw.hpp"
#include "conjunction.hpp"
#include "resolver.hpp"

namespace ratio
{
class disjunction_flaw : public flaw
class disjunction_flaw final : public flaw
{
public:
disjunction_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<std::unique_ptr<riddle::conjunction>> &&xprs) noexcept;
disjunction_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<std::unique_ptr<riddle::conjunction>> &&conjs) noexcept;

[[nodiscard]] const std::vector<std::unique_ptr<riddle::conjunction>> &get_conjunctions() const noexcept { return xprs; }
[[nodiscard]] const std::vector<std::unique_ptr<riddle::conjunction>> &get_conjunctions() const noexcept { return conjs; }

private:
void compute_resolvers() override;

class choose_conjunction final : public resolver
{
public:
choose_conjunction(disjunction_flaw &df, riddle::conjunction &conj, const utils::rational &cost);

void apply() override;

private:
riddle::conjunction &conj;
};

private:
std::vector<std::unique_ptr<riddle::conjunction>> xprs;
std::vector<std::unique_ptr<riddle::conjunction>> conjs;
};
} // namespace ratio
3 changes: 2 additions & 1 deletion include/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ namespace ratio
{
static_assert(std::is_base_of_v<resolver, Tp>, "Tp must be a subclass of resolver");
auto r = new Tp(std::forward<Args>(args)...);
rhos[r->get_rho()].push_back(std::unique_ptr<resolver>(r));
rhos[variable(r->get_rho())].push_back(std::unique_ptr<resolver>(r));
r->get_flaw().resolvers.push_back(*r);
return *r;
}

Expand Down
12 changes: 7 additions & 5 deletions src/flaws/disj_flaw.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
#include <cassert>
#include "disj_flaw.hpp"
#include "solver.hpp"
#include "graph.hpp"

namespace ratio
{
disj_flaw::disj_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<utils::lit> &&lits, bool exclusive) noexcept : flaw(s, std::move(causes), exclusive), lits(lits)
{
assert(!lits.empty());
}
disj_flaw::disj_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<utils::lit> &&lits, bool exclusive) noexcept : flaw(s, std::move(causes), exclusive), lits(lits) { assert(!lits.empty()); }

void disj_flaw::compute_resolvers()
{
throw std::runtime_error("Not implemented yet");
for (const auto &l : lits)
if (get_solver().get_sat().value(l) != utils::False)
get_solver().get_graph().new_resolver<choose_lit>(*this, utils::rational::one / lits.size(), l);
}

disj_flaw::choose_lit::choose_lit(disj_flaw &ef, const utils::rational &cost, const utils::lit &l) : resolver(ef, l, cost) {}
} // namespace ratio
16 changes: 11 additions & 5 deletions src/flaws/disjunction_flaw.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
#include <cassert>
#include "disjunction_flaw.hpp"
#include "solver.hpp"
#include "graph.hpp"

namespace ratio
{
disjunction_flaw::disjunction_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<std::unique_ptr<riddle::conjunction>> &&xprs) noexcept : flaw(s, std::move(causes)), xprs(std::move(xprs))
{
assert(!xprs.empty());
}
disjunction_flaw::disjunction_flaw(solver &s, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<std::unique_ptr<riddle::conjunction>> &&conjs) noexcept : flaw(s, std::move(causes)), conjs(std::move(conjs)) { assert(!conjs.empty()); }

void disjunction_flaw::compute_resolvers()
{
throw std::runtime_error("Not implemented yet");
for (auto &conj : conjs)
{
auto cost = conj->compute_cost();
get_solver().get_graph().new_resolver<choose_conjunction>(*this, *conj, get_solver().arithmetic_value(*cost).get_rational());
}
}

disjunction_flaw::choose_conjunction::choose_conjunction(disjunction_flaw &df, riddle::conjunction &conj, const utils::rational &cost) : resolver(df, cost), conj(conj) {}

void disjunction_flaw::choose_conjunction::apply() { conj.execute(); }
} // namespace ratio

0 comments on commit 15c23b9

Please sign in to comment.