From 7e956d2bb7829eb9a3e367395938875d543518c2 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 10:10:47 -0500 Subject: [PATCH 01/20] Fix docs --- gtsam/discrete/DiscreteConditional.cpp | 6 +++--- gtsam/discrete/DiscreteConditional.h | 2 +- gtsam/inference/Conditional.h | 7 ++----- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index eb31d2e1ea..8c0f91807a 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -248,17 +248,17 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // Get all Possible Configurations const auto allPosbValues = frontalAssignments(); - // Find the MPE + // Find the maximum for (const auto& frontalVals : allPosbValues) { double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) - // Update MPE solution if better + // Update maximum solution if better if (pValueS > maxP) { maxP = pValueS; mpe = frontalVals; } } - // set values (inPlace) to mpe + // set values (inPlace) to maximum for (Key j : frontals()) { (*values)[j] = mpe[j]; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 5908cc782e..de9d949714 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -182,7 +182,7 @@ class GTSAM_EXPORT DiscreteConditional /** * solve a conditional * @param parentsValues Known values of the parents - * @return MPE value of the child (1 frontal variable). + * @return maximum value for the (single) frontal variable. */ size_t solve(const DiscreteValues& parentsValues) const; diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index 295122879e..7594da78d0 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -25,15 +25,12 @@ namespace gtsam { /** - * TODO: Update comments. The following comments are out of date!!! - * - * Base class for conditional densities, templated on KEY type. This class - * provides storage for the keys involved in a conditional, and iterators and + * Base class for conditional densities. This class iterators and * access to the frontal and separator keys. * * Derived classes *must* redefine the Factor and shared_ptr typedefs to refer * to the associated factor type and shared_ptr type of the derived class. See - * IndexConditional and GaussianConditional for examples. + * SymbolicConditional and GaussianConditional for examples. * \nosubgrouping */ template From 0076db7e20c77b8ce8f9c04aab2dc36bd8dd2f93 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 10:11:32 -0500 Subject: [PATCH 02/20] cleanup --- gtsam/discrete/DiscreteKey.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp index 5ddad22b04..121d611038 100644 --- a/gtsam/discrete/DiscreteKey.cpp +++ b/gtsam/discrete/DiscreteKey.cpp @@ -33,16 +33,13 @@ namespace gtsam { KeyVector DiscreteKeys::indices() const { KeyVector js; - for(const DiscreteKey& key: *this) - js.push_back(key.first); + for (const DiscreteKey& key : *this) js.push_back(key.first); return js; } - map DiscreteKeys::cardinalities() const { - map cs; - cs.insert(begin(),end()); -// for(const DiscreteKey& key: *this) -// cs.insert(key); + map DiscreteKeys::cardinalities() const { + map cs; + cs.insert(begin(), end()); return cs; } From 6e4f50dfacc46bf5f73570bff163d51274ece64d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 10:12:07 -0500 Subject: [PATCH 03/20] Better print and new `max` variant --- gtsam/discrete/DecisionTreeFactor.cpp | 9 +++++++-- gtsam/discrete/DecisionTreeFactor.h | 7 ++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ad4cbad434..9de750f2eb 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -22,6 +22,7 @@ #include #include +#include #include using namespace std; @@ -65,9 +66,13 @@ namespace gtsam { /* ************************************************************************* */ void DecisionTreeFactor::print(const string& s, - const KeyFormatter& formatter) const { + const KeyFormatter& formatter) const { cout << s; - ADT::print("Potentials:",formatter); + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + ADT::print("Potentials:", formatter); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 8beeb4c4a0..251575739d 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -127,11 +127,16 @@ namespace gtsam { return combine(keys, ADT::Ring::add); } - /// Create new factor by maximizing over all values with the same separator values + /// Create new factor by maximizing over all values with the same separator. shared_ptr max(size_t nrFrontals) const { return combine(nrFrontals, ADT::Ring::max); } + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, ADT::Ring::max); + } + /// @} /// @name Advanced Interface /// @{ From ec39197cc3fb8b3a547d2b015f52537770eb80eb Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 10:12:31 -0500 Subject: [PATCH 04/20] `optimize` now computes MPE --- gtsam/discrete/DiscreteFactorGraph.cpp | 81 ++++++-- gtsam/discrete/DiscreteFactorGraph.h | 37 ++-- .../tests/testDiscreteFactorGraph.cpp | 183 +++++++++--------- 3 files changed, 189 insertions(+), 112 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index c1248c60b9..d8e9aa244f 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -95,22 +95,74 @@ namespace gtsam { // } // } - /* ************************************************************************* */ - DiscreteValues DiscreteFactorGraph::optimize() const - { + /* ************************************************************************ */ + /** + * @brief Lookup table for max-product + * + * This inherits from a DiscreteConditional but is not normalized to 1 + * + */ + class Lookup : public DiscreteConditional { + public: + Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials) + : DiscreteConditional(nFrontals, keys, potentials) {} + }; + + // Alternate eliminate function for MPE + std::pair // + EliminateForMPE(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { + // PRODUCT: multiply all factors + gttic(product); + DecisionTreeFactor product; + for (auto&& factor : factors) product = (*factor) * product; + gttoc(product); + + // max out frontals, this is the factor on the separator + gttic(max); + DecisionTreeFactor::shared_ptr max = product.max(frontalKeys); + gttoc(max); + + // Ordering keys for the conditional so that frontalKeys are really in front + DiscreteKeys orderedKeys; + for (auto&& key : frontalKeys) + orderedKeys.emplace_back(key, product.cardinality(key)); + for (auto&& key : max->keys()) + orderedKeys.emplace_back(key, product.cardinality(key)); + + // Make lookup with product + gttic(lookup); + size_t nrFrontals = frontalKeys.size(); + auto lookup = boost::make_shared(nrFrontals, orderedKeys, product); + gttoc(lookup); + + return std::make_pair( + boost::dynamic_pointer_cast(lookup), max); + } + + /* ************************************************************************ */ + DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct( + OptionalOrderingType orderingType) const { + gttic(DiscreteFactorGraph_maxProduct); + return BaseEliminateable::eliminateSequential(orderingType, + EliminateForMPE); + } + + /* ************************************************************************ */ + DiscreteValues DiscreteFactorGraph::optimize( + OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_optimize); - return BaseEliminateable::eliminateSequential()->optimize(); + return maxProduct()->optimize(); } - /* ************************************************************************* */ + /* ************************************************************************ */ std::pair // - EliminateDiscrete(const DiscreteFactorGraph& factors, const Ordering& frontalKeys) { - + EliminateDiscrete(const DiscreteFactorGraph& factors, + const Ordering& frontalKeys) { // PRODUCT: multiply all factors gttic(product); DecisionTreeFactor product; - for(const DiscreteFactor::shared_ptr& factor: factors) - product = (*factor) * product; + for (auto&& factor : factors) product = (*factor) * product; gttoc(product); // sum out frontals, this is the factor on the separator @@ -120,15 +172,18 @@ namespace gtsam { // Ordering keys for the conditional so that frontalKeys are really in front Ordering orderedKeys; - orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); - orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); + orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), + frontalKeys.end()); + orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), + sum->keys().end()); // now divide product/sum to get conditional gttic(divide); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum, orderedKeys)); + auto conditional = + boost::make_shared(product, *sum, orderedKeys); gttoc(divide); - return std::make_pair(cond, sum); + return std::make_pair(conditional, sum); } /* ************************************************************************ */ diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 1da840eb8e..b4e98c876c 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -128,18 +128,31 @@ class GTSAM_EXPORT DiscreteFactorGraph const std::string& s = "DiscreteFactorGraph", const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /** Solve the factor graph by performing variable elimination in COLAMD order using - * the dense elimination function specified in \c function, - * followed by back-substitution resulting from elimination. Is equivalent - * to calling graph.eliminateSequential()->optimize(). */ - DiscreteValues optimize() const; - - -// /** Permute the variables in the factors */ -// GTSAM_EXPORT void permuteWithInverse(const Permutation& inversePermutation); -// -// /** Apply a reduction, which is a remapping of variable indices. */ -// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction); + /** + * @brief Implement the max-product algorithm + * + * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM + * @return DiscreteBayesNet::shared_ptr DAG with lookup tables + */ + boost::shared_ptr maxProduct( + OptionalOrderingType orderingType = boost::none) const; + + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param orderingType + * @return DiscreteValues : MPE + */ + DiscreteValues optimize( + OptionalOrderingType orderingType = boost::none) const; + + // /** Permute the variables in the factors */ + // GTSAM_EXPORT void permuteWithInverse(const Permutation& + // inversePermutation); + // + // /** Apply a reduction, which is a remapping of variable indices. */ + // GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& + // inverseReduction); /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 579244c57f..14432d08cb 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -30,8 +30,8 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { - DiscreteKey PC(0,4), ME(1, 4), AI(2, 4), A(3, 3); +TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { + DiscreteKey PC(0, 4), ME(1, 4), AI(2, 4), A(3, 3); DiscreteFactorGraph graph; graph.add(AI, "1 0 0 1"); @@ -47,25 +47,18 @@ TEST_UNSAFE( DiscreteFactorGraph, debugScheduler) { graph.add(PC & ME, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); graph.add(PC & AI, "0 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0"); -// graph.print("Graph: "); - DecisionTreeFactor product = graph.product(); - DecisionTreeFactor::shared_ptr sum = product.sum(1); -// sum->print("Debug SUM: "); - DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); - -// cond->print("marginal:"); + // Check MPE. + auto actualMPE = graph.optimize(); + DiscreteValues mpe; + insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0); + EXPECT(assert_equal(mpe, actualMPE)); -// pair result = EliminateDiscrete(graph, 1); -// result.first->print("BayesNet: "); -// result.second->print("New factor: "); -// + // Check Bayes Net Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3); - DiscreteEliminationTree eliminationTree(graph, ordering); -// eliminationTree.print("Elimination tree: "); - eliminationTree.eliminate(EliminateDiscrete); -// solver.optimize(); -// DiscreteBayesNet::shared_ptr bayesNet = solver.eliminate(); + ordering += Key(0), Key(1), Key(2), Key(3); + auto chordal = graph.eliminateSequential(ordering); + // happens to be the same, but not in general! + EXPECT(assert_equal(mpe, chordal->optimize())); } /* ************************************************************************* */ @@ -115,10 +108,9 @@ TEST_UNSAFE( DiscreteFactorGraph, DiscreteFactorGraphEvaluationTest) { } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, test) -{ +TEST(DiscreteFactorGraph, test) { // Declare keys and ordering - DiscreteKey C(0,2), B(1,2), A(2,2); + DiscreteKey C(0, 2), B(1, 2), A(2, 2); // A simple factor graph (A)-fAC-(C)-fBC-(B) // with smoothness priors @@ -127,7 +119,6 @@ TEST( DiscreteFactorGraph, test) graph.add(C & B, "3 1 1 3"); // Test EliminateDiscrete - // FIXME: apparently Eliminate returns a conditional rather than a net Ordering frontalKeys; frontalKeys += Key(0); DiscreteConditional::shared_ptr conditional; @@ -138,7 +129,7 @@ TEST( DiscreteFactorGraph, test) CHECK(conditional); DiscreteBayesNet expected; Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); - // cout << signature << endl; + DiscreteConditional expectedConditional(signature); EXPECT(assert_equal(expectedConditional, *conditional)); expected.add(signature); @@ -151,7 +142,6 @@ TEST( DiscreteFactorGraph, test) // add conditionals to complete expected Bayes net expected.add(B | A = "5/3 3/5"); expected.add(A % "1/1"); - // GTSAM_PRINT(expected); // Test elimination tree Ordering ordering; @@ -162,42 +152,82 @@ TEST( DiscreteFactorGraph, test) boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); EXPECT(assert_equal(expected, *actual)); -// // Test solver -// DiscreteBayesNet::shared_ptr actual2 = solver.eliminate(); -// EXPECT(assert_equal(expected, *actual2)); + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 0)(2, 0); + EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression + + // Check Bayes Net + auto chordal = graph.eliminateSequential(); + auto notOptimal = chordal->optimize(); + // happens to be the same but not in general! + EXPECT(assert_equal(mpe, notOptimal)); - // Test optimization - DiscreteValues expectedValues; - insert(expectedValues)(0, 0)(1, 0)(2, 0); - auto actualValues = graph.optimize(); - EXPECT(assert_equal(expectedValues, actualValues)); + // Test eliminateSequential + DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering); + EXPECT(assert_equal(expected, *actual2)); + auto notOptimal2 = actual2->optimize(); + // happens to be the same but not in general! + EXPECT(assert_equal(mpe, notOptimal2)); + + // Test mpe + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE) -{ +TEST_UNSAFE(DiscreteFactorGraph, testMPE) { // Declare a bunch of keys - DiscreteKey C(0,2), A(1,2), B(2,2); + DiscreteKey C(0, 2), A(1, 2), B(2, 2); // Create Factor graph DiscreteFactorGraph graph; graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & B, "0.1 0.9 0.4 0.6"); - // graph.product().print(); - // DiscreteSequentialSolver(graph).eliminate()->print(); + // Check MPE. + auto actualMPE = graph.optimize(); + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 1)(2, 1); + EXPECT(assert_equal(mpe, actualMPE)); + + // Check Bayes Net + auto chordal = graph.eliminateSequential(); + auto notOptimal = chordal->optimize(); + // happens to be the same but not in general + EXPECT(assert_equal(mpe, notOptimal)); +} + +/* ************************************************************************* */ +TEST(DiscreteFactorGraph, marginalIsNotMPE) { + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create Bayes net such that marginal on A is bigger for 0 than 1, but the + // MPE does not have A=0. + DiscreteBayesNet bayesNet; + bayesNet.add(B | A = "1/1 1/2"); + bayesNet.add(A % "10/9"); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // Which we verify using max-product: + DiscreteFactorGraph graph(bayesNet); auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression - DiscreteValues expectedMPE; - insert(expectedMPE)(0, 0)(1, 1)(2, 1); - EXPECT(assert_equal(expectedMPE, actualMPE)); + // Optimize on BayesNet maximizes marginal, then the conditional marginals: + auto notOptimal = bayesNet.optimize(); + EXPECT(graph(notOptimal) < graph(mpe)); + EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression } /* ************************************************************************* */ -TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) -{ +TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { // The factor graph in Darwiche09book, page 244 - DiscreteKey A(4,2), C(3,2), S(2,2), T1(0,2), T2(1,2); + DiscreteKey A(4, 2), C(3, 2), S(2, 2), T1(0, 2), T2(1, 2); // Create Factor graph DiscreteFactorGraph graph; @@ -206,53 +236,32 @@ TEST( DiscreteFactorGraph, testMPE_Darwiche09book_p244) graph.add(C & T1, "0.80 0.20 0.20 0.80"); graph.add(S & C & T2, "0.80 0.20 0.20 0.80 0.95 0.05 0.05 0.95"); graph.add(T1 & T2 & A, "1 0 0 1 0 1 1 0"); - graph.add(A, "1 0");// evidence, A = yes (first choice in Darwiche) - //graph.product().print("Darwiche-product"); - // graph.product().potentials().dot("Darwiche-product"); - // DiscreteSequentialSolver(graph).eliminate()->print(); - - DiscreteValues expectedMPE; - insert(expectedMPE)(4, 0)(2, 0)(3, 1)(0, 1)(1, 1); - - // Use the solver machinery. - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto actualMPE = chordal->optimize(); - EXPECT(assert_equal(expectedMPE, actualMPE)); -// DiscreteConditional::shared_ptr root = chordal->back(); -// EXPECT_DOUBLES_EQUAL(0.4, (*root)(*actualMPE), 1e-9); - - // Let us create the Bayes tree here, just for fun, because we don't use it now -// typedef JunctionTreeOrdered JT; -// GenericMultifrontalSolver solver(graph); -// BayesTreeOrdered::shared_ptr bayesTree = solver.eliminate(&EliminateDiscrete); -//// bayesTree->print("Bayes Tree"); -// EXPECT_LONGS_EQUAL(2,bayesTree->size()); + graph.add(A, "1 0"); // evidence, A = yes (first choice in Darwiche) + + DiscreteValues mpe; + insert(mpe)(4, 0)(2, 1)(3, 1)(0, 1)(1, 1); + EXPECT_DOUBLES_EQUAL(0.33858, graph(mpe), 1e-5); // regression + // You can check visually by printing product: + // graph.product().print("Darwiche-product"); + // Check MPE. + auto actualMPE = graph.optimize(); + EXPECT(assert_equal(mpe, actualMPE)); + + // Check Bayes Net Ordering ordering; - ordering += Key(0),Key(1),Key(2),Key(3),Key(4); - DiscreteBayesTree::shared_ptr bayesTree = graph.eliminateMultifrontal(ordering); + ordering += Key(0), Key(1), Key(2), Key(3), Key(4); + auto chordal = graph.eliminateSequential(ordering); + auto notOptimal = chordal->optimize(); // not MPE ! + EXPECT(graph(notOptimal) < graph(mpe)); + + // Let us create the Bayes tree here, just for fun, because we don't use it + DiscreteBayesTree::shared_ptr bayesTree = + graph.eliminateMultifrontal(ordering); // bayesTree->print("Bayes Tree"); - EXPECT_LONGS_EQUAL(2,bayesTree->size()); - -#ifdef OLD -// Create the elimination tree manually -VariableIndexOrdered structure(graph); -typedef EliminationTreeOrdered ETree; -ETree::shared_ptr eTree = ETree::Create(graph, structure); -//eTree->print(">>>>>>>>>>> Elimination Tree <<<<<<<<<<<<<<<<<"); - -// eliminate normally and check solution -DiscreteBayesNet::shared_ptr bayesNet = eTree->eliminate(&EliminateDiscrete); -// bayesNet->print(">>>>>>>>>>>>>> Bayes Net <<<<<<<<<<<<<<<<<<"); -auto actualMPE = optimize(*bayesNet); -EXPECT(assert_equal(expectedMPE, actualMPE)); - -// Approximate and check solution -// DiscreteBayesNet::shared_ptr approximateNet = eTree->approximate(); -// approximateNet->print(">>>>>>>>>>>>>> Approximate Net <<<<<<<<<<<<<<<<<<"); -// EXPECT(assert_equal(expectedMPE, *actualMPE)); -#endif + EXPECT_LONGS_EQUAL(2, bayesTree->size()); } + #ifdef OLD /* ************************************************************************* */ From 34a3b022d948a93b8e741a2366e2d20316e1b52e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 13:08:16 -0500 Subject: [PATCH 05/20] New lookup classes --- gtsam/discrete/DiscreteLookupDAG.cpp | 153 +++++++++++++++++++++++++++ gtsam/discrete/DiscreteLookupDAG.h | 138 ++++++++++++++++++++++++ 2 files changed, 291 insertions(+) create mode 100644 gtsam/discrete/DiscreteLookupDAG.cpp create mode 100644 gtsam/discrete/DiscreteLookupDAG.h diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp new file mode 100644 index 0000000000..37e45de80e --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -0,0 +1,153 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupTable.cpp + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include + +#include +#include + +using std::pair; +using std::vector; + +namespace gtsam { + +// Instantiate base class +template class GTSAM_EXPORT + Conditional; + +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +void DiscreteLookupTable::print(const std::string& s, + const KeyFormatter& formatter) const { + using std::cout; + using std::endl; + + cout << s << " g( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "; "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << "):\n"; + ADT::print("", formatter); + cout << endl; +} + +/* ************************************************************************* */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +vector DiscreteLookupTable::frontalAssignments() const { + vector> pairs; + for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); + vector> rpairs(pairs.rbegin(), pairs.rend()); + return DiscreteValues::CartesianProduct(rpairs); +} + +/* ************************************************************************** */ +// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( +static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional, + const DiscreteValues& given, + bool forceComplete = true) { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + DiscreteLookupTable::ADT adt(conditional); + size_t value; + for (Key j : conditional.parents()) { + try { + value = given.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (std::out_of_range&) { + if (forceComplete) { + given.print("parentsValues: "); + throw std::runtime_error( + "DiscreteLookupTable::Choose: parent value missing"); + } + } + } + return adt; +} + +/* ************************************************************************** */ +void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { + ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) + + // Initialize + DiscreteValues mpe; + double maxP = 0; + + // Get all Possible Configurations + const auto allPosbValues = frontalAssignments(); + + // Find the maximum + for (const auto& frontalVals : allPosbValues) { + double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues) + // Update maximum solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = frontalVals; + } + } + + // set values (inPlace) to maximum + for (Key j : frontals()) { + (*values)[j] = mpe[j]; + } +} + +/* ************************************************************************** */ +size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO(Duy): only works for one key now, seems horribly slow this way + size_t mpe = 0; + DiscreteValues frontals; + double maxP = 0; + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; +} + +/* ************************************************************************** */ +DiscreteValues DiscreteLookupDAG::argmax() const { + DiscreteValues result; + return argmax(result); +} + +DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { + // Argmax each node in turn in topological sort order (parents first). + for (auto lookupTable : boost::adaptors::reverse(*this)) + lookupTable->argmaxInPlace(&result); + return result; +} +/* ************************************************************************** */ + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h new file mode 100644 index 0000000000..a69b0b1eea --- /dev/null +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -0,0 +1,138 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteLookupDAG.h + * @date JAnuary, 2022 + * @author Frank dellaert + */ + +#pragma once + +#include +#include +#include + +#include + +#include + +namespace gtsam { + +/** + * @brief DiscreteLookupTable table for max-product + */ +class DiscreteLookupTable + : public DecisionTreeFactor, + public Conditional { + public: + using This = DiscreteLookupTable; + using shared_ptr = boost::shared_ptr; + using BaseConditional = Conditional; + + /** + * @brief Construct a new Discrete Lookup Table object + * + * @param nFrontals number of frontal variables + * @param keys a orted list of gtsam::Keys + * @param potentials the algebraic decision tree with lookup values + */ + DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials) + : DecisionTreeFactor(keys, potentials), BaseConditional(nFrontals) {} + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Lookup Table: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /** + * @brief return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. + */ + size_t argmax(const DiscreteValues& parentsValues) const; + + /** + * @brief Calculate assignment for frontal variables that maximizes value. + * @param (in/out) parentsValues Known assignments for the parents. + */ + void argmaxInPlace(DiscreteValues* parentsValues) const; + + /// Return all assignments for frontal variables. + std::vector frontalAssignments() const; +}; + +/** A DAG made from lookup tables, as defined above. */ +class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { + public: + using Base = BayesNet; + using This = DiscreteLookupDAG; + using shared_ptr = boost::shared_ptr; + + /// @name Standard Constructors + /// @{ + + /// Construct empty DAG. + DiscreteLookupDAG() {} + + /// Destructor + virtual ~DiscreteLookupDAG() {} + + /// @} + + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This& bn, double tol = 1e-9) const; + + /// @} + + /// @name Standard Interface + /// @{ + + /** + * @brief argmax by back-substitution. + * + * Assumes the DAG is reverse topologically sorted, i.e. last + * conditional will be optimized first. If the DAG resulted from + * eliminating a factor graph, this is true for the elimination ordering. + * + * @return optimal assignment for all variables. + */ + DiscreteValues argmax() const; + + /** + * @brief argmax by back-substitution, given certain variables. + * + * Assumes the DAG is reverse topologically sorted *and* that the + * DAG does not contain any conditionals for the given variables. + * + * @return given assignment extended w. optimal assignment for all variables. + */ + DiscreteValues argmax(DiscreteValues given) const; + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam From fcdb5b43c1d67160b814b7a877a8eda8b1bc3f48 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 13:09:04 -0500 Subject: [PATCH 06/20] Deprecated solve --- gtsam/discrete/DiscreteDistribution.cpp | 17 +++++++++++++++++ gtsam/discrete/DiscreteDistribution.h | 12 +++++++++--- .../discrete/tests/testDiscreteDistribution.cpp | 6 ++++++ 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DiscreteDistribution.cpp b/gtsam/discrete/DiscreteDistribution.cpp index 7397714709..5f6fba6a28 100644 --- a/gtsam/discrete/DiscreteDistribution.cpp +++ b/gtsam/discrete/DiscreteDistribution.cpp @@ -49,4 +49,21 @@ std::vector DiscreteDistribution::pmf() const { return array; } +/* ************************************************************************** */ +size_t DiscreteDistribution::argmax() const { + size_t maxValue = 0; + double maxP = 0; + assert(nrFrontals() == 1); + Key j = firstFrontalKey(); + for (size_t value = 0; value < cardinality(j); value++) { + double pValueS = (*this)(value); + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + maxValue = value; + } + } + return maxValue; +} + } // namespace gtsam diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index fae6e355bd..8dcc75733f 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -91,10 +91,10 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { std::vector pmf() const; /** - * solve a conditional - * @return MPE value of the child (1 frontal variable). + * @brief Return assignment that maximizes distribution. + * @return Optimal assignment (1 frontal variable). */ - size_t solve() const { return Base::solve({}); } + size_t argmax() const; /** * sample @@ -103,6 +103,12 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { size_t sample() const { return Base::sample(); } /// @} +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve() const { return Base::solve({}); } + /// @} +#endif }; // DiscreteDistribution diff --git a/gtsam/discrete/tests/testDiscreteDistribution.cpp b/gtsam/discrete/tests/testDiscreteDistribution.cpp index 5c0c42e737..5e59aaa65b 100644 --- a/gtsam/discrete/tests/testDiscreteDistribution.cpp +++ b/gtsam/discrete/tests/testDiscreteDistribution.cpp @@ -74,6 +74,12 @@ TEST(DiscreteDistribution, sample) { prior.sample(); } +/* ************************************************************************* */ +TEST(DiscreteDistribution, argmax) { + DiscreteDistribution prior(X % "2/3"); + EXPECT_LONGS_EQUAL(prior.argmax(), 1); +} + /* ************************************************************************* */ int main() { TestResult tr; From 5add858c24c15df4336f6654ab0d867b6757c1e7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 13:18:28 -0500 Subject: [PATCH 07/20] Now doing MPE with DAG class --- gtsam/discrete/DiscreteFactorGraph.cpp | 43 +++++++----- gtsam/discrete/DiscreteFactorGraph.h | 9 +-- .../tests/testDiscreteFactorGraph.cpp | 68 ++++++++----------- 3 files changed, 60 insertions(+), 60 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index d8e9aa244f..a166fdce93 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -96,18 +97,6 @@ namespace gtsam { // } /* ************************************************************************ */ - /** - * @brief Lookup table for max-product - * - * This inherits from a DiscreteConditional but is not normalized to 1 - * - */ - class Lookup : public DiscreteConditional { - public: - Lookup(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials) - : DiscreteConditional(nFrontals, keys, potentials) {} - }; - // Alternate eliminate function for MPE std::pair // EliminateForMPE(const DiscreteFactorGraph& factors, @@ -133,7 +122,8 @@ namespace gtsam { // Make lookup with product gttic(lookup); size_t nrFrontals = frontalKeys.size(); - auto lookup = boost::make_shared(nrFrontals, orderedKeys, product); + auto lookup = boost::make_shared(nrFrontals, + orderedKeys, product); gttoc(lookup); return std::make_pair( @@ -141,18 +131,37 @@ namespace gtsam { } /* ************************************************************************ */ - DiscreteBayesNet::shared_ptr DiscreteFactorGraph::maxProduct( + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_maxProduct); - return BaseEliminateable::eliminateSequential(orderingType, - EliminateForMPE); + + // The solution below is a bitclunky: the elimination machinery does not + // allow for differently *typed* versions of elimination, so we eliminate + // into a Bayes Net using the special eliminate function above, and then + // create the DiscreteLookupDAG after the fact, in linear time. + auto bayesNet = + BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE); + + // Copy to the DAG + DiscreteLookupDAG dag; + for (auto&& conditional : *bayesNet) { + if (auto lookupTable = + boost::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } + } + return dag; } /* ************************************************************************ */ DiscreteValues DiscreteFactorGraph::optimize( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_optimize); - return maxProduct()->optimize(); + DiscreteLookupDAG dag = maxProduct(); + return dag.argmax(); } /* ************************************************************************ */ diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index b4e98c876c..7c658f5484 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -18,10 +18,11 @@ #pragma once -#include +#include +#include #include +#include #include -#include #include #include @@ -132,9 +133,9 @@ class GTSAM_EXPORT DiscreteFactorGraph * @brief Implement the max-product algorithm * * @param orderingType : one of COLAMD, METIS, NATURAL, CUSTOM - * @return DiscreteBayesNet::shared_ptr DAG with lookup tables + * @return DiscreteLookupDAG::shared_ptr DAG with lookup tables */ - boost::shared_ptr maxProduct( + DiscreteLookupDAG maxProduct( OptionalOrderingType orderingType = boost::none) const; /** diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index 14432d08cb..e63cc26b8f 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -52,13 +52,6 @@ TEST_UNSAFE(DiscreteFactorGraph, debugScheduler) { DiscreteValues mpe; insert(mpe)(0, 2)(1, 1)(2, 0)(3, 0); EXPECT(assert_equal(mpe, actualMPE)); - - // Check Bayes Net - Ordering ordering; - ordering += Key(0), Key(1), Key(2), Key(3); - auto chordal = graph.eliminateSequential(ordering); - // happens to be the same, but not in general! - EXPECT(assert_equal(mpe, chordal->optimize())); } /* ************************************************************************* */ @@ -125,57 +118,46 @@ TEST(DiscreteFactorGraph, test) { DecisionTreeFactor::shared_ptr newFactor; boost::tie(conditional, newFactor) = EliminateDiscrete(graph, frontalKeys); - // Check Bayes net + // Check Conditional CHECK(conditional); - DiscreteBayesNet expected; Signature signature((C | B, A) = "9/1 1/1 1/1 1/9"); - DiscreteConditional expectedConditional(signature); EXPECT(assert_equal(expectedConditional, *conditional)); - expected.add(signature); // Check Factor CHECK(newFactor); DecisionTreeFactor expectedFactor(B & A, "10 6 6 10"); EXPECT(assert_equal(expectedFactor, *newFactor)); - // add conditionals to complete expected Bayes net - expected.add(B | A = "5/3 3/5"); - expected.add(A % "1/1"); - - // Test elimination tree + // Test using elimination tree Ordering ordering; ordering += Key(0), Key(1), Key(2); DiscreteEliminationTree etree(graph, ordering); DiscreteBayesNet::shared_ptr actual; DiscreteFactorGraph::shared_ptr remainingGraph; boost::tie(actual, remainingGraph) = etree.eliminate(&EliminateDiscrete); - EXPECT(assert_equal(expected, *actual)); - - DiscreteValues mpe; - insert(mpe)(0, 0)(1, 0)(2, 0); - EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression - // Check Bayes Net - auto chordal = graph.eliminateSequential(); - auto notOptimal = chordal->optimize(); - // happens to be the same but not in general! - EXPECT(assert_equal(mpe, notOptimal)); + // Check Bayes net + DiscreteBayesNet expectedBayesNet; + expectedBayesNet.add(signature); + expectedBayesNet.add(B | A = "5/3 3/5"); + expectedBayesNet.add(A % "1/1"); + EXPECT(assert_equal(expectedBayesNet, *actual)); // Test eliminateSequential DiscreteBayesNet::shared_ptr actual2 = graph.eliminateSequential(ordering); - EXPECT(assert_equal(expected, *actual2)); - auto notOptimal2 = actual2->optimize(); - // happens to be the same but not in general! - EXPECT(assert_equal(mpe, notOptimal2)); + EXPECT(assert_equal(expectedBayesNet, *actual2)); // Test mpe + DiscreteValues mpe; + insert(mpe)(0, 0)(1, 0)(2, 0); auto actualMPE = graph.optimize(); EXPECT(assert_equal(mpe, actualMPE)); + EXPECT_DOUBLES_EQUAL(9, graph(mpe), 1e-5); // regression } /* ************************************************************************* */ -TEST_UNSAFE(DiscreteFactorGraph, testMPE) { +TEST_UNSAFE(DiscreteFactorGraph, testMaxProduct) { // Declare a bunch of keys DiscreteKey C(0, 2), A(1, 2), B(2, 2); @@ -184,17 +166,20 @@ TEST_UNSAFE(DiscreteFactorGraph, testMPE) { graph.add(C & A, "0.2 0.8 0.3 0.7"); graph.add(C & B, "0.1 0.9 0.4 0.6"); - // Check MPE. - auto actualMPE = graph.optimize(); + // Created expected MPE DiscreteValues mpe; insert(mpe)(0, 0)(1, 1)(2, 1); - EXPECT(assert_equal(mpe, actualMPE)); - // Check Bayes Net - auto chordal = graph.eliminateSequential(); - auto notOptimal = chordal->optimize(); - // happens to be the same but not in general - EXPECT(assert_equal(mpe, notOptimal)); + // Do max-product with different orderings + for (Ordering::OrderingType orderingType : + {Ordering::COLAMD, Ordering::METIS, Ordering::NATURAL, + Ordering::CUSTOM}) { + DiscreteLookupDAG dag = graph.maxProduct(orderingType); + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); + auto actualMPE2 = graph.optimize(); // all in one + EXPECT(assert_equal(mpe, actualMPE2)); + } } /* ************************************************************************* */ @@ -218,10 +203,12 @@ TEST(DiscreteFactorGraph, marginalIsNotMPE) { EXPECT(assert_equal(mpe, actualMPE)); EXPECT_DOUBLES_EQUAL(0.315789, graph(mpe), 1e-5); // regression +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 // Optimize on BayesNet maximizes marginal, then the conditional marginals: auto notOptimal = bayesNet.optimize(); EXPECT(graph(notOptimal) < graph(mpe)); EXPECT_DOUBLES_EQUAL(0.263158, graph(notOptimal), 1e-5); // regression +#endif } /* ************************************************************************* */ @@ -252,8 +239,11 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4); auto chordal = graph.eliminateSequential(ordering); + EXPECT_LONGS_EQUAL(2, chordal->size()); +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 auto notOptimal = chordal->optimize(); // not MPE ! EXPECT(graph(notOptimal) < graph(mpe)); +#endif // Let us create the Bayes tree here, just for fun, because we don't use it DiscreteBayesTree::shared_ptr bayesTree = From 756430074478dd8c91b40a1519961752dca02633 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 13:18:46 -0500 Subject: [PATCH 08/20] deprecated solve --- gtsam/discrete/DiscreteConditional.cpp | 20 +++++++++++--------- gtsam/discrete/DiscreteConditional.h | 18 ++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 8c0f91807a..db0ef1048e 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -238,6 +238,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( } /* ************************************************************************** */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 void DiscreteConditional::solveInPlace(DiscreteValues* values) const { ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) @@ -264,14 +265,6 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { } } -/* ******************************************************************************** */ -void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { - assert(nrFrontals() == 1); - Key j = (firstFrontalKey()); - size_t sampled = sample(*values); // Sample variable given parents - (*values)[j] = sampled; // store result in partial solution -} - /* ************************************************************************** */ size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) @@ -294,8 +287,17 @@ size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { } return mpe; } +#endif -/* ******************************************************************************** */ +/* ************************************************************************** */ +void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { + assert(nrFrontals() == 1); + Key j = (firstFrontalKey()); + size_t sampled = sample(*values); // Sample variable given parents + (*values)[j] = sampled; // store result in partial solution +} + +/* ************************************************************************** */ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index de9d949714..ef0a4c9072 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -179,13 +179,6 @@ class GTSAM_EXPORT DiscreteConditional /** Single variable version of likelihood. */ DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; - /** - * solve a conditional - * @param parentsValues Known values of the parents - * @return maximum value for the (single) frontal variable. - */ - size_t solve(const DiscreteValues& parentsValues) const; - /** * sample * @param parentsValues Known values of the parents @@ -203,9 +196,6 @@ class GTSAM_EXPORT DiscreteConditional /// @name Advanced Interface /// @{ - /// solve a conditional, in place - void solveInPlace(DiscreteValues* parentsValues) const; - /// sample in place, stores result in partial solution void sampleInPlace(DiscreteValues* parentsValues) const; @@ -228,6 +218,14 @@ class GTSAM_EXPORT DiscreteConditional const Names& names = {}) const override; /// @} + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + size_t GTSAM_DEPRECATED solve(const DiscreteValues& parentsValues) const; + void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const; + /// @} +#endif }; // DiscreteConditional From e22f8f04bc7352adbcce3d59be4cdfcec3e9f602 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 13:18:54 -0500 Subject: [PATCH 09/20] deprecated optimize --- gtsam/discrete/DiscreteBayesNet.cpp | 7 ++++ gtsam/discrete/DiscreteBayesNet.h | 36 +++++++------------ gtsam/discrete/tests/testDiscreteBayesNet.cpp | 15 +------- 3 files changed, 20 insertions(+), 38 deletions(-) diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index 7294c8b296..ccc52585e6 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -43,6 +43,7 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const { } /* ************************************************************************* */ +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 DiscreteValues DiscreteBayesNet::optimize() const { DiscreteValues result; return optimize(result); @@ -50,10 +51,16 @@ DiscreteValues DiscreteBayesNet::optimize() const { DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const { // solve each node in turn in topological sort order (parents first) +#ifdef _MSC_VER +#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!") +#else +#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!" +#endif for (auto conditional : boost::adaptors::reverse(*this)) conditional->solveInPlace(&result); return result; } +#endif /* ************************************************************************* */ DiscreteValues DiscreteBayesNet::sample() const { diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index bd5536135a..4916cad7c0 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -31,12 +31,12 @@ namespace gtsam { -/** A Bayes net made from linear-Discrete densities */ +/** A Bayes net made from discrete conditional distributions. */ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet { public: - typedef FactorGraph Base; + typedef BayesNet Base; typedef DiscreteBayesNet This; typedef DiscreteConditional ConditionalType; typedef boost::shared_ptr shared_ptr; @@ -45,7 +45,7 @@ namespace gtsam { /// @name Standard Constructors /// @{ - /** Construct empty factor graph */ + /// Construct empty Bayes net. DiscreteBayesNet() {} /** Construct from iterator over conditionals */ @@ -98,27 +98,6 @@ namespace gtsam { return evaluate(values); } - /** - * @brief solve by back-substitution. - * - * Assumes the Bayes net is reverse topologically sorted, i.e. last - * conditional will be optimized first. If the Bayes net resulted from - * eliminating a factor graph, this is true for the elimination ordering. - * - * @return a sampled value for all variables. - */ - DiscreteValues optimize() const; - - /** - * @brief solve by back-substitution, given certain variables. - * - * Assumes the Bayes net is reverse topologically sorted *and* that the - * Bayes net does not contain any conditionals for the given values. - * - * @return given values extended with optimized value for other variables. - */ - DiscreteValues optimize(DiscreteValues given) const; - /** * @brief do ancestral sampling * @@ -152,7 +131,16 @@ namespace gtsam { std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, const DiscreteFactor::Names& names = {}) const; + ///@} + +#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 + /// @name Deprecated functionality + /// @{ + + DiscreteValues GTSAM_DEPRECATED optimize() const; + DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const; /// @} +#endif private: /** Serialization function */ diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 0ba53c69ab..c35d4742c0 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -106,26 +106,13 @@ TEST(DiscreteBayesNet, Asia) { DiscreteConditional expected2(Bronchitis % "11/9"); EXPECT(assert_equal(expected2, *chordal->back())); - // solve - auto actualMPE = chordal->optimize(); - DiscreteValues expectedMPE; - insert(expectedMPE)(Asia.first, 0)(Dyspnea.first, 0)(XRay.first, 0)( - Tuberculosis.first, 0)(Smoking.first, 0)(Either.first, 0)( - LungCancer.first, 0)(Bronchitis.first, 0); - EXPECT(assert_equal(expectedMPE, actualMPE)); - // add evidence, we were in Asia and we have dyspnea fg.add(Asia, "0 1"); fg.add(Dyspnea, "0 1"); // solve again, now with evidence DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - auto actualMPE2 = chordal2->optimize(); - DiscreteValues expectedMPE2; - insert(expectedMPE2)(Asia.first, 1)(Dyspnea.first, 1)(XRay.first, 0)( - Tuberculosis.first, 0)(Smoking.first, 1)(Either.first, 0)( - LungCancer.first, 0)(Bronchitis.first, 1); - EXPECT(assert_equal(expectedMPE2, actualMPE2)); + EXPECT(assert_equal(expected2, *chordal->back())); // now sample from it DiscreteValues expectedSample; From 2f49612b8c1132fa97457439f57330e8eb914c70 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 14:06:50 -0500 Subject: [PATCH 10/20] MPE now works --- gtsam/discrete/DiscreteLookupDAG.cpp | 4 -- gtsam/discrete/DiscreteLookupDAG.h | 18 ++++-- .../tests/testDiscreteFactorGraph.cpp | 2 +- .../discrete/tests/testDiscreteLookupDAG.cpp | 58 +++++++++++++++++++ 4 files changed, 72 insertions(+), 10 deletions(-) create mode 100644 gtsam/discrete/tests/testDiscreteLookupDAG.cpp diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 37e45de80e..4fe3a53a47 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -27,10 +27,6 @@ using std::vector; namespace gtsam { -// Instantiate base class -template class GTSAM_EXPORT - Conditional; - /* ************************************************************************** */ // TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( void DiscreteLookupTable::print(const std::string& s, diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index a69b0b1eea..31cb3dfbf8 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -22,17 +22,19 @@ #include #include - +#include +#include #include namespace gtsam { /** * @brief DiscreteLookupTable table for max-product + * + * Inherits from discrete conditional for convenience, but is not normalized. + * Is used in pax-product algorithm. */ -class DiscreteLookupTable - : public DecisionTreeFactor, - public Conditional { +class DiscreteLookupTable : public DiscreteConditional { public: using This = DiscreteLookupTable; using shared_ptr = boost::shared_ptr; @@ -47,7 +49,7 @@ class DiscreteLookupTable */ DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials) - : DecisionTreeFactor(keys, potentials), BaseConditional(nFrontals) {} + : DiscreteConditional(nFrontals, keys, potentials) {} /// GTSAM-style print void print( @@ -100,6 +102,12 @@ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { /// @name Standard Interface /// @{ + /** Add a DiscreteLookupTable */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + /** * @brief argmax by back-substitution. * diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index e63cc26b8f..f4819dab54 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -239,7 +239,7 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4); auto chordal = graph.eliminateSequential(ordering); - EXPECT_LONGS_EQUAL(2, chordal->size()); + EXPECT_LONGS_EQUAL(5, chordal->size()); #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 auto notOptimal = chordal->optimize(); // not MPE ! EXPECT(graph(notOptimal) < graph(mpe)); diff --git a/gtsam/discrete/tests/testDiscreteLookupDAG.cpp b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp new file mode 100644 index 0000000000..04b8597804 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteLookupDAG.cpp + * + * @date January, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include + +using namespace gtsam; +using namespace boost::assign; + +/* ************************************************************************* */ +TEST(DiscreteLookupDAG, argmax) { + using ADT = AlgebraicDecisionTree; + + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create lookup table corresponding to "marginalIsNotMPE" in testDFG. + DiscreteLookupDAG dag; + + ADT adtB(DiscreteKeys{B, A}, std::vector{0.5, 1. / 3, 0.5, 2. / 3}); + dag.add(1, DiscreteKeys{B, A}, adtB); + + ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19)); + dag.add(1, DiscreteKeys{A}, adtA); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // check: + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); +} +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ From e713897235ce2dda6363f81675186c09143d361b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 14:26:35 -0500 Subject: [PATCH 11/20] made internal protected choose to avoid copy/paste in Lookup --- gtsam/discrete/DiscreteConditional.cpp | 53 +++++++++++++------------- gtsam/discrete/DiscreteConditional.h | 19 +++++---- gtsam/discrete/DiscreteLookupDAG.cpp | 39 ++----------------- gtsam/discrete/DiscreteLookupDAG.h | 3 -- 4 files changed, 41 insertions(+), 73 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index db0ef1048e..164a45f407 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -16,26 +16,25 @@ * @author Frank Dellaert */ +#include +#include #include #include #include -#include -#include - -#include #include +#include #include +#include #include #include -#include #include -#include +#include using namespace std; +using std::pair; using std::stringstream; using std::vector; -using std::pair; namespace gtsam { // Instantiate base class @@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s, cout << endl; } -/* ******************************************************************************** */ +/* ************************************************************************** */ bool DiscreteConditional::equals(const DiscreteFactor& other, double tol) const { if (!dynamic_cast(&other)) { @@ -159,14 +158,13 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ************************************************************************** */ -static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, - const DiscreteValues& given, - bool forceComplete = true) { +DiscreteConditional::ADT DiscreteConditional::choose( + const DiscreteValues& given, bool forceComplete) const { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the parent variables. - DiscreteConditional::ADT adt(conditional); + DiscreteConditional::ADT adt(*this); size_t value; - for (Key j : conditional.parents()) { + for (Key j : parents()) { try { value = given.at(j); adt = adt.choose(j, value); // ADT keeps getting smaller. @@ -174,7 +172,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, if (forceComplete) { given.print("parentsValues: "); throw runtime_error( - "DiscreteConditional::Choose: parent value missing"); + "DiscreteConditional::choose: parent value missing"); } } } @@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, /* ************************************************************************** */ DiscreteConditional::shared_ptr DiscreteConditional::choose( const DiscreteValues& given) const { - ADT adt = Choose(*this, given, false); // P(F|S=given) + ADT adt = choose(given, false); // P(F|S=given) // Collect all keys not in given. DiscreteKeys dKeys; @@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( return boost::make_shared(discreteKeys, adt); } -/* ******************************************************************************** */ +/* ****************************************************************************/ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( size_t parent_value) const { if (nrFrontals() != 1) @@ -240,7 +238,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( /* ************************************************************************** */ #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 void DiscreteConditional::solveInPlace(DiscreteValues* values) const { - ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) + ADT pFS = choose(*values, true); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; @@ -267,25 +265,24 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { /* ************************************************************************** */ size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // Then, find the max over all remaining - // TODO, only works for one key now, seems horribly slow this way - size_t mpe = 0; - DiscreteValues frontals; + size_t max = 0; double maxP = 0; + DiscreteValues frontals; assert(nrFrontals() == 1); Key j = (firstFrontalKey()); for (size_t value = 0; value < cardinality(j); value++) { frontals[j] = value; double pValueS = pFS(frontals); // P(F=value|S=parentsValues) - // Update MPE solution if better + // Update solution if better if (pValueS > maxP) { maxP = pValueS; - mpe = value; + max = value; } } - return mpe; + return max; } #endif @@ -302,7 +299,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way if (nrFrontals() != 1) { @@ -325,7 +322,8 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { return distribution(rng); } -/* ******************************************************************************** */ +/* ******************************************************************************** + */ size_t DiscreteConditional::sample(size_t parent_value) const { if (nrParents() != 1) throw std::invalid_argument( @@ -336,7 +334,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const { return sample(values); } -/* ******************************************************************************** */ +/* ******************************************************************************** + */ size_t DiscreteConditional::sample() const { if (nrParents() != 0) throw std::invalid_argument( diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index ef0a4c9072..af05e932bb 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -93,14 +93,14 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} - /** + /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) - * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); - /** + /** * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) * Assumes but *does not check* that f(Y)=sum_X f(X,Y). * Makes sure the keys are ordered as given. Does not check orderedKeys. @@ -157,17 +157,17 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** + /** * @brief restrict to given *parent* values. - * + * * Note: does not need be complete set. Examples: - * + * * P(C|D,E) + . -> P(C|D,E) * P(C|D,E) + E -> P(C|D) * P(C|D,E) + D -> P(C|E) * P(C|D,E) + D,E -> P(C) * P(C|D,E) + C -> error! - * + * * @return a shared_ptr to a new DiscreteConditional */ shared_ptr choose(const DiscreteValues& given) const; @@ -226,6 +226,11 @@ class GTSAM_EXPORT DiscreteConditional void GTSAM_DEPRECATED solveInPlace(DiscreteValues* parentsValues) const; /// @} #endif + + protected: + /// Internal version of choose + DiscreteConditional::ADT choose(const DiscreteValues& given, + bool forceComplete) const; }; // DiscreteConditional diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 4fe3a53a47..1edf508a1e 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -49,42 +49,9 @@ void DiscreteLookupTable::print(const std::string& s, cout << endl; } -/* ************************************************************************* */ -// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( -vector DiscreteLookupTable::frontalAssignments() const { - vector> pairs; - for (Key key : frontals()) pairs.emplace_back(key, cardinalities_.at(key)); - vector> rpairs(pairs.rbegin(), pairs.rend()); - return DiscreteValues::CartesianProduct(rpairs); -} - -/* ************************************************************************** */ -// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( -static DiscreteLookupTable::ADT Choose(const DiscreteLookupTable& conditional, - const DiscreteValues& given, - bool forceComplete = true) { - // Get the big decision tree with all the levels, and then go down the - // branches based on the value of the parent variables. - DiscreteLookupTable::ADT adt(conditional); - size_t value; - for (Key j : conditional.parents()) { - try { - value = given.at(j); - adt = adt.choose(j, value); // ADT keeps getting smaller. - } catch (std::out_of_range&) { - if (forceComplete) { - given.print("parentsValues: "); - throw std::runtime_error( - "DiscreteLookupTable::Choose: parent value missing"); - } - } - } - return adt; -} - /* ************************************************************************** */ void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { - ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) + ADT pFS = choose(*values, true); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; @@ -111,13 +78,13 @@ void DiscreteLookupTable::argmaxInPlace(DiscreteValues* values) const { /* ************************************************************************** */ size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { - ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) + ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues) // Then, find the max over all remaining // TODO(Duy): only works for one key now, seems horribly slow this way size_t mpe = 0; - DiscreteValues frontals; double maxP = 0; + DiscreteValues frontals; assert(nrFrontals() == 1); Key j = (firstFrontalKey()); for (size_t value = 0; value < cardinality(j); value++) { diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index 31cb3dfbf8..1b3a38b406 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -68,9 +68,6 @@ class DiscreteLookupTable : public DiscreteConditional { * @param (in/out) parentsValues Known assignments for the parents. */ void argmaxInPlace(DiscreteValues* parentsValues) const; - - /// Return all assignments for frontal variables. - std::vector frontalAssignments() const; }; /** A DAG made from lookup tables, as defined above. */ From b17fcfb64f77ff2f867a2f5a214a6817aefca708 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 14:47:28 -0500 Subject: [PATCH 12/20] optimalAssignment -> optimize. Not deprecating as in unstable. --- gtsam_unstable/discrete/CSP.cpp | 12 ----------- gtsam_unstable/discrete/CSP.h | 6 ------ gtsam_unstable/discrete/Scheduler.cpp | 17 ---------------- gtsam_unstable/discrete/Scheduler.h | 3 --- .../discrete/examples/schedulingExample.cpp | 2 +- .../discrete/examples/schedulingQuals12.cpp | 2 +- .../discrete/examples/schedulingQuals13.cpp | 2 +- gtsam_unstable/discrete/tests/testCSP.cpp | 20 ++++++++----------- .../discrete/tests/testScheduler.cpp | 2 +- gtsam_unstable/discrete/tests/testSudoku.cpp | 8 ++++---- 10 files changed, 16 insertions(+), 58 deletions(-) diff --git a/gtsam_unstable/discrete/CSP.cpp b/gtsam_unstable/discrete/CSP.cpp index e204a67796..08143c469f 100644 --- a/gtsam_unstable/discrete/CSP.cpp +++ b/gtsam_unstable/discrete/CSP.cpp @@ -14,18 +14,6 @@ using namespace std; namespace gtsam { -/// Find the best total assignment - can be expensive -DiscreteValues CSP::optimalAssignment() const { - DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(); - return chordal->optimize(); -} - -/// Find the best total assignment - can be expensive -DiscreteValues CSP::optimalAssignment(const Ordering& ordering) const { - DiscreteBayesNet::shared_ptr chordal = this->eliminateSequential(ordering); - return chordal->optimize(); -} - bool CSP::runArcConsistency(const VariableIndex& index, Domains* domains) const { bool changed = false; diff --git a/gtsam_unstable/discrete/CSP.h b/gtsam_unstable/discrete/CSP.h index e7fb301156..40853bed66 100644 --- a/gtsam_unstable/discrete/CSP.h +++ b/gtsam_unstable/discrete/CSP.h @@ -43,12 +43,6 @@ class GTSAM_UNSTABLE_EXPORT CSP : public DiscreteFactorGraph { // return result; // } - /// Find the best total assignment - can be expensive. - DiscreteValues optimalAssignment() const; - - /// Find the best total assignment, with given ordering - can be expensive. - DiscreteValues optimalAssignment(const Ordering& ordering) const; - // /* // * Perform loopy belief propagation // * True belief propagation would check for each value in domain diff --git a/gtsam_unstable/discrete/Scheduler.cpp b/gtsam_unstable/discrete/Scheduler.cpp index f166405932..b86df6c290 100644 --- a/gtsam_unstable/discrete/Scheduler.cpp +++ b/gtsam_unstable/discrete/Scheduler.cpp @@ -255,23 +255,6 @@ DiscreteBayesNet::shared_ptr Scheduler::eliminate() const { return chordal; } -/** Find the best total assignment - can be expensive */ -DiscreteValues Scheduler::optimalAssignment() const { - DiscreteBayesNet::shared_ptr chordal = eliminate(); - - if (ISDEBUG("Scheduler::optimalAssignment")) { - DiscreteBayesNet::const_iterator it = chordal->end() - 1; - const Student& student = students_.front(); - cout << endl; - (*it)->print(student.name_); - } - - gttic(my_optimize); - DiscreteValues mpe = chordal->optimize(); - gttoc(my_optimize); - return mpe; -} - /** find the assignment of students to slots with most possible committees */ DiscreteValues Scheduler::bestSchedule() const { DiscreteValues best; diff --git a/gtsam_unstable/discrete/Scheduler.h b/gtsam_unstable/discrete/Scheduler.h index a97368bb25..8d269e81a6 100644 --- a/gtsam_unstable/discrete/Scheduler.h +++ b/gtsam_unstable/discrete/Scheduler.h @@ -147,9 +147,6 @@ class GTSAM_UNSTABLE_EXPORT Scheduler : public CSP { /** Eliminate, return a Bayes net */ DiscreteBayesNet::shared_ptr eliminate() const; - /** Find the best total assignment - can be expensive */ - DiscreteValues optimalAssignment() const; - /** find the assignment of students to slots with most possible committees */ DiscreteValues bestSchedule() const; diff --git a/gtsam_unstable/discrete/examples/schedulingExample.cpp b/gtsam_unstable/discrete/examples/schedulingExample.cpp index 2a9addf918..7ed00bcf61 100644 --- a/gtsam_unstable/discrete/examples/schedulingExample.cpp +++ b/gtsam_unstable/discrete/examples/schedulingExample.cpp @@ -122,7 +122,7 @@ void runLargeExample() { // SETDEBUG("timing-verbose", true); SETDEBUG("DiscreteConditional::DiscreteConditional", true); gttic(large); - auto MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); diff --git a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp index 8260bfb068..e6a47f5f86 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp @@ -143,7 +143,7 @@ void runLargeExample() { } #else gttic(large); - auto MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); diff --git a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp index cf3ce04535..82ea16a47a 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp @@ -167,7 +167,7 @@ void runLargeExample() { } #else gttic(large); - auto MPE = scheduler.optimalAssignment(); + auto MPE = scheduler.optimize(); gttoc(large); tictoc_finishedIteration(); tictoc_print(); diff --git a/gtsam_unstable/discrete/tests/testCSP.cpp b/gtsam_unstable/discrete/tests/testCSP.cpp index 88defd9860..fb386b2553 100644 --- a/gtsam_unstable/discrete/tests/testCSP.cpp +++ b/gtsam_unstable/discrete/tests/testCSP.cpp @@ -132,7 +132,7 @@ TEST(CSP, allInOne) { EXPECT(assert_equal(expectedProduct, product)); // Solve - auto mpe = csp.optimalAssignment(); + auto mpe = csp.optimize(); DiscreteValues expected; insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); EXPECT(assert_equal(expected, mpe)); @@ -172,22 +172,18 @@ TEST(CSP, WesternUS) { csp.addAllDiff(WY, CO); csp.addAllDiff(CO, NM); + DiscreteValues mpe; + insert(mpe)(0, 2)(1, 3)(2, 2)(3, 1)(4, 1)(5, 3)(6, 3)(7, 2)(8, 0)(9, 1)(10, 0); + // Create ordering according to example in ND-CSP.lyx Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7), Key(8), Key(9), Key(10); + // Solve using that ordering: - auto mpe = csp.optimalAssignment(ordering); - // GTSAM_PRINT(mpe); - DiscreteValues expected; - insert(expected)(WA.first, 1)(CA.first, 1)(NV.first, 3)(OR.first, 0)( - MT.first, 1)(WY.first, 0)(NM.first, 3)(CO.first, 2)(ID.first, 2)( - UT.first, 1)(AZ.first, 0); + auto actualMPE = csp.optimize(ordering); - // TODO: Fix me! mpe result seems to be right. (See the printing) - // It has the same prob as the expected solution. - // Is mpe another solution, or the expected solution is unique??? - EXPECT(assert_equal(expected, mpe)); + EXPECT(assert_equal(mpe, actualMPE)); EXPECT_DOUBLES_EQUAL(1, csp(mpe), 1e-9); // Write out the dual graph for hmetis @@ -227,7 +223,7 @@ TEST(CSP, ArcConsistency) { EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); // Solve - auto mpe = csp.optimalAssignment(); + auto mpe = csp.optimize(); DiscreteValues expected; insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); EXPECT(assert_equal(expected, mpe)); diff --git a/gtsam_unstable/discrete/tests/testScheduler.cpp b/gtsam_unstable/discrete/tests/testScheduler.cpp index 7822cbd38b..086057a466 100644 --- a/gtsam_unstable/discrete/tests/testScheduler.cpp +++ b/gtsam_unstable/discrete/tests/testScheduler.cpp @@ -122,7 +122,7 @@ TEST(schedulingExample, test) { // Do exact inference gttic(small); - auto MPE = s.optimalAssignment(); + auto MPE = s.optimize(); gttoc(small); // print MPE, commented out as unit tests don't print diff --git a/gtsam_unstable/discrete/tests/testSudoku.cpp b/gtsam_unstable/discrete/tests/testSudoku.cpp index 35f3ba8437..8b28581699 100644 --- a/gtsam_unstable/discrete/tests/testSudoku.cpp +++ b/gtsam_unstable/discrete/tests/testSudoku.cpp @@ -100,7 +100,7 @@ class Sudoku : public CSP { /// solve and print solution void printSolution() const { - auto MPE = optimalAssignment(); + auto MPE = optimize(); printAssignment(MPE); } @@ -126,7 +126,7 @@ TEST(Sudoku, small) { 0, 1, 0, 0); // optimize and check - auto solution = csp.optimalAssignment(); + auto solution = csp.optimize(); DiscreteValues expected; insert(expected)(csp.key(0, 0), 0)(csp.key(0, 1), 1)(csp.key(0, 2), 2)( csp.key(0, 3), 3)(csp.key(1, 0), 2)(csp.key(1, 1), 3)(csp.key(1, 2), 0)( @@ -148,7 +148,7 @@ TEST(Sudoku, small) { EXPECT_LONGS_EQUAL(16, new_csp.size()); // Check that solution - auto new_solution = new_csp.optimalAssignment(); + auto new_solution = new_csp.optimize(); // csp.printAssignment(new_solution); EXPECT(assert_equal(expected, new_solution)); } @@ -250,7 +250,7 @@ TEST(Sudoku, AJC_3star_Feb8_2012) { EXPECT_LONGS_EQUAL(81, new_csp.size()); // Check that solution - auto solution = new_csp.optimalAssignment(); + auto solution = new_csp.optimize(); // csp.printAssignment(solution); EXPECT_LONGS_EQUAL(6, solution.at(key99)); } From 2ac79af17fe202f61092ea8355c321d5000fff7b Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 14:47:46 -0500 Subject: [PATCH 13/20] Added optimize variants that take custom ordering --- gtsam/discrete/DiscreteFactorGraph.cpp | 39 ++++++++++++++------------ gtsam/discrete/DiscreteFactorGraph.h | 16 +++++++++++ gtsam/discrete/DiscreteLookupDAG.cpp | 17 +++++++++++ gtsam/discrete/DiscreteLookupDAG.h | 5 ++++ 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index a166fdce93..7c03d21f9c 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -131,36 +131,39 @@ namespace gtsam { } /* ************************************************************************ */ + // The max-product solution below is a bit clunky: the elimination machinery + // does not allow for differently *typed* versions of elimination, so we + // eliminate into a Bayes Net using the special eliminate function above, and + // then create the DiscreteLookupDAG after the fact, in linear time. + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_maxProduct); - - // The solution below is a bitclunky: the elimination machinery does not - // allow for differently *typed* versions of elimination, so we eliminate - // into a Bayes Net using the special eliminate function above, and then - // create the DiscreteLookupDAG after the fact, in linear time. auto bayesNet = BaseEliminateable::eliminateSequential(orderingType, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); + } - // Copy to the DAG - DiscreteLookupDAG dag; - for (auto&& conditional : *bayesNet) { - if (auto lookupTable = - boost::dynamic_pointer_cast(conditional)) { - dag.push_back(lookupTable); - } else { - throw std::runtime_error( - "DiscreteFactorGraph::maxProduct: Expected look up table."); - } - } - return dag; + DiscreteLookupDAG DiscreteFactorGraph::maxProduct( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_maxProduct); + auto bayesNet = + BaseEliminateable::eliminateSequential(ordering, EliminateForMPE); + return DiscreteLookupDAG::FromBayesNet(*bayesNet); } /* ************************************************************************ */ DiscreteValues DiscreteFactorGraph::optimize( OptionalOrderingType orderingType) const { gttic(DiscreteFactorGraph_optimize); - DiscreteLookupDAG dag = maxProduct(); + DiscreteLookupDAG dag = maxProduct(orderingType); + return dag.argmax(); + } + + DiscreteValues DiscreteFactorGraph::optimize( + const Ordering& ordering) const { + gttic(DiscreteFactorGraph_optimize); + DiscreteLookupDAG dag = maxProduct(ordering); return dag.argmax(); } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 7c658f5484..59827f9a57 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -138,6 +138,14 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteLookupDAG maxProduct( OptionalOrderingType orderingType = boost::none) const; + /** + * @brief Implement the max-product algorithm + * + * @param ordering + * @return DiscreteLookupDAG::shared_ptr `DAG with lookup tables + */ + DiscreteLookupDAG maxProduct(const Ordering& ordering) const; + /** * @brief Find the maximum probable explanation (MPE) by doing max-product. * @@ -147,6 +155,14 @@ class GTSAM_EXPORT DiscreteFactorGraph DiscreteValues optimize( OptionalOrderingType orderingType = boost::none) const; + /** + * @brief Find the maximum probable explanation (MPE) by doing max-product. + * + * @param ordering + * @return DiscreteValues : MPE + */ + DiscreteValues optimize(const Ordering& ordering) const; + // /** Permute the variables in the factors */ // GTSAM_EXPORT void permuteWithInverse(const Permutation& // inversePermutation); diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 1edf508a1e..16620cc249 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -16,6 +16,7 @@ * @author Frank Dellaert */ +#include #include #include @@ -99,6 +100,22 @@ size_t DiscreteLookupTable::argmax(const DiscreteValues& parentsValues) const { return mpe; } +/* ************************************************************************** */ +DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( + const DiscreteBayesNet& bayesNet) { + DiscreteLookupDAG dag; + for (auto&& conditional : bayesNet) { + if (auto lookupTable = + boost::dynamic_pointer_cast(conditional)) { + dag.push_back(lookupTable); + } else { + throw std::runtime_error( + "DiscreteFactorGraph::maxProduct: Expected look up table."); + } + } + return dag; +} + /* ************************************************************************** */ DiscreteValues DiscreteLookupDAG::argmax() const { DiscreteValues result; diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index 1b3a38b406..f1eb24ec36 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -28,6 +28,8 @@ namespace gtsam { +class DiscreteBayesNet; + /** * @brief DiscreteLookupTable table for max-product * @@ -83,6 +85,9 @@ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { /// Construct empty DAG. DiscreteLookupDAG() {} + // Create from BayesNet with LookupTables + static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet); + /// Destructor virtual ~DiscreteLookupDAG() {} From ad21632fd27331f21fde0cd416bb0b669ef5003d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 17:35:33 -0500 Subject: [PATCH 14/20] fix typos --- gtsam/discrete/tests/testDiscreteDistribution.cpp | 2 +- python/gtsam/tests/test_DiscreteDistribution.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteDistribution.cpp b/gtsam/discrete/tests/testDiscreteDistribution.cpp index 5e59aaa65b..d88b510f81 100644 --- a/gtsam/discrete/tests/testDiscreteDistribution.cpp +++ b/gtsam/discrete/tests/testDiscreteDistribution.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /* - * @file testDiscretePrior.cpp + * @file testDiscreteDistribution.cpp * @brief unit tests for DiscreteDistribution * @author Frank dellaert * @date December 2021 diff --git a/python/gtsam/tests/test_DiscreteDistribution.py b/python/gtsam/tests/test_DiscreteDistribution.py index fa999fd6b5..3986bf2dfc 100644 --- a/python/gtsam/tests/test_DiscreteDistribution.py +++ b/python/gtsam/tests/test_DiscreteDistribution.py @@ -20,7 +20,7 @@ X = 0, 2 -class TestDiscretePrior(GtsamTestCase): +class TestDiscreteDistribution(GtsamTestCase): """Tests for Discrete Priors.""" def test_constructor(self): From 125708fbb754887962da1a5ab962fd330913e2cf Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 17:35:39 -0500 Subject: [PATCH 15/20] Fix wrapper --- gtsam/discrete/discrete.i | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index e2310f4344..97c267aba2 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -111,11 +111,9 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { gtsam::DecisionTreeFactor* likelihood( const gtsam::DiscreteValues& frontalValues) const; gtsam::DecisionTreeFactor* likelihood(size_t value) const; - size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(size_t value) const; size_t sample() const; - void solveInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -138,7 +136,7 @@ virtual class DiscreteDistribution : gtsam::DiscreteConditional { gtsam::DefaultKeyFormatter) const; double operator()(size_t value) const; std::vector pmf() const; - size_t solve() const; + size_t argmax() const; }; #include @@ -163,8 +161,6 @@ class DiscreteBayesNet { void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; - gtsam::DiscreteValues optimize() const; - gtsam::DiscreteValues optimize(gtsam::DiscreteValues given) const; gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; string markdown(const gtsam::KeyFormatter& keyFormatter = @@ -217,6 +213,21 @@ class DiscreteBayesTree { std::map> names) const; }; +#include +class DiscreteLookupDAG { + DiscreteLookupDAG(); + void push_back(const gtsam::DiscreteLookupTable* table); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::DiscreteLookupTable* at(size_t i) const; + void print(string s = "DiscreteLookupDAG\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + gtsam::DiscreteValues argmax() const; + gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const; +}; + #include class DotWriter { DotWriter(double figureWidthInches = 5, double figureHeightInches = 5, @@ -260,6 +271,9 @@ class DiscreteFactorGraph { double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; + gtsam::DiscreteLookupDAG maxProduct(); + gtsam::DiscreteLookupDAG maxProduct(const gtsam::Ordering& ordering); + gtsam::DiscreteBayesNet eliminateSequential(); gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering); std::pair From 03314ed781fccd71bc98985d556d7f334b7dac24 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 17:39:06 -0500 Subject: [PATCH 16/20] updates to fix various issues --- gtsam/discrete/DiscreteConditional.cpp | 6 ++---- gtsam/discrete/DiscreteFactorGraph.h | 8 -------- gtsam/discrete/DiscreteLookupDAG.cpp | 8 +------- gtsam/discrete/DiscreteLookupDAG.h | 26 +++++++++----------------- 4 files changed, 12 insertions(+), 36 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 164a45f407..9a4897b727 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -322,8 +322,7 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { return distribution(rng); } -/* ******************************************************************************** - */ +/* ************************************************************************** */ size_t DiscreteConditional::sample(size_t parent_value) const { if (nrParents() != 1) throw std::invalid_argument( @@ -334,8 +333,7 @@ size_t DiscreteConditional::sample(size_t parent_value) const { return sample(values); } -/* ******************************************************************************** - */ +/* ************************************************************************** */ size_t DiscreteConditional::sample() const { if (nrParents() != 0) throw std::invalid_argument( diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 59827f9a57..e0f0a104bd 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -163,14 +163,6 @@ class GTSAM_EXPORT DiscreteFactorGraph */ DiscreteValues optimize(const Ordering& ordering) const; - // /** Permute the variables in the factors */ - // GTSAM_EXPORT void permuteWithInverse(const Permutation& - // inversePermutation); - // - // /** Apply a reduction, which is a remapping of variable indices. */ - // GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& - // inverseReduction); - /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 16620cc249..d96b38b0ec 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file DiscreteLookupTable.cpp + * @file DiscreteLookupDAG.cpp * @date Feb 14, 2011 * @author Duy-Nguyen Ta * @author Frank Dellaert @@ -116,12 +116,6 @@ DiscreteLookupDAG DiscreteLookupDAG::FromBayesNet( return dag; } -/* ************************************************************************** */ -DiscreteValues DiscreteLookupDAG::argmax() const { - DiscreteValues result; - return argmax(result); -} - DiscreteValues DiscreteLookupDAG::argmax(DiscreteValues result) const { // Argmax each node in turn in topological sort order (parents first). for (auto lookupTable : boost::adaptors::reverse(*this)) diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index f1eb24ec36..8cb651f28a 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -11,7 +11,7 @@ /** * @file DiscreteLookupDAG.h - * @date JAnuary, 2022 + * @date January, 2022 * @author Frank dellaert */ @@ -34,7 +34,7 @@ class DiscreteBayesNet; * @brief DiscreteLookupTable table for max-product * * Inherits from discrete conditional for convenience, but is not normalized. - * Is used in pax-product algorithm. + * Is used in the max-product algorithm. */ class DiscreteLookupTable : public DiscreteConditional { public: @@ -85,7 +85,7 @@ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { /// Construct empty DAG. DiscreteLookupDAG() {} - // Create from BayesNet with LookupTables + /// Create from BayesNet with LookupTables static DiscreteLookupDAG FromBayesNet(const DiscreteBayesNet& bayesNet); /// Destructor @@ -111,25 +111,17 @@ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { } /** - * @brief argmax by back-substitution. + * @brief argmax by back-substitution, optionally given certain variables. * * Assumes the DAG is reverse topologically sorted, i.e. last - * conditional will be optimized first. If the DAG resulted from - * eliminating a factor graph, this is true for the elimination ordering. - * - * @return optimal assignment for all variables. - */ - DiscreteValues argmax() const; - - /** - * @brief argmax by back-substitution, given certain variables. - * - * Assumes the DAG is reverse topologically sorted *and* that the - * DAG does not contain any conditionals for the given variables. + * conditional will be optimized first *and* that the + * DAG does not contain any conditionals for the given variables. If the DAG + * resulted from eliminating a factor graph, this is true for the elimination + * ordering. * * @return given assignment extended w. optimal assignment for all variables. */ - DiscreteValues argmax(DiscreteValues given) const; + DiscreteValues argmax(DiscreteValues given = DiscreteValues()) const; /// @} private: From f9b14893c86fdde1dc870b9e2d20e50390b71633 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 18:10:47 -0500 Subject: [PATCH 17/20] moved argmax to conditional --- gtsam/discrete/DiscreteConditional.cpp | 20 ++++++++++++++++++++ gtsam/discrete/DiscreteConditional.h | 6 ++++++ gtsam/discrete/DiscreteDistribution.cpp | 17 ----------------- gtsam/discrete/DiscreteDistribution.h | 12 ------------ 4 files changed, 26 insertions(+), 29 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 9a4897b727..06b2856f8f 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -286,6 +286,26 @@ size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { } #endif +/* ************************************************************************** */ +size_t DiscreteConditional::argmax() const { + size_t maxValue = 0; + double maxP = 0; + assert(nrFrontals() == 1); + assert(nrParents() == 0); + DiscreteValues frontals; + Key j = firstFrontalKey(); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = (*this)(frontals); + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + maxValue = value; + } + } + return maxValue; +} + /* ************************************************************************** */ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { assert(nrFrontals() == 1); diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index af05e932bb..48d94a3837 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -192,6 +192,12 @@ class GTSAM_EXPORT DiscreteConditional /// Zero parent version. size_t sample() const; + /** + * @brief Return assignment that maximizes distribution. + * @return Optimal assignment (1 frontal variable). + */ + size_t argmax() const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/DiscreteDistribution.cpp b/gtsam/discrete/DiscreteDistribution.cpp index 5f6fba6a28..7397714709 100644 --- a/gtsam/discrete/DiscreteDistribution.cpp +++ b/gtsam/discrete/DiscreteDistribution.cpp @@ -49,21 +49,4 @@ std::vector DiscreteDistribution::pmf() const { return array; } -/* ************************************************************************** */ -size_t DiscreteDistribution::argmax() const { - size_t maxValue = 0; - double maxP = 0; - assert(nrFrontals() == 1); - Key j = firstFrontalKey(); - for (size_t value = 0; value < cardinality(j); value++) { - double pValueS = (*this)(value); - // Update MPE solution if better - if (pValueS > maxP) { - maxP = pValueS; - maxValue = value; - } - } - return maxValue; -} - } // namespace gtsam diff --git a/gtsam/discrete/DiscreteDistribution.h b/gtsam/discrete/DiscreteDistribution.h index 8dcc75733f..c5147dbc19 100644 --- a/gtsam/discrete/DiscreteDistribution.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -90,18 +90,6 @@ class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { /// Return entire probability mass function. std::vector pmf() const; - /** - * @brief Return assignment that maximizes distribution. - * @return Optimal assignment (1 frontal variable). - */ - size_t argmax() const; - - /** - * sample - * @return sample from conditional - */ - size_t sample() const { return Base::sample(); } - /// @} #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 /// @name Deprecated functionality From e3c98b0fafee3352847f0aeffbc1a7490fdcd488 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 18:12:30 -0500 Subject: [PATCH 18/20] Fix python tests --- python/gtsam/tests/test_DiscreteBayesNet.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index 6abd660cfc..3ae3b625cd 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -79,7 +79,7 @@ def test_Asia(self): self.gtsamAssertEquals(chordal.at(7), expected2) # solve - actualMPE = chordal.optimize() + actualMPE = fg.optimize() expectedMPE = DiscreteValues() for key in [Asia, Dyspnea, XRay, Tuberculosis, Smoking, Either, LungCancer, Bronchitis]: expectedMPE[key[0]] = 0 @@ -94,8 +94,7 @@ def test_Asia(self): fg.add(Dyspnea, "0 1") # solve again, now with evidence - chordal2 = fg.eliminateSequential(ordering) - actualMPE2 = chordal2.optimize() + actualMPE2 = fg.optimize() expectedMPE2 = DiscreteValues() for key in [XRay, Tuberculosis, Either, LungCancer]: expectedMPE2[key[0]] = 0 @@ -105,6 +104,7 @@ def test_Asia(self): list(expectedMPE2.items())) # now sample from it + chordal2 = fg.eliminateSequential(ordering) actualSample = chordal2.sample() self.assertEqual(len(actualSample), 8) @@ -122,10 +122,6 @@ def test_fragment(self): for key in [Asia, Smoking]: given[key[0]] = 0 - # Now optimize fragment: - actual = fragment.optimize(given) - self.assertEqual(len(actual), 5) - # Now sample from fragment: actual = fragment.sample(given) self.assertEqual(len(actual), 5) From 99a97da5f77b506e83daf5cb76b02fa16188e615 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 21 Jan 2022 18:12:38 -0500 Subject: [PATCH 19/20] Fix examples --- examples/DiscreteBayesNetExample.cpp | 9 ++++----- examples/DiscreteBayesNet_FG.cpp | 8 ++++---- examples/HMMExample.cpp | 8 ++++---- examples/UGM_chain.cpp | 5 ++--- examples/UGM_small.cpp | 5 ++--- gtsam_unstable/discrete/examples/schedulingExample.cpp | 8 ++++---- gtsam_unstable/discrete/examples/schedulingQuals12.cpp | 4 ++-- gtsam_unstable/discrete/examples/schedulingQuals13.cpp | 4 ++-- 8 files changed, 24 insertions(+), 27 deletions(-) diff --git a/examples/DiscreteBayesNetExample.cpp b/examples/DiscreteBayesNetExample.cpp index febc1e1288..dfd7beb63b 100644 --- a/examples/DiscreteBayesNetExample.cpp +++ b/examples/DiscreteBayesNetExample.cpp @@ -53,10 +53,9 @@ int main(int argc, char **argv) { // Create solver and eliminate Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7); - DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); // solve - auto mpe = chordal->optimize(); + auto mpe = fg.optimize(); GTSAM_PRINT(mpe); // We can also build a Bayes tree (directed junction tree). @@ -69,14 +68,14 @@ int main(int argc, char **argv) { fg.add(Dyspnea, "0 1"); // solve again, now with evidence - DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering); - auto mpe2 = chordal2->optimize(); + auto mpe2 = fg.optimize(); GTSAM_PRINT(mpe2); // We can also sample from it + DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { - auto sample = chordal2->sample(); + auto sample = chordal->sample(); GTSAM_PRINT(sample); } return 0; diff --git a/examples/DiscreteBayesNet_FG.cpp b/examples/DiscreteBayesNet_FG.cpp index 69283a1be7..88904001a0 100644 --- a/examples/DiscreteBayesNet_FG.cpp +++ b/examples/DiscreteBayesNet_FG.cpp @@ -85,7 +85,7 @@ int main(int argc, char **argv) { } // "Most Probable Explanation", i.e., configuration with largest value - auto mpe = graph.eliminateSequential()->optimize(); + auto mpe = graph.optimize(); cout << "\nMost Probable Explanation (MPE):" << endl; print(mpe); @@ -96,8 +96,7 @@ int main(int argc, char **argv) { graph.add(Cloudy, "1 0"); // solve again, now with evidence - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto mpe_with_evidence = chordal->optimize(); + auto mpe_with_evidence = graph.optimize(); cout << "\nMPE given C=0:" << endl; print(mpe_with_evidence); @@ -110,7 +109,8 @@ int main(int argc, char **argv) { cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1] << endl; - // We can also sample from it + // We can also sample from the eliminated graph + DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); cout << "\n10 samples:" << endl; for (size_t i = 0; i < 10; i++) { auto sample = chordal->sample(); diff --git a/examples/HMMExample.cpp b/examples/HMMExample.cpp index b46baf4e09..3a76730016 100644 --- a/examples/HMMExample.cpp +++ b/examples/HMMExample.cpp @@ -59,16 +59,16 @@ int main(int argc, char **argv) { // Convert to factor graph DiscreteFactorGraph factorGraph(hmm); + // Do max-prodcut + auto mpe = factorGraph.optimize(); + GTSAM_PRINT(mpe); + // Create solver and eliminate // This will create a DAG ordered with arrow of time reversed DiscreteBayesNet::shared_ptr chordal = factorGraph.eliminateSequential(ordering); chordal->print("Eliminated"); - // solve - auto mpe = chordal->optimize(); - GTSAM_PRINT(mpe); - // We can also sample from it cout << "\n10 samples:" << endl; for (size_t k = 0; k < 10; k++) { diff --git a/examples/UGM_chain.cpp b/examples/UGM_chain.cpp index ababef0220..ad21af9fa7 100644 --- a/examples/UGM_chain.cpp +++ b/examples/UGM_chain.cpp @@ -68,9 +68,8 @@ int main(int argc, char** argv) { << graph.size() << " factors (Unary+Edge)."; // "Decoding", i.e., configuration with largest value - // We use sequential variable elimination - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto optimalDecoding = chordal->optimize(); + // Uses max-product. + auto optimalDecoding = graph.optimize(); optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n"); // "Inference" Computing marginals for each node diff --git a/examples/UGM_small.cpp b/examples/UGM_small.cpp index 24bd0c0ba7..bc6a413178 100644 --- a/examples/UGM_small.cpp +++ b/examples/UGM_small.cpp @@ -61,9 +61,8 @@ int main(int argc, char** argv) { } // "Decoding", i.e., configuration with largest value (MPE) - // We use sequential variable elimination - DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential(); - auto optimalDecoding = chordal->optimize(); + // Uses max-product + auto optimalDecoding = graph.optimize(); GTSAM_PRINT(optimalDecoding); // "Inference" Computing marginals diff --git a/gtsam_unstable/discrete/examples/schedulingExample.cpp b/gtsam_unstable/discrete/examples/schedulingExample.cpp index 7ed00bcf61..487edc97a3 100644 --- a/gtsam_unstable/discrete/examples/schedulingExample.cpp +++ b/gtsam_unstable/discrete/examples/schedulingExample.cpp @@ -165,11 +165,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - DiscreteValues values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(6 - s); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); @@ -319,11 +319,11 @@ void accomodateStudent() { // GTSAM_PRINT(*chordal); // solve root node only - DiscreteValues values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(0); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) diff --git a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp index e6a47f5f86..830d59ba73 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp @@ -190,11 +190,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - DiscreteValues values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); + DiscreteValues values; values[dkey.first] = bestSlot; size_t count = (*root)(values); diff --git a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp index 82ea16a47a..b24f9bf0a4 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp @@ -212,11 +212,11 @@ void solveStaged(size_t addMutex = 2) { root->print(""/*scheduler.studentName(s)*/); // solve root node only - DiscreteValues values; - size_t bestSlot = root->solve(values); + size_t bestSlot = root->argmax(); // get corresponding count DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); + DiscreteValues values; values[dkey.first] = bestSlot; double count = (*root)(values); From 06150d143c9a2e03a8ef5b4ba687103bef3d3a11 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 21 Jan 2022 00:23:42 -0500 Subject: [PATCH 20/20] fix for setup.py install deprecation --- .github/scripts/python.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/scripts/python.sh b/.github/scripts/python.sh index 6cc62d2b06..6f5643fc75 100644 --- a/.github/scripts/python.sh +++ b/.github/scripts/python.sh @@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ make -j2 install cd $GITHUB_WORKSPACE/build/python -$PYTHON setup.py install --user --prefix= +pip install --user --install-option="--prefix=" . cd $GITHUB_WORKSPACE/python/gtsam/tests $PYTHON -m unittest discover -v