Skip to content

Commit

Permalink
Merge pull request #1356 from borglab/hybrid/elimination
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Dec 30, 2022
2 parents 04ba193 + a4659f0 commit 90c2f2e
Show file tree
Hide file tree
Showing 17 changed files with 425 additions and 119 deletions.
44 changes: 36 additions & 8 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>

Expand All @@ -36,20 +37,17 @@ GaussianMixture::GaussianMixture(
conditionals_(conditionals) {}

/* *******************************************************************************/
const GaussianMixture::Conditionals &GaussianMixture::conditionals() {
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
return conditionals_;
}

/* *******************************************************************************/
GaussianMixture GaussianMixture::FromConditionals(
GaussianMixture::GaussianMixture(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
Conditionals dt(discreteParents, conditionalsList);

return GaussianMixture(continuousFrontals, continuousParents, discreteParents,
dt);
}
const std::vector<GaussianConditional::shared_ptr> &conditionalsList)
: GaussianMixture(continuousFrontals, continuousParents, discreteParents,
Conditionals(discreteParents, conditionalsList)) {}

/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add(
Expand Down Expand Up @@ -128,6 +126,36 @@ void GaussianMixture::print(const std::string &s,
});
}

/* ************************************************************************* */
KeyVector GaussianMixture::continuousParents() const {
// Get all parent keys:
const auto range = parents();
KeyVector continuousParentKeys(range.begin(), range.end());
// Loop over all discrete keys:
for (const auto &discreteKey : discreteKeys()) {
const Key key = discreteKey.first;
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
}
return continuousParentKeys;
}

/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &frontals) const {
// TODO(dellaert): check that values has all frontals
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->likelihood(frontals);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}

/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
Expand Down
33 changes: 21 additions & 12 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

namespace gtsam {

class GaussianMixtureFactor;

/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
* part of a Bayes Network. This is the result of the elimination of a
Expand Down Expand Up @@ -112,21 +114,11 @@ class GTSAM_EXPORT GaussianMixture
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
static This FromConditionals(
GaussianMixture(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/// @}
/// @name Standard API
/// @{

GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;

/// Returns the total number of continuous components
size_t nrComponents() const;

/// @}
/// @name Testable
/// @{
Expand All @@ -140,9 +132,25 @@ class GTSAM_EXPORT GaussianMixture
const KeyFormatter &formatter = DefaultKeyFormatter) const override;

/// @}
/// @name Standard API
/// @{

GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;

/// Returns the total number of continuous components
size_t nrComponents() const;

/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;

// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
// on the parents.
boost::shared_ptr<GaussianMixtureFactor> likelihood(
const VectorValues &frontals) const;

/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();
const Conditionals &conditionals() const;

/**
* @brief Compute error of the GaussianMixture as a tree.
Expand Down Expand Up @@ -181,6 +189,7 @@ class GTSAM_EXPORT GaussianMixture
* @return Sum
*/
Sum add(const Sum &sum) const;
/// @}
};

/// Return the DiscreteKey vector as a set.
Expand Down
51 changes: 29 additions & 22 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,42 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol);
}

/* *******************************************************************************/
GaussianMixtureFactor GaussianMixtureFactor::FromFactors(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors) {
Factors dt(discreteKeys, factors);

return GaussianMixtureFactor(continuousKeys, discreteKeys, dt);
if (e == nullptr) return false;

// This will return false if either factors_ is empty or e->factors_ is empty,
// but not if both are empty or both are not empty:
if (factors_.empty() ^ e->factors_.empty()) return false;

// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const GaussianFactor::shared_ptr &f1,
const GaussianFactor::shared_ptr &f2) {
return f1->equals(*f2, tol);
});
}

/* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
std::cout << "{\n";
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
gf->print("", formatter);
return rd.str();
} else {
return "nullptr";
}
});
if (factors_.empty()) {
std::cout << " empty" << std::endl;
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
gf->print("", formatter);
return rd.str();
} else {
return "nullptr";
}
});
}
std::cout << "}" << std::endl;
}

Expand Down
11 changes: 4 additions & 7 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Construct a new GaussianMixtureFactor object using a vector of
* GaussianFactor shared pointers.
*
* @param keys Vector of keys for continuous factors.
* @param continuousKeys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys.
* @param factors Vector of gaussian factor shared pointers.
*/
GaussianMixtureFactor(const KeyVector &keys, const DiscreteKeys &discreteKeys,
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors)
: GaussianMixtureFactor(keys, discreteKeys,
: GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {}

static This FromFactors(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors);

/// @}
/// @name Testable
/// @{
Expand Down
21 changes: 18 additions & 3 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,38 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Add HybridConditional to Bayes Net
using Base::add;

/// Add a Gaussian Mixture to the Bayes Net.
void addMixture(const GaussianMixture::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
}

/// Add a Gaussian conditional to the Bayes Net.
void addGaussian(const GaussianConditional::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
}

/// Add a discrete conditional to the Bayes Net.
void addDiscrete(const DiscreteConditional::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
}

/// Add a Gaussian Mixture to the Bayes Net.
template <typename... T>
void addMixture(T &&...args) {
void emplaceMixture(T &&...args) {
push_back(HybridConditional(
boost::make_shared<GaussianMixture>(std::forward<T>(args)...)));
}

/// Add a Gaussian conditional to the Bayes Net.
template <typename... T>
void addGaussian(T &&...args) {
void emplaceGaussian(T &&...args) {
push_back(HybridConditional(
boost::make_shared<GaussianConditional>(std::forward<T>(args)...)));
}

/// Add a discrete conditional to the Bayes Net.
template <typename... T>
void addDiscrete(T &&...args) {
void emplaceDiscrete(T &&...args) {
push_back(HybridConditional(
boost::make_shared<DiscreteConditional>(std::forward<T>(args)...)));
}
Expand Down
43 changes: 41 additions & 2 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,12 +468,51 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
return error_tree;
}

/* ************************************************************************ */
double HybridGaussianFactorGraph::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) {
auto factor = factors_.at(idx);

if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(continuousValues, discreteValues);
}
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(continuousValues, discreteValues);
}

} else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(continuousValues);
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(continuousValues);
}
}
}
return error;
}

/* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = this->error(continuousValues, discreteValues);
// NOTE: The 0.5 term is handled by each factor
return std::exp(-error);
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree =
error_tree.apply([](double error) { return exp(-error); });
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
});
return prob_tree;
}

Expand Down
25 changes: 25 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;

/**
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @param continuousValues The continuous VectorValues
* for computing the error.
* @param discreteValues The specific discrete assignment
* whose error we wish to compute.
* @return double
*/
double error(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;

/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
Expand All @@ -193,6 +206,18 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;

/**
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
*
* @param continuousValues The vector values for which to compute the
* posterior probability.
* @param discreteValues The specific assignment to use for the computation.
* @return double
*/
double probPrime(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
Expand Down
Loading

0 comments on commit 90c2f2e

Please sign in to comment.