diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 75018cf928..4f3e3f7f14 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -179,9 +179,9 @@ namespace gtsam { } /* ************************************************************************* */ - std::string DecisionTreeFactor::markdown( - const KeyFormatter& keyFormatter) const { - std::stringstream ss; + string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; // Print out header and construct argument for `cartesianProduct`. ss << "|"; @@ -200,7 +200,10 @@ namespace gtsam { for (const auto& kv : rows) { ss << "|"; auto assignment = kv.first; - for (auto& key : keys()) ss << assignment.at(key) << "|"; + for (auto& key : keys()) { + size_t index = assignment.at(key); + ss << Translate(names, key, index) << "|"; + } ss << kv.second << "|\n"; } return ss.str(); diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 46509db822..f8832c2237 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -192,9 +192,15 @@ namespace gtsam { std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, bool showZero = true) const; - /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp index d9fba630e5..510fb56389 100644 --- a/gtsam/discrete/DiscreteBayesNet.cpp +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -63,12 +63,13 @@ namespace gtsam { /* ************************************************************************* */ std::string DiscreteBayesNet::markdown( - const KeyFormatter& keyFormatter) const { + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteBayesNet` of size " << size() << endl << endl; for(const DiscreteConditional::shared_ptr& conditional: *this) - ss << conditional->markdown(keyFormatter) << endl; + ss << conditional->markdown(keyFormatter, names) << endl; return ss.str(); } /* ************************************************************************* */ diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index aed4cec0aa..5332b51dd0 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -108,8 +108,8 @@ namespace gtsam { /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} diff --git a/gtsam/discrete/DiscreteBayesTree.cpp b/gtsam/discrete/DiscreteBayesTree.cpp index 8a9186d05a..07d6e0f0ee 100644 --- a/gtsam/discrete/DiscreteBayesTree.cpp +++ b/gtsam/discrete/DiscreteBayesTree.cpp @@ -57,13 +57,14 @@ namespace gtsam { /* **************************************************************************/ std::string DiscreteBayesTree::markdown( - const KeyFormatter& keyFormatter) const { + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl; auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique, size_t& indent) { - ss << "\n" << clique->conditional()->markdown(keyFormatter); + ss << "\n" << clique->conditional()->markdown(keyFormatter, names); return indent + 1; }; size_t indent; diff --git a/gtsam/discrete/DiscreteBayesTree.h b/gtsam/discrete/DiscreteBayesTree.h index 12d6017cc3..6189f25d54 100644 --- a/gtsam/discrete/DiscreteBayesTree.h +++ b/gtsam/discrete/DiscreteBayesTree.h @@ -93,8 +93,8 @@ class GTSAM_EXPORT DiscreteBayesTree /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} }; diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 46d5509e06..b4f95780d5 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -282,9 +282,18 @@ size_t DiscreteConditional::sample(size_t parent_value) const { return sample(values); } +/* ******************************************************************************** */ +size_t DiscreteConditional::sample() const { + if (nrParents() != 0) + throw std::invalid_argument( + "sample() can only be invoked on no-parent prior"); + DiscreteValues values; + return sample(values); +} + /* ************************************************************************* */ -std::string DiscreteConditional::markdown( - const KeyFormatter& keyFormatter) const { +std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { std::stringstream ss; // Print out signature. @@ -317,7 +326,7 @@ std::string DiscreteConditional::markdown( ss << "|"; const_iterator it; for(Key parent: parents()) { - ss << keyFormatter(parent) << "|"; + ss << "*" << keyFormatter(parent) << "*|"; pairs.emplace_back(parent, cardinalities_.at(parent)); } @@ -331,7 +340,10 @@ std::string DiscreteConditional::markdown( pairs.rend() - nrParents()); const auto frontal_assignments = cartesianProduct(slatnorf); for (const auto& a : frontal_assignments) { - for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it); + for (it = beginFrontals(); it != endFrontals(); ++it) { + size_t index = a.at(*it); + ss << Translate(names, *it, index); + } ss << "|"; } ss << "\n"; @@ -348,8 +360,10 @@ std::string DiscreteConditional::markdown( for (const auto& a : assignments) { if (count == 0) { ss << "|"; - for (it = beginParents(); it != endParents(); ++it) - ss << a.at(*it) << "|"; + for (it = beginParents(); it != endParents(); ++it) { + size_t index = a.at(*it); + ss << Translate(names, *it, index) << "|"; + } } ss << operator()(a) << "|"; count = (count + 1) % n; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index d21e3ae264..7ce3dc9308 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -162,9 +162,12 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, size_t sample(const DiscreteValues& parentsValues) const; - /// Single value version. + /// Single parent version. size_t sample(size_t parent_value) const; + /// Zero parent version. + size_t sample() const; + /// @} /// @name Advanced Interface /// @{ @@ -180,8 +183,8 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override; + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} }; diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp index c101653d28..1a12ef405a 100644 --- a/gtsam/discrete/DiscreteFactor.cpp +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -19,9 +19,20 @@ #include +#include + using namespace std; namespace gtsam { -/* ************************************************************************* */ -} // namespace gtsam +string DiscreteFactor::Translate(const Names& names, Key key, size_t index) { + if (names.empty()) { + stringstream ss; + ss << index; + return ss.str(); + } else { + return names.at(key)[index]; + } +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 76ed703bb2..e30c0a6fec 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -89,9 +89,22 @@ class GTSAM_EXPORT DiscreteFactor: public Factor { /// @name Wrapper support /// @{ - /// Render as markdown table. + /// Translation table from values to strings. + using Names = std::map>; + + /// Translate an integer index value for given key to a string. + static std::string Translate(const Names& names, Key key, size_t index); + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ virtual std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const = 0; + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const = 0; /// @} }; diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp index bd84e13647..be046d2902 100644 --- a/gtsam/discrete/DiscreteFactorGraph.cpp +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -16,15 +16,17 @@ * @author Frank Dellaert */ -//#define ENABLE_TIMING -#include -#include #include +#include #include +#include #include -#include #include -#include +#include + +using std::vector; +using std::string; +using std::map; namespace gtsam { @@ -64,7 +66,7 @@ namespace gtsam { } /* ************************************************************************* */ - void DiscreteFactorGraph::print(const std::string& s, + void DiscreteFactorGraph::print(const string& s, const KeyFormatter& formatter) const { std::cout << s << std::endl; std::cout << "size: " << size() << std::endl; @@ -130,14 +132,15 @@ namespace gtsam { } /* ************************************************************************* */ - std::string DiscreteFactorGraph::markdown( - const KeyFormatter& keyFormatter) const { + string DiscreteFactorGraph::markdown( + const KeyFormatter& keyFormatter, + const DiscreteFactor::Names& names) const { using std::endl; std::stringstream ss; ss << "`DiscreteFactorGraph` of size " << size() << endl << endl; for (size_t i = 0; i < factors_.size(); i++) { ss << "factor " << i << ":\n"; - ss << factors_[i]->markdown(keyFormatter) << endl; + ss << factors_[i]->markdown(keyFormatter, names) << endl; } return ss.str(); } diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 6856493f7f..9aa04d6497 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -24,7 +24,10 @@ #include #include #include + #include +#include +#include namespace gtsam { @@ -140,9 +143,15 @@ public EliminateableFactorGraph { /// @name Wrapper support /// @{ - /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, a map from Key to category names. + * @return std::string a (potentially long) markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const DiscreteFactor::Names& names = {}) const; /// @} }; // \ DiscreteFactorGraph diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h index 1a7c6ae6cb..9ac8acb17a 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscretePrior.h @@ -98,7 +98,7 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { * sample * @return sample from conditional */ - size_t sample() const { return Base::sample({}); } + size_t sample() const { return Base::sample(); } /// @} }; diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 5bd4a2913a..a837328838 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -52,6 +52,8 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -84,10 +86,13 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; size_t sample(size_t value) const; + size_t sample() const; void solveInPlace(gtsam::DiscreteValues @parentsValues) const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -101,7 +106,6 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { double operator()(size_t value) const; std::vector pmf() const; size_t solve() const; - size_t sample() const; }; #include @@ -130,6 +134,8 @@ class DiscreteBayesNet { gtsam::DiscreteValues sample() const; string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -164,6 +170,8 @@ class DiscreteBayesTree { string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; #include @@ -211,6 +219,8 @@ class DiscreteFactorGraph { string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + string markdown(const gtsam::KeyFormatter& keyFormatter, + std::map> names) const; }; } // namespace gtsam diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 6af7ca7313..c4e5f06bb3 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -119,6 +119,27 @@ TEST(DecisionTreeFactor, markdown) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Check markdown representation with a value formatter. +TEST(DecisionTreeFactor, markdownWithValueFormatter) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + string expected = + "|A|B|value|\n" + "|:-:|:-:|:-:|\n" + "|Zero|-|1|\n" + "|Zero|+|2|\n" + "|One|-|3|\n" + "|One|+|4|\n" + "|Two|-|5|\n" + "|Two|+|6|\n"; + auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; + DecisionTreeFactor::Names names{{12, {"Zero", "One", "Two"}}, + {5, {"-", "+"}}}; + string actual = f.markdown(keyFormatter, names); + EXPECT(actual == expected); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index 1de45905a6..251978c99b 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -187,7 +187,7 @@ TEST(DiscreteBayesNet, markdown) { "|1|0.01|\n" "\n" " *P(Smoking|Asia)*:\n\n" - "|Asia|0|1|\n" + "|*Asia*|0|1|\n" "|:-:|:-:|:-:|\n" "|0|0.8|0.2|\n" "|1|0.7|0.3|\n\n"; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 00ae1acd01..b498b0541a 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -143,7 +143,7 @@ TEST(DiscreteConditional, markdown_multivalued) { A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); string expected = " *P(a1|b1)*:\n\n" - "|b1|0|1|2|\n" + "|*b1*|0|1|2|\n" "|:-:|:-:|:-:|:-:|\n" "|0|0.02|0.88|0.1|\n" "|1|0.02|0.2|0.78|\n" @@ -161,17 +161,19 @@ TEST(DiscreteConditional, markdown) { DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); string expected = " *P(A|B,C)*:\n\n" - "|B|C|0|1|\n" + "|*B*|*C*|T|F|\n" "|:-:|:-:|:-:|:-:|\n" - "|0|0|0|1|\n" - "|0|1|0.25|0.75|\n" - "|0|2|0.5|0.5|\n" - "|1|0|0.75|0.25|\n" - "|1|1|0|1|\n" - "|1|2|1|0|\n"; - vector names{"C", "B", "A"}; - auto formatter = [names](Key key) { return names[key]; }; - string actual = conditional.markdown(formatter); + "|-|Zero|0|1|\n" + "|-|One|0.25|0.75|\n" + "|-|Two|0.5|0.5|\n" + "|+|Zero|0.75|0.25|\n" + "|+|One|0|1|\n" + "|+|Two|1|0|\n"; + vector keyNames{"C", "B", "A"}; + auto formatter = [keyNames](Key key) { return keyNames[key]; }; + DecisionTreeFactor::Names names{ + {0, {"Zero", "One", "Two"}}, {1, {"-", "+"}}, {2, {"T", "F"}}}; + string actual = conditional.markdown(formatter, names); EXPECT(actual == expected); } diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index b91926cc05..23f093b229 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -28,6 +28,8 @@ static const DiscreteKey X(0, 2); /* ************************************************************************* */ TEST(DiscretePrior, constructors) { DiscretePrior actual(X % "2/3"); + EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual.nrParents()); DecisionTreeFactor f(X, "0.4 0.6"); DiscretePrior expected(f); EXPECT(assert_equal(expected, actual, 1e-9)); @@ -41,12 +43,18 @@ TEST(DiscretePrior, operator) { } /* ************************************************************************* */ -TEST(DiscretePrior, to_vector) { +TEST(DiscretePrior, pmf) { DiscretePrior prior(X % "2/3"); vector expected {0.4, 0.6}; EXPECT(prior.pmf() == expected); } +/* ************************************************************************* */ +TEST(DiscretePrior, sample) { + DiscretePrior prior(X % "2/3"); + prior.sample(); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam_unstable/discrete/Constraint.h b/gtsam_unstable/discrete/Constraint.h index 5c21028a0c..85748f0546 100644 --- a/gtsam_unstable/discrete/Constraint.h +++ b/gtsam_unstable/discrete/Constraint.h @@ -86,8 +86,8 @@ class GTSAM_EXPORT Constraint : public DiscreteFactor { /// @{ /// Render as markdown table. - std::string markdown( - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override { + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override { return (boost::format("`Constraint` on %1% variables\n") % (size())).str(); } diff --git a/gtsam_unstable/discrete/examples/schedulingExample.cpp b/gtsam_unstable/discrete/examples/schedulingExample.cpp index 3460664db7..2a9addf918 100644 --- a/gtsam_unstable/discrete/examples/schedulingExample.cpp +++ b/gtsam_unstable/discrete/examples/schedulingExample.cpp @@ -115,7 +115,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference diff --git a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp index 19694c31ec..8260bfb068 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals12.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals12.cpp @@ -115,7 +115,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference diff --git a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp index 4b96b1eeba..cf3ce04535 100644 --- a/gtsam_unstable/discrete/examples/schedulingQuals13.cpp +++ b/gtsam_unstable/discrete/examples/schedulingQuals13.cpp @@ -139,7 +139,7 @@ void runLargeExample() { // Do brute force product and output that to file if (scheduler.nrStudents() == 1) { // otherwise too slow DecisionTreeFactor product = scheduler.product(); - product.dot("scheduling-large", false); + product.dot("scheduling-large", DefaultKeyFormatter, false); } // Do exact inference diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 1b2ce70cd7..86bc303a9a 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -50,7 +50,7 @@ def test_markdown(self): "0/1 1/3 1/1 3/1 0/1 1/0") expected = \ " *P(A|B,C)*:\n\n" \ - "|B|C|0|1|\n" \ + "|*B*|*C*|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \ "|0|1|0.25|0.75|\n" \ diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py index 4f017d66a4..5bf6a8d196 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -6,7 +6,7 @@ See LICENSE for the license information Unit tests for Discrete Priors. -Author: Varun Agrawal +Author: Frank Dellaert """ # pylint: disable=no-name-in-module, invalid-name @@ -42,6 +42,11 @@ def test_pmf(self): expected = np.array([0.4, 0.6]) np.testing.assert_allclose(expected, prior.pmf()) + def test_sample(self): + prior = DiscretePrior(X, "2/3") + actual = prior.sample() + self.assertIsInstance(actual, int) + def test_markdown(self): """Test the _repr_markdown_ method."""