Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test and fix all conditionals #1386

Merged
merged 13 commits into from
Jan 15, 2023
6 changes: 5 additions & 1 deletion gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/hybrid/HybridValues.h>

#include <algorithm>
#include <boost/make_shared.hpp>
Expand Down Expand Up @@ -510,6 +510,10 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
return ss.str();
}

/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const{
return this->evaluate(x.discrete());
}
/* ************************************************************************* */

} // namespace gtsam
15 changes: 12 additions & 3 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,13 @@ class GTSAM_EXPORT DiscreteConditional
}

/// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override {
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}

using DecisionTreeFactor::error; ///< DiscreteValues version
using DecisionTreeFactor::operator(); ///< DiscreteValues version

/**
* @brief restrict to given *parent* values.
*
Expand Down Expand Up @@ -235,6 +238,14 @@ class GTSAM_EXPORT DiscreteConditional
/// @name HybridValues methods.
/// @{

/**
* Calculate probability for HybridValues `x`.
* Dispatches to DiscreteValues version.
*/
double evaluate(const HybridValues& x) const override;

using BaseConditional::operator(); ///< HybridValues version

/**
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
* This is actually just -error(x).
Expand All @@ -243,8 +254,6 @@ class GTSAM_EXPORT DiscreteConditional
return -error(x);
}

using DecisionTreeFactor::evaluate;

/// @}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
Expand Down
13 changes: 12 additions & 1 deletion gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
};

#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridValues.h>
virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional();
DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f);
Expand All @@ -95,9 +96,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
DiscreteConditional(const gtsam::DecisionTreeFactor& joint,
const gtsam::DecisionTreeFactor& marginal,
const gtsam::Ordering& orderedKeys);

// Standard interface
double logNormalizationConstant() const;
double logProbability(const gtsam::DiscreteValues& values) const;
double evaluate(const gtsam::DiscreteValues& values) const;
double operator()(const gtsam::DiscreteValues& values) const;
double error(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteConditional operator*(
const gtsam::DiscreteConditional& other) const;
gtsam::DiscreteConditional marginal(gtsam::Key key) const;
Expand All @@ -119,6 +123,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
size_t sample(size_t value) const;
size_t sample() const;
void sampleInPlace(gtsam::DiscreteValues @parentsValues) const;

// Markdown and HTML
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter,
Expand All @@ -127,6 +133,11 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
gtsam::DefaultKeyFormatter) const;
string html(const gtsam::KeyFormatter& keyFormatter,
std::map<gtsam::Key, std::vector<std::string>> names) const;

// Expose HybridValues versions
double logProbability(const gtsam::HybridValues& x) const;
double evaluate(const gtsam::HybridValues& x) const;
double error(const gtsam::HybridValues& x) const;
};

#include <gtsam/discrete/DiscreteDistribution.h>
Expand Down
2 changes: 2 additions & 0 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ TEST(DiscreteConditional, PriorProbability) {
DiscreteConditional dc(Asia, "4/6");
DiscreteValues values{{asiaKey, 0}};
EXPECT_DOUBLES_EQUAL(0.4, dc.evaluate(values), 1e-9);
EXPECT(DiscreteConditional::CheckInvariants(dc, values));
}

/* ************************************************************************* */
Expand All @@ -109,6 +110,7 @@ TEST(DiscreteConditional, probability) {
EXPECT_DOUBLES_EQUAL(0.2, C_given_DE(given), 1e-9);
EXPECT_DOUBLES_EQUAL(log(0.2), C_given_DE.logProbability(given), 1e-9);
EXPECT_DOUBLES_EQUAL(-log(0.2), C_given_DE.error(given), 1e-9);
EXPECT(DiscreteConditional::CheckInvariants(C_given_DE, given));
}

/* ************************************************************************* */
Expand Down
14 changes: 13 additions & 1 deletion gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,22 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
}

/* *******************************************************************************/
double GaussianMixture::logProbability(const HybridValues &values) const {
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
}

/* *******************************************************************************/
double GaussianMixture::logProbability(const HybridValues &values) const {
auto conditional = conditionals_(values.discrete());
return conditional->logProbability(values.continuous());
}

/* *******************************************************************************/
double GaussianMixture::evaluate(const HybridValues &values) const {
auto conditional = conditionals_(values.discrete());
return conditional->evaluate(values.continuous());
}

} // namespace gtsam
38 changes: 31 additions & 7 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,20 +174,44 @@ class GTSAM_EXPORT GaussianMixture
const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this Gaussian Mixture given the
* continuous values and a discrete assignment.
* @brief Compute the error of this Gaussian Mixture.
*
* This requires some care, as different mixture components may have
* different normalization constants. Let's consider p(x|y,m), where m is
dellaert marked this conversation as resolved.
Show resolved Hide resolved
* discrete. We need the error to satisfy the invariant:
*
* error(x;y,m) = K - log(probability(x;y,m))
dellaert marked this conversation as resolved.
Show resolved Hide resolved
*
* For all x,y,m. But note that K, for the GaussianMixture, cannot depend on
dellaert marked this conversation as resolved.
Show resolved Hide resolved
* any arguments. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
* We resolve by having K == 0.0 and
*
* error(x;y,m) = error_m(x;y) - K_m
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;

/**
* @brief Compute the logProbability of this Gaussian Mixture.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double logProbability(const HybridValues &values) const override;

// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
/// Calculate probability density for given `values`.
double evaluate(const HybridValues &values) const override;

// /// Evaluate probability density, sugar.
// double operator()(const HybridValues &values) const { return
// evaluate(values); }
/// Evaluate probability density, sugar.
double operator()(const HybridValues &values) const {
return evaluate(values);
}

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
Expand Down
35 changes: 35 additions & 0 deletions gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
: !(e->inner_);
}

/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
}
if (auto gm = asMixture()) {
return gm->error(values);
}
if (auto dc = asDiscrete()) {
return dc->error(values.discrete());
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::logProbability(const HybridValues &values) const {
if (auto gc = asGaussian()) {
Expand All @@ -136,4 +151,24 @@ double HybridConditional::logProbability(const HybridValues &values) const {
"HybridConditional::logProbability: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::logNormalizationConstant() const {
if (auto gc = asGaussian()) {
return gc->logNormalizationConstant();
}
if (auto gm = asMixture()) {
return gm->logNormalizationConstant(); // 0.0!
}
if (auto dc = asDiscrete()) {
return dc->logNormalizationConstant(); // 0.0!
}
throw std::runtime_error(
"HybridConditional::logProbability: conditional type not handled");
}

/* ************************************************************************ */
double HybridConditional::evaluate(const HybridValues &values) const {
return std::exp(logProbability(values));
}

} // namespace gtsam
15 changes: 14 additions & 1 deletion gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,22 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() const { return inner_; }

/// Return the logProbability of the underlying conditional.
/// Return the error of the underlying conditional.
double error(const HybridValues& values) const override;

/// Return the log-probability (or density) of the underlying conditional.
double logProbability(const HybridValues& values) const override;

/**
* Return the log normalization constant.
* Note this is 0.0 for discrete and hybrid conditionals, but depends
* on the continuous parameters for Gaussian conditionals.
*/
double logNormalizationConstant() const override;

/// Return the probability (or density) of the underlying conditional.
double evaluate(const HybridValues& values) const override;

/// Check if VectorValues `measurements` contains all frontal keys.
bool frontalsIn(const VectorValues& measurements) const {
for (Key key : frontals()) {
Expand Down
1 change: 1 addition & 0 deletions gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ virtual class HybridConditional {
size_t nrParents() const;

// Standard interface:
double logNormalizationConstant() const;
double logProbability(const gtsam::HybridValues& values) const;
double evaluate(const gtsam::HybridValues& values) const;
double operator()(const gtsam::HybridValues& values) const;
Expand Down
83 changes: 83 additions & 0 deletions gtsam/hybrid/tests/testHybridConditional.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file testHybridConditional.cpp
* @brief Unit tests for HybridConditional class
* @date January 2023
*/

#include <gtsam/hybrid/HybridConditional.h>

#include "TinyHybridExample.h"

// Include for test suite
#include <CppUnitLite/TestHarness.h>

using namespace gtsam;

using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;

/* ****************************************************************************/
// Check invariants for all conditionals in a tiny Bayes net.
TEST(HybridConditional, Invariants) {
// Create hybrid Bayes net p(z|x,m)p(x)P(m)
auto bn = tiny::createHybridBayesNet();

// Create values to check invariants.
const VectorValues c{{X(0), Vector1(5.1)}, {Z(0), Vector1(4.9)}};
const DiscreteValues d{{M(0), 1}};
const HybridValues values{c, d};

// Check invariants for p(z|x,m)
auto hc0 = bn.at(0);
CHECK(hc0->isHybrid());

// Check invariants as a GaussianMixture.
const auto mixture = hc0->asMixture();
EXPECT(GaussianMixture::CheckInvariants(*mixture, values));

// Check invariants as a HybridConditional.
EXPECT(HybridConditional::CheckInvariants(*hc0, values));

// Check invariants for p(x)
auto hc1 = bn.at(1);
CHECK(hc1->isContinuous());

// Check invariants as a GaussianConditional.
const auto gaussian = hc1->asGaussian();
EXPECT(GaussianConditional::CheckInvariants(*gaussian, c));
EXPECT(GaussianConditional::CheckInvariants(*gaussian, values));

// Check invariants as a HybridConditional.
EXPECT(HybridConditional::CheckInvariants(*hc1, values));

// Check invariants for p(m)
auto hc2 = bn.at(2);
CHECK(hc2->isDiscrete());

// Check invariants as a DiscreteConditional.
const auto discrete = hc2->asDiscrete();
EXPECT(DiscreteConditional::CheckInvariants(*discrete, d));
EXPECT(DiscreteConditional::CheckInvariants(*discrete, values));

// Check invariants as a HybridConditional.
EXPECT(HybridConditional::CheckInvariants(*hc2, values));
}

/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */
26 changes: 26 additions & 0 deletions gtsam/inference/Conditional-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,30 @@ double Conditional<FACTOR, DERIVEDCONDITIONAL>::evaluate(
const HybridValues& c) const {
throw std::runtime_error("Conditional::evaluate is not implemented");
}

/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
double Conditional<FACTOR, DERIVEDCONDITIONAL>::normalizationConstant() const {
return std::exp(logNormalizationConstant());
}

/* ************************************************************************* */
template <class FACTOR, class DERIVEDCONDITIONAL>
template <class VALUES>
bool Conditional<FACTOR, DERIVEDCONDITIONAL>::CheckInvariants(
const DERIVEDCONDITIONAL& conditional, const VALUES& values) {
const double prob_or_density = conditional.evaluate(values);
if (prob_or_density < 0.0) return false; // prob_or_density is negative.
if (std::abs(prob_or_density - conditional(values)) > 1e-9)
return false; // operator and evaluate differ
const double logProb = conditional.logProbability(values);
if (std::abs(prob_or_density - std::exp(logProb)) > 1e-9)
return false; // logProb is not consistent with prob_or_density
const double expected =
conditional.logNormalizationConstant() - conditional.error(values);
if (std::abs(logProb - expected) > 1e-9)
return false; // logProb is not consistent with error
return true;
}

} // namespace gtsam
Loading