Skip to content

Commit

Permalink
Merge pull request #1386 from borglab/hybrid/more_tests
Browse files Browse the repository at this point in the history
Test and fix all conditionals
  • Loading branch information
dellaert authored Jan 15, 2023
2 parents 5c59862 + 51c4641 commit 618ac28
Show file tree
Hide file tree
Showing 19 changed files with 338 additions and 41 deletions.
6 changes: 5 additions & 1 deletion gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <boost/make_shared.hpp>
Expand Down Expand Up @@ -510,6 +510,10 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
return ss.str();
}

/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const{
return this->evaluate(x.discrete());
}
/* ************************************************************************* */

} // namespace gtsam
15 changes: 12 additions & 3 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,13 @@ class GTSAM_EXPORT DiscreteConditional
}

/// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override {
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}

using DecisionTreeFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version

/**
* @brief restrict to given *parent* values.
*
Expand Down Expand Up @@ -235,6 +238,14 @@ class GTSAM_EXPORT DiscreteConditional
/// @name HybridValues methods.
/// @{

/**
* Calculate probability for HybridValues `x`.
* Dispatches to DiscreteValues version.
*/
double evaluate(const HybridValues& x) const override;

using BaseConditional::operator(); ///< HybridValues version

/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* This is actually just -error(x).
Expand All @@ -243,8 +254,6 @@ class GTSAM_EXPORT DiscreteConditional
return -error(x);
}

using DecisionTreeFactor::evaluate;

/// @}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
Expand Down
13 changes: 12 additions & 1 deletion gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
};

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridValues.h>
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
Expand All @@ -95,9 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);

// Standard interface
double logNormalizationConstant() const;
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
double error(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const;
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
Expand All @@ -119,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
size_t sample(size_t value) const;
size_t sample() const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;

// Markdown and HTML
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
Expand All @@ -127,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;

// Expose HybridValues versions
double logProbability(const gtsam::HybridValues& x) const;
double evaluate(const gtsam::HybridValues& x) const;
double error(const gtsam::HybridValues& x) const;
};

#include <gtsam/discrete/DiscreteDistribution.h>
Expand Down
2 changes: 2 additions & 0 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ TEST(DiscreteConditional, PriorProbability) {
DiscreteConditional dc(Asia, "4/6");
DiscreteValues values{{asiaKey, 0}};
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
EXPECT(DiscreteConditional::CheckInvariants(dc, values));
}

/* ************************************************************************* */
Expand All @@ -109,6 +110,7 @@ TEST(DiscreteConditional, probability) {
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
EXPECT(DiscreteConditional::CheckInvariants(C_given_DE, given));
}

/* ************************************************************************* */
Expand Down
14 changes: 13 additions & 1 deletion gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,22 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
}

/* *******************************************************************************/
double GaussianMixture::logProbability(const HybridValues &values) const {
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
}

/* *******************************************************************************/
double GaussianMixture::logProbability(const HybridValues &values) const {
auto conditional = conditionals_(values.discrete());
return conditional->logProbability(values.continuous());
}

/* *******************************************************************************/
double GaussianMixture::evaluate(const HybridValues &values) const {
auto conditional = conditionals_(values.discrete());
return conditional->evaluate(values.continuous());
}

} // namespace gtsam
38 changes: 31 additions & 7 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,44 @@ class GTSAM_EXPORT GaussianMixture
const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this Gaussian Mixture given the
* continuous values and a discrete assignment.
* @brief Compute the error of this Gaussian Mixture.
*
* This requires some care, as different mixture components may have
* different normalization constants. Let's consider p(x|y,m), where m is
* discrete. We need the error to satisfy the invariant:
*
* error(x;y,m) = K - log(probability(x;y,m))
*
* For all x,y,m. But note that K, for the GaussianMixture, cannot depend on
* any arguments. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
* We resolve by having K == 0.0 and
*
* error(x;y,m) = error_m(x;y) - K_m
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;

/**
* @brief Compute the logProbability of this Gaussian Mixture.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double logProbability(const HybridValues &values) const override;

// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
/// Calculate probability density for given `values`.
double evaluate(const HybridValues &values) const override;

// /// Evaluate probability density, sugar.
// double operator()(const HybridValues &values) const { return
// evaluate(values); }
/// Evaluate probability density, sugar.
double operator()(const HybridValues &values) const {
return evaluate(values);
}

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
Expand Down
35 changes: 35 additions & 0 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
: !(e->inner_);
}

/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
}
if (auto gm = asMixture()) {
return gm->error(values);
}
if (auto dc = asDiscrete()) {
return dc->error(values.discrete());
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
Expand All @@ -136,4 +151,24 @@ double HybridConditional::logProbability(const HybridValues &values) const {
"HybridConditional::logProbability: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::logNormalizationConstant() const {
if (auto gc = asGaussian()) {
return gc->logNormalizationConstant();
}
if (auto gm = asMixture()) {
return gm->logNormalizationConstant(); // 0.0!
}
if (auto dc = asDiscrete()) {
return dc->logNormalizationConstant(); // 0.0!
}
throw std::runtime_error(
"HybridConditional::logProbability: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::evaluate(const HybridValues &values) const {
return std::exp(logProbability(values));
}

} // namespace gtsam
15 changes: 14 additions & 1 deletion gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,22 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() const { return inner_; }

/// Return the logProbability of the underlying conditional.
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;

/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;

/**
* Return the log normalization constant.
* Note this is 0.0 for discrete and hybrid conditionals, but depends
* on the continuous parameters for Gaussian conditionals.
*/
double logNormalizationConstant() const override;

/// Return the probability (or density) of the underlying conditional.
double evaluate(const HybridValues& values) const override;

/// Check if VectorValues `measurements` contains all frontal keys.
bool frontalsIn(const VectorValues& measurements) const {
for (Key key : frontals()) {
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ virtual class HybridConditional {
size_t nrParents() const;

// Standard interface:
double logNormalizationConstant() const;
double logProbability(const gtsam::HybridValues& values) const;
double evaluate(const gtsam::HybridValues& values) const;
double operator()(const gtsam::HybridValues& values) const;
Expand Down
83 changes: 83 additions & 0 deletions gtsam/hybrid/tests/testHybridConditional.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file testHybridConditional.cpp
* @brief Unit tests for HybridConditional class
* @date January 2023
*/

#include <gtsam/hybrid/HybridConditional.h>

#include "TinyHybridExample.h"

// Include for test suite
#include <CppUnitLite/TestHarness.h>

using namespace gtsam;

using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;

/* ****************************************************************************/
// Check invariants for all conditionals in a tiny Bayes net.
TEST(HybridConditional, Invariants) {
// Create hybrid Bayes net p(z|x,m)p(x)P(m)
auto bn = tiny::createHybridBayesNet();

// Create values to check invariants.
const VectorValues c{{X(0), Vector1(5.1)}, {Z(0), Vector1(4.9)}};
const DiscreteValues d{{M(0), 1}};
const HybridValues values{c, d};

// Check invariants for p(z|x,m)
auto hc0 = bn.at(0);
CHECK(hc0->isHybrid());

// Check invariants as a GaussianMixture.
const auto mixture = hc0->asMixture();
EXPECT(GaussianMixture::CheckInvariants(*mixture, values));

// Check invariants as a HybridConditional.
EXPECT(HybridConditional::CheckInvariants(*hc0, values));

// Check invariants for p(x)
auto hc1 = bn.at(1);
CHECK(hc1->isContinuous());

// Check invariants as a GaussianConditional.
const auto gaussian = hc1->asGaussian();
EXPECT(GaussianConditional::CheckInvariants(*gaussian, c));
EXPECT(GaussianConditional::CheckInvariants(*gaussian, values));

// Check invariants as a HybridConditional.
EXPECT(HybridConditional::CheckInvariants(*hc1, values));

// Check invariants for p(m)
auto hc2 = bn.at(2);
CHECK(hc2->isDiscrete());

// Check invariants as a DiscreteConditional.
const auto discrete = hc2->asDiscrete();
EXPECT(DiscreteConditional::CheckInvariants(*discrete, d));
EXPECT(DiscreteConditional::CheckInvariants(*discrete, values));

// Check invariants as a HybridConditional.
EXPECT(HybridConditional::CheckInvariants(*hc2, values));
}

/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */
26 changes: 26 additions & 0 deletions gtsam/inference/Conditional-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,30 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
const HybridValues& c) const {
throw std::runtime_error("Conditional::evaluate is not implemented");
}

/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
return std::exp(logNormalizationConstant());
}

/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
template <class VALUES>
bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants(
const DERIVEDCONDITIONAL& conditional, const VALUES& values) {
const double prob_or_density = conditional.evaluate(values);
if (prob_or_density < 0.0) return false; // prob_or_density is negative.
if (std::abs(prob_or_density - conditional(values)) > 1e-9)
return false; // operator and evaluate differ
const double logProb = conditional.logProbability(values);
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
return false; // logProb is not consistent with prob_or_density
const double expected =
conditional.logNormalizationConstant() - conditional.error(values);
if (std::abs(logProb - expected) > 1e-9)
return false; // logProb is not consistent with error
return true;
}

} // namespace gtsam
Loading

0 comments on commit 618ac28

Please sign in to comment.