Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Bayes Tree Optimize #1282

Merged
merged 13 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
* @date January 2022
*/

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>

namespace gtsam {
Expand Down Expand Up @@ -111,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune(
}

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture();
}

/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
}

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional();
Expand All @@ -125,22 +131,39 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
try {
GaussianMixture gm = *this->atGaussian(idx);
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx);
gbn.push_back(gm(assignment));

} catch (std::exception &exc) {
// if factor at `idx` is discrete-only, just continue.
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, add gaussian conditional.
gbn.push_back((this->atGaussian(idx)));

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
continue;
}
}

return gbn;
}

/* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this);
return dag.argmax();
// Solve for the MPE
DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
}
}

DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();

// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
return HybridValues(mpe, gbn.optimize());
}

/* *******************************************************************************/
Expand Down
5 changes: 4 additions & 1 deletion gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
}

/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atGaussian(size_t i) const;
GaussianMixture::shared_ptr atMixture(size_t i) const;

/// Get a specific Gaussian conditional by index `i`.
GaussianConditional::shared_ptr atGaussian(size_t i) const;

/// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
Expand Down
59 changes: 55 additions & 4 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/

#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/BayesTree-inst.h>
Expand All @@ -35,6 +37,52 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}

/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
HybridBayesNet hbn;
DiscreteBayesNet dbn;

KeyVector added_keys;

// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();

// Record the key being added
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());

if (conditional->isDiscrete()) {
// If discrete, we use it to compute the MPE
dbn.push_back(conditional->asDiscreteConditional());

} else {
// Else conditional is hybrid or continuous-only,
// so we directly add it to the Hybrid Bayes net.
hbn.push_back(conditional);
}
}
}
// Get the MPE
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = hbn.choose(mpe);

// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif

return HybridValues(mpe, gbn.optimize());
}

/* ************************************************************************* */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesNet gbn;
Expand All @@ -50,11 +98,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();

KeyVector frontals(conditional->frontals().begin(),
conditional->frontals().end());

// Record the key being added
added_keys.insert(added_keys.end(), frontals.begin(), frontals.end());
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());

// If conditional is hybrid (and not discrete-only), we get the Gaussian
// Conditional corresponding to the assignment and add it to the Gaussian
Expand All @@ -65,9 +111,14 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
(*gm)(assignment);

gbn.push_back(gaussian_conditional);

} else if (conditional->isContinuous()) {
// If conditional is Gaussian, we simply add it to the Bayes net.
gbn.push_back(conditional->asGaussian());
}
}
}

// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
Expand Down
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;

/**
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
* set of discrete variables and using it to compute the best continuous
* update delta.
*
* @return HybridValues
*/
HybridValues optimize() const;

/**
* @brief Recursively optimize the BayesTree to produce a vector solution.
*
Expand Down
11 changes: 11 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ class GTSAM_EXPORT HybridConditional
return boost::static_pointer_cast<GaussianMixture>(inner_);
}

/**
* @brief Return HybridConditional as a GaussianConditional
*
* @return GaussianConditional::shared_ptr
*/
GaussianConditional::shared_ptr asGaussian() {
if (!isContinuous())
throw std::invalid_argument("Not a continuous conditional");
return boost::static_pointer_cast<GaussianConditional>(inner_);
}

/**
* @brief Return conditional as a DiscreteConditional
*
Expand Down
76 changes: 0 additions & 76 deletions gtsam/hybrid/HybridLookupDAG.cpp

This file was deleted.

Loading