Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Mixture Error #1318

Merged
merged 20 commits into from
Dec 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace gtsam {
static inline double id(const double& x) { return x; }
};

AlgebraicDecisionTree() : Base(1.0) {}
AlgebraicDecisionTree(double leaf = 1.0) : Base(leaf) {}

// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}
Expand Down Expand Up @@ -158,9 +158,9 @@ namespace gtsam {
}

/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
void print(const std::string& s = "",
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.8g") % v).str();
};
Expand Down
30 changes: 28 additions & 2 deletions gtsam/hybrid/GaussianMixture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ size_t GaussianMixture::nrComponents() const {

/* *******************************************************************************/
GaussianConditional::shared_ptr GaussianMixture::operator()(
const DiscreteValues &discreteVals) const {
auto &ptr = conditionals_(discreteVals);
const DiscreteValues &discreteValues) const {
auto &ptr = conditionals_(discreteValues);
if (!ptr) return nullptr;
auto conditional = boost::dynamic_pointer_cast<GaussianConditional>(ptr);
if (conditional)
Expand Down Expand Up @@ -207,4 +207,30 @@ void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
conditionals_.root_ = pruned_conditionals.root_;
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::error(
const VectorValues &continuousValues) const {
// functor to calculate to double error value from GaussianConditional.
auto errorFunc =
[continuousValues](const GaussianConditional::shared_ptr &conditional) {
if (conditional) {
return conditional->error(continuousValues);
} else {
// Return arbitrarily large error if conditional is null
// Conditional is null if it is pruned out.
return 1e50;
}
};
DecisionTree<Key, double> errorTree(conditionals_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousValues);
}

} // namespace gtsam
22 changes: 21 additions & 1 deletion gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture
/// @{

GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteVals) const;
const DiscreteValues &discreteValues) const;

/// Returns the total number of continuous components
size_t nrComponents() const;
Expand All @@ -144,6 +144,26 @@ class GTSAM_EXPORT GaussianMixture
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();

/**
* @brief Compute error of the GaussianMixture as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the conditionals, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
Expand Down
22 changes: 22 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,26 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
};
return {factors_, wrap};
}

/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixtureFactor::error(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc =
[continuousValues](const GaussianFactor::shared_ptr &factor) {
return factor->error(continuousValues);
};
DecisionTree<Key, double> errorTree(factors_, errorFunc);
return errorTree;
}

/* *******************************************************************************/
double GaussianMixtureFactor::error(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
// Directly index to get the conditional, no need to build the whole tree.
auto factor = factors_(discreteValues);
return factor->error(continuousValues);
}

} // namespace gtsam
24 changes: 24 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@

#pragma once

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>

namespace gtsam {

class GaussianFactorGraph;

// Needed for wrapper.
using GaussianFactorVector = std::vector<gtsam::GaussianFactor::shared_ptr>;

/**
Expand Down Expand Up @@ -126,6 +130,26 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
*/
Sum add(const Sum &sum) const;

/**
* @brief Compute error of the GaussianMixtureFactor as a tree.
*
* @param continuousValues The continuous VectorValues.
* @return AlgebraicDecisionTree<Key> A decision tree with the same keys
* as the factors involved, and leaf values as the error.
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;

/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
Expand Down
52 changes: 52 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,56 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
return gbn.optimize();
}

/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = this->choose(discreteValues);
return gbn.error(continuousValues);
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::error(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree;

// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> conditional_error;

if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = this->atMixture(idx);
conditional_error = gm->error(continuousValues);

// Assign for the first index, add error for subsequent ones.
if (idx == 0) {
error_tree = conditional_error;
} else {
error_tree = error_tree + conditional_error;
}

} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = this->atGaussian(idx)->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}

return error_tree;
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
return error_tree.apply([](double error) { return exp(-error); });
}

} // namespace gtsam
33 changes: 33 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,39 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves);

/**
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;

/**
dellaert marked this conversation as resolved.
Show resolved Hide resolved
* @brief Compute conditional error for each discrete assignment,
* and return as a tree.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues &continuousValues) const;

/**
* @brief Compute unnormalized probability q(μ|M),
* for each discrete assignment, and return as a tree.
* q(μ|M) is the unnormalized probability at the MLE point μ,
* conditioned on the discrete variables.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues &continuousValues) const;

/// @}

private:
Expand Down
54 changes: 54 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,4 +423,58 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
return ordering;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);

// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
AlgebraicDecisionTree<Key> factor_error;

if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixtureFactor::shared_ptr gaussianMixture =
boost::static_pointer_cast<GaussianMixtureFactor>(factors_.at(idx));
// Compute factor error.
factor_error = gaussianMixture->error(continuousValues);

// If first factor, assign error, else add it.
if (idx == 0) {
error_tree = factor_error;
} else {
error_tree = error_tree + factor_error;
}

} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, get the (double) error
// and add it to the error_tree
auto hybridGaussianFactor =
boost::static_pointer_cast<HybridGaussianFactor>(factors_.at(idx));
GaussianFactor::shared_ptr gaussian = hybridGaussianFactor->inner();

// Compute the error of the gaussian factor.
double error = gaussian->error(continuousValues);
// Add the gaussian factor error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we skip.
continue;
}
}

return error_tree;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
AlgebraicDecisionTree<Key> prob_tree =
error_tree.apply([](double error) { return exp(-error); });
return prob_tree;
}

} // namespace gtsam
27 changes: 25 additions & 2 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class JacobianFactor;

/**
* @brief Main elimination function for HybridGaussianFactorGraph.
*
*
* @param factors The factor graph to eliminate.
* @param keys The elimination ordering.
* @return The conditional on the ordering keys and the remaining factors.
Expand Down Expand Up @@ -99,11 +99,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
using shared_ptr = boost::shared_ptr<This>; ///< shared_ptr to This

using Values = gtsam::Values; ///< backwards compatibility
using Indices = KeyVector; ///> map from keys to values
using Indices = KeyVector; ///< map from keys to values

/// @name Constructors
/// @{

/// @brief Default constructor.
HybridGaussianFactorGraph() = default;

/**
Expand Down Expand Up @@ -170,6 +171,28 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
}
}

/**
* @brief Compute error for each discrete assignment,
dellaert marked this conversation as resolved.
Show resolved Hide resolved
* and return as a tree.
*
* Error \f$ e = \Vert x - \mu \Vert_{\Sigma} \f$.
*
* @param continuousValues Continuous values at which to compute the error.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> error(const VectorValues& continuousValues) const;

/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
* for each discrete assignment, and return as a tree.
*
* @param continuousValues Continuous values at which to compute the
* probability.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
Expand Down
Loading