Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Likelihoods and other improvements #1356

Merged
merged 10 commits into from
Dec 30, 2022
44 changes: 36 additions & 8 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>

Expand All @@ -36,20 +37,17 @@ GaussianMixture::GaussianMixture(
conditionals_(conditionals) {}

/* *******************************************************************************/
const GaussianMixture::Conditionals &GaussianMixture::conditionals() {
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
return conditionals_;
}

/* *******************************************************************************/
GaussianMixture GaussianMixture::FromConditionals(
GaussianMixture::GaussianMixture(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
Conditionals dt(discreteParents, conditionalsList);

return GaussianMixture(continuousFrontals, continuousParents, discreteParents,
dt);
}
const std::vector<GaussianConditional::shared_ptr> &conditionalsList)
: GaussianMixture(continuousFrontals, continuousParents, discreteParents,
Conditionals(discreteParents, conditionalsList)) {}

/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add(
Expand Down Expand Up @@ -128,6 +126,36 @@ void GaussianMixture::print(const std::string &s,
});
}

/* ************************************************************************* */
KeyVector GaussianMixture::continuousParents() const {
// Get all parent keys:
const auto range = parents();
KeyVector continuousParentKeys(range.begin(), range.end());
// Loop over all discrete keys:
for (const auto &discreteKey : discreteKeys()) {
const Key key = discreteKey.first;
// remove that key from continuousParentKeys:
continuousParentKeys.erase(std::remove(continuousParentKeys.begin(),
continuousParentKeys.end(), key),
continuousParentKeys.end());
}
return continuousParentKeys;
}

/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &frontals) const {
// TODO(dellaert): check that values has all frontals
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals(), [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->likelihood(frontals);
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}

/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> s;
Expand Down
33 changes: 21 additions & 12 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

namespace gtsam {

class GaussianMixtureFactor;

/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
* part of a Bayes Network. This is the result of the elimination of a
Expand Down Expand Up @@ -112,21 +114,11 @@ class GTSAM_EXPORT GaussianMixture
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
static This FromConditionals(
GaussianMixture(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals);

/// @}
/// @name Standard API
/// @{

GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;

/// Returns the total number of continuous components
size_t nrComponents() const;

/// @}
/// @name Testable
/// @{
Expand All @@ -140,9 +132,25 @@ class GTSAM_EXPORT GaussianMixture
const KeyFormatter &formatter = DefaultKeyFormatter) const override;

/// @}
/// @name Standard API
/// @{

GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;

/// Returns the total number of continuous components
size_t nrComponents() const;

/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;

// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
// on the parents.
boost::shared_ptr<GaussianMixtureFactor> likelihood(
const VectorValues &frontals) const;

/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();
const Conditionals &conditionals() const;

/**
* @brief Compute error of the GaussianMixture as a tree.
Expand Down Expand Up @@ -181,6 +189,7 @@ class GTSAM_EXPORT GaussianMixture
* @return Sum
*/
Sum add(const Sum &sum) const;
/// @}
};

/// Return the DiscreteKey vector as a set.
Expand Down
51 changes: 29 additions & 22 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,42 @@ GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && Base::equals(*e, tol);
}

/* *******************************************************************************/
GaussianMixtureFactor GaussianMixtureFactor::FromFactors(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors) {
Factors dt(discreteKeys, factors);

return GaussianMixtureFactor(continuousKeys, discreteKeys, dt);
if (e == nullptr) return false;

// This will return false if either factors_ is empty or e->factors_ is empty,
// but not if both are empty or both are not empty:
if (factors_.empty() ^ e->factors_.empty()) return false;

// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const GaussianFactor::shared_ptr &f1,
const GaussianFactor::shared_ptr &f2) {
return f1->equals(*f2, tol);
});
}

/* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
std::cout << "{\n";
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
gf->print("", formatter);
return rd.str();
} else {
return "nullptr";
}
});
if (factors_.empty()) {
std::cout << " empty" << std::endl;
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
gf->print("", formatter);
return rd.str();
} else {
return "nullptr";
}
});
}
std::cout << "}" << std::endl;
}

Expand Down
11 changes: 4 additions & 7 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Construct a new GaussianMixtureFactor object using a vector of
* GaussianFactor shared pointers.
*
* @param keys Vector of keys for continuous factors.
* @param continuousKeys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys.
* @param factors Vector of gaussian factor shared pointers.
*/
GaussianMixtureFactor(const KeyVector &keys, const DiscreteKeys &discreteKeys,
GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors)
: GaussianMixtureFactor(keys, discreteKeys,
: GaussianMixtureFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {}

static This FromFactors(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors);

/// @}
/// @name Testable
/// @{
Expand Down
21 changes: 18 additions & 3 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,38 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Add HybridConditional to Bayes Net
using Base::add;

/// Add a Gaussian Mixture to the Bayes Net.
void addMixture(const GaussianMixture::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
}

/// Add a Gaussian conditional to the Bayes Net.
void addGaussian(const GaussianConditional::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
}

/// Add a discrete conditional to the Bayes Net.
void addDiscrete(const DiscreteConditional::shared_ptr &ptr) {
push_back(HybridConditional(ptr));
}

/// Add a Gaussian Mixture to the Bayes Net.
template <typename... T>
void addMixture(T &&...args) {
void emplaceMixture(T &&...args) {
push_back(HybridConditional(
boost::make_shared<GaussianMixture>(std::forward<T>(args)...)));
}

/// Add a Gaussian conditional to the Bayes Net.
template <typename... T>
void addGaussian(T &&...args) {
void emplaceGaussian(T &&...args) {
push_back(HybridConditional(
boost::make_shared<GaussianConditional>(std::forward<T>(args)...)));
}

/// Add a discrete conditional to the Bayes Net.
template <typename... T>
void addDiscrete(T &&...args) {
void emplaceDiscrete(T &&...args) {
push_back(HybridConditional(
boost::make_shared<DiscreteConditional>(std::forward<T>(args)...)));
}
Expand Down
43 changes: 41 additions & 2 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,12 +468,51 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
return error_tree;
}

/* ************************************************************************ */
double HybridGaussianFactorGraph::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) {
auto factor = factors_.at(idx);

if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(continuousValues, discreteValues);
}
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(continuousValues, discreteValues);
}

} else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(continuousValues);
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(continuousValues);
}
}
}
return error;
}

/* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = this->error(continuousValues, discreteValues);
// NOTE: The 0.5 term is handled by each factor
return std::exp(-error);
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree =
error_tree.apply([](double error) { return exp(-error); });
AlgebraicDecisionTree<Key> prob_tree = error_tree.apply([](double error) {
// NOTE: The 0.5 term is handled by each factor
return exp(-error);
});
return prob_tree;
}

Expand Down
25 changes: 25 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;

/**
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @param continuousValues The continuous VectorValues
* for computing the error.
* @param discreteValues The specific discrete assignment
* whose error we wish to compute.
* @return double
*/
double error(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;

/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
Expand All @@ -193,6 +206,18 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;

/**
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
*
* @param continuousValues The vector values for which to compute the
* posterior probability.
* @param discreteValues The specific assignment to use for the computation.
* @return double
*/
double probPrime(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
Expand Down
Loading