From 8c752049de6e683206de4678c48ed3e884f150a4 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 13 Jan 2023 11:57:20 -0800 Subject: [PATCH 01/13] move normalization constant to base class --- gtsam/inference/Conditional-inst.h | 23 +++++++++++++++++++++++ gtsam/inference/Conditional.h | 18 ++++++++++++++++++ gtsam/linear/GaussianConditional.h | 9 +-------- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/gtsam/inference/Conditional-inst.h b/gtsam/inference/Conditional-inst.h index ee13946d9c..5a17c44ccd 100644 --- a/gtsam/inference/Conditional-inst.h +++ b/gtsam/inference/Conditional-inst.h @@ -56,4 +56,27 @@ double Conditional::evaluate( const HybridValues& c) const { throw std::runtime_error("Conditional::evaluate is not implemented"); } + +/* ************************************************************************* */ +template +double Conditional::normalizationConstant() const { + return std::exp(logNormalizationConstant()); +} + +/* ************************************************************************* */ +template +bool Conditional::checkInvariants( + const HybridValues& values) const { + const double probability = evaluate(values); + if (probability < 0.0 || probability > 1.0) + return false; // probability is not in [0,1] + const double logProb = logProbability(values); + if (std::abs(probability - std::exp(logProb)) > 1e-9) + return false; // logProb is not consistent with probability + const double expected = + this->logNormalizationConstant() - this->error(values); + if (std::abs(logProb - expected) > 1e-9) + return false; // logProb is not consistent with error +} + } // namespace gtsam diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index 9083c5c1a8..bb75f9c6e4 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -141,6 +141,15 @@ namespace gtsam { return evaluate(x); } + /** + * By default, log normalization constant = 0.0. + * Override if this depends on the parameters. + */ + virtual double logNormalizationConstant() const; + + /** Non-virtual, exponentiate logNormalizationConstant. */ + double normalizationConstant() const; + /// @} /// @name Advanced Interface /// @{ @@ -172,7 +181,16 @@ namespace gtsam { /** Mutable iterator pointing past the last parent key. */ typename FACTOR::iterator endParents() { return asFactor().end(); } + /** Check that the invariants hold for derived class at a given point. */ + bool checkInvariants(const HybridValues& values) const; + + /// @} + private: + + /// @name Serialization + /// @{ + // Cast to factor type (non-const) (casts down to derived conditional type, then up to factor type) FACTOR& asFactor() { return static_cast(static_cast(*this)); } diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 880d13064e..69e2ef2d34 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -136,14 +136,7 @@ namespace gtsam { * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) * log = - 0.5 * n*log(2*pi) - 0.5 * log det(Sigma) */ - double logNormalizationConstant() const; - - /** - * normalization constant = 1.0 / sqrt((2*pi)^n*det(Sigma)) - */ - inline double normalizationConstant() const { - return exp(logNormalizationConstant()); - } + double logNormalizationConstant() const override; /** * Calculate log-probability log(evaluate(x)) for given values `x`: From 7ea8bd0fbacf5cba4cf2300ab66b0b957de67135 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 13 Jan 2023 11:58:17 -0800 Subject: [PATCH 02/13] Add check --- gtsam/linear/tests/testGaussianConditional.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/gtsam/linear/tests/testGaussianConditional.cpp b/gtsam/linear/tests/testGaussianConditional.cpp index eb90f8aabe..0bfb95351c 100644 --- a/gtsam/linear/tests/testGaussianConditional.cpp +++ b/gtsam/linear/tests/testGaussianConditional.cpp @@ -134,6 +134,21 @@ static const auto unitPrior = noiseModel::Isotropic::Sigma(1, sigma)); } // namespace density +/* ************************************************************************* */ +bool checkInvariants(const GaussianConditional* self, + const HybridValues& values) { + const double probability = self->evaluate(values); + if (probability < 0.0 || probability > 1.0) + return false; // probability is not in [0,1] + const double logProb = self->logProbability(values); + if (std::abs(probability - std::exp(logProb)) > 1e-9) + return false; // logProb is not consistent with probability + const double expected = + self->logNormalizationConstant() - self->error(values); + if (std::abs(logProb - expected) > 1e-9) + return false; // logProb is not consistent with error +} + /* ************************************************************************* */ // Check that the evaluate function matches direct calculation with R. TEST(GaussianConditional, Evaluate1) { @@ -164,6 +179,7 @@ TEST(GaussianConditional, Evaluate1) { integral += 0.1 * sigma * density; } EXPECT_DOUBLES_EQUAL(1.0, integral, 1e-9); + EXPECT(checkInvariants(&density::unitPrior, mean)); } /* ************************************************************************* */ From a4aebb548ae5b2231d4b41c0a9771f8bcc7f22c1 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 12 Jan 2023 13:29:28 -0800 Subject: [PATCH 03/13] Expose correct error versions --- gtsam/discrete/DiscreteConditional.h | 3 ++- gtsam/hybrid/GaussianMixture.cpp | 7 +++++++ gtsam/hybrid/GaussianMixture.h | 13 +++++++++++-- gtsam/hybrid/HybridConditional.cpp | 15 +++++++++++++++ gtsam/hybrid/HybridConditional.h | 3 +++ gtsam/linear/GaussianConditional.h | 2 +- gtsam/linear/HessianFactor.h | 3 +++ gtsam/linear/JacobianFactor.h | 7 ++++++- 8 files changed, 48 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 94451d407c..2760ea538f 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -243,7 +243,8 @@ class GTSAM_EXPORT DiscreteConditional return -error(x); } - using DecisionTreeFactor::evaluate; + using DecisionTreeFactor::error; ///< HybridValues version + using DecisionTreeFactor::evaluate; ///< HybridValues version /// @} diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index c5ffed27b5..f61b280cb7 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -289,6 +289,13 @@ AlgebraicDecisionTree GaussianMixture::logProbability( return errorTree; } +/* *******************************************************************************/ +double GaussianMixture::error(const HybridValues &values) const { + // Directly index to get the conditional, no need to build the whole tree. + auto conditional = conditionals_(values.discrete()); + return conditional->error(values.continuous()); +} + /* *******************************************************************************/ double GaussianMixture::logProbability(const HybridValues &values) const { // Directly index to get the conditional, no need to build the whole tree. diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index a9f82d5559..a8d07cbc84 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -174,8 +174,17 @@ class GTSAM_EXPORT GaussianMixture const VectorValues &continuousValues) const; /** - * @brief Compute the logProbability of this Gaussian Mixture given the - * continuous values and a discrete assignment. + * @brief Compute the error of this Gaussian Mixture. + * + * log(probability(x)) = K - error(x) + * + * @param values Continuous values and discrete assignment. + * @return double + */ + double error(const HybridValues &values) const override; + + /** + * @brief Compute the logProbability of this Gaussian Mixture. * * @param values Continuous values and discrete assignment. * @return double diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index df92ffcb8d..55fd5d5d44 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -121,6 +121,21 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const { : !(e->inner_); } +/* ************************************************************************ */ +double HybridConditional::error(const HybridValues &values) const { + if (auto gc = asGaussian()) { + return gc->error(values.continuous()); + } + if (auto gm = asMixture()) { + return gm->error(values); + } + if (auto dc = asDiscrete()) { + return dc->error(values.discrete()); + } + throw std::runtime_error( + "HybridConditional::error: conditional type not handled"); +} + /* ************************************************************************ */ double HybridConditional::logProbability(const HybridValues &values) const { if (auto gc = asGaussian()) { diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 030e6c8354..19c070974b 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -176,6 +176,9 @@ class GTSAM_EXPORT HybridConditional /// Get the type-erased pointer to the inner type boost::shared_ptr inner() const { return inner_; } + /// Return the error of the underlying conditional. + double error(const HybridValues& values) const override; + /// Return the logProbability of the underlying conditional. double logProbability(const HybridValues& values) const override; diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 69e2ef2d34..18f0257cb0 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -264,7 +264,7 @@ namespace gtsam { using Conditional::evaluate; // Expose evaluate(const HybridValues&) method.. using Conditional::operator(); // Expose evaluate(const HybridValues&) method.. - using Base::error; // Expose error(const HybridValues&) method.. + using JacobianFactor::error; // Expose error(const HybridValues&) method.. /// @} diff --git a/gtsam/linear/HessianFactor.h b/gtsam/linear/HessianFactor.h index 3eefe12288..492df138f4 100644 --- a/gtsam/linear/HessianFactor.h +++ b/gtsam/linear/HessianFactor.h @@ -196,6 +196,9 @@ namespace gtsam { /** Compare to another factor for testing (implementing Testable) */ bool equals(const GaussianFactor& lf, double tol = 1e-9) const override; + /// HybridValues simply extracts the \class VectorValues and calls error. + using GaussianFactor::error; + /** * Evaluate the factor error f(x). * returns 0.5*[x -1]'*H*[x -1] (also see constructor documentation) diff --git a/gtsam/linear/JacobianFactor.h b/gtsam/linear/JacobianFactor.h index 8bcf18268a..ae661c642e 100644 --- a/gtsam/linear/JacobianFactor.h +++ b/gtsam/linear/JacobianFactor.h @@ -198,7 +198,12 @@ namespace gtsam { Vector unweighted_error(const VectorValues& c) const; /** (A*x-b) */ Vector error_vector(const VectorValues& c) const; /** (A*x-b)/sigma */ - double error(const VectorValues& c) const override; /** 0.5*(A*x-b)'*D*(A*x-b) */ + + /// HybridValues simply extracts the \class VectorValues and calls error. + using GaussianFactor::error; + + //// 0.5*(A*x-b)'*D*(A*x-b). + double error(const VectorValues& c) const override; /** Return the augmented information matrix represented by this GaussianFactor. * The augmented information matrix contains the information matrix with an From ebb5ae6f183b1844a0444dc2988707a796bab743 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 13 Jan 2023 12:28:39 -0800 Subject: [PATCH 04/13] Expose correct error versions --- gtsam/linear/GaussianConditional.cpp | 9 +++++++-- gtsam/linear/GaussianConditional.h | 7 ++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 10f4eabbbf..7792e119b6 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -205,9 +205,14 @@ namespace gtsam { } /* ************************************************************************* */ - double GaussianConditional::evaluate(const VectorValues& c) const { - return exp(logProbability(c)); + double GaussianConditional::evaluate(const VectorValues& x) const { + return exp(logProbability(x)); } + + double GaussianConditional::evaluate(const HybridValues& x) const { + return evaluate(x.continuous()); + } + /* ************************************************************************* */ VectorValues GaussianConditional::solve(const VectorValues& x) const { // Concatenate all vector values that correspond to parent variables diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 18f0257cb0..15efeae011 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -262,7 +262,12 @@ namespace gtsam { */ double logProbability(const HybridValues& x) const override; - using Conditional::evaluate; // Expose evaluate(const HybridValues&) method.. + /** + * Calculate probability for HybridValues `x`. + * Simply dispatches to VectorValues version. + */ + double evaluate(const HybridValues& x) const override; + using Conditional::operator(); // Expose evaluate(const HybridValues&) method.. using JacobianFactor::error; // Expose error(const HybridValues&) method.. From b99d464049ac45e5da862f865c6d2ef989f1b184 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 13 Jan 2023 12:29:11 -0800 Subject: [PATCH 05/13] Consistency test in testGaussianConditional --- gtsam/inference/Conditional-inst.h | 16 -------- gtsam/inference/Conditional.h | 5 +-- .../linear/tests/testGaussianConditional.cpp | 40 +++++++++++++------ 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/gtsam/inference/Conditional-inst.h b/gtsam/inference/Conditional-inst.h index 5a17c44ccd..1b439649e9 100644 --- a/gtsam/inference/Conditional-inst.h +++ b/gtsam/inference/Conditional-inst.h @@ -63,20 +63,4 @@ double Conditional::normalizationConstant() const { return std::exp(logNormalizationConstant()); } -/* ************************************************************************* */ -template -bool Conditional::checkInvariants( - const HybridValues& values) const { - const double probability = evaluate(values); - if (probability < 0.0 || probability > 1.0) - return false; // probability is not in [0,1] - const double logProb = logProbability(values); - if (std::abs(probability - std::exp(logProb)) > 1e-9) - return false; // logProb is not consistent with probability - const double expected = - this->logNormalizationConstant() - this->error(values); - if (std::abs(logProb - expected) > 1e-9) - return false; // logProb is not consistent with error -} - } // namespace gtsam diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index bb75f9c6e4..bba4c7bd5b 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -145,7 +145,7 @@ namespace gtsam { * By default, log normalization constant = 0.0. * Override if this depends on the parameters. */ - virtual double logNormalizationConstant() const; + virtual double logNormalizationConstant() const { return 0.0; } /** Non-virtual, exponentiate logNormalizationConstant. */ double normalizationConstant() const; @@ -181,9 +181,6 @@ namespace gtsam { /** Mutable iterator pointing past the last parent key. */ typename FACTOR::iterator endParents() { return asFactor().end(); } - /** Check that the invariants hold for derived class at a given point. */ - bool checkInvariants(const HybridValues& values) const; - /// @} private: diff --git a/gtsam/linear/tests/testGaussianConditional.cpp b/gtsam/linear/tests/testGaussianConditional.cpp index 0bfb95351c..12c668c258 100644 --- a/gtsam/linear/tests/testGaussianConditional.cpp +++ b/gtsam/linear/tests/testGaussianConditional.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -135,18 +136,20 @@ static const auto unitPrior = } // namespace density /* ************************************************************************* */ -bool checkInvariants(const GaussianConditional* self, - const HybridValues& values) { - const double probability = self->evaluate(values); +template +bool checkInvariants(const GaussianConditional& conditional, + const VALUES& values) { + const double probability = conditional.evaluate(values); if (probability < 0.0 || probability > 1.0) return false; // probability is not in [0,1] - const double logProb = self->logProbability(values); + const double logProb = conditional.logProbability(values); if (std::abs(probability - std::exp(logProb)) > 1e-9) return false; // logProb is not consistent with probability const double expected = - self->logNormalizationConstant() - self->error(values); + conditional.logNormalizationConstant() - conditional.error(values); if (std::abs(logProb - expected) > 1e-9) return false; // logProb is not consistent with error + return true; } /* ************************************************************************* */ @@ -169,6 +172,12 @@ TEST(GaussianConditional, Evaluate1) { using density::key; using density::sigma; + // Check Invariants at the mean and a different value + for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { + EXPECT(checkInvariants(density::unitPrior, vv)); + EXPECT(checkInvariants(density::unitPrior, HybridValues{vv, {}, {}})); + } + // Let's numerically integrate and see that we integrate to 1.0. double integral = 0.0; // Loop from -5*sigma to 5*sigma in 0.1*sigma steps: @@ -179,7 +188,6 @@ TEST(GaussianConditional, Evaluate1) { integral += 0.1 * sigma * density; } EXPECT_DOUBLES_EQUAL(1.0, integral, 1e-9); - EXPECT(checkInvariants(&density::unitPrior, mean)); } /* ************************************************************************* */ @@ -196,6 +204,12 @@ TEST(GaussianConditional, Evaluate2) { using density::key; using density::sigma; + // Check Invariants at the mean and a different value + for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { + EXPECT(checkInvariants(density::widerPrior, vv)); + EXPECT(checkInvariants(density::widerPrior, HybridValues{vv, {}, {}})); + } + // Let's numerically integrate and see that we integrate to 1.0. double integral = 0.0; // Loop from -5*sigma to 5*sigma in 0.1*sigma steps: @@ -400,17 +414,17 @@ TEST(GaussianConditional, FromMeanAndStddev) { double expected1 = 0.5 * e1.dot(e1); EXPECT_DOUBLES_EQUAL(expected1, conditional1.error(values), 1e-9); - double expected2 = conditional1.logNormalizationConstant() - 0.5 * e1.dot(e1); - EXPECT_DOUBLES_EQUAL(expected2, conditional1.logProbability(values), 1e-9); - auto conditional2 = GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), A2, X(2), b, sigma); Vector2 e2 = (x0 - (A1 * x1 + A2 * x2 + b)) / sigma; - double expected3 = 0.5 * e2.dot(e2); - EXPECT_DOUBLES_EQUAL(expected3, conditional2.error(values), 1e-9); + double expected2 = 0.5 * e2.dot(e2); + EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9); - double expected4 = conditional2.logNormalizationConstant() - 0.5 * e2.dot(e2); - EXPECT_DOUBLES_EQUAL(expected4, conditional2.logProbability(values), 1e-9); + // Check Invariants for both conditionals + for (auto conditional : {conditional1, conditional2}) { + EXPECT(checkInvariants(conditional, values)); + EXPECT(checkInvariants(conditional, HybridValues{values, {}, {}})); + } } /* ************************************************************************* */ From cd2d37e724fe9465eb7f474e00eb9a7d1eef24b5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Fri, 13 Jan 2023 14:55:14 -0800 Subject: [PATCH 06/13] Made CheckInvariants a static method in Conditional.* --- gtsam/inference/Conditional-inst.h | 18 +++++++++++ gtsam/inference/Conditional.h | 4 +++ .../linear/tests/testGaussianConditional.cpp | 32 ++++++------------- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/gtsam/inference/Conditional-inst.h b/gtsam/inference/Conditional-inst.h index 1b439649e9..8445b74bdf 100644 --- a/gtsam/inference/Conditional-inst.h +++ b/gtsam/inference/Conditional-inst.h @@ -63,4 +63,22 @@ double Conditional::normalizationConstant() const { return std::exp(logNormalizationConstant()); } +/* ************************************************************************* */ +template +template +bool Conditional::CheckInvariants( + const DERIVEDCONDITIONAL& conditional, const VALUES& values) { + const double probability = conditional.evaluate(values); + if (probability < 0.0 || probability > 1.0) + return false; // probability is not in [0,1] + const double logProb = conditional.logProbability(values); + if (std::abs(probability - std::exp(logProb)) > 1e-9) + return false; // logProb is not consistent with probability + const double expected = + conditional.logNormalizationConstant() - conditional.error(values); + if (std::abs(logProb - expected) > 1e-9) + return false; // logProb is not consistent with error + return true; +} + } // namespace gtsam diff --git a/gtsam/inference/Conditional.h b/gtsam/inference/Conditional.h index bba4c7bd5b..b4b1080aaa 100644 --- a/gtsam/inference/Conditional.h +++ b/gtsam/inference/Conditional.h @@ -181,6 +181,10 @@ namespace gtsam { /** Mutable iterator pointing past the last parent key. */ typename FACTOR::iterator endParents() { return asFactor().end(); } + template + static bool CheckInvariants(const DERIVEDCONDITIONAL& conditional, + const VALUES& values); + /// @} private: diff --git a/gtsam/linear/tests/testGaussianConditional.cpp b/gtsam/linear/tests/testGaussianConditional.cpp index 12c668c258..0479ce9a11 100644 --- a/gtsam/linear/tests/testGaussianConditional.cpp +++ b/gtsam/linear/tests/testGaussianConditional.cpp @@ -135,23 +135,6 @@ static const auto unitPrior = noiseModel::Isotropic::Sigma(1, sigma)); } // namespace density -/* ************************************************************************* */ -template -bool checkInvariants(const GaussianConditional& conditional, - const VALUES& values) { - const double probability = conditional.evaluate(values); - if (probability < 0.0 || probability > 1.0) - return false; // probability is not in [0,1] - const double logProb = conditional.logProbability(values); - if (std::abs(probability - std::exp(logProb)) > 1e-9) - return false; // logProb is not consistent with probability - const double expected = - conditional.logNormalizationConstant() - conditional.error(values); - if (std::abs(logProb - expected) > 1e-9) - return false; // logProb is not consistent with error - return true; -} - /* ************************************************************************* */ // Check that the evaluate function matches direct calculation with R. TEST(GaussianConditional, Evaluate1) { @@ -174,8 +157,9 @@ TEST(GaussianConditional, Evaluate1) { // Check Invariants at the mean and a different value for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { - EXPECT(checkInvariants(density::unitPrior, vv)); - EXPECT(checkInvariants(density::unitPrior, HybridValues{vv, {}, {}})); + EXPECT(GaussianConditional::CheckInvariants(density::unitPrior, vv)); + EXPECT(GaussianConditional::CheckInvariants(density::unitPrior, + HybridValues{vv, {}, {}})); } // Let's numerically integrate and see that we integrate to 1.0. @@ -206,8 +190,9 @@ TEST(GaussianConditional, Evaluate2) { // Check Invariants at the mean and a different value for (auto vv : {mean, VectorValues{{key, Vector1(4)}}}) { - EXPECT(checkInvariants(density::widerPrior, vv)); - EXPECT(checkInvariants(density::widerPrior, HybridValues{vv, {}, {}})); + EXPECT(GaussianConditional::CheckInvariants(density::widerPrior, vv)); + EXPECT(GaussianConditional::CheckInvariants(density::widerPrior, + HybridValues{vv, {}, {}})); } // Let's numerically integrate and see that we integrate to 1.0. @@ -422,8 +407,9 @@ TEST(GaussianConditional, FromMeanAndStddev) { // Check Invariants for both conditionals for (auto conditional : {conditional1, conditional2}) { - EXPECT(checkInvariants(conditional, values)); - EXPECT(checkInvariants(conditional, HybridValues{values, {}, {}})); + EXPECT(GaussianConditional::CheckInvariants(conditional, values)); + EXPECT(GaussianConditional::CheckInvariants(conditional, + HybridValues{values, {}, {}})); } } From 0a6334ef1fba261f5f3d0e3b6ceed9d6c9a55b70 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 14 Jan 2023 10:22:41 -0800 Subject: [PATCH 07/13] check invariants --- gtsam/discrete/tests/testDiscreteConditional.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index fdfe4a145b..6e73cfc6eb 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -96,6 +96,7 @@ TEST(DiscreteConditional, PriorProbability) { DiscreteConditional dc(Asia, "4/6"); DiscreteValues values{{asiaKey, 0}}; EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9); + EXPECT(DiscreteConditional::CheckInvariants(dc, values)); } /* ************************************************************************* */ @@ -109,6 +110,7 @@ TEST(DiscreteConditional, probability) { EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9); EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9); EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9); + EXPECT(DiscreteConditional::CheckInvariants(C_given_DE, given)); } /* ************************************************************************* */ From 693d18233a48f8961b0b669b9983d51a39c5105d Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 14 Jan 2023 10:23:00 -0800 Subject: [PATCH 08/13] Adapt to continuous densities --- gtsam/inference/Conditional-inst.h | 11 ++++++----- gtsam/linear/GaussianConditional.h | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/gtsam/inference/Conditional-inst.h b/gtsam/inference/Conditional-inst.h index 8445b74bdf..4aa9c51265 100644 --- a/gtsam/inference/Conditional-inst.h +++ b/gtsam/inference/Conditional-inst.h @@ -68,12 +68,13 @@ template template bool Conditional::CheckInvariants( const DERIVEDCONDITIONAL& conditional, const VALUES& values) { - const double probability = conditional.evaluate(values); - if (probability < 0.0 || probability > 1.0) - return false; // probability is not in [0,1] + const double prob_or_density = conditional.evaluate(values); + if (prob_or_density < 0.0) return false; // prob_or_density is negative. + if (std::abs(prob_or_density - conditional(values)) > 1e-9) + return false; // operator and evaluate differ const double logProb = conditional.logProbability(values); - if (std::abs(probability - std::exp(logProb)) > 1e-9) - return false; // logProb is not consistent with probability + if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9) + return false; // logProb is not consistent with prob_or_density const double expected = conditional.logNormalizationConstant() - conditional.error(values); if (std::abs(logProb - expected) > 1e-9) diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 15efeae011..4611e30d06 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -34,7 +34,7 @@ namespace gtsam { /** * A GaussianConditional functions as the node in a Bayes network. * It has a set of parents y,z, etc. and implements a probability density on x. - * The negative log-probability is given by \f$ \frac{1}{2} |Rx - (d - Sy - Tz - ...)|^2 \f$ + * The negative log-density is given by \f$ \frac{1}{2} |Rx - (d - Sy - Tz - ...)|^2 \f$ * @ingroup linear */ class GTSAM_EXPORT GaussianConditional : From ab439bfbb0d025577ce365e333dfb50e695e3f18 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 14 Jan 2023 10:23:21 -0800 Subject: [PATCH 09/13] Checking mixture invariants, WIP --- gtsam/hybrid/GaussianMixture.cpp | 7 +- gtsam/hybrid/GaussianMixture.h | 13 ++-- gtsam/hybrid/HybridConditional.cpp | 20 ++++++ gtsam/hybrid/HybridConditional.h | 12 +++- gtsam/hybrid/tests/testHybridConditional.cpp | 75 ++++++++++++++++++++ 5 files changed, 119 insertions(+), 8 deletions(-) create mode 100644 gtsam/hybrid/tests/testHybridConditional.cpp diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index f61b280cb7..1913be7aa5 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -298,9 +298,14 @@ double GaussianMixture::error(const HybridValues &values) const { /* *******************************************************************************/ double GaussianMixture::logProbability(const HybridValues &values) const { - // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); return conditional->logProbability(values.continuous()); } +/* *******************************************************************************/ +double GaussianMixture::evaluate(const HybridValues &values) const { + auto conditional = conditionals_(values.discrete()); + return conditional->evaluate(values.continuous()); +} + } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index a8d07cbc84..2137acff6a 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -175,7 +175,7 @@ class GTSAM_EXPORT GaussianMixture /** * @brief Compute the error of this Gaussian Mixture. - * + * * log(probability(x)) = K - error(x) * * @param values Continuous values and discrete assignment. @@ -191,12 +191,13 @@ class GTSAM_EXPORT GaussianMixture */ double logProbability(const HybridValues &values) const override; - // /// Calculate probability density for given values `x`. - // double evaluate(const HybridValues &values) const; + /// Calculate probability density for given `values`. + double evaluate(const HybridValues &values) const override; - // /// Evaluate probability density, sugar. - // double operator()(const HybridValues &values) const { return - // evaluate(values); } + /// Evaluate probability density, sugar. + double operator()(const HybridValues &values) const { + return evaluate(values); + } /** * @brief Prune the decision tree of Gaussian factors as per the discrete diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 55fd5d5d44..24f61a85f4 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -151,4 +151,24 @@ double HybridConditional::logProbability(const HybridValues &values) const { "HybridConditional::logProbability: conditional type not handled"); } +/* ************************************************************************ */ +double HybridConditional::logNormalizationConstant() const { + if (auto gc = asGaussian()) { + return gc->logNormalizationConstant(); + } + if (auto gm = asMixture()) { + return gm->logNormalizationConstant(); // 0.0! + } + if (auto dc = asDiscrete()) { + return dc->logNormalizationConstant(); // 0.0! + } + throw std::runtime_error( + "HybridConditional::logProbability: conditional type not handled"); +} + +/* ************************************************************************ */ +double HybridConditional::evaluate(const HybridValues &values) const { + return std::exp(logProbability(values)); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 19c070974b..c8cb968dff 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -179,9 +179,19 @@ class GTSAM_EXPORT HybridConditional /// Return the error of the underlying conditional. double error(const HybridValues& values) const override; - /// Return the logProbability of the underlying conditional. + /// Return the log-probability (or density) of the underlying conditional. double logProbability(const HybridValues& values) const override; + /** + * Return the log normalization constant. + * Note this is 0.0 for discrete and hybrid conditionals, but depends + * on the continuous parameters for Gaussian conditionals. + */ + double logNormalizationConstant() const override; + + /// Return the probability (or density) of the underlying conditional. + double evaluate(const HybridValues& values) const override; + /// Check if VectorValues `measurements` contains all frontal keys. bool frontalsIn(const VectorValues& measurements) const { for (Key key : frontals()) { diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp new file mode 100644 index 0000000000..da766a56f2 --- /dev/null +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -0,0 +1,75 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file testHybridConditional.cpp + * @brief Unit tests for HybridConditional class + * @date January 2023 + */ + +#include + +#include "TinyHybridExample.h" + +// Include for test suite +#include + +using namespace gtsam; + +using symbol_shorthand::M; +using symbol_shorthand::X; +using symbol_shorthand::Z; + +/* ****************************************************************************/ +// Check invariants for all conditionals in a tiny Bayes net. +TEST(HybridConditional, Invariants) { + // Create hybrid Bayes net p(z|x,m)p(x)P(m) + auto bn = tiny::createHybridBayesNet(); + + // Create values to check invariants. + const VectorValues c{{X(0), Vector1(5.1)}, {Z(0), Vector1(4.9)}}; + const DiscreteValues d{{M(0), 1}}; + const HybridValues values{c, d}; + + // Check invariants for p(z|x,m) + auto hc1 = bn.at(0); + CHECK(hc1->isHybrid()); + GTSAM_PRINT(*hc1); + + // Check invariants as a GaussianMixture. + const auto mixture = hc1->asMixture(); + double probability = mixture->evaluate(values); + CHECK(probability >= 0.0); + EXPECT_DOUBLES_EQUAL(probability, (*mixture)(values), 1e-9); + double logProb = mixture->logProbability(values); + EXPECT_DOUBLES_EQUAL(probability, std::exp(logProb), 1e-9); + double expected = + mixture->logNormalizationConstant() - mixture->error(values); + EXPECT_DOUBLES_EQUAL(logProb, expected, 1e-9); + EXPECT(GaussianMixture::CheckInvariants(*mixture, values)); + + // Check invariants as a HybridConditional. + probability = hc1->evaluate(values); + CHECK(probability >= 0.0); + EXPECT_DOUBLES_EQUAL(probability, (*hc1)(values), 1e-9); + logProb = hc1->logProbability(values); + EXPECT_DOUBLES_EQUAL(probability, std::exp(logProb), 1e-9); + expected = hc1->logNormalizationConstant() - hc1->error(values); + EXPECT_DOUBLES_EQUAL(logProb, expected, 1e-9); + EXPECT(HybridConditional::CheckInvariants(*hc1, values)); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ From c9fcfe3299fb7a743f6919c134bb328fc6e453ff Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 14 Jan 2023 12:56:38 -0800 Subject: [PATCH 10/13] Resolve GaussianMixture error crisis --- gtsam/hybrid/GaussianMixture.cpp | 2 +- gtsam/hybrid/GaussianMixture.h | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 1913be7aa5..9de8aba590 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -293,7 +293,7 @@ AlgebraicDecisionTree GaussianMixture::logProbability( double GaussianMixture::error(const HybridValues &values) const { // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); - return conditional->error(values.continuous()); + return conditional->error(values.continuous()) - conditional->logNormalizationConstant(); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 2137acff6a..d90e08409a 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -176,7 +176,21 @@ class GTSAM_EXPORT GaussianMixture /** * @brief Compute the error of this Gaussian Mixture. * - * log(probability(x)) = K - error(x) + * This requires some care, as different mixture components may have + * different normalization constants. Let's consider p(x|y,m), where m is + * discrete. We need the error to satisfy the invariant: + * + * error(x;y,m) = K - log(probability(x;y,m)) + * + * For all x,y,m. But note that K, for the GaussianMixture, cannot depend on + * any arguments. Hence, we delegate to the underlying Gaussian + * conditionals, indexed by m, which do satisfy: + * + * log(probability_m(x;y)) = K_m - error_m(x;y) + * + * We resolve by having K == 0.0 and + * + * error(x;y,m) = error_m(x;y) - K_m * * @param values Continuous values and discrete assignment. * @return double From ce8bf7ac48d5b3a5fd2c421ee30a7cb00e700e3c Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 14 Jan 2023 12:57:12 -0800 Subject: [PATCH 11/13] Expose all needed versions of evaluate, operator(), error --- gtsam/discrete/DiscreteConditional.cpp | 6 +++++- gtsam/discrete/DiscreteConditional.h | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 0d6c5e976f..214bc64da2 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -510,6 +510,10 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, return ss.str(); } +/* ************************************************************************* */ +double DiscreteConditional::evaluate(const HybridValues& x) const{ + return this->evaluate(x.discrete()); +} /* ************************************************************************* */ } // namespace gtsam diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 2760ea538f..f073c2d761 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -160,10 +160,13 @@ class GTSAM_EXPORT DiscreteConditional } /// Evaluate, just look up in AlgebraicDecisonTree - double operator()(const DiscreteValues& values) const override { + double evaluate(const DiscreteValues& values) const { return ADT::operator()(values); } + using DecisionTreeFactor::error; ///< DiscreteValues version + using DecisionTreeFactor::operator(); ///< DiscreteValues version + /** * @brief restrict to given *parent* values. * @@ -235,6 +238,14 @@ class GTSAM_EXPORT DiscreteConditional /// @name HybridValues methods. /// @{ + /** + * Calculate probability for HybridValues `x`. + * Dispatches to DiscreteValues version. + */ + double evaluate(const HybridValues& x) const override; + + using BaseConditional::operator(); ///< HybridValues version + /** * Calculate log-probability log(evaluate(x)) for HybridValues `x`. * This is actually just -error(x). @@ -243,9 +254,6 @@ class GTSAM_EXPORT DiscreteConditional return -error(x); } - using DecisionTreeFactor::error; ///< HybridValues version - using DecisionTreeFactor::evaluate; ///< HybridValues version - /// @} #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 From bead5ce4da3f44955f9777b04e7e1254801af714 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 14 Jan 2023 12:57:53 -0800 Subject: [PATCH 12/13] Test all HybridConditionals in all possible calling conventions. --- gtsam/hybrid/tests/testHybridConditional.cpp | 46 ++++++++++++-------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp index da766a56f2..406306df7c 100644 --- a/gtsam/hybrid/tests/testHybridConditional.cpp +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -40,31 +40,39 @@ TEST(HybridConditional, Invariants) { const HybridValues values{c, d}; // Check invariants for p(z|x,m) - auto hc1 = bn.at(0); - CHECK(hc1->isHybrid()); - GTSAM_PRINT(*hc1); + auto hc0 = bn.at(0); + CHECK(hc0->isHybrid()); // Check invariants as a GaussianMixture. - const auto mixture = hc1->asMixture(); - double probability = mixture->evaluate(values); - CHECK(probability >= 0.0); - EXPECT_DOUBLES_EQUAL(probability, (*mixture)(values), 1e-9); - double logProb = mixture->logProbability(values); - EXPECT_DOUBLES_EQUAL(probability, std::exp(logProb), 1e-9); - double expected = - mixture->logNormalizationConstant() - mixture->error(values); - EXPECT_DOUBLES_EQUAL(logProb, expected, 1e-9); + const auto mixture = hc0->asMixture(); EXPECT(GaussianMixture::CheckInvariants(*mixture, values)); // Check invariants as a HybridConditional. - probability = hc1->evaluate(values); - CHECK(probability >= 0.0); - EXPECT_DOUBLES_EQUAL(probability, (*hc1)(values), 1e-9); - logProb = hc1->logProbability(values); - EXPECT_DOUBLES_EQUAL(probability, std::exp(logProb), 1e-9); - expected = hc1->logNormalizationConstant() - hc1->error(values); - EXPECT_DOUBLES_EQUAL(logProb, expected, 1e-9); + EXPECT(HybridConditional::CheckInvariants(*hc0, values)); + + // Check invariants for p(x) + auto hc1 = bn.at(1); + CHECK(hc1->isContinuous()); + + // Check invariants as a GaussianConditional. + const auto gaussian = hc1->asGaussian(); + EXPECT(GaussianConditional::CheckInvariants(*gaussian, c)); + EXPECT(GaussianConditional::CheckInvariants(*gaussian, values)); + + // Check invariants as a HybridConditional. EXPECT(HybridConditional::CheckInvariants(*hc1, values)); + + // Check invariants for p(m) + auto hc2 = bn.at(2); + CHECK(hc2->isDiscrete()); + + // Check invariants as a DiscreteConditional. + const auto discrete = hc2->asDiscrete(); + EXPECT(DiscreteConditional::CheckInvariants(*discrete, d)); + EXPECT(DiscreteConditional::CheckInvariants(*discrete, values)); + + // Check invariants as a HybridConditional. + EXPECT(HybridConditional::CheckInvariants(*hc2, values)); } /* ************************************************************************* */ From 51c46410dc178ef4923c8e1bfc0c226c6ab3311e Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 14 Jan 2023 13:24:54 -0800 Subject: [PATCH 13/13] Make sure all conditional methods can be called in wrappers and satisfy invariants there, as well. --- gtsam/discrete/discrete.i | 13 ++++++++- gtsam/hybrid/hybrid.i | 1 + gtsam/linear/linear.i | 7 +++++ python/gtsam/tests/test_HybridBayesNet.py | 34 +++++++++++++++++++---- 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index a25897ffa9..78efd57e28 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { }; #include +#include virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(); DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); @@ -95,9 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + + // Standard interface + double logNormalizationConstant() const; double logProbability(const gtsam::DiscreteValues& values) const; double evaluate(const gtsam::DiscreteValues& values) const; - double operator()(const gtsam::DiscreteValues& values) const; + double error(const gtsam::DiscreteValues& values) const; gtsam::DiscreteConditional operator*( const gtsam::DiscreteConditional& other) const; gtsam::DiscreteConditional marginal(gtsam::Key key) const; @@ -119,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { size_t sample(size_t value) const; size_t sample() const; void sampleInPlace(gtsam::DiscreteValues @parentsValues) const; + + // Markdown and HTML string markdown(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; string markdown(const gtsam::KeyFormatter& keyFormatter, @@ -127,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { gtsam::DefaultKeyFormatter) const; string html(const gtsam::KeyFormatter& keyFormatter, std::map> names) const; + + // Expose HybridValues versions + double logProbability(const gtsam::HybridValues& x) const; + double evaluate(const gtsam::HybridValues& x) const; + double error(const gtsam::HybridValues& x) const; }; #include diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index aad1cca9bd..bbadd1aa8a 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -61,6 +61,7 @@ virtual class HybridConditional { size_t nrParents() const; // Standard interface: + double logNormalizationConstant() const; double logProbability(const gtsam::HybridValues& values) const; double evaluate(const gtsam::HybridValues& values) const; double operator()(const gtsam::HybridValues& values) const; diff --git a/gtsam/linear/linear.i b/gtsam/linear/linear.i index 2d88c5f938..c0230f1c21 100644 --- a/gtsam/linear/linear.i +++ b/gtsam/linear/linear.i @@ -456,6 +456,7 @@ class GaussianFactorGraph { }; #include +#include virtual class GaussianConditional : gtsam::JacobianFactor { // Constructors GaussianConditional(size_t key, Vector d, Matrix R, @@ -497,6 +498,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor { bool equals(const gtsam::GaussianConditional& cg, double tol) const; // Standard Interface + double logNormalizationConstant() const; double logProbability(const gtsam::VectorValues& x) const; double evaluate(const gtsam::VectorValues& x) const; double error(const gtsam::VectorValues& x) const; @@ -518,6 +520,11 @@ virtual class GaussianConditional : gtsam::JacobianFactor { // enabling serialization functionality void serialize() const; + + // Expose HybridValues versions + double logProbability(const gtsam::HybridValues& x) const; + double evaluate(const gtsam::HybridValues& x) const; + double error(const gtsam::HybridValues& x) const; }; #include diff --git a/python/gtsam/tests/test_HybridBayesNet.py b/python/gtsam/tests/test_HybridBayesNet.py index c949551c4d..01e1c5a5d8 100644 --- a/python/gtsam/tests/test_HybridBayesNet.py +++ b/python/gtsam/tests/test_HybridBayesNet.py @@ -17,8 +17,8 @@ from gtsam.symbol_shorthand import A, X from gtsam.utils.test_case import GtsamTestCase -from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, - GaussianMixture, HybridBayesNet, HybridValues, noiseModel) +from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues, GaussianConditional, + GaussianMixture, HybridBayesNet, HybridValues, noiseModel, VectorValues) class TestHybridBayesNet(GtsamTestCase): @@ -53,9 +53,13 @@ def test_evaluate(self): # Create values at which to evaluate. values = HybridValues() - values.insert(asiaKey, 0) - values.insert(X(0), [-6]) - values.insert(X(1), [1]) + continuous = VectorValues() + continuous.insert(X(0), [-6]) + continuous.insert(X(1), [1]) + values.insert(continuous) + discrete = DiscreteValues() + discrete[asiaKey] = 0 + values.insert(discrete) conditionalProbability = conditional.evaluate(values.continuous()) mixtureProbability = conditional0.evaluate(values.continuous()) @@ -68,6 +72,26 @@ def test_evaluate(self): self.assertAlmostEqual(bayesNet.logProbability(values), math.log(bayesNet.evaluate(values))) + # Check invariance for all conditionals: + self.check_invariance(bayesNet.at(0).asGaussian(), continuous) + self.check_invariance(bayesNet.at(0).asGaussian(), values) + self.check_invariance(bayesNet.at(0), values) + + self.check_invariance(bayesNet.at(1), values) + + self.check_invariance(bayesNet.at(2).asDiscrete(), discrete) + self.check_invariance(bayesNet.at(2).asDiscrete(), values) + self.check_invariance(bayesNet.at(2), values) + + def check_invariance(self, conditional, values): + """Check invariance for given conditional.""" + probability = conditional.evaluate(values) + self.assertTrue(probability >= 0.0) + logProb = conditional.logProbability(values) + self.assertAlmostEqual(probability, np.exp(logProb)) + expected = conditional.logNormalizationConstant() - conditional.error(values) + self.assertAlmostEqual(logProb, expected) + if __name__ == "__main__": unittest.main()