Hybrid/simplify #1388

Merged · 29 commits · Jan 17, 2023

Changes from all commits

Commits
34a9aef
normalizationConstants returns all constants as a DecisionTreeFactor
dellaert Jan 12, 2023
1dcc6dd
All tests still work with zero constant!
dellaert Jan 12, 2023
03ad393
Removed FactorAndConstant, no longer needed
dellaert Jan 13, 2023
906330f
Add discrete contribution to logProbability
dellaert Jan 13, 2023
681c75c
Expose toFactorGraph to wrapper
dellaert Jan 13, 2023
dfef2c2
Simplify elimination
dellaert Jan 13, 2023
070cdb7
insert_or_assign
dellaert Jan 14, 2023
96e3eb7
Some test refactoring
dellaert Jan 14, 2023
c22b2ca
Improved docs
dellaert Jan 15, 2023
5b0408c
Check for error>0 and proper normalization constant
dellaert Jan 16, 2023
191e614
Fix print
dellaert Jan 16, 2023
57e59d1
Compute log-normalization constant as the max of the individual norma…
dellaert Jan 16, 2023
7a41180
Refactored tests and removed incorrect (R not upper-triangular) test.
dellaert Jan 16, 2023
207c9b7
Implemented the "hidden constant" scheme.
dellaert Jan 17, 2023
3a446d7
Explicitly implement logNormalizationConstant
dellaert Jan 17, 2023
202a5a3
Fixed toFactorGraph and added test to verify
dellaert Jan 17, 2023
a5951d8
Fixed test to work with "hidden constant" scheme
dellaert Jan 17, 2023
8357fc7
Fix python tests (and expose HybridBayesNet.error)
dellaert Jan 17, 2023
e31884c
Eradicated GraphAndConstant
dellaert Jan 17, 2023
9af7236
Added DEBUG_MARGINALS flag
dellaert Jan 17, 2023
519b2bb
Added comment
dellaert Jan 17, 2023
32d69a3
Trap if conditional==null.
dellaert Jan 17, 2023
f4859f0
Fix logProbability tests
dellaert Jan 17, 2023
4283925
Ratio test succeeds on fg, but not on posterior yet,
dellaert Jan 17, 2023
b494a61
Removed obsolete normalizationConstants method
dellaert Jan 17, 2023
892759e
Add math related to hybrid classes
dellaert Jan 17, 2023
c3ca31f
Added partial elimination test
dellaert Jan 17, 2023
e444962
Added correction with the normalization constant in the second elimin…
dellaert Jan 17, 2023
f714c4a
Merge branch 'develop' into hybrid/simplify
dellaert Jan 17, 2023
719 changes: 719 additions & 0 deletions doc/Hybrid.lyx

Large diffs are not rendered by default.

Binary file added doc/Hybrid.pdf
7 changes: 7 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
@@ -254,6 +254,13 @@ class GTSAM_EXPORT DiscreteConditional
return -error(x);
}

/**
* logNormalizationConstant K is just zero, such that
*   logProbability(x) = log(evaluate(x)) = -error(x),
* and hence error(x) = -log(evaluate(x)) >= 0 for all x.
*/
double logNormalizationConstant() const override { return 0.0; }

/// @}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
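
The identity in the new doc comment is easy to sanity-check. Here is a minimal sketch (not part of this PR; the binary key `A` and the "2/8" spec are made up for illustration):

```cpp
// Sketch: with K = 0, logProbability(x) = log(P(x)) = -error(x).
#include <gtsam/discrete/DiscreteConditional.h>

#include <cassert>
#include <cmath>

using namespace gtsam;

int main() {
  // Hypothetical binary variable with P(A=0) = 0.2, P(A=1) = 0.8.
  const DiscreteKey A(0, 2);
  const DiscreteConditional prior(A, "2/8");

  DiscreteValues x;
  x[A.first] = 1;  // pick the assignment A = 1

  const double p = prior(x);  // evaluates P(A=1) = 0.8
  assert(std::abs(prior.logProbability(x) - std::log(p)) < 1e-9);
  assert(std::abs(prior.error(x) + prior.logProbability(x)) < 1e-9);
  return 0;
}
```
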
85 changes: 61 additions & 24 deletions gtsam/hybrid/GaussianMixture.cpp
@@ -35,7 +35,18 @@ GaussianMixture::GaussianMixture(
: BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents),
BaseConditional(continuousFrontals.size()),
-      conditionals_(conditionals) {}
+      conditionals_(conditionals) {
+  // Calculate logConstant_ as the maximum of the log constants of the
+  // conditionals, by visiting the decision tree:
+  logConstant_ = -std::numeric_limits<double>::infinity();
+  conditionals_.visit(
+      [this](const GaussianConditional::shared_ptr &conditional) {
+        if (conditional) {
+          this->logConstant_ = std::max(
+              this->logConstant_, conditional->logNormalizationConstant());
+        }
+      });
+}

/* *******************************************************************************/
const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
Expand Down Expand Up @@ -63,28 +74,22 @@ GaussianMixture::GaussianMixture(
// GaussianMixtureFactor, no?
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
-  using Y = GraphAndConstant;
+  using Y = GaussianFactorGraph;
   auto add = [](const Y &graph1, const Y &graph2) {
-    auto result = graph1.graph;
-    result.push_back(graph2.graph);
-    return Y(result, graph1.constant + graph2.constant);
+    auto result = graph1;
+    result.push_back(graph2);
+    return result;
};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}

/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
-  auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
-    GaussianFactorGraph result;
-    result.push_back(conditional);
-    if (conditional) {
-      return GraphAndConstant(result, conditional->logNormalizationConstant());
-    } else {
-      return GraphAndConstant(result, 0.0);
-    }
+  auto wrap = [](const GaussianConditional::shared_ptr &gc) {
+    return GaussianFactorGraph{gc};
   };
-  return {conditionals_, lambda};
+  return {conditionals_, wrap};
}

/* *******************************************************************************/
@@ -170,22 +175,43 @@ KeyVector GaussianMixture::continuousParents() const {
}

/* ************************************************************************* */
-boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
-    const VectorValues &frontals) const {
-  // Check that values has all frontals
-  for (auto &&kv : frontals) {
-    if (frontals.find(kv.first) == frontals.end()) {
-      throw std::runtime_error("GaussianMixture: frontals missing factor key.");
-    }
-  }
+bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
+  // Check that `given` contains a value for every frontal key of this
+  // conditional (the original loop iterated `given` against itself).
+  for (Key key : frontals()) {
+    if (given.find(key) == given.end()) {
+      return false;
+    }
+  }
+  return true;
+}

/* ************************************************************************* */
+boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
+    const VectorValues &given) const {
+  if (!allFrontalsGiven(given)) {
+    throw std::runtime_error(
+        "GaussianMixture::likelihood: given values are missing some frontals.");
+  }
+
   const DiscreteKeys discreteParentKeys = discreteKeys();
   const KeyVector continuousParentKeys = continuousParents();
   const GaussianMixtureFactor::Factors likelihoods(
       conditionals_, [&](const GaussianConditional::shared_ptr &conditional) {
-        return GaussianMixtureFactor::FactorAndConstant{
-            conditional->likelihood(frontals),
-            conditional->logNormalizationConstant()};
+        const auto likelihood_m = conditional->likelihood(given);
+        const double Cgm_Kgcm =
+            logConstant_ - conditional->logNormalizationConstant();
+        if (Cgm_Kgcm == 0.0) {
+          return likelihood_m;
+        } else {
+          // Add a constant factor to the likelihood in case the noise models
+          // are not all equal: a keyless JacobianFactor with
+          // b = [sqrt(2 * Cgm_Kgcm)] adds 0.5 * ||b||^2 = Cgm_Kgcm to the
+          // error of mode m.
+          GaussianFactorGraph gfg;
+          gfg.push_back(likelihood_m);
+          Vector c(1);
+          c << std::sqrt(2.0 * Cgm_Kgcm);
+          auto constantFactor = boost::make_shared<JacobianFactor>(c);
+          gfg.push_back(constantFactor);
+          return boost::make_shared<JacobianFactor>(gfg);
+        }
});
return boost::make_shared<GaussianMixtureFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
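
The choice c = sqrt(2 * Cgm_Kgcm) above works because a JacobianFactor with no keys and right-hand side b has error 0.5 * ||b||^2 regardless of the values. A standalone sketch checking that identity (the constant 0.7 is illustrative only):

```cpp
// Sketch: a keyless JacobianFactor with b = [sqrt(2*c)] "hides" the
// constant c, since its error is 0.5 * ||b||^2 = c for any VectorValues.
#include <gtsam/linear/JacobianFactor.h>

#include <cassert>
#include <cmath>

using namespace gtsam;

int main() {
  const double c = 0.7;  // stand-in for Cgm_Kgcm = logConstant_ - K_m
  Vector b(1);
  b << std::sqrt(2.0 * c);
  const JacobianFactor constantFactor(b);  // no keys, just a residual
  assert(std::abs(constantFactor.error(VectorValues()) - c) < 1e-9);
  return 0;
}
```
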
Expand Down Expand Up @@ -285,6 +311,16 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return 1e50;
}
};
return DecisionTree<Key, double>(conditionals_, errorFunc);
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}
@@ -293,7 +329,8 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
-  return conditional->error(values.continuous()) - conditional->logNormalizationConstant();
+  return conditional->error(values.continuous()) + //
+         logConstant_ - conditional->logNormalizationConstant();
}

/* *******************************************************************************/
44 changes: 33 additions & 11 deletions gtsam/hybrid/GaussianMixture.h
@@ -63,7 +63,8 @@ class GTSAM_EXPORT GaussianMixture
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

private:
-  Conditionals conditionals_;
+  Conditionals conditionals_;  ///< a decision tree of Gaussian conditionals.
+  double logConstant_;         ///< log of the normalization constant.

/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
@@ -155,10 +156,16 @@ class GTSAM_EXPORT GaussianMixture
/// Returns the continuous keys among the parents.
KeyVector continuousParents() const;

-  // Create a likelihood factor for a Gaussian mixture, return a Mixture factor
-  // on the parents.
+  /// The log normalization constant is the max of the individual
+  /// log-normalization constants.
+  double logNormalizationConstant() const override { return logConstant_; }
+
+  /**
+   * Create a likelihood factor for a Gaussian mixture, return a Mixture factor
+   * on the parents.
+   */
   boost::shared_ptr<GaussianMixtureFactor> likelihood(
-      const VectorValues &frontals) const;
+      const VectorValues &given) const;

/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals() const;
@@ -182,21 +189,33 @@
*
* error(x;y,m) = K - log(probability(x;y,m))
*
-   * For all x,y,m. But note that K, for the GaussianMixture, cannot depend on
-   * any arguments. Hence, we delegate to the underlying Gaussian
+   * For all x,y,m. But note that K, the (log) normalization constant defined
+   * in Conditional.h, should not depend on x, y, or m, only on the parameters
+   * of the density. Hence, we delegate to the underlying Gaussian
    * conditionals, indexed by m, which do satisfy:
    *
    * log(probability_m(x;y)) = K_m - error_m(x;y)
    *
-   * We resolve by having K == 0.0 and
+   * We resolve by having K == max(K_m) and
    *
-   * error(x;y,m) = error_m(x;y) - K_m
+   * error(x;y,m) = error_m(x;y) + K - K_m
    *
+   * which also makes error(x;y,m) >= 0 for all x,y,m.
*
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const HybridValues &values) const override;
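
Spelled out as math (a restatement of the comment above, added here for readability, not part of the PR):

```latex
% Each component m satisfies log p_m(x|y) = K_m - E_m(x;y), where K_m is
% its log-normalization constant. Choosing the shared constant K = max_m K_m:
\[
  \log p_m(x \mid y) = K_m - E_m(x; y), \qquad K = \max_m K_m,
\]
\[
  E(x; y, m) = E_m(x; y) + K - K_m \geq 0,
  \qquad \log p(x \mid y, m) = K - E(x; y, m).
\]
```

A single constant K, independent of m, thus serves the whole mixture, and the per-mode deficit K - K_m is exactly the Cgm_Kgcm constant that likelihood() hides in its extra JacobianFactor.
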

/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
* only, with the leaf values as the error for each assignment.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;
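
A usage illustration of this tree-valued error (a sketch, not from the PR; the keys, sigmas, and the sharedMeanAndStddev helper call are assumptions about the surrounding GTSAM API):

```cpp
// Sketch: query the per-mode error of a two-mode GaussianMixture P(x0 | m0).
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/inference/Symbol.h>

#include <iostream>

using namespace gtsam;

int main() {
  const Key x0 = Symbol('x', 0);
  const DiscreteKey m0(Symbol('m', 0), 2);

  // Two modes with the same mean but different sigmas, so the
  // log-normalization constants K_m differ across modes.
  const auto c0 = GaussianConditional::sharedMeanAndStddev(x0, Vector1(0.0), 0.5);
  const auto c1 = GaussianConditional::sharedMeanAndStddev(x0, Vector1(0.0), 2.0);
  const GaussianMixture::Conditionals conditionals(m0, c0, c1);
  const GaussianMixture mixture({x0}, {}, {m0}, conditionals);

  VectorValues values;
  values.insert(x0, Vector1(1.0));

  // One error value per discrete assignment, as a decision tree on m0:
  const AlgebraicDecisionTree<Key> errors = mixture.error(values);
  DiscreteValues mode;
  mode[m0.first] = 0;
  std::cout << "error(m0=0) = " << errors(mode) << "\n";
  mode[m0.first] = 1;
  std::cout << "error(m0=1) = " << errors(mode) << "\n";
  return 0;
}
```
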

/**
* @brief Compute the logProbability of this Gaussian Mixture.
*
@@ -233,6 +252,9 @@
/// @}

private:
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
50 changes: 18 additions & 32 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
@@ -31,11 +31,8 @@ namespace gtsam {
/* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
-                                             const Mixture &factors)
-    : Base(continuousKeys, discreteKeys),
-      factors_(factors, [](const GaussianFactor::shared_ptr &gf) {
-        return FactorAndConstant{gf, 0.0};
-      }) {}
+                                             const Factors &factors)
+    : Base(continuousKeys, discreteKeys), factors_(factors) {}

/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
@@ -48,11 +45,10 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {

// Check the base and the factors:
return Base::equals(*e, tol) &&
-         factors_.equals(e->factors_, [tol](const FactorAndConstant &f1,
-                                            const FactorAndConstant &f2) {
-           return f1.factor->equals(*(f2.factor), tol) &&
-                  std::abs(f1.constant - f2.constant) < tol;
-         });
+         factors_.equals(e->factors_,
+                         [tol](const sharedFactor &f1, const sharedFactor &f2) {
+                           return f1->equals(*f2, tol);
+                         });
}

/* *******************************************************************************/
Expand All @@ -65,8 +61,7 @@ void GaussianMixtureFactor::print(const std::string &s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
-        [&](const FactorAndConstant &gf_z) -> std::string {
-          auto gf = gf_z.factor;
+        [&](const sharedFactor &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf && !gf->empty()) {
Expand All @@ -81,24 +76,19 @@ void GaussianMixtureFactor::print(const std::string &s,
}

/* *******************************************************************************/
-GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
+GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()(
     const DiscreteValues &assignment) const {
-  return factors_(assignment).factor;
-}
-
-/* *******************************************************************************/
-double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
-  return factors_(assignment).constant;
+  return factors_(assignment);
}

/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixtureFactor::add(
const GaussianFactorGraphTree &sum) const {
-  using Y = GraphAndConstant;
+  using Y = GaussianFactorGraph;
   auto add = [](const Y &graph1, const Y &graph2) {
-    auto result = graph1.graph;
-    result.push_back(graph2.graph);
-    return Y(result, graph1.constant + graph2.constant);
+    auto result = graph1;
+    result.push_back(graph2);
+    return result;
};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
Expand All @@ -107,29 +97,25 @@ GaussianFactorGraphTree GaussianMixtureFactor::add(
/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
-  auto wrap = [](const FactorAndConstant &factor_z) {
-    GaussianFactorGraph result;
-    result.push_back(factor_z.factor);
-    return GraphAndConstant(result, factor_z.constant);
-  };
+  auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
return {factors_, wrap};
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
-  auto errorFunc = [continuousValues](const FactorAndConstant &factor_z) {
-    return factor_z.error(continuousValues);
+  auto errorFunc = [&continuousValues](const sharedFactor &gf) {
+    return gf->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixtureFactor::error(const HybridValues &values) const {
-  const FactorAndConstant factor_z = factors_(values.discrete());
-  return factor_z.error(values.continuous());
+  const sharedFactor gf = factors_(values.discrete());
+  return gf->error(values.continuous());
}
/* *******************************************************************************/
