diff --git a/CMakeLists.txt b/CMakeLists.txt index a88639b492..74019da446 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,12 +9,18 @@ endif() # Set the version number for the library set (GTSAM_VERSION_MAJOR 4) -set (GTSAM_VERSION_MINOR 1) -set (GTSAM_VERSION_PATCH 1) +set (GTSAM_VERSION_MINOR 2) +set (GTSAM_VERSION_PATCH 0) +set (GTSAM_PRERELEASE_VERSION "a0") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") -set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}") -set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING}) +if (${GTSAM_VERSION_PATCH} EQUAL 0) + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}") +else() + set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}") +endif() +message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}") + set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR}) set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR}) set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH}) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 50b21fc768..7aed00c57d 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -134,17 +134,34 @@ namespace gtsam { return boost::make_shared(dkeys, result); } + /* ************************************************************************* */ + std::vector> DecisionTreeFactor::enumerate() const { + // Get all possible assignments + std::vector> pairs; + for (auto& key : keys()) { + pairs.emplace_back(key, cardinalities_.at(key)); + } + // Reverse to make cartesianProduct output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = cartesianProduct(rpairs); + + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); + } + return result; + } + /* ************************************************************************* */ std::string DecisionTreeFactor::markdown( const KeyFormatter& keyFormatter) const { std::stringstream ss; // Print out header and construct argument for `cartesianProduct`. - std::vector> pairs; ss << "|"; for (auto& key : keys()) { ss << keyFormatter(key) << "|"; - pairs.emplace_back(key, cardinalities_.at(key)); } ss << "value|\n"; @@ -154,12 +171,12 @@ namespace gtsam { ss << ":-:|\n"; // Print out all rows. - std::vector> rpairs(pairs.rbegin(), pairs.rend()); - const auto assignments = cartesianProduct(rpairs); - for (const auto& assignment : assignments) { + auto rows = enumerate(); + for (const auto& kv : rows) { ss << "|"; + auto assignment = kv.first; for (auto& key : keys()) ss << assignment.at(key) << "|"; - ss << operator()(assignment) << "|\n"; + ss << kv.second << "|\n"; } return ss.str(); } diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 27ee67cf23..f90af56dd0 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -61,6 +61,15 @@ namespace gtsam { DiscreteFactor(keys.indices()), Potentials(keys, table) { } + /// Single-key specialization + template + DecisionTreeFactor(const DiscreteKey& key, SOURCE table) + : DecisionTreeFactor(DiscreteKeys{key}, table) {} + + /// Single-key specialization, with vector of doubles. + DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) + : DecisionTreeFactor(DiscreteKeys{key}, row) {} + /** Construct from a DiscreteConditional type */ DecisionTreeFactor(const DiscreteConditional& c); @@ -162,6 +171,9 @@ namespace gtsam { // Potentials::reduceWithInverse(inverseReduction); // } + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index d78eed08f2..aed4cec0aa 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace gtsam { @@ -75,6 +76,11 @@ namespace gtsam { // Add inherited versions of add. using Base::add; + /** Add a DiscretePrior using a table or a string */ + void add(const DiscreteKey& key, const std::string& spec) { + emplace_shared(key, spec); + } + /** Add a DiscreteCondtional */ template void add(Args&&... args) { diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 080dbba9bc..46d5509e06 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other, } /* ******************************************************************************** */ -Potentials::ADT DiscreteConditional::choose( +static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, + const DiscreteValues& parentsValues) { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the parent variables. + DiscreteConditional::ADT adt(conditional); + size_t value; + for (Key j : conditional.parents()) { + try { + value = parentsValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (std::out_of_range&) { + parentsValues.print("parentsValues: "); + throw runtime_error("DiscreteConditional::choose: parent value missing"); + }; + } + return adt; +} + +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::choose( const DiscreteValues& parentsValues) const { // Get the big decision tree with all the levels, and then go down the // branches based on the value of the parent variables. - ADT pFS(*this); + ADT adt(*this); size_t value; for (Key j : parents()) { try { value = parentsValues.at(j); - pFS = pFS.choose(j, value); // ADT keeps getting smaller. + adt = adt.choose(j, value); // ADT keeps getting smaller. } catch (exception&) { - cout << "Key: " << j << " Value: " << value << endl; parentsValues.print("parentsValues: "); throw runtime_error("DiscreteConditional::choose: parent value missing"); }; } - return pFS; + + // Convert ADT to factor. + DiscreteKeys discreteKeys; + for (Key j : frontals()) { + discreteKeys.emplace_back(j, this->cardinality(j)); + } + return boost::make_shared(discreteKeys, adt); } /* ******************************************************************************** */ -DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor( - const DiscreteValues& parentsValues) const { - ADT pFS = choose(parentsValues); +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + const DiscreteValues& frontalValues) const { + // Get the big decision tree with all the levels, and then go down the + // branches based on the value of the frontal variables. + ADT adt(*this); + size_t value; + for (Key j : frontals()) { + try { + value = frontalValues.at(j); + adt = adt.choose(j, value); // ADT keeps getting smaller. + } catch (exception&) { + frontalValues.print("frontalValues: "); + throw runtime_error("DiscreteConditional::choose: frontal value missing"); + }; + } // Convert ADT to factor. - if (nrFrontals() != 1) { - throw std::runtime_error("Expected only one frontal variable in choose."); + DiscreteKeys discreteKeys; + for (Key j : parents()) { + discreteKeys.emplace_back(j, this->cardinality(j)); } - DiscreteKeys keys; - const Key frontalKey = keys_[0]; - size_t frontalCardinality = this->cardinality(frontalKey); - keys.push_back(DiscreteKey(frontalKey, frontalCardinality)); - return boost::make_shared(keys, pFS); + return boost::make_shared(discreteKeys, adt); +} + +/* ******************************************************************************** */ +DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood( + size_t parent_value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value likelihood can only be invoked on single-variable " + "conditional"); + DiscreteValues values; + values.emplace(keys_[0], parent_value); + return likelihood(values); } /* ******************************************************************************** */ void DiscreteConditional::solveInPlace(DiscreteValues* values) const { // TODO: Abhijit asks: is this really the fastest way? He thinks it is. - ADT pFS = choose(*values); // P(F|S=parentsValues) + ADT pFS = Choose(*this, *values); // P(F|S=parentsValues) // Initialize DiscreteValues mpe; @@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const { size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const { // TODO: is this really the fastest way? I think it is. - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // Then, find the max over all remaining // TODO, only works for one key now, seems horribly slow this way @@ -203,10 +248,14 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { static mt19937 rng(2); // random number generator // Get the correct conditional density - ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way - assert(nrFrontals() == 1); + if (nrFrontals() != 1) { + throw std::invalid_argument( + "DiscreteConditional::sample can only be called on single variable " + "conditionals"); + } Key key = firstFrontalKey(); size_t nj = cardinality(key); vector p(nj); @@ -222,13 +271,24 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { return distribution(rng); } +/* ******************************************************************************** */ +size_t DiscreteConditional::sample(size_t parent_value) const { + if (nrParents() != 1) + throw std::invalid_argument( + "Single value sample() can only be invoked on single-parent " + "conditional"); + DiscreteValues values; + values.emplace(keys_.back(), parent_value); + return sample(values); +} + /* ************************************************************************* */ std::string DiscreteConditional::markdown( const KeyFormatter& keyFormatter) const { std::stringstream ss; // Print out signature. - ss << " $P("; + ss << " *P("; bool first = true; for (Key key : frontals()) { if (!first) ss << ","; @@ -237,7 +297,7 @@ std::string DiscreteConditional::markdown( } if (nrParents() == 0) { // We have no parents, call factor method. - ss << ")$:" << std::endl; + ss << ")*:\n" << std::endl; ss << DecisionTreeFactor::markdown(keyFormatter); return ss.str(); } @@ -250,7 +310,7 @@ std::string DiscreteConditional::markdown( ss << keyFormatter(parent); first = false; } - ss << ")$:" << std::endl; + ss << ")*:\n" << std::endl; // Print out header and construct argument for `cartesianProduct`. std::vector> pairs; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index a8000f4867..d21e3ae264 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -62,8 +62,6 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, * conditional probability table (CPT) in 00 01 10 11 order. For * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... * - * The first string is parsed to add a key and parents. - * * Example: DiscreteConditional P(D, {B,E}, table); */ DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents, @@ -75,8 +73,7 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, * probability table (CPT) in 00 01 10 11 order. For three-valued, it would * be 00 01 02 10 11 12 20 21 22, etc.... * - * The first string is parsed to add a key and parents. The second string - * parses into a table. + * The string is parsed into a Signature::Table. * * Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); */ @@ -84,6 +81,10 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, const std::string& spec) : DiscreteConditional(Signature(key, parents, spec)) {} + /// No-parent specialization; can also use DiscretePrior. + DiscreteConditional(const DiscreteKey& key, const std::string& spec) + : DiscreteConditional(Signature(key, {}, spec)) {} + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); @@ -135,13 +136,17 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); } - /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ - ADT choose(const DiscreteValues& parentsValues) const; - /** Restrict to given parent values, returns DecisionTreeFactor */ - DecisionTreeFactor::shared_ptr chooseAsFactor( + DecisionTreeFactor::shared_ptr choose( const DiscreteValues& parentsValues) const; + /** Convert to a likelihood factor by providing value before bar. */ + DecisionTreeFactor::shared_ptr likelihood( + const DiscreteValues& frontalValues) const; + + /** Single variable version of likelihood. */ + DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const; + /** * solve a conditional * @param parentsValues Known values of the parents @@ -156,6 +161,10 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor, */ size_t sample(const DiscreteValues& parentsValues) const; + + /// Single value version. + size_t sample(size_t parent_value) const; + /// @} /// @name Advanced Interface /// @{ diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h index 38091829fb..6856493f7f 100644 --- a/gtsam/discrete/DiscreteFactorGraph.h +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -101,29 +101,12 @@ public EliminateableFactorGraph { /// @} - // Add single key decision-tree factor. - template - void add(const DiscreteKey& j, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j); - emplace_shared(keys, table); + /** Add a decision-tree factor */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); } - - // Add binary key decision-tree factor. - template - void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { - DiscreteKeys keys; - keys.push_back(j1); - keys.push_back(j2); - emplace_shared(keys, table); - } - - // Add shared discreteFactor immediately from arguments. - template - void add(const DiscreteKeys& keys, SOURCE table) { - emplace_shared(keys, table); - } - + /** Return the set of variables involved in the factors (set union) */ KeySet keys() const; diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index 86f1bcf63d..ae4dac38fc 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -43,9 +43,7 @@ namespace gtsam { DiscreteKeys() : std::vector::vector() {} /// Construct from a key - DiscreteKeys(const DiscreteKey& key) { - push_back(key); - } + explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); } /// Construct from a vector of keys DiscreteKeys(const std::vector& keys) : diff --git a/gtsam/discrete/DiscretePrior.cpp b/gtsam/discrete/DiscretePrior.cpp new file mode 100644 index 0000000000..3941e0199e --- /dev/null +++ b/gtsam/discrete/DiscretePrior.cpp @@ -0,0 +1,50 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscretePrior.cpp + * @date December 2021 + * @author Frank Dellaert + */ + +#include + +namespace gtsam { + +void DiscretePrior::print(const std::string& s, + const KeyFormatter& formatter) const { + Base::print(s, formatter); +} + +double DiscretePrior::operator()(size_t value) const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "Single value operator can only be invoked on single-variable " + "priors"); + DiscreteValues values; + values.emplace(keys_[0], value); + return Base::operator()(values); +} + +std::vector DiscretePrior::pmf() const { + if (nrFrontals() != 1) + throw std::invalid_argument( + "DiscretePrior::pmf only defined for single-variable priors"); + const size_t nrValues = cardinalities_.at(keys_[0]); + std::vector array; + array.reserve(nrValues); + for (size_t v = 0; v < nrValues; v++) { + array.push_back(operator()(v)); + } + return array; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h new file mode 100644 index 0000000000..1a7c6ae6cb --- /dev/null +++ b/gtsam/discrete/DiscretePrior.h @@ -0,0 +1,111 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscretePrior.h + * @date December 2021 + * @author Frank Dellaert + */ + +#pragma once + +#include + +#include + +namespace gtsam { + +/** + * A prior probability on a set of discrete variables. + * Derives from DiscreteConditional + */ +class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { + public: + using Base = DiscreteConditional; + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + DiscretePrior() {} + + /// Constructor from factor. + DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {} + + /** + * Construct from a Signature. + * + * Example: DiscretePrior P(D % "3/2"); + */ + DiscretePrior(const Signature& s) : Base(s) {} + + /** + * Construct from key and a Signature::Table specifying the + * conditional probability table (CPT). + * + * Example: DiscretePrior P(D, table); + */ + DiscretePrior(const DiscreteKey& key, const Signature::Table& table) + : Base(Signature(key, {}, table)) {} + + /** + * Construct from key and a string specifying the conditional + * probability table (CPT). + * + * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); + */ + DiscretePrior(const DiscreteKey& key, const std::string& spec) + : DiscretePrior(Signature(key, {}, spec)) {} + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Prior: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// @} + /// @name Standard interface + /// @{ + + /// Evaluate given a single value. + double operator()(size_t value) const; + + /// We also want to keep the Base version, taking DiscreteValues: + // TODO(dellaert): does not play well with wrapper! + // using Base::operator(); + + /// Return entire probability mass function. + std::vector pmf() const; + + /** + * solve a conditional + * @return MPE value of the child (1 frontal variable). + */ + size_t solve() const { return Base::solve({}); } + + /** + * sample + * @return sample from conditional + */ + size_t sample() const { return Base::sample({}); } + + /// @} +}; +// DiscretePrior + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index f0dc72a24d..36caccfc83 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -30,25 +30,37 @@ class DiscreteFactor { }; #include -virtual class DecisionTreeFactor: gtsam::DiscreteFactor { +virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(); + + DecisionTreeFactor(const gtsam::DiscreteKey& key, + const std::vector& spec); + DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const std::vector& keys, string table); + DecisionTreeFactor(const gtsam::DiscreteConditional& c); + void print(string s = "DecisionTreeFactor\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; string dot(bool showZero = false) const; + std::vector> enumerate() const; string markdown(const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; + gtsam::DefaultKeyFormatter) const; }; #include virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(); DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); + DiscreteConditional(const gtsam::DiscreteKey& key, string spec); DiscreteConditional(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal); DiscreteConditional(const gtsam::DecisionTreeFactor& joint, @@ -62,20 +74,43 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; gtsam::DecisionTreeFactor* toFactor() const; - gtsam::DecisionTreeFactor* chooseAsFactor(const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* choose( + const gtsam::DiscreteValues& parentsValues) const; + gtsam::DecisionTreeFactor* likelihood( + const gtsam::DiscreteValues& frontalValues) const; + gtsam::DecisionTreeFactor* likelihood(size_t value) const; size_t solve(const gtsam::DiscreteValues& parentsValues) const; size_t sample(const gtsam::DiscreteValues& parentsValues) const; - void solveInPlace(gtsam::DiscreteValues@ parentsValues) const; - void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const; + size_t sample(size_t value) const; + void solveInPlace(gtsam::DiscreteValues @parentsValues) const; + void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; string markdown(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +virtual class DiscretePrior : gtsam::DiscreteConditional { + DiscretePrior(); + DiscretePrior(const gtsam::DecisionTreeFactor& f); + DiscretePrior(const gtsam::DiscreteKey& key, string spec); + void print(string s = "Discrete Prior\n", + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; + double operator()(size_t value) const; + std::vector pmf() const; + size_t solve() const; + size_t sample() const; }; #include -class DiscreteBayesNet { +class DiscreteBayesNet { DiscreteBayesNet(); + void add(const gtsam::DiscreteConditional& s); + void add(const gtsam::DiscreteKey& key, string spec); + void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, + string spec); void add(const gtsam::DiscreteKey& key, - const gtsam::DiscreteKeys& parents, string spec); + const std::vector& parents, string spec); bool empty() const; size_t size() const; gtsam::KeySet keys() const; @@ -86,15 +121,13 @@ class DiscreteBayesNet { bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; string dot(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; - void saveGraph(string s, - const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; - void add(const gtsam::DiscreteConditional& s); + void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; double operator()(const gtsam::DiscreteValues& values) const; gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues sample() const; string markdown(const gtsam::KeyFormatter& keyFormatter = - gtsam::DefaultKeyFormatter) const; + gtsam::DefaultKeyFormatter) const; }; #include @@ -142,11 +175,13 @@ class DotWriter { class DiscreteFactorGraph { DiscreteFactorGraph(); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); - + void add(const gtsam::DiscreteKey& j, string table); - void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); + void add(const gtsam::DiscreteKey& j, const std::vector& spec); + void add(const gtsam::DiscreteKeys& keys, string table); - + void add(const std::vector& keys, string table); + bool empty() const; size_t size() const; gtsam::KeySet keys() const; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index ad8e9bd2a8..6af7ca7313 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -34,7 +34,7 @@ TEST( DecisionTreeFactor, constructors) DiscreteKey X(0,2), Y(1,3), Z(2,2); // Create factors - DecisionTreeFactor f1(X, "2 8"); + DecisionTreeFactor f1(X, {2, 8}); DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7"); DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); EXPECT_LONGS_EQUAL(1,f1.size()); @@ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max) DecisionTreeFactor::shared_ptr actual22 = f2.sum(1); } +/* ************************************************************************* */ +// Check enumerate yields the correct list of assignment/value pairs. +TEST(DecisionTreeFactor, enumerate) { + DiscreteKey A(12, 3), B(5, 2); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); + auto actual = f.enumerate(); + std::vector> expected; + DiscreteValues values; + for (size_t a : {0, 1, 2}) { + for (size_t b : {0, 1}) { + values[12] = a; + values[5] = b; + expected.emplace_back(values, f(values)); + } + } + EXPECT(actual == expected); +} + /* ************************************************************************* */ // Check markdown representation looks as expected. TEST(DecisionTreeFactor, markdown) { DiscreteKey A(12, 3), B(5, 2); - DecisionTreeFactor f1(A & B, "1 2 3 4 5 6"); + DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); string expected = "|A|B|value|\n" "|:-:|:-:|:-:|\n" @@ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) { "|2|0|5|\n" "|2|1|6|\n"; auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; - string actual = f1.markdown(formatter); + string actual = f.markdown(formatter); EXPECT(actual == expected); } diff --git a/gtsam/discrete/tests/testDiscreteBayesNet.cpp b/gtsam/discrete/tests/testDiscreteBayesNet.cpp index ea58165669..1de45905a6 100644 --- a/gtsam/discrete/tests/testDiscreteBayesNet.cpp +++ b/gtsam/discrete/tests/testDiscreteBayesNet.cpp @@ -75,8 +75,8 @@ TEST(DiscreteBayesNet, bayesNet) { TEST(DiscreteBayesNet, Asia) { DiscreteBayesNet asia; - asia.add(Asia % "99/1"); - asia.add(Smoking % "50/50"); + asia.add(Asia, "99/1"); + asia.add(Smoking % "50/50"); // Signature version asia.add(Tuberculosis | Asia = "99/1 95/5"); asia.add(LungCancer | Smoking = "99/1 90/10"); @@ -180,13 +180,13 @@ TEST(DiscreteBayesNet, markdown) { string expected = "`DiscreteBayesNet` of size 2\n" "\n" - " $P(Asia)$:\n" + " *P(Asia)*:\n\n" "|Asia|value|\n" "|:-:|:-:|\n" "|0|0.99|\n" "|1|0.01|\n" "\n" - " $P(Smoking|Asia)$:\n" + " *P(Smoking|Asia)*:\n\n" "|Asia|0|1|\n" "|:-:|:-:|:-:|\n" "|0|0.8|0.2|\n" diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index f42792e71d..00ae1acd01 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -10,10 +10,11 @@ * -------------------------------------------------------------------------- */ /* - * @file testDecisionTreeFactor.cpp + * @file testDiscreteConditional.cpp * @brief unit tests for DiscreteConditional * @author Duy-Nguyen Ta - * @date Feb 14, 2011 + * @author Frank dellaert + * @date Feb 14, 2011 */ #include @@ -30,24 +31,21 @@ using namespace std; using namespace gtsam; /* ************************************************************************* */ -TEST( DiscreteConditional, constructors) -{ - DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! - - DiscreteConditional::shared_ptr expected1 = // - boost::make_shared(X | Y = "1/1 2/3 1/4"); - EXPECT(expected1); - EXPECT_LONGS_EQUAL(0, *(expected1->beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected1->beginParents())); - EXPECT(expected1->endParents() == expected1->end()); - EXPECT(expected1->endFrontals() == expected1->beginParents()); - +TEST(DiscreteConditional, constructors) { + DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! + + DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); + EXPECT(expected.endParents() == expected.end()); + EXPECT(expected.endFrontals() == expected.beginParents()); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(*expected1, actual1, 1e-9)); + EXPECT(assert_equal(expected, actual1, 1e-9)); - DecisionTreeFactor f2(X & Y & Z, - "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); + DecisionTreeFactor f2( + X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); } @@ -107,13 +105,27 @@ TEST(DiscreteConditional, Combine) { EXPECT(assert_equal(expected, *actual, 1e-5)); } +/* ************************************************************************* */ +TEST(DiscreteConditional, likelihood) { + DiscreteKey X(0, 2), Y(1, 3); + DiscreteConditional conditional(X | Y = "2/8 4/6 5/5"); + + auto actual0 = conditional.likelihood(0); + DecisionTreeFactor expected0(Y, "0.2 0.4 0.5"); + EXPECT(assert_equal(expected0, *actual0, 1e-9)); + + auto actual1 = conditional.likelihood(1); + DecisionTreeFactor expected1(Y, "0.8 0.6 0.5"); + EXPECT(assert_equal(expected1, *actual1, 1e-9)); +} + /* ************************************************************************* */ // Check markdown representation looks as expected, no parents. TEST(DiscreteConditional, markdown_prior) { DiscreteKey A(Symbol('x', 1), 3); DiscreteConditional conditional(A % "1/2/2"); string expected = - " $P(x1)$:\n" + " *P(x1)*:\n\n" "|x1|value|\n" "|:-:|:-:|\n" "|0|0.2|\n" @@ -130,7 +142,7 @@ TEST(DiscreteConditional, markdown_multivalued) { DiscreteConditional conditional( A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3"); string expected = - " $P(a1|b1)$:\n" + " *P(a1|b1)*:\n\n" "|b1|0|1|2|\n" "|:-:|:-:|:-:|:-:|\n" "|0|0.02|0.88|0.1|\n" @@ -148,7 +160,7 @@ TEST(DiscreteConditional, markdown) { DiscreteKey A(2, 2), B(1, 2), C(0, 3); DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0"); string expected = - " $P(A|B,C)$:\n" + " *P(A|B,C)*:\n\n" "|B|C|0|1|\n" "|:-:|:-:|:-:|:-:|\n" "|0|0|0|1|\n" diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp new file mode 100644 index 0000000000..b91926cc05 --- /dev/null +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -0,0 +1,55 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file testDiscretePrior.cpp + * @brief unit tests for DiscretePrior + * @author Frank dellaert + * @date December 2021 + */ + +#include +#include +#include + +using namespace std; +using namespace gtsam; + +static const DiscreteKey X(0, 2); + +/* ************************************************************************* */ +TEST(DiscretePrior, constructors) { + DiscretePrior actual(X % "2/3"); + DecisionTreeFactor f(X, "0.4 0.6"); + DiscretePrior expected(f); + EXPECT(assert_equal(expected, actual, 1e-9)); +} + +/* ************************************************************************* */ +TEST(DiscretePrior, operator) { + DiscretePrior prior(X % "2/3"); + EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); + EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); +} + +/* ************************************************************************* */ +TEST(DiscretePrior, to_vector) { + DiscretePrior prior(X % "2/3"); + vector expected {0.4, 0.6}; + EXPECT(prior.pmf() == expected); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam_unstable/discrete/Scheduler.cpp b/gtsam_unstable/discrete/Scheduler.cpp index 36c1ddda58..e34613c3b3 100644 --- a/gtsam_unstable/discrete/Scheduler.cpp +++ b/gtsam_unstable/discrete/Scheduler.cpp @@ -133,10 +133,10 @@ void Scheduler::addStudentSpecificConstraints(size_t i, Potentials::ADT p(dummy & areaKey, available_); // available_ is Doodle string Potentials::ADT q = p.choose(dummyIndex, *slot); - DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q)); - CSP::push_back(f); + CSP::add(areaKey, q); } else { - CSP::add(s.key_, areaKey, available_); // available_ is Doodle string + DiscreteKeys keys {s.key_, areaKey}; + CSP::add(keys, available_); // available_ is Doodle string } } diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py new file mode 100644 index 0000000000..12a60d5cb1 --- /dev/null +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -0,0 +1,54 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for DecisionTreeFactors. +Author: Frank Dellaert +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys +from gtsam.utils.test_case import GtsamTestCase + + +class TestDecisionTreeFactor(GtsamTestCase): + """Tests for DecisionTreeFactors.""" + + def setUp(self): + A = (12, 3) + B = (5, 2) + self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6") + + def test_enumerate(self): + actual = self.factor.enumerate() + _, values = zip(*actual) + self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def test_markdown(self): + """Test whether the _repr_markdown_ method.""" + + expected = \ + "|A|B|value|\n" \ + "|:-:|:-:|:-:|\n" \ + "|0|0|1|\n" \ + "|0|1|2|\n" \ + "|1|0|3|\n" \ + "|1|1|4|\n" \ + "|2|0|5|\n" \ + "|2|1|6|\n" + + def formatter(x: int): + return "A" if x == 12 else "B" + + actual = self.factor._repr_markdown_(formatter) + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index bf09da1935..bdd5a05464 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -14,7 +14,7 @@ import unittest from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, - DiscreteKeys, DiscreteValues, Ordering) + DiscreteKeys, DiscretePrior, DiscreteValues, Ordering) from gtsam.utils.test_case import GtsamTestCase @@ -53,24 +53,18 @@ def test_Asia(self): XRay = (2, 2) Dyspnea = (1, 2) - def P(keys): - dks = DiscreteKeys() - for key in keys: - dks.push_back(key) - return dks - asia = DiscreteBayesNet() - asia.add(Asia, P([]), "99/1") - asia.add(Smoking, P([]), "50/50") + asia.add(Asia, "99/1") + asia.add(Smoking, "50/50") - asia.add(Tuberculosis, P([Asia]), "99/1 95/5") - asia.add(LungCancer, P([Smoking]), "99/1 90/10") - asia.add(Bronchitis, P([Smoking]), "70/30 40/60") + asia.add(Tuberculosis, [Asia], "99/1 95/5") + asia.add(LungCancer, [Smoking], "99/1 90/10") + asia.add(Bronchitis, [Smoking], "70/30 40/60") - asia.add(Either, P([Tuberculosis, LungCancer]), "F T T T") + asia.add(Either, [Tuberculosis, LungCancer], "F T T T") - asia.add(XRay, P([Either]), "95/5 2/98") - asia.add(Dyspnea, P([Either, Bronchitis]), "9/1 2/8 3/7 1/9") + asia.add(XRay, [Either], "95/5 2/98") + asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9") # Convert to factor graph fg = DiscreteFactorGraph(asia) @@ -80,7 +74,7 @@ def P(keys): for j in range(8): ordering.push_back(j) chordal = fg.eliminateSequential(ordering) - expected2 = DiscreteConditional(Bronchitis, P([]), "11/9") + expected2 = DiscretePrior(Bronchitis, "11/9") self.gtsamAssertEquals(chordal.at(7), expected2) # solve diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index d87734de99..b1ed4fe696 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -14,20 +14,10 @@ import unittest from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, - DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, - Ordering) + DiscreteConditional, DiscreteFactorGraph, Ordering) from gtsam.utils.test_case import GtsamTestCase -def P(*args): - """ Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.""" - # TODO: We can make life easier by providing variable argument functions in C++ itself. - dks = DiscreteKeys() - for key in args: - dks.push_back(key) - return dks - - class TestDiscreteBayesNet(GtsamTestCase): """Tests for Discrete Bayes Nets.""" @@ -40,25 +30,25 @@ def test_elimination(self): # Create thin-tree Bayesnet. bayesNet = DiscreteBayesNet() - bayesNet.add(keys[0], P(keys[8], keys[12]), "2/3 1/4 3/2 4/1") - bayesNet.add(keys[1], P(keys[8], keys[12]), "4/1 2/3 3/2 1/4") - bayesNet.add(keys[2], P(keys[9], keys[12]), "1/4 8/2 2/3 4/1") - bayesNet.add(keys[3], P(keys[9], keys[12]), "1/4 2/3 3/2 4/1") + bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1") + bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1") - bayesNet.add(keys[4], P(keys[10], keys[13]), "2/3 1/4 3/2 4/1") - bayesNet.add(keys[5], P(keys[10], keys[13]), "4/1 2/3 3/2 1/4") - bayesNet.add(keys[6], P(keys[11], keys[13]), "1/4 3/2 2/3 4/1") - bayesNet.add(keys[7], P(keys[11], keys[13]), "1/4 2/3 3/2 4/1") + bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1") - bayesNet.add(keys[8], P(keys[12], keys[14]), "T 1/4 3/2 4/1") - bayesNet.add(keys[9], P(keys[12], keys[14]), "4/1 2/3 F 1/4") - bayesNet.add(keys[10], P(keys[13], keys[14]), "1/4 3/2 2/3 4/1") - bayesNet.add(keys[11], P(keys[13], keys[14]), "1/4 2/3 3/2 4/1") + bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1") + bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4") + bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1") - bayesNet.add(keys[12], P(keys[14]), "3/1 3/1") - bayesNet.add(keys[13], P(keys[14]), "1/3 3/1") + bayesNet.add(keys[12], [keys[14]], "3/1 3/1") + bayesNet.add(keys[13], [keys[14]], "1/3 3/1") - bayesNet.add(keys[14], P(), "1/3") + bayesNet.add(keys[14], "1/3") # Create a factor graph out of the Bayes net. factorGraph = DiscreteFactorGraph(bayesNet) diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 5e24dc40b9..1b2ce70cd7 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -13,12 +13,29 @@ import unittest -from gtsam import DiscreteConditional, DiscreteKeys +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase class TestDiscreteConditional(GtsamTestCase): """Tests for Discrete Conditionals.""" + + def test_single_value_versions(self): + X = (0, 2) + Y = (1, 3) + conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5") + + actual0 = conditional.likelihood(0) + expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5") + self.gtsamAssertEquals(actual0, expected0, 1e-9) + + actual1 = conditional.likelihood(1) + expected1 = DecisionTreeFactor(Y, "0.8 0.6 0.5") + self.gtsamAssertEquals(actual1, expected1, 1e-9) + + actual = conditional.sample(2) + self.assertIsInstance(actual, int) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" @@ -32,7 +49,7 @@ def test_markdown(self): conditional = DiscreteConditional(A, parents, "0/1 1/3 1/1 3/1 0/1 1/0") expected = \ - " $P(A|B,C)$:\n" \ + " *P(A|B,C)*:\n\n" \ "|B|C|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \ diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index 9dafff33f0..1ba145e096 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -32,11 +32,11 @@ def test_evaluation(self): graph = DiscreteFactorGraph() # Add two unary factors (priors) - graph.add(P1, "0.9 0.3") + graph.add(P1, [0.9, 0.3]) graph.add(P2, "0.9 0.6") # Add a binary factor - graph.add(P1, P2, "4 1 10 4") + graph.add([P1, P2], "4 1 10 4") # Instantiate Values assignment = DiscreteValues() @@ -85,8 +85,8 @@ def test_optimize(self): # A simple factor graph (A)-fAC-(C)-fBC-(B) # with smoothness priors graph = DiscreteFactorGraph() - graph.add(A, C, "3 1 1 3") - graph.add(C, B, "3 1 1 3") + graph.add([A, C], "3 1 1 3") + graph.add([C, B], "3 1 1 3") # Test optimization expectedValues = DiscreteValues() @@ -105,8 +105,8 @@ def test_MPE(self): # Create Factor graph graph = DiscreteFactorGraph() - graph.add(C, A, "0.2 0.8 0.3 0.7") - graph.add(C, B, "0.1 0.9 0.4 0.6") + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") actualMPE = graph.optimize() diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py new file mode 100644 index 0000000000..4f017d66a4 --- /dev/null +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -0,0 +1,60 @@ +""" +GTSAM Copyright 2010-2021, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Discrete Priors. +Author: Varun Agrawal +""" + +# pylint: disable=no-name-in-module, invalid-name + +import unittest + +import numpy as np +from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior +from gtsam.utils.test_case import GtsamTestCase + +X = 0, 2 + + +class TestDiscretePrior(GtsamTestCase): + """Tests for Discrete Priors.""" + + def test_constructor(self): + """Test various constructors.""" + actual = DiscretePrior(X, "2/3") + keys = DiscreteKeys() + keys.push_back(X) + f = DecisionTreeFactor(keys, "0.4 0.6") + expected = DiscretePrior(f) + self.gtsamAssertEquals(actual, expected) + + def test_operator(self): + prior = DiscretePrior(X, "2/3") + self.assertAlmostEqual(prior(0), 0.4) + self.assertAlmostEqual(prior(1), 0.6) + + def test_pmf(self): + prior = DiscretePrior(X, "2/3") + expected = np.array([0.4, 0.6]) + np.testing.assert_allclose(expected, prior.pmf()) + + def test_markdown(self): + """Test the _repr_markdown_ method.""" + + prior = DiscretePrior(X, "2/3") + expected = " *P(0)*:\n\n" \ + "|0|value|\n" \ + "|:-:|:-:|\n" \ + "|0|0.4|\n" \ + "|1|0.6|\n" \ + + actual = prior._repr_markdown_() + self.assertEqual(actual, expected) + + +if __name__ == "__main__": + unittest.main()