Skip to content

Commit

Permalink
Merge pull request #1050 from borglab/feature/betterMPE
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Jan 22, 2022
2 parents 0638727 + 82fed10 commit b441eea
Show file tree
Hide file tree
Showing 35 changed files with 736 additions and 344 deletions.
9 changes: 4 additions & 5 deletions examples/DiscreteBayesNetExample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ int main(int argc, char **argv) {
// Create solver and eliminate
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);

// solve
auto mpe = chordal->optimize();
auto mpe = fg.optimize();
GTSAM_PRINT(mpe);

// We can also build a Bayes tree (directed junction tree).
Expand All @@ -69,14 +68,14 @@ int main(int argc, char **argv) {
fg.add(Dyspnea, "0 1");

// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto mpe2 = chordal2->optimize();
auto mpe2 = fg.optimize();
GTSAM_PRINT(mpe2);

// We can also sample from it
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) {
auto sample = chordal2->sample();
auto sample = chordal->sample();
GTSAM_PRINT(sample);
}
return 0;
Expand Down
8 changes: 4 additions & 4 deletions examples/DiscreteBayesNet_FG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ int main(int argc, char **argv) {
}

// "Most Probable Explanation", i.e., configuration with largest value
auto mpe = graph.eliminateSequential()->optimize();
auto mpe = graph.optimize();
cout << "\nMost Probable Explanation (MPE):" << endl;
print(mpe);

Expand All @@ -96,8 +96,7 @@ int main(int argc, char **argv) {
graph.add(Cloudy, "1 0");

// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto mpe_with_evidence = chordal->optimize();
auto mpe_with_evidence = graph.optimize();

cout << "\nMPE given C=0:" << endl;
print(mpe_with_evidence);
Expand All @@ -110,7 +109,8 @@ int main(int argc, char **argv) {
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
<< endl;

// We can also sample from it
// We can also sample from the eliminated graph
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) {
auto sample = chordal->sample();
Expand Down
8 changes: 4 additions & 4 deletions examples/HMMExample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ int main(int argc, char **argv) {
// Convert to factor graph
DiscreteFactorGraph factorGraph(hmm);

// Do max-prodcut
auto mpe = factorGraph.optimize();
GTSAM_PRINT(mpe);

// Create solver and eliminate
// This will create a DAG ordered with arrow of time reversed
DiscreteBayesNet::shared_ptr chordal =
factorGraph.eliminateSequential(ordering);
chordal->print("Eliminated");

// solve
auto mpe = chordal->optimize();
GTSAM_PRINT(mpe);

// We can also sample from it
cout << "\n10 samples:" << endl;
for (size_t k = 0; k < 10; k++) {
Expand Down
5 changes: 2 additions & 3 deletions examples/UGM_chain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ int main(int argc, char** argv) {
<< graph.size() << " factors (Unary+Edge).";

// "Decoding", i.e., configuration with largest value
// We use sequential variable elimination
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto optimalDecoding = chordal->optimize();
// Uses max-product.
auto optimalDecoding = graph.optimize();
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");

// "Inference" Computing marginals for each node
Expand Down
5 changes: 2 additions & 3 deletions examples/UGM_small.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ int main(int argc, char** argv) {
}

// "Decoding", i.e., configuration with largest value (MPE)
// We use sequential variable elimination
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto optimalDecoding = chordal->optimize();
// Uses max-product
auto optimalDecoding = graph.optimize();
GTSAM_PRINT(optimalDecoding);

// "Inference" Computing marginals
Expand Down
9 changes: 7 additions & 2 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/base/FastSet.h>

#include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility>

using namespace std;
Expand Down Expand Up @@ -65,9 +66,13 @@ namespace gtsam {

/* ************************************************************************* */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
cout << s;
ADT::print("Potentials:",formatter);
cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("Potentials:", formatter);
}

/* ************************************************************************* */
Expand Down
7 changes: 6 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add);
}

/// Create new factor by maximizing over all values with the same separator values
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}

/// @}
/// @name Advanced Interface
/// @{
Expand Down
7 changes: 7 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,24 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
}

/* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
return optimize(result);
}

DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
#endif

/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
Expand Down
36 changes: 12 additions & 24 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@

namespace gtsam {

/** A Bayes net made from linear-Discrete densities */
/** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:

typedef FactorGraph<DiscreteConditional> Base;
typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
Expand All @@ -45,7 +45,7 @@ namespace gtsam {
/// @name Standard Constructors
/// @{

/** Construct empty factor graph */
/// Construct empty Bayes net.
DiscreteBayesNet() {}

/** Construct from iterator over conditionals */
Expand Down Expand Up @@ -98,27 +98,6 @@ namespace gtsam {
return evaluate(values);
}

/**
* @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues optimize() const;

/**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;

/**
* @brief do ancestral sampling
*
Expand Down Expand Up @@ -152,7 +131,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;

///@}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{

DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @}
#endif

private:
/** Serialization function */
Expand Down
Loading

0 comments on commit b441eea

Please sign in to comment.