From 8eb623b58f821c7dbaf7ddda3ac6cc7af54a3f5c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 21:34:22 -0500 Subject: [PATCH 1/7] Added an optional names argument for discrete markdown renderers --- gtsam/discrete/DecisionTreeFactor.cpp | 11 ++++++---- gtsam/discrete/DecisionTreeFactor.h | 12 +++++++--- gtsam/discrete/DiscreteBayesNet.cpp | 5 +++-- gtsam/discrete/DiscreteBayesNet.h | 4 ++-- gtsam/discrete/DiscreteBayesTree.cpp | 5 +++-- gtsam/discrete/DiscreteBayesTree.h | 4 ++-- gtsam/discrete/DiscreteConditional.cpp | 15 ++++++++----- gtsam/discrete/DiscreteConditional.h | 4 ++-- gtsam/discrete/DiscreteFactor.cpp | 15 +++++++++++-- gtsam/discrete/DiscreteFactor.h | 15 +++++++++++-- gtsam/discrete/DiscreteFactorGraph.cpp | 21 ++++++++++-------- gtsam/discrete/DiscreteFactorGraph.h | 15 ++++++++++--- gtsam/discrete/discrete.i | 10 +++++++++ .../discrete/tests/testDecisionTreeFactor.cpp | 21 ++++++++++++++++++ .../tests/testDiscreteConditional.cpp | 22 ++++++++++--------- gtsam_unstable/discrete/Constraint.h | 4 ++-- 16 files changed, 133 insertions(+), 50 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 75018cf928..4f3e3f7f14 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -179,9 +179,9 @@ namespace gtsam { } /* ************************************************************************* */ - std::string DecisionTreeFactor::markdown( - const KeyFormatter& keyFormatter) const { - std::stringstream ss; + string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; // Print out header and construct argument for `cartesianProduct`. ss << "|"; @@ -200,7 +200,10 @@ namespace gtsam { for (const auto& kv : rows) { ss << "|"; auto assignment = kv.first; - for (auto& key : keys()) ss << assignment.at(key) << "|"; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << Translate(names, key, index) << "|"; + } ss << kv.second << "|\n"; } return ss.str(); diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 46509db822..f8832c2237 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -192,9 +192,15 @@ namespace gtsam { std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, bool showZero = true) const; - /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index d9fba630e5..510fb56389 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -63,12 +63,13 @@ namespace gtsam { /* ************************************************************************* */ std::string DiscreteBayesNet::markdown( - const KeyFormatter& keyFormatter) const { + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteBayesNet` of size " << size() << endl << endl; for(const DiscreteConditional::shared_ptr& conditional: *this) - ss << conditional->markdown(keyFormatter) << endl; + ss << conditional->markdown(keyFormatter, names) << endl; return ss.str(); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index aed4cec0aa..5332b51dd0 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -108,8 +108,8 @@ namespace gtsam { /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} diff --git a/gtsam/discrete/DiscreteBayesTree.cpp b/gtsam/discrete/DiscreteBayesTree.cpp index 8a9186d05a..07d6e0f0ee 100644 --- a/gtsam/discrete/DiscreteBayesTree.cpp +++ b/gtsam/discrete/DiscreteBayesTree.cpp @@ -57,13 +57,14 @@ namespace gtsam { /* **************************************************************************/ std::string DiscreteBayesTree::markdown( - const KeyFormatter& keyFormatter) const { + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl; auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, size_t& indent) { - ss << "\n" << clique->conditional()->markdown(keyFormatter); + ss << "\n" << clique->conditional()->markdown(keyFormatter, names); return indent + 1; }; size_t indent; diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 12d6017cc3..6189f25d54 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -93,8 +93,8 @@ class GTSAM_EXPORT DiscreteBayesTree /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} }; diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 46d5509e06..203f00f89f 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -283,8 +283,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const { } /* ************************************************************************* */ -std::string DiscreteConditional::markdown( - const KeyFormatter& keyFormatter) const { +std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { std::stringstream ss; // Print out signature. @@ -331,7 +331,10 @@ std::string DiscreteConditional::markdown( pairs.rend() - nrParents()); const auto frontal_assignments = cartesianProduct(slatnorf); for (const auto& a : frontal_assignments) { - for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it); + for (it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << Translate(names, *it, index); + } ss << "|"; } ss << "\n"; @@ -348,8 +351,10 @@ std::string DiscreteConditional::markdown( for (const auto& a : assignments) { if (count == 0) { ss << "|"; - for (it = beginParents(); it != endParents(); ++it) - ss << a.at(*it) << "|"; + for (it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << Translate(names, *it, index) << "|"; + } } ss << operator()(a) << "|"; count = (count + 1) % n; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index d21e3ae264..1cad927e92 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -180,8 +180,8 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} }; diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index c101653d28..1a12ef405a 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -19,9 +19,20 @@ #include +#include + using namespace std; namespace gtsam { -/* ************************************************************************* */ -} // namespace gtsam +string DiscreteFactor::Translate(const Names& names, Key key, size_t index) { + if (names.empty()) { + stringstream ss; + ss << index; + return ss.str(); + } else { + return names.at(key)[index]; + } +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 76ed703bb2..70545a5ca5 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -89,9 +89,20 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { /// @name Wrapper support /// @{ - /// Render as markdown table. + using Names = std::map>; + + static std::string Translate(const Names& names, Key key, size_t index); + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ virtual std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const = 0; + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const = 0; /// @} }; diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index bd84e13647..be046d2902 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -16,15 +16,17 @@ * @author Frank Dellaert */ -//#define ENABLE_TIMING -#include -#include #include +#include #include +#include #include -#include #include -#include +#include + +using std::vector; +using std::string; +using std::map; namespace gtsam { @@ -64,7 +66,7 @@ namespace gtsam { } /* ************************************************************************* */ - void DiscreteFactorGraph::print(const std::string& s, + void DiscreteFactorGraph::print(const string& s, const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; @@ -130,14 +132,15 @@ namespace gtsam { } /* ************************************************************************* */ - std::string DiscreteFactorGraph::markdown( - const KeyFormatter& keyFormatter) const { + string DiscreteFactorGraph::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteFactorGraph` of size " << size() << endl << endl; for (size_t i = 0; i < factors_.size(); i++) { ss << "factor " << i << ":\n"; - ss << factors_[i]->markdown(keyFormatter) << endl; + ss << factors_[i]->markdown(keyFormatter, names) << endl; } return ss.str(); } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 6856493f7f..9aa04d6497 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -24,7 +24,10 @@ #include #include #include + #include +#include +#include namespace gtsam { @@ -140,9 +143,15 @@ public EliminateableFactorGraph { /// @name Wrapper support /// @{ - /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, a map from Key to category names. + * @return std::string a (potentially long) markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} }; // \ DiscreteFactorGraph diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 5bd4a2913a..e298deaf1b 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -52,6 +52,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -88,6 +90,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -130,6 +134,8 @@ class DiscreteBayesNet { gtsam::DiscreteValues sample() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -164,6 +170,8 @@ class DiscreteBayesTree { string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -211,6 +219,8 @@ class DiscreteFactorGraph { string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 6af7ca7313..c4e5f06bb3 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -119,6 +119,27 @@ TEST(DecisionTreeFactor, markdown) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(DecisionTreeFactor, markdownWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|Zero|-|1|\n" + "|Zero|+|2|\n" + "|One|-|3|\n" + "|One|+|4|\n" + "|Two|-|5|\n" + "|Two|+|6|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 00ae1acd01..90da07cdcd 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -161,17 +161,19 @@ TEST(DiscreteConditional, markdown) { DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); string expected = " *P(A|B,C)*:\n\n" - "|B|C|0|1|\n" + "|B|C|T|F|\n" "|:-:|:-:|:-:|:-:|\n" - "|0|0|0|1|\n" - "|0|1|0.25|0.75|\n" - "|0|2|0.5|0.5|\n" - "|1|0|0.75|0.25|\n" - "|1|1|0|1|\n" - "|1|2|1|0|\n"; - vector names{"C", "B", "A"}; - auto formatter = [names](Key key) { return names[key]; }; - string actual = conditional.markdown(formatter); + "|-|Zero|0|1|\n" + "|-|One|0.25|0.75|\n" + "|-|Two|0.5|0.5|\n" + "|+|Zero|0.75|0.25|\n" + "|+|One|0|1|\n" + "|+|Two|1|0|\n"; + vector keyNames{"C", "B", "A"}; + auto formatter = [keyNames](Key key) { return keyNames[key]; }; + DecisionTreeFactor::Names names{ + {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}}; + string actual = conditional.markdown(formatter, names); EXPECT(actual == expected); } diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 5c21028a0c..85748f0546 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -86,8 +86,8 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor { /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override { return (boost::format("`Constraint` on %1% variables\n") % (size())).str(); } From c51bba81d8cf8ebf81e67bed6a33be3dd2e681e3 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 28 Dec 2021 21:22:03 -0500 Subject: [PATCH 2/7] Fix sample() --- gtsam/discrete/DiscretePrior.h | 2 +- python/gtsam/tests/test_DiscretePrior.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h index 1a7c6ae6cb..d11d9be066 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscretePrior.h @@ -98,7 +98,7 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { * sample * @return sample from conditional */ - size_t sample() const { return Base::sample({}); } + size_t sample() const { return Base::sample(DiscreteValues()); } /// @} }; diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py index 4f017d66a4..5bf6a8d196 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -6,7 +6,7 @@ See LICENSE for the license information Unit tests for Discrete Priors. -Author: Varun Agrawal +Author: Frank Dellaert """ # pylint: disable=no-name-in-module, invalid-name @@ -42,6 +42,11 @@ def test_pmf(self): expected = np.array([0.4, 0.6]) np.testing.assert_allclose(expected, prior.pmf()) + def test_sample(self): + prior = DiscretePrior(X, "2/3") + actual = prior.sample() + self.assertIsInstance(actual, int) + def test_markdown(self): """Test the _repr_markdown_ method.""" From fca23e0559a39e3a1c402c0c0ebe18bebb5b71a2 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 22:38:39 -0500 Subject: [PATCH 3/7] italicized parent values --- gtsam/discrete/DiscreteConditional.cpp | 2 +- gtsam/discrete/tests/testDiscreteBayesNet.cpp | 3 ++- gtsam/discrete/tests/testDiscreteConditional.cpp | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 203f00f89f..af4ad4495d 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -317,7 +317,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, ss << "|"; const_iterator it; for(Key parent: parents()) { - ss << keyFormatter(parent) << "|"; + ss << "*" << keyFormatter(parent) << "*|"; pairs.emplace_back(parent, cardinalities_.at(parent)); } diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 1de45905a6..de8e05edc3 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -187,12 +187,13 @@ TEST(DiscreteBayesNet, markdown) { "|1|0.01|\n" "\n" " *P(Smoking|Asia)*:\n\n" - "|Asia|0|1|\n" + "|*Asia*|0|1|\n" "|:-:|:-:|:-:|\n" "|0|0.8|0.2|\n" "|1|0.7|0.3|\n\n"; auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; string actual = fragment.markdown(formatter); + cout << actual << endl; EXPECT(actual == expected); } diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 90da07cdcd..b498b0541a 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -143,7 +143,7 @@ TEST(DiscreteConditional, markdown_multivalued) { A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); string expected = " *P(a1|b1)*:\n\n" - "|b1|0|1|2|\n" + "|*b1*|0|1|2|\n" "|:-:|:-:|:-:|:-:|\n" "|0|0.02|0.88|0.1|\n" "|1|0.02|0.2|0.78|\n" @@ -161,7 +161,7 @@ TEST(DiscreteConditional, markdown) { DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); string expected = " *P(A|B,C)*:\n\n" - "|B|C|T|F|\n" + "|*B*|*C*|T|F|\n" "|:-:|:-:|:-:|:-:|\n" "|-|Zero|0|1|\n" "|-|One|0.25|0.75|\n" From 88c79a2a56f564a64b30e520f5074b8b283c3111 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 22:48:55 -0500 Subject: [PATCH 4/7] Fixed some examples --- gtsam_unstable/discrete/examples/schedulingExample.cpp | 2 +- gtsam_unstable/discrete/examples/schedulingQuals12.cpp | 2 +- gtsam_unstable/discrete/examples/schedulingQuals13.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gtsam_unstable/discrete/examples/schedulingExample.cpp b/gtsam_unstable/discrete/examples/schedulingExample.cpp index 3460664db7..2a9addf918 100644 --- a/gtsam_unstable/discrete/examples/schedulingExample.cpp +++ b/gtsam_unstable/discrete/examples/schedulingExample.cpp @@ -115,7 +115,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference diff --git a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp index 19694c31ec..8260bfb068 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp @@ -115,7 +115,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference diff --git a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp index 4b96b1eeba..cf3ce04535 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp @@ -139,7 +139,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference From 53a6523943392afba36f6f679e501cdc607b459a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 23:23:51 -0500 Subject: [PATCH 5/7] Fixed issues with sample --- gtsam/discrete/DiscreteConditional.cpp | 9 +++++++++ gtsam/discrete/DiscreteConditional.h | 5 ++++- gtsam/discrete/DiscretePrior.h | 2 +- gtsam/discrete/discrete.i | 2 +- gtsam/discrete/tests/testDiscretePrior.cpp | 10 +++++++++- 5 files changed, 24 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index af4ad4495d..b4f95780d5 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -282,6 +282,15 @@ size_t DiscreteConditional::sample(size_t parent_value) const { return sample(values); } +/* ******************************************************************************** */ +size_t DiscreteConditional::sample() const { + if (nrParents() != 0) + throw std::invalid_argument( + "sample() can only be invoked on no-parent prior"); + DiscreteValues values; + return sample(values); +} + /* ************************************************************************* */ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, const Names& names) const { diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 1cad927e92..7ce3dc9308 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -162,9 +162,12 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, size_t sample(const DiscreteValues& parentsValues) const; - /// Single value version. + /// Single parent version. size_t sample(size_t parent_value) const; + /// Zero parent version. + size_t sample() const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h index d11d9be066..9ac8acb17a 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscretePrior.h @@ -98,7 +98,7 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { * sample * @return sample from conditional */ - size_t sample() const { return Base::sample(DiscreteValues()); } + size_t sample() const { return Base::sample(); } /// @} }; diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index e298deaf1b..a837328838 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -86,6 +86,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(size_t value) const; + size_t sample() const; void solveInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = @@ -105,7 +106,6 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { double operator()(size_t value) const; std::vector pmf() const; size_t solve() const; - size_t sample() const; }; #include diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index b91926cc05..23f093b229 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -28,6 +28,8 @@ static const DiscreteKey X(0, 2); /* ************************************************************************* */ TEST(DiscretePrior, constructors) { DiscretePrior actual(X % "2/3"); + EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual.nrParents()); DecisionTreeFactor f(X, "0.4 0.6"); DiscretePrior expected(f); EXPECT(assert_equal(expected, actual, 1e-9)); @@ -41,12 +43,18 @@ TEST(DiscretePrior, operator) { } /* ************************************************************************* */ -TEST(DiscretePrior, to_vector) { +TEST(DiscretePrior, pmf) { DiscretePrior prior(X % "2/3"); vector expected {0.4, 0.6}; EXPECT(prior.pmf() == expected); } +/* ************************************************************************* */ +TEST(DiscretePrior, sample) { + DiscretePrior prior(X % "2/3"); + prior.sample(); +} + /* ************************************************************************* */ int main() { TestResult tr; From 8c3d51262996d8235ebc1e8e168e45ff916f7c57 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 2 Jan 2022 23:24:03 -0500 Subject: [PATCH 6/7] Fixed python test --- python/gtsam/tests/test_DiscreteConditional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 1b2ce70cd7..86bc303a9a 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -50,7 +50,7 @@ def test_markdown(self): "0/1 1/3 1/1 3/1 0/1 1/0") expected = \ " *P(A|B,C)*:\n\n" \ - "|B|C|0|1|\n" \ + "|*B*|*C*|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \ "|0|1|0.25|0.75|\n" \ From 9d6f9f647ad345ebce40418436ec254745e942bd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Mon, 3 Jan 2022 11:13:32 -0500 Subject: [PATCH 7/7] Address comments --- gtsam/discrete/DiscreteFactor.h | 2 ++ gtsam/discrete/tests/testDiscreteBayesNet.cpp | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 70545a5ca5..e30c0a6fec 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -89,8 +89,10 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { /// @name Wrapper support /// @{ + /// Translation table from values to strings. using Names = std::map>; + /// Translate an integer index value for given key to a string. static std::string Translate(const Names& names, Key key, size_t index); /** diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index de8e05edc3..251978c99b 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -193,7 +193,6 @@ TEST(DiscreteBayesNet, markdown) { "|1|0.7|0.3|\n\n"; auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; }; string actual = fragment.markdown(formatter); - cout << actual << endl; EXPECT(actual == expected); }