Skip to content

Commit

Permalink
Merge pull request #1388 from borglab/hybrid/simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Jan 17, 2023
2 parents 7cd3600 + f714c4a commit 0571af9
Show file tree
Hide file tree
Showing 22 changed files with 1,309 additions and 502 deletions.
719 changes: 719 additions & 0 deletions doc/Hybrid.lyx

Large diffs are not rendered by default.

Binary file added doc/Hybrid.pdf
Binary file not shown.
7 changes: 7 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ class GTSAM_EXPORT DiscreteConditional
return -error(x);
}

/**
* logNormalizationConstant K is just zero, such that
* logProbability(x) = log(evaluate(x)) = - error(x)
* and hence error(x) = - log(evaluate(x)) > 0 for all x.
*/
double logNormalizationConstant() const override { return 0.0; }

/// @}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
Expand Down
85 changes: 61 additions & 24 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,18 @@ GaussianMixture::GaussianMixture(
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents),
BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {}
conditionals_(conditionals) {
// Calculate logConstant_ as the maximum of the log constants of the
// conditionals, by visiting the decision tree:
logConstant_ = -std::numeric_limits<double>::infinity();
conditionals_.visit(
[this](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
this->logConstant_ = std::max(
this->logConstant_, conditional->logNormalizationConstant());
}
});
}

/* *******************************************************************************/
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
Expand Down Expand Up @@ -63,28 +74,22 @@ GaussianMixture::GaussianMixture(
// GaussianMixtureFactor, no?
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
auto result = graph1;
result.push_back(graph2);
return result;
};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result;
result.push_back(conditional);
if (conditional) {
return GraphAndConstant(result, conditional->logNormalizationConstant());
} else {
return GraphAndConstant(result, 0.0);
}
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
return GaussianFactorGraph{gc};
};
return {conditionals_, lambda};
return {conditionals_, wrap};
}

/* *******************************************************************************/
Expand Down Expand Up @@ -170,22 +175,43 @@ KeyVector GaussianMixture::continuousParents() const {
}

/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &frontals) const {
// Check that values has all frontals
for (auto &&kv : frontals) {
if (frontals.find(kv.first) == frontals.end()) {
throw std::runtime_error("GaussianMixture: frontals missing factor key.");
bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
for (auto &&kv : given) {
if (given.find(kv.first) == given.end()) {
return false;
}
}
return true;
}

/* ************************************************************************* */
boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
const VectorValues &given) const {
if (!allFrontalsGiven(given)) {
throw std::runtime_error(
"GaussianMixture::likelihood: given values are missing some frontals.");
}

const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const GaussianMixtureFactor::Factors likelihoods(
conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
return GaussianMixtureFactor::FactorAndConstant{
conditional->likelihood(frontals),
conditional->logNormalizationConstant()};
const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm =
logConstant_ - conditional->logNormalizationConstant();
if (Cgm_Kgcm == 0.0) {
return likelihood_m;
} else {
// Add a constant factor to the likelihood in case the noise models
// are not all equal.
GaussianFactorGraph gfg;
gfg.push_back(likelihood_m);
Vector c(1);
c << std::sqrt(2.0 * Cgm_Kgcm);
auto constantFactor = boost::make_shared<JacobianFactor>(c);
gfg.push_back(constantFactor);
return boost::make_shared<JacobianFactor>(gfg);
}
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
Expand Down Expand Up @@ -285,6 +311,16 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return 1e50;
}
};
return DecisionTree<Key, double>(conditionals_, errorFunc);
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}
Expand All @@ -293,7 +329,8 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
}

/* *******************************************************************************/
Expand Down
44 changes: 33 additions & 11 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class GTSAM_EXPORT GaussianMixture
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

private:
Conditionals conditionals_;
Conditionals conditionals_; ///< a decision tree of Gaussian conditionals.
double logConstant_; ///< log of the normalization constant.

/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
Expand Down Expand Up @@ -155,10 +156,16 @@ class GTSAM_EXPORT GaussianMixture
/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;

// Create a likelihood factor for a Gaussian mixture, return a Mixture factor
// on the parents.
/// The log normalization constant is max of the the individual
/// log-normalization constants.
double logNormalizationConstant() const override { return logConstant_; }

/**
* Create a likelihood factor for a Gaussian mixture, return a Mixture factor
* on the parents.
*/
boost::shared_ptr<GaussianMixtureFactor> likelihood(
const VectorValues &frontals) const;
const VectorValues &given) const;

/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals() const;
Expand All @@ -182,21 +189,33 @@ class GTSAM_EXPORT GaussianMixture
*
* error(x;y,m) = K - log(probability(x;y,m))
*
* For all x,y,m. But note that K, for the GaussianMixture, cannot depend on
* any arguments. Hence, we delegate to the underlying Gaussian
* For all x,y,m. But note that K, the (log) normalization constant defined
* in Conditional.h, should not depend on x, y, or m, only on the parameters
* of the density. Hence, we delegate to the underlying Gaussian
* conditionals, indexed by m, which do satisfy:
*
*
* log(probability_m(x;y)) = K_m - error_m(x;y)
*
* We resolve by having K == 0.0 and
*
* error(x;y,m) = error_m(x;y) - K_m
*
* We resolve by having K == max(K_m) and
*
* error(x;y,m) = error_m(x;y) + K - K_m
*
* which also makes error(x;y,m) >= 0 for all x,y,m.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;

/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute the logProbability of this Gaussian Mixture.
*
Expand Down Expand Up @@ -233,6 +252,9 @@ class GTSAM_EXPORT GaussianMixture
/// @}

private:
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
Expand Down
50 changes: 18 additions & 32 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@ namespace gtsam {
/* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Mixture &factors)
: Base(continuousKeys, discreteKeys),
factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
return FactorAndConstant{gf, 0.0};
}) {}
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}

/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
Expand All @@ -48,11 +45,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {

// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
const FactorAndConstant &f2) {
return f1.factor->equals(*(f2.factor), tol) &&
std::abs(f1.constant - f2.constant) < tol;
});
factors_.equals(e->factors_,
[tol](const sharedFactor &f1, const sharedFactor &f2) {
return f1->equals(*f2, tol);
});
}

/* *******************************************************************************/
Expand All @@ -65,8 +61,7 @@ void GaussianMixtureFactor::print(const std::string &s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const FactorAndConstant &gf_z) -> std::string {
auto gf = gf_z.factor;
[&](const sharedFactor &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
Expand All @@ -81,24 +76,19 @@ void GaussianMixtureFactor::print(const std::string &s,
}

/* *******************************************************************************/
GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()(
const DiscreteValues &assignment) const {
return factors_(assignment).factor;
}

/* *******************************************************************************/
double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
return factors_(assignment).constant;
return factors_(assignment);
}

/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixtureFactor::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
auto result = graph1;
result.push_back(graph2);
return result;
};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
Expand All @@ -107,29 +97,25 @@ GaussianFactorGraphTree GaussianMixtureFactor::add(
/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result;
result.push_back(factor_z.factor);
return GraphAndConstant(result, factor_z.constant);
};
auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
return {factors_, wrap};
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
return factor_z.error(continuousValues);
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixtureFactor::error(const HybridValues &values) const {
const FactorAndConstant factor_z = factors_(values.discrete());
return factor_z.error(values.continuous());
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
}
/* *******************************************************************************/

Expand Down
Loading

0 comments on commit 0571af9

Please sign in to comment.