diff --git a/examples/DiscreteBayesNetExample.cpp b/examples/DiscreteBayesNetExample.cpp index febc1e1288..dfd7beb63b 100644 --- a/examples/DiscreteBayesNetExample.cpp +++ b/examples/DiscreteBayesNetExample.cpp @@ -53,10 +53,9 @@ int main(int argc, char **argv) { // Create solver and eliminate Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); - DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); // solve - auto mpe = chordal->optimize(); + auto mpe = fg.optimize(); GTSAM_PRINT(mpe); // We can also build a Bayes tree (directed junction tree). @@ -69,14 +68,14 @@ int main(int argc, char **argv) { fg.add(Dyspnea, "0 1"); // solve again, now with evidence - DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - auto mpe2 = chordal2->optimize(); + auto mpe2 = fg.optimize(); GTSAM_PRINT(mpe2); // We can also sample from it + DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { - auto sample = chordal2->sample(); + auto sample = chordal->sample(); GTSAM_PRINT(sample); } return 0; diff --git a/examples/DiscreteBayesNet_FG.cpp b/examples/DiscreteBayesNet_FG.cpp index 69283a1be7..88904001a0 100644 --- a/examples/DiscreteBayesNet_FG.cpp +++ b/examples/DiscreteBayesNet_FG.cpp @@ -85,7 +85,7 @@ int main(int argc, char **argv) { } // "Most Probable Explanation", i.e., configuration with largest value - auto mpe = graph.eliminateSequential()->optimize(); + auto mpe = graph.optimize(); cout << "\nMost Probable Explanation (MPE):" << endl; print(mpe); @@ -96,8 +96,7 @@ int main(int argc, char **argv) { graph.add(Cloudy, "1 0"); // solve again, now with evidence - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto mpe_with_evidence = chordal->optimize(); + auto mpe_with_evidence = graph.optimize(); cout << "\nMPE given C=0:" << endl; print(mpe_with_evidence); @@ -110,7 +109,8 @@ int main(int argc, char **argv) { cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1] << endl; - // We can also sample from it + // We can also sample from the eliminated graph + DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { auto sample = chordal->sample(); diff --git a/examples/HMMExample.cpp b/examples/HMMExample.cpp index b46baf4e09..3a76730016 100644 --- a/examples/HMMExample.cpp +++ b/examples/HMMExample.cpp @@ -59,16 +59,16 @@ int main(int argc, char **argv) { // Convert to factor graph DiscreteFactorGraph factorGraph(hmm); + // Do max-prodcut + auto mpe = factorGraph.optimize(); + GTSAM_PRINT(mpe); + // Create solver and eliminate // This will create a DAG ordered with arrow of time reversed DiscreteBayesNet::shared_ptr chordal = factorGraph.eliminateSequential(ordering); chordal->print("Eliminated"); - // solve - auto mpe = chordal->optimize(); - GTSAM_PRINT(mpe); - // We can also sample from it cout << "\n10 samples:" << endl; for (size_t k = 0; k < 10; k++) { diff --git a/examples/UGM_chain.cpp b/examples/UGM_chain.cpp index ababef0220..ad21af9fa7 100644 --- a/examples/UGM_chain.cpp +++ b/examples/UGM_chain.cpp @@ -68,9 +68,8 @@ int main(int argc, char** argv) { << graph.size() << " factors (Unary+Edge)."; // "Decoding", i.e., configuration with largest value - // We use sequential variable elimination - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto optimalDecoding = chordal->optimize(); + // 
Uses max-product. + auto optimalDecoding = graph.optimize(); optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); // "Inference" Computing marginals for each node diff --git a/examples/UGM_small.cpp b/examples/UGM_small.cpp index 24bd0c0ba7..bc6a413178 100644 --- a/examples/UGM_small.cpp +++ b/examples/UGM_small.cpp @@ -61,9 +61,8 @@ int main(int argc, char** argv) { } // "Decoding", i.e., configuration with largest value (MPE) - // We use sequential variable elimination - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto optimalDecoding = chordal->optimize(); + // Uses max-product + auto optimalDecoding = graph.optimize(); GTSAM_PRINT(optimalDecoding); // "Inference" Computing marginals diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ad4cbad434..9de750f2eb 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -22,6 +22,7 @@ #include #include +#include #include using namespace std; @@ -65,9 +66,13 @@ namespace gtsam { /* ************************************************************************* */ void DecisionTreeFactor::print(const string& s, - const KeyFormatter& formatter) const { + const KeyFormatter& formatter) const { cout << s; - ADT::print("Potentials:",formatter); + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + ADT::print("Potentials:", formatter); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 8beeb4c4a0..251575739d 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -127,11 +127,16 @@ namespace gtsam { return combine(keys, ADT::Ring::add); } - /// Create new factor by maximizing over all values with the same separator values + /// Create new factor by maximizing over all values with the same separator. shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, ADT::Ring::max); } + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, ADT::Ring::max); + } + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 7294c8b296..ccc52585e6 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -43,6 +43,7 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { } /* ************************************************************************* */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 DiscreteValues DiscreteBayesNet::optimize() const { DiscreteValues result; return optimize(result); @@ -50,10 +51,16 @@ DiscreteValues DiscreteBayesNet::optimize() const { DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const { // solve each node in turn in topological sort order (parents first) +#ifdef _MSC_VER +#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!") +#else +#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!" 
+#endif for (auto conditional : boost::adaptors::reverse(*this)) conditional->solveInPlace(&result); return result; } +#endif /* ************************************************************************* */ DiscreteValues DiscreteBayesNet::sample() const { diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index bd5536135a..4916cad7c0 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -31,12 +31,12 @@ namespace gtsam { -/** A Bayes net made from linear-Discrete densities */ +/** A Bayes net made from discrete conditional distributions. */ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { public: - typedef FactorGraph Base; + typedef BayesNet Base; typedef DiscreteBayesNet This; typedef DiscreteConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -45,7 +45,7 @@ namespace gtsam { /// @name Standard Constructors /// @{ - /** Construct empty factor graph */ + /// Construct empty Bayes net. DiscreteBayesNet() {} /** Construct from iterator over conditionals */ @@ -98,27 +98,6 @@ namespace gtsam { return evaluate(values); } - /** - * @brief solve by back-substitution. - * - * Assumes the Bayes net is reverse topologically sorted, i.e. last - * conditional will be optimized first. If the Bayes net resulted from - * eliminating a factor graph, this is true for the elimination ordering. - * - * @return a sampled value for all variables. - */ - DiscreteValues optimize() const; - - /** - * @brief solve by back-substitution, given certain variables. - * - * Assumes the Bayes net is reverse topologically sorted *and* that the - * Bayes net does not contain any conditionals for the given values. - * - * @return given values extended with optimized value for other variables. 
- */ - DiscreteValues optimize(DiscreteValues given) const; - /** * @brief do ancestral sampling * @@ -152,7 +131,16 @@ namespace gtsam { std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const DiscreteFactor::Names& names = {}) const; + ///@} + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + + DiscreteValues GTSAM_DEPRECATED optimize() const; + DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const; /// @} +#endif private: /** Serialization function */ diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index eb31d2e1ea..06b2856f8f 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -16,26 +16,25 @@ * @author Frank Dellaert */ +#include +#include #include #include #include -#include -#include - -#include #include +#include #include +#include #include #include -#include #include -#include +#include using namespace std; +using std::pair; using std::stringstream; using std::vector; -using std::pair; namespace gtsam { // Instantiate base class @@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s, cout << endl; } -/* ******************************************************************************** */ +/* ************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, double tol) const { if (!dynamic_cast(&other)) { @@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ************************************************************************** */ -static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, - const DiscreteValues& given, - bool forceComplete = true) { +DiscreteConditional::ADT DiscreteConditional::choose( + const DiscreteValues& given, bool forceComplete) const { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the parent variables. - DiscreteConditional::ADT adt(conditional); + DiscreteConditional::ADT adt(*this); size_t value; - for (Key j : conditional.parents()) { + for (Key j : parents()) { try { value = given.at(j); adt = adt.choose(j, value); // ADT keeps getting smaller. @@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, if (forceComplete) { given.print("parentsValues: "); throw runtime_error( - "DiscreteConditional::Choose: parent value missing"); + "DiscreteConditional::choose: parent value missing"); } } } @@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, /* ************************************************************************** */ DiscreteConditional::shared_ptr DiscreteConditional::choose( const DiscreteValues& given) const { - ADT adt = Choose(*this, given, false); // P(F|S=given) + ADT adt = choose(given, false); // P(F|S=given) // Collect all keys not in given. 
DiscreteKeys dKeys; @@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( return boost::make_shared(discreteKeys, adt); } -/* ******************************************************************************** */ +/* ****************************************************************************/ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( size_t parent_value) const { if (nrFrontals() != 1) @@ -238,8 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( } /* ************************************************************************** */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 void DiscreteConditional::solveInPlace(DiscreteValues* values) const { - ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) + ADT pFS = choose(*values, true); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; @@ -248,59 +247,79 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // Get all Possible Configurations const auto allPosbValues = frontalAssignments(); - // Find the MPE + // Find the maximum for (const auto& frontalVals : allPosbValues) { double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) - // Update MPE solution if better + // Update maximum solution if better if (pValueS > maxP) { maxP = pValueS; mpe = frontalVals; } } - // set values (inPlace) to mpe + // set values (inPlace) to maximum for (Key j : frontals()) { (*values)[j] = mpe[j]; } } -/* ******************************************************************************** */ -void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { - assert(nrFrontals() == 1); - Key j = (firstFrontalKey()); - size_t sampled = sample(*values); // Sample variable given parents - (*values)[j] = sampled; // store result in partial solution -} - /* ************************************************************************** */ size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // Then, find the max over all remaining - // TODO, only works for one key now, seems horribly slow this way - size_t mpe = 0; - DiscreteValues frontals; + size_t max = 0; double maxP = 0; + DiscreteValues frontals; assert(nrFrontals() == 1); Key j = (firstFrontalKey()); for (size_t value = 0; value < cardinality(j); value++) { frontals[j] = value; double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update solution if better + if (pValueS > maxP) { + maxP = pValueS; + max = value; + } + } + return max; +} +#endif + +/* ************************************************************************** */ +size_t DiscreteConditional::argmax() const { + size_t maxValue = 0; + double maxP = 0; + assert(nrFrontals() == 1); + assert(nrParents() == 0); + DiscreteValues frontals; + Key j = firstFrontalKey(); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = (*this)(frontals); // Update MPE solution if better if (pValueS > maxP) { maxP = pValueS; - mpe = value; + maxValue = value; } } - return mpe; + return maxValue; } -/* ******************************************************************************** */ +/* ************************************************************************** */ +void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + size_t sampled = sample(*values); // Sample variable 
given parents + (*values)[j] = sampled; // store result in partial solution +} + +/* ************************************************************************** */ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way if (nrFrontals() != 1) { @@ -323,7 +342,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { return distribution(rng); } -/* ******************************************************************************** */ +/* ************************************************************************** */ size_t DiscreteConditional::sample(size_t parent_value) const { if (nrParents() != 1) throw std::invalid_argument( @@ -334,7 +353,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const { return sample(values); } -/* ******************************************************************************** */ +/* ************************************************************************** */ size_t DiscreteConditional::sample() const { if (nrParents() != 0) throw std::invalid_argument( diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 5908cc782e..48d94a3837 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} - /** + /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) - * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); - /** + /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Makes sure the keys are ordered as given. Does not check orderedKeys. @@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** + /** * @brief restrict to given *parent* values. - * + * * Note: does not need be complete set. Examples: - * + * * P(C|D,E) + . -> P(C|D,E) * P(C|D,E) + E -> P(C|D) * P(C|D,E) + D -> P(C|E) * P(C|D,E) + D,E -> P(C) * P(C|D,E) + C -> error! - * + * * @return a shared_ptr to a new DiscreteConditional */ shared_ptr choose(const DiscreteValues& given) const; @@ -179,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional /** Single variable version of likelihood. */ DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; - /** - * solve a conditional - * @param parentsValues Known values of the parents - * @return MPE value of the child (1 frontal variable). - */ - size_t solve(const DiscreteValues& parentsValues) const; - /** * sample * @param parentsValues Known values of the parents @@ -199,13 +192,16 @@ class GTSAM_EXPORT DiscreteConditional /// Zero parent version. size_t sample() const; + /** + * @brief Return assignment that maximizes distribution. + * @return Optimal assignment (1 frontal variable). 
+ */ + size_t argmax() const; + /// @} /// @name Advanced Interface /// @{ - /// solve a conditional, in place - void solveInPlace(DiscreteValues* parentsValues) const; - /// sample in place, stores result in partial solution void sampleInPlace(DiscreteValues* parentsValues) const; @@ -228,6 +224,19 @@ class GTSAM_EXPORT DiscreteConditional const Names& names = {}) const override; /// @} + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const; + void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const; + /// @} +#endif + + protected: + /// Internal version of choose + DiscreteConditional::ADT choose(const DiscreteValues& given, + bool forceComplete) const; }; // DiscreteConditional diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index fae6e355bd..c5147dbc19 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -90,19 +90,13 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { /// Return entire probability mass function. std::vector pmf() const; - /** - * solve a conditional - * @return MPE value of the child (1 frontal variable). - */ - size_t solve() const { return Base::solve({}); } - - /** - * sample - * @return sample from conditional - */ - size_t sample() const { return Base::sample(); } - /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); } + /// @} +#endif }; // DiscreteDistribution diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index c1248c60b9..7c03d21f9c 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -95,22 +96,85 @@ namespace gtsam { // } // } - /* ************************************************************************* */ - DiscreteValues DiscreteFactorGraph::optimize() const - { + /* ************************************************************************ */ + // Alternate eliminate function for MPE + std::pair // + EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { + // PRODUCT: multiply all factors + gttic(product); + DecisionTreeFactor product; + for (auto&& factor : factors) product = (*factor) * product; + gttoc(product); + + // max out frontals, this is the factor on the separator + gttic(max); + DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + gttoc(max); + + // Ordering keys for the conditional so that frontalKeys are really in front + DiscreteKeys orderedKeys; + for (auto&& key : frontalKeys) + orderedKeys.emplace_back(key, product.cardinality(key)); + for (auto&& key : max->keys()) + orderedKeys.emplace_back(key, product.cardinality(key)); + + // Make lookup with product + gttic(lookup); + size_t nrFrontals = frontalKeys.size(); + auto lookup = boost::make_shared(nrFrontals, + orderedKeys, product); + gttoc(lookup); + + return std::make_pair( + boost::dynamic_pointer_cast(lookup), max); + } + + /* ************************************************************************ */ + // The max-product solution below is a bit clunky: the elimination machinery + // does not allow for differently *typed* versions of elimination, so we + // eliminate into a Bayes Net using the special eliminate function above, and + // then 
create the DiscreteLookupDAG after the fact, in linear time. + + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_maxProduct); + auto bayesNet = + BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } + + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_maxProduct); + auto bayesNet = + BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } + + /* ************************************************************************ */ + DiscreteValues DiscreteFactorGraph::optimize( + OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_optimize); - return BaseEliminateable::eliminateSequential()->optimize(); + DiscreteLookupDAG dag = maxProduct(orderingType); + return dag.argmax(); } - /* ************************************************************************* */ - std::pair // - EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { + DiscreteValues DiscreteFactorGraph::optimize( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_optimize); + DiscreteLookupDAG dag = maxProduct(ordering); + return dag.argmax(); + } + /* ************************************************************************ */ + std::pair // + EliminateDiscrete(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); DecisionTreeFactor product; - for(const DiscreteFactor::shared_ptr& factor: factors) - product = (*factor) * product; + for (auto&& factor : factors) product = (*factor) * product; gttoc(product); // sum out frontals, this is the factor on the separator @@ -120,15 +184,18 @@ namespace gtsam { // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; - orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); - orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), + frontalKeys.end()); + orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), + sum->keys().end()); // now divide product/sum to get conditional gttic(divide); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); + auto conditional = + boost::make_shared(product, *sum, orderedKeys); gttoc(divide); - return std::make_pair(cond, sum); + return std::make_pair(conditional, sum); } /* ************************************************************************ */ diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 1da840eb8e..e0f0a104bd 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -18,10 +18,11 @@ #pragma once -#include +#include +#include #include +#include #include -#include #include #include @@ -128,18 +129,39 @@ class GTSAM_EXPORT DiscreteFactorGraph const std::string& s = "DiscreteFactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /** Solve the factor graph by performing variable elimination in COLAMD order using - * the dense elimination function specified in \c function, - * followed by back-substitution resulting from elimination. Is equivalent - * to calling graph.eliminateSequential()->optimize(). 
*/ - DiscreteValues optimize() const; + /** + * @brief Implement the max-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables + */ + DiscreteLookupDAG maxProduct( + OptionalOrderingType orderingType = boost::none) const; + /** + * @brief Implement the max-product algorithm + * + * @param ordering + * @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables + */ + DiscreteLookupDAG maxProduct(const Ordering& ordering) const; + + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param orderingType + * @return DiscreteValues : MPE + */ + DiscreteValues optimize( + OptionalOrderingType orderingType = boost::none) const; -// /** Permute the variables in the factors */ -// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); -// -// /** Apply a reduction, which is a remapping of variable indices. */ -// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param ordering + * @return DiscreteValues : MPE + */ + DiscreteValues optimize(const Ordering& ordering) const; /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp index 5ddad22b04..121d611038 100644 --- a/gtsam/discrete/DiscreteKey.cpp +++ b/gtsam/discrete/DiscreteKey.cpp @@ -33,16 +33,13 @@ namespace gtsam { KeyVector DiscreteKeys::indices() const { KeyVector js; - for(const DiscreteKey& key: *this) - js.push_back(key.first); + for (const DiscreteKey& key : *this) js.push_back(key.first); return js; } - map DiscreteKeys::cardinalities() const { - map cs; - cs.insert(begin(),end()); -// for(const DiscreteKey& key: *this) -// cs.insert(key); + map DiscreteKeys::cardinalities() const { + map cs; + cs.insert(begin(), end()); return cs; } diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp new file mode 100644 index 0000000000..d96b38b0ec --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -0,0 +1,127 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. 
(see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.cpp + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include + +using std::pair; +using std::vector; + +namespace gtsam { + +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +void DiscreteLookupTable::print(const std::string& s, + const KeyFormatter& formatter) const { + using std::cout; + using std::endl; + + cout << s << " g( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "; "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << "):\n"; + ADT::print("", formatter); + cout << endl; +} + +/* ************************************************************************** */ +void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = choose(*values, true); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + +/* ************************************************************************** */ +size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way + size_t mpe = 0; + double maxP = 0; + DiscreteValues frontals; + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; +} + +/* ************************************************************************** */ +DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( + const DiscreteBayesNet& bayesNet) { + DiscreteLookupDAG dag; + for (auto&& conditional : bayesNet) { + if (auto lookupTable = + boost::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } + } + return dag; +} + +DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { + // Argmax each node in turn in topological sort order (parents first). 
+ for (auto lookupTable : boost::adaptors::reverse(*this)) + lookupTable->argmaxInPlace(&result); + return result; +} +/* ************************************************************************** */ + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h new file mode 100644 index 0000000000..8cb651f28a --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -0,0 +1,140 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.h + * @date January, 2022 + * @author Frank dellaert + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace gtsam { + +class DiscreteBayesNet; + +/** + * @brief DiscreteLookupTable table for max-product + * + * Inherits from discrete conditional for convenience, but is not normalized. + * Is used in the max-product algorithm. + */ +class DiscreteLookupTable : public DiscreteConditional { + public: + using This = DiscreteLookupTable; + using shared_ptr = boost::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a orted list of gtsam::Keys + * @param potentials the algebraic decision tree with lookup values + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials) + : DiscreteConditional(nFrontals, keys, potentials) {} + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Lookup Table: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /** + * @brief return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. + */ + size_t argmax(const DiscreteValues& parentsValues) const; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; +}; + +/** A DAG made from lookup tables, as defined above. */ +class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { + public: + using Base = BayesNet; + using This = DiscreteLookupDAG; + using shared_ptr = boost::shared_ptr; + + /// @name Standard Constructors + /// @{ + + /// Construct empty DAG. + DiscreteLookupDAG() {} + + /// Create from BayesNet with LookupTables + static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet); + + /// Destructor + virtual ~DiscreteLookupDAG() {} + + /// @} + + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This& bn, double tol = 1e-9) const; + + /// @} + + /// @name Standard Interface + /// @{ + + /** Add a DiscreteLookupTable */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + + /** + * @brief argmax by back-substitution, optionally given certain variables. + * + * Assumes the DAG is reverse topologically sorted, i.e. last + * conditional will be optimized first *and* that the + * DAG does not contain any conditionals for the given variables. 
If the DAG + * resulted from eliminating a factor graph, this is true for the elimination + * ordering. + * + * @return given assignment extended w. optimal assignment for all variables. + */ + DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const; + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index e2310f4344..97c267aba2 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -111,11 +111,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { gtsam::DecisionTreeFactor* likelihood( const gtsam::DiscreteValues& frontalValues) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const; - size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(size_t value) const; size_t sample() const; - void solveInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -138,7 +136,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional { gtsam::DefaultKeyFormatter) const; double operator()(size_t value) const; std::vector pmf() const; - size_t solve() const; + size_t argmax() const; }; #include @@ -163,8 +161,6 @@ class DiscreteBayesNet { void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; - gtsam::DiscreteValues optimize() const; - gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const; gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; string markdown(const gtsam::KeyFormatter& keyFormatter = @@ -217,6 +213,21 @@ class DiscreteBayesTree { std::map> names) const; }; +#include +class DiscreteLookupDAG { + DiscreteLookupDAG(); + void push_back(const gtsam::DiscreteLookupTable* table); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteLookupTable* at(size_t i) const; + void print(string s = "DiscreteLookupDAG\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + gtsam::DiscreteValues argmax() const; + gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const; +}; + #include class DotWriter { DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, @@ -260,6 +271,9 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; + gtsam::DiscreteLookupDAG maxProduct(); + gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesNet eliminateSequential(); gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); std::pair diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 0ba53c69ab..c35d4742c0 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) { DiscreteConditional expected2(Bronchitis % "11/9"); 
EXPECT(assert_equal(expected2, *chordal->back())); - // solve - auto actualMPE = chordal->optimize(); - DiscreteValues expectedMPE; - insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)( - Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)( - LungCancer.first, 0)(Bronchitis.first, 0); - EXPECT(assert_equal(expectedMPE, actualMPE)); - // add evidence, we were in Asia and we have dyspnea fg.add(Asia, "0 1"); fg.add(Dyspnea, "0 1"); // solve again, now with evidence DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - auto actualMPE2 = chordal2->optimize(); - DiscreteValues expectedMPE2; - insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)( - Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)( - LungCancer.first, 0)(Bronchitis.first, 1); - EXPECT(assert_equal(expectedMPE2, actualMPE2)); + EXPECT(assert_equal(expected2, *chordal->back())); // now sample from it DiscreteValues expectedSample; diff --git a/gtsam/discrete/tests/testDiscreteDistribution.cpp b/gtsam/discrete/tests/testDiscreteDistribution.cpp index 5c0c42e737..d88b510f81 100644 --- a/gtsam/discrete/tests/testDiscreteDistribution.cpp +++ b/gtsam/discrete/tests/testDiscreteDistribution.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /* - * @file testDiscretePrior.cpp + * @file testDiscreteDistribution.cpp * @brief unit tests for DiscreteDistribution * @author Frank dellaert * @date December 2021 @@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) { prior.sample(); } +/* ************************************************************************* */ +TEST(DiscreteDistribution, argmax) { + DiscreteDistribution prior(X % "2/3"); + EXPECT_LONGS_EQUAL(prior.argmax(), 1); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 579244c57f..f4819dab54 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -30,8 +30,8 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { - DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); +TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { + DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteFactorGraph graph; graph.add(AI, "1 0 0 1"); @@ -47,25 +47,11 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); -// graph.print("Graph: "); - DecisionTreeFactor product = graph.product(); - DecisionTreeFactor::shared_ptr sum = product.sum(1); -// sum->print("Debug SUM: "); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); - -// cond->print("marginal:"); - -// pair result = EliminateDiscrete(graph, 1); -// result.first->print("BayesNet: "); -// result.second->print("New factor: "); -// - Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3); - DiscreteEliminationTree eliminationTree(graph, ordering); -// eliminationTree.print("Elimination tree: "); - eliminationTree.eliminate(EliminateDiscrete); -// solver.optimize(); -// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate(); + // Check MPE. 
+ auto actualMPE = graph.optimize(); + DiscreteValues mpe; + insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0); + EXPECT(assert_equal(mpe, actualMPE)); } /* ************************************************************************* */ @@ -115,10 +101,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, test) -{ +TEST(DiscreteFactorGraph, test) { // Declare keys and ordering - DiscreteKey C(0,2), B(1,2), A(2,2); + DiscreteKey C(0, 2), B(1, 2), A(2, 2); // A simple factor graph (A)-fAC-(C)-fBC-(B) // with smoothness priors @@ -127,77 +112,109 @@ TEST( DiscreteFactorGraph, test) graph.add(C & B, "3 1 1 3"); // Test EliminateDiscrete - // FIXME: apparently Eliminate returns a conditional rather than a net Ordering frontalKeys; frontalKeys += Key(0); DiscreteConditional::shared_ptr conditional; DecisionTreeFactor::shared_ptr newFactor; boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); - // Check Bayes net + // Check Conditional CHECK(conditional); - DiscreteBayesNet expected; Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); - // cout << signature << endl; DiscreteConditional expectedConditional(signature); EXPECT(assert_equal(expectedConditional, *conditional)); - expected.add(signature); // Check Factor CHECK(newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); EXPECT(assert_equal(expectedFactor, *newFactor)); - // add conditionals to complete expected Bayes net - expected.add(B | A = "5/3 3/5"); - expected.add(A % "1/1"); - // GTSAM_PRINT(expected); - - // Test elimination tree + // Test using elimination tree Ordering ordering; ordering += Key(0), Key(1), Key(2); DiscreteEliminationTree etree(graph, ordering); DiscreteBayesNet::shared_ptr actual; DiscreteFactorGraph::shared_ptr remainingGraph; boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); - EXPECT(assert_equal(expected, *actual)); - -// // Test solver -// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); -// EXPECT(assert_equal(expected, *actual2)); - // Test optimization - DiscreteValues expectedValues; - insert(expectedValues)(0, 0)(1, 0)(2, 0); - auto actualValues = graph.optimize(); - EXPECT(assert_equal(expectedValues, actualValues)); + // Check Bayes net + DiscreteBayesNet expectedBayesNet; + expectedBayesNet.add(signature); + expectedBayesNet.add(B | A = "5/3 3/5"); + expectedBayesNet.add(A % "1/1"); + EXPECT(assert_equal(expectedBayesNet, *actual)); + + // Test eliminateSequential + DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering); + EXPECT(assert_equal(expectedBayesNet, *actual2)); + + // Test mpe + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 0)(2, 0); + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE) -{ +TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) { // Declare a bunch of keys - DiscreteKey C(0,2), A(1,2), B(2,2); + DiscreteKey C(0, 2), A(1, 2), B(2, 2); // Create Factor graph DiscreteFactorGraph graph; graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & B, "0.1 0.9 0.4 0.6"); - // graph.product().print(); - // DiscreteSequentialSolver(graph).eliminate()->print(); - auto actualMPE = graph.optimize(); + // Created expected MPE + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 1)(2, 1); + + // Do max-product with 
different orderings + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + DiscreteLookupDAG dag = graph.maxProduct(orderingType); + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); + auto actualMPE2 = graph.optimize(); // all in one + EXPECT(assert_equal(mpe, actualMPE2)); + } +} - DiscreteValues expectedMPE; - insert(expectedMPE)(0, 0)(1, 1)(2, 1); - EXPECT(assert_equal(expectedMPE, actualMPE)); +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, marginalIsNotMPE) { + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create Bayes net such that marginal on A is bigger for 0 than 1, but the + // MPE does not have A=0. + DiscreteBayesNet bayesNet; + bayesNet.add(B | A = "1/1 1/2"); + bayesNet.add(A % "10/9"); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // Which we verify using max-product: + DiscreteFactorGraph graph(bayesNet); + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + // Optimize on BayesNet maximizes marginal, then the conditional marginals: + auto notOptimal = bayesNet.optimize(); + EXPECT(graph(notOptimal) < graph(mpe)); + EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression +#endif } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) -{ +TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { // The factor graph in Darwiche09book, page 244 - DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); + DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2); // Create Factor graph DiscreteFactorGraph graph; @@ -206,53 +223,35 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); - graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) - //graph.product().print("Darwiche-product"); - // graph.product().potentials().dot("Darwiche-product"); - // DiscreteSequentialSolver(graph).eliminate()->print(); - - DiscreteValues expectedMPE; - insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); - - // Use the solver machinery. 
- DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto actualMPE = chordal->optimize(); - EXPECT(assert_equal(expectedMPE, actualMPE)); -// DiscreteConditional::shared_ptr root = chordal->back(); -// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9); - - // Let us create the Bayes tree here, just for fun, because we don't use it now -// typedef JunctionTreeOrdered JT; -// GenericMultifrontalSolver solver(graph); -// BayesTreeOrdered::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete); -//// bayesTree->print("Bayes Tree"); -// EXPECT_LONGS_EQUAL(2,bayesTree->size()); + graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche) - Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3),Key(4); - DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); - // bayesTree->print("Bayes Tree"); - EXPECT_LONGS_EQUAL(2,bayesTree->size()); + DiscreteValues mpe; + insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1); + EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression + // You can check visually by printing product: + // graph.product().print("Darwiche-product"); -#ifdef OLD -// Create the elimination tree manually -VariableIndexOrdered structure(graph); -typedef EliminationTreeOrdered ETree; -ETree::shared_ptr eTree = ETree::Create(graph, structure); -//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<"); - -// eliminate normally and check solution -DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete); -// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<"); -auto actualMPE = optimize(*bayesNet); -EXPECT(assert_equal(expectedMPE, actualMPE)); - -// Approximate and check solution -// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate(); -// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<"); -// EXPECT(assert_equal(expectedMPE, *actualMPE)); + // Check MPE. + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + + // Check Bayes Net + Ordering ordering; + ordering += Key(0), Key(1), Key(2), Key(3), Key(4); + auto chordal = graph.eliminateSequential(ordering); + EXPECT_LONGS_EQUAL(5, chordal->size()); +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + auto notOptimal = chordal->optimize(); // not MPE ! + EXPECT(graph(notOptimal) < graph(mpe)); #endif + + // Let us create the Bayes tree here, just for fun, because we don't use it + DiscreteBayesTree::shared_ptr bayesTree = + graph.eliminateMultifrontal(ordering); + // bayesTree->print("Bayes Tree"); + EXPECT_LONGS_EQUAL(2, bayesTree->size()); } + #ifdef OLD /* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testDiscreteLookupDAG.cpp b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp new file mode 100644 index 0000000000..04b8597804 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. 
(see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteLookupDAG.cpp + * + * @date January, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include + +using namespace gtsam; +using namespace boost::assign; + +/* ************************************************************************* */ +TEST(DiscreteLookupDAG, argmax) { + using ADT = AlgebraicDecisionTree; + + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create lookup table corresponding to "marginalIsNotMPE" in testDFG. + DiscreteLookupDAG dag; + + ADT adtB(DiscreteKeys{B, A}, std::vector{0.5, 1. / 3, 0.5, 2. / 3}); + dag.add(1, DiscreteKeys{B, A}, adtB); + + ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19)); + dag.add(1, DiscreteKeys{A}, adtA); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // check: + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); +} +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index 295122879e..7594da78d0 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -25,15 +25,12 @@ namespace gtsam { /** - * TODO: Update comments. The following comments are out of date!!! - * - * Base class for conditional densities, templated on KEY type. This class - * provides storage for the keys involved in a conditional, and iterators and + * Base class for conditional densities. This class iterators and * access to the frontal and separator keys. * * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * to the associated factor type and shared_ptr type of the derived class. See - * IndexConditional and GaussianConditional for examples. + * SymbolicConditional and GaussianConditional for examples. * \nosubgrouping */ template diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index e204a67796..08143c469f 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -14,18 +14,6 @@ using namespace std; namespace gtsam { -/// Find the best total assignment - can be expensive -DiscreteValues CSP::optimalAssignment() const { - DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(); - return chordal->optimize(); -} - -/// Find the best total assignment - can be expensive -DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const { - DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering); - return chordal->optimize(); -} - bool CSP::runArcConsistency(const VariableIndex& index, Domains* domains) const { bool changed = false; diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index e7fb301156..40853bed66 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { // return result; // } - /// Find the best total assignment - can be expensive. - DiscreteValues optimalAssignment() const; - - /// Find the best total assignment, with given ordering - can be expensive. 
- DiscreteValues optimalAssignment(const Ordering& ordering) const; - // /* // * Perform loopy belief propagation // * True belief propagation would check for each value in domain diff --git a/gtsam_unstable/discrete/Scheduler.cpp b/gtsam_unstable/discrete/Scheduler.cpp index f166405932..b86df6c290 100644 --- a/gtsam_unstable/discrete/Scheduler.cpp +++ b/gtsam_unstable/discrete/Scheduler.cpp @@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const { return chordal; } -/** Find the best total assignment - can be expensive */ -DiscreteValues Scheduler::optimalAssignment() const { - DiscreteBayesNet::shared_ptr chordal = eliminate(); - - if (ISDEBUG("Scheduler::optimalAssignment")) { - DiscreteBayesNet::const_iterator it = chordal->end() - 1; - const Student& student = students_.front(); - cout << endl; - (*it)->print(student.name_); - } - - gttic(my_optimize); - DiscreteValues mpe = chordal->optimize(); - gttoc(my_optimize); - return mpe; -} - /** find the assignment of students to slots with most possible committees */ DiscreteValues Scheduler::bestSchedule() const { DiscreteValues best; diff --git a/gtsam_unstable/discrete/Scheduler.h b/gtsam_unstable/discrete/Scheduler.h index a97368bb25..8d269e81a6 100644 --- a/gtsam_unstable/discrete/Scheduler.h +++ b/gtsam_unstable/discrete/Scheduler.h @@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP { /** Eliminate, return a Bayes net */ DiscreteBayesNet::shared_ptr eliminate() const; - /** Find the best total assignment - can be expensive */ - DiscreteValues optimalAssignment() const; - /** find the assignment of students to slots with most possible committees */ DiscreteValues bestSchedule() const; diff --git a/gtsam_unstable/discrete/examples/schedulingExample.cpp b/gtsam_unstable/discrete/examples/schedulingExample.cpp index 2a9addf918..487edc97a3 100644 --- a/gtsam_unstable/discrete/examples/schedulingExample.cpp +++ b/gtsam_unstable/discrete/examples/schedulingExample.cpp @@ -122,7 +122,7 @@ void runLargeExample() { // SETDEBUG("timing-verbose", true); SETDEBUG("DiscreteConditional::DiscreteConditional", true); gttic(large); - auto MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); @@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - DiscreteValues values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(6 - s); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); @@ -319,11 +319,11 @@ void accomodateStudent() { // GTSAM_PRINT(*chordal); // solve root node only - DiscreteValues values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(0); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) diff --git a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp index 8260bfb068..830d59ba73 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp @@ -143,7 +143,7 @@ void runLargeExample() { } #else gttic(large); - auto MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); 
   gttoc(large);
   tictoc_finishedIteration();
   tictoc_print();
@@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) {
     root->print(""/*scheduler.studentName(s)*/);
 
     // solve root node only
-    DiscreteValues values;
-    size_t bestSlot = root->solve(values);
+    size_t bestSlot = root->argmax();
 
     // get corresponding count
     DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
+    DiscreteValues values;
     values[dkey.first] = bestSlot;
     size_t count = (*root)(values);
 
diff --git a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp
index cf3ce04535..b24f9bf0a4 100644
--- a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp
+++ b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp
@@ -167,7 +167,7 @@ void runLargeExample() {
   }
 #else
   gttic(large);
-  auto MPE = scheduler.optimalAssignment();
+  auto MPE = scheduler.optimize();
   gttoc(large);
   tictoc_finishedIteration();
   tictoc_print();
@@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) {
     root->print(""/*scheduler.studentName(s)*/);
 
     // solve root node only
-    DiscreteValues values;
-    size_t bestSlot = root->solve(values);
+    size_t bestSlot = root->argmax();
 
     // get corresponding count
     DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s);
+    DiscreteValues values;
     values[dkey.first] = bestSlot;
     double count = (*root)(values);
 
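
Reviewer note: the solve(values)-to-argmax() change in the three scheduling examples above only concerns the root clique. The sketch below shows the surrounding pattern; it is not part of the patch, and the way root is obtained (eliminateMultifrontal() plus roots().front()->conditional()) is an assumption about code outside this diff, while the argmax() and (*root)(values) lines mirror the edited ones.

// Rough sketch, not part of the patch. How `root` is obtained is an assumption
// (typical Bayes-tree usage); only the argmax()/evaluation lines mirror the
// changes made in the scheduling examples above.
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>

using namespace gtsam;

size_t solveRootOnly(const DiscreteFactorGraph& graph, Key studentKey) {
  // Eliminate into a Bayes tree and take the root clique's conditional.
  DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal();
  auto root = bayesTree->roots().front()->conditional();

  // Before: DiscreteValues values; size_t bestSlot = root->solve(values);
  size_t bestSlot = root->argmax();

  // Evaluate the root conditional at that single assignment, as the examples do:
  DiscreteValues values;
  values[studentKey] = bestSlot;
  double count = (*root)(values);
  (void)count;
  return bestSlot;
}
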
diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp
index 88defd9860..fb386b2553 100644
--- a/gtsam_unstable/discrete/tests/testCSP.cpp
+++ b/gtsam_unstable/discrete/tests/testCSP.cpp
@@ -132,7 +132,7 @@ TEST(CSP, allInOne) {
   EXPECT(assert_equal(expectedProduct, product));
 
   // Solve
-  auto mpe = csp.optimalAssignment();
+  auto mpe = csp.optimize();
   DiscreteValues expected;
   insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1);
   EXPECT(assert_equal(expected, mpe));
@@ -172,22 +172,18 @@ TEST(CSP, WesternUS) {
   csp.addAllDiff(WY, CO);
   csp.addAllDiff(CO, NM);
 
+  DiscreteValues mpe;
+  insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0);
+
   // Create ordering according to example in ND-CSP.lyx
   Ordering ordering;
   ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7),
       Key(8), Key(9), Key(10);
+
   // Solve using that ordering:
-  auto mpe = csp.optimalAssignment(ordering);
-  // GTSAM_PRINT(mpe);
-  DiscreteValues expected;
-  insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)(
-      MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)(
-      UT.first, 1)(AZ.first, 0);
+  auto actualMPE = csp.optimize(ordering);
 
-  // TODO: Fix me! mpe result seems to be right. (See the printing)
-  // It has the same prob as the expected solution.
-  // Is mpe another solution, or the expected solution is unique???
-  EXPECT(assert_equal(expected, mpe));
+  EXPECT(assert_equal(mpe, actualMPE));
   EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9);
 
   // Write out the dual graph for hmetis
@@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) {
   EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9);
 
   // Solve
-  auto mpe = csp.optimalAssignment();
+  auto mpe = csp.optimize();
   DiscreteValues expected;
   insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2);
   EXPECT(assert_equal(expected, mpe));
diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp
index 7822cbd38b..086057a466 100644
--- a/gtsam_unstable/discrete/tests/testScheduler.cpp
+++ b/gtsam_unstable/discrete/tests/testScheduler.cpp
@@ -122,7 +122,7 @@ TEST(schedulingExample, test) {
 
   // Do exact inference
   gttic(small);
-  auto MPE = s.optimalAssignment();
+  auto MPE = s.optimize();
   gttoc(small);
 
   // print MPE, commented out as unit tests don't print
diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp
index 35f3ba8437..8b28581699 100644
--- a/gtsam_unstable/discrete/tests/testSudoku.cpp
+++ b/gtsam_unstable/discrete/tests/testSudoku.cpp
@@ -100,7 +100,7 @@ class Sudoku : public CSP {
 
   /// solve and print solution
   void printSolution() const {
-    auto MPE = optimalAssignment();
+    auto MPE = optimize();
     printAssignment(MPE);
   }
 
@@ -126,7 +126,7 @@ TEST(Sudoku, small) {
       0, 1, 0, 0);
 
   // optimize and check
-  auto solution = csp.optimalAssignment();
+  auto solution = csp.optimize();
   DiscreteValues expected;
   insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)(
       csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)(
@@ -148,7 +148,7 @@ TEST(Sudoku, small) {
   EXPECT_LONGS_EQUAL(16, new_csp.size());
 
   // Check that solution
-  auto new_solution = new_csp.optimalAssignment();
+  auto new_solution = new_csp.optimize();
   // csp.printAssignment(new_solution);
   EXPECT(assert_equal(expected, new_solution));
 }
@@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) {
   EXPECT_LONGS_EQUAL(81, new_csp.size());
 
   // Check that solution
-  auto solution = new_csp.optimalAssignment();
+  auto solution = new_csp.optimize();
   // csp.printAssignment(solution);
   EXPECT_LONGS_EQUAL(6, solution.at(key99));
 }
diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py
index 6abd660cfc..3ae3b625cd 100644
--- a/python/gtsam/tests/test_DiscreteBayesNet.py
+++ b/python/gtsam/tests/test_DiscreteBayesNet.py
@@ -79,7 +79,7 @@ def test_Asia(self):
         self.gtsamAssertEquals(chordal.at(7), expected2)
 
         # solve
-        actualMPE = chordal.optimize()
+        actualMPE = fg.optimize()
         expectedMPE = DiscreteValues()
         for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]:
             expectedMPE[key[0]] = 0
@@ -94,8 +94,7 @@ def test_Asia(self):
         fg.add(Dyspnea, "0 1")
 
         # solve again, now with evidence
-        chordal2 = fg.eliminateSequential(ordering)
-        actualMPE2 = chordal2.optimize()
+        actualMPE2 = fg.optimize()
         expectedMPE2 = DiscreteValues()
         for key in [XRay, Tuberculosis, Either, LungCancer]:
             expectedMPE2[key[0]] = 0
@@ -105,6 +104,7 @@ def test_Asia(self):
                          list(expectedMPE2.items()))
 
         # now sample from it
+        chordal2 = fg.eliminateSequential(ordering)
         actualSample = chordal2.sample()
         self.assertEqual(len(actualSample), 8)
 
@@ -122,10 +122,6 @@ def test_fragment(self):
         for key in [Asia, Smoking]:
             given[key[0]] = 0
 
-        # Now optimize fragment:
-        actual = fragment.optimize(given)
-        self.assertEqual(len(actual), 5)
-
         # Now sample from fragment:
         actual = fragment.sample(given)
         self.assertEqual(len(actual), 5)
diff --git a/python/gtsam/tests/test_DiscreteDistribution.py b/python/gtsam/tests/test_DiscreteDistribution.py
index fa999fd6b5..3986bf2dfc 100644
--- a/python/gtsam/tests/test_DiscreteDistribution.py
+++ b/python/gtsam/tests/test_DiscreteDistribution.py
@@ -20,7 +20,7 @@
 X = 0, 2
 
 
-class TestDiscretePrior(GtsamTestCase):
+class TestDiscreteDistribution(GtsamTestCase):
     """Tests for Discrete Priors."""
 
     def test_constructor(self):