Merge pull request #1277 from borglab/feature/nonlinear-hybrid
varunagrawal authored Aug 22, 2022
2 parents 7977f77 + 84456f4 commit 05b1174
Showing 6 changed files with 561 additions and 515 deletions.
27 changes: 26 additions & 1 deletion gtsam/hybrid/GaussianMixture.cpp
@@ -119,11 +119,36 @@ void GaussianMixture::print(const std::string &s,
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty())
if (gf && !gf->empty())
gf->print("", formatter);
else
return {"nullptr"};
return rd.str();
});
}

/* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
// Functional which loops over all assignments and creates a set of
// GaussianConditionals
auto pruner = [&decisionTree](
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// Typecast so we can use this to get the probability value
DiscreteValues values(choices);

if (decisionTree(values) == 0.0) {
// empty aka null pointer
boost::shared_ptr<GaussianConditional> null;
return null;
} else {
return conditional;
}
};

auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_;
}

} // namespace gtsam
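
The new prune method maps a functional over the conditionals tree: any leaf whose discrete probability is zero is replaced by a null conditional. Below is a minimal, self-contained sketch of that same pattern, where a std::map of hypothetical assignment strings stands in for the GTSAM DecisionTree and a bare struct stands in for GaussianConditional:

#include <map>
#include <memory>
#include <string>

// Stand-in for GaussianConditional; the real tree stores
// GaussianConditional::shared_ptr at its leaves.
struct Conditional { std::string name; };
using ConditionalPtr = std::shared_ptr<Conditional>;

int main() {
  // Hypothetical discrete assignments with their pruned probabilities
  // (this plays the role of the DecisionTreeFactor argument).
  std::map<std::string, double> probability{
      {"00", 0.7}, {"01", 0.0}, {"10", 0.3}, {"11", 0.0}};

  std::map<std::string, ConditionalPtr> conditionals{
      {"00", std::make_shared<Conditional>(Conditional{"c00"})},
      {"01", std::make_shared<Conditional>(Conditional{"c01"})},
      {"10", std::make_shared<Conditional>(Conditional{"c10"})},
      {"11", std::make_shared<Conditional>(Conditional{"c11"})}};

  // The pruner: zero-probability assignments get a null conditional,
  // mirroring the lambda passed to conditionals_.apply() above.
  for (auto &entry : conditionals) {
    if (probability.at(entry.first) == 0.0) entry.second = nullptr;
  }
  return 0;
}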
12 changes: 11 additions & 1 deletion gtsam/hybrid/GaussianMixture.h
@@ -21,6 +21,7 @@

#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h>
@@ -121,7 +122,7 @@ class GTSAM_EXPORT GaussianMixture
/// Test equality with base HybridFactor
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

/* print utility */
/// Print utility
void print(
const std::string &s = "GaussianMixture\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
@@ -131,6 +132,15 @@ class GTSAM_EXPORT GaussianMixture
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();

/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
*
* @param decisionTree A pruned decision tree of discrete keys where the
* leaves are probabilities.
*/
void prune(const DecisionTreeFactor &decisionTree);

/**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
* maintaining the decision tree structure.
35 changes: 32 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -184,6 +184,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// sum out frontals, this is the factor on the separator
GaussianMixtureFactor::Sum sum = sumFrontals(factors);

// If a tree leaf contains nullptr,
// convert that leaf to an empty GaussianFactorGraph.
// Needed since the DecisionTree will otherwise create
// a GFG with a single (null) factor.
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
bool hasNull =
std::any_of(gfg.begin(), gfg.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });

return hasNull ? GaussianFactorGraph() : gfg;
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);

using EliminationPair = GaussianFactorGraph::EliminationResult;

KeyVector keysOfEliminated; // Not the ordering
@@ -195,7 +208,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (graph.empty()) {
return {nullptr, nullptr};
}
auto result = EliminatePreferCholesky(graph, frontalKeys);
std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>>
result = EliminatePreferCholesky(graph, frontalKeys);

if (keysOfEliminated.empty()) {
keysOfEliminated =
result.first->keys(); // Initialize the keysOfEliminated to be the
@@ -235,14 +251,27 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};

} else {
// Create a resulting DCGaussianMixture on the separator.
// Create a resulting GaussianMixtureFactor on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>(
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor};
}
}
/* ************************************************************************ */
/* ************************************************************************
* Function to eliminate variables **under the following assumptions**:
* 1. When the ordering is fully continuous, and the graph only contains
* continuous and hybrid factors
* 2. When the ordering is fully discrete, and the graph only contains discrete
* factors
*
* Any usage outside of this is considered incorrect.
*
* \warning This function is not meant to be used with arbitrary hybrid factor
graphs. For example, if there exist continuous parents, and one tries to
* eliminate a discrete variable (as specified in the ordering), the result will
* be INCORRECT and there will be NO error raised.
*/
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
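
The nullptr check added above relies on std::any_of over the factors of each leaf graph. Here is a self-contained illustration of that sanitization step, with a plain std::vector of shared_ptr standing in for GaussianFactorGraph (the Factor type is hypothetical):

#include <algorithm>
#include <memory>
#include <vector>

struct Factor {};  // stand-in for GaussianFactor
using FactorPtr = std::shared_ptr<Factor>;

// Returns an empty graph if any stored factor pointer is null,
// mirroring the emptyGaussian lambda mapped over the sum tree.
std::vector<FactorPtr> sanitize(const std::vector<FactorPtr> &graph) {
  bool hasNull = std::any_of(graph.begin(), graph.end(),
                             [](const FactorPtr &ptr) { return !ptr; });
  return hasNull ? std::vector<FactorPtr>() : graph;
}

int main() {
  std::vector<FactorPtr> good{std::make_shared<Factor>()};
  std::vector<FactorPtr> bad{std::make_shared<Factor>(), nullptr};
  // sanitize(good) keeps the single factor; sanitize(bad) becomes empty.
  return (sanitize(good).size() == 1 && sanitize(bad).empty()) ? 0 : 1;
}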
72 changes: 65 additions & 7 deletions gtsam/hybrid/HybridGaussianISAM.cpp
@@ -41,6 +41,7 @@ HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree)
void HybridGaussianISAM::updateInternal(
const HybridGaussianFactorGraph& newFactors,
HybridBayesTree::Cliques* orphans,
const boost::optional<Ordering>& ordering,
const HybridBayesTree::Eliminate& function) {
// Remove the contaminated part of the Bayes tree
BayesNetType bn;
@@ -74,16 +75,21 @@ void HybridGaussianISAM::updateInternal(
std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast));

// KeyVector new

// Get an ordering where the new keys are eliminated last
const VariableIndex index(factors);
const Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
true);
Ordering elimination_ordering;
if (ordering) {
elimination_ordering = *ordering;
} else {
elimination_ordering = Ordering::ColamdConstrainedLast(
index,
KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()),
true);
}

// eliminate all factors (top, added, orphans) into a new Bayes tree
auto bayesTree = factors.eliminateMultifrontal(ordering, function, index);
HybridBayesTree::shared_ptr bayesTree =
factors.eliminateMultifrontal(elimination_ordering, function, index);

// Re-add into Bayes tree data structures
this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(),
@@ -93,9 +99,61 @@

/* ************************************************************************* */
void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
const boost::optional<Ordering>& ordering,
const HybridBayesTree::Eliminate& function) {
Cliques orphans;
this->updateInternal(newFactors, &orphans, function);
this->updateInternal(newFactors, &orphans, ordering, function);
}

/* ************************************************************************* */
/**
* @brief Check if `b` is a subset of `a`.
* Non-const since they need to be sorted.
*
* @param a KeyVector
* @param b KeyVector
* @return True if the keys of b are a subset of the keys of a, else false.
*/
bool IsSubset(KeyVector a, KeyVector b) {
std::sort(a.begin(), a.end());
std::sort(b.begin(), b.end());
return std::includes(a.begin(), a.end(), b.begin(), b.end());
}

/* ************************************************************************* */
void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->clique(root)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;

std::vector<gtsam::Key> prunedKeys;
for (auto&& clique : nodes()) {
// The cliques can be repeated for each frontal so we record it in
// prunedKeys and check if we have already pruned a particular clique.
if (std::find(prunedKeys.begin(), prunedKeys.end(), clique.first) !=
prunedKeys.end()) {
continue;
}

// Add all the keys of the current clique to be pruned to prunedKeys
for (auto&& key : clique.second->conditional()->frontals()) {
prunedKeys.push_back(key);
}

// Convert parents() to a KeyVector for comparison
KeyVector parents;
for (auto&& parent : clique.second->conditional()->parents()) {
parents.push_back(parent);
}

if (IsSubset(parents, decisionTree->keys())) {
auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
clique.second->conditional()->inner());

gaussianMixture->prune(prunedDiscreteFactor);
}
}
}

} // namespace gtsam
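
The subset test above decides which cliques get their GaussianMixture pruned: a clique qualifies when all of its parent keys appear among the keys of the pruned discrete factor. A self-contained sketch of the same check, with plain integer keys in place of gtsam::Key values (the key values are made up):

#include <algorithm>
#include <cstdint>
#include <vector>

using Key = std::uint64_t;
using KeyVector = std::vector<Key>;

// Same idea as the IsSubset helper: std::includes requires both
// ranges to be sorted, hence the by-value parameters.
bool IsSubset(KeyVector a, KeyVector b) {
  std::sort(a.begin(), a.end());
  std::sort(b.begin(), b.end());
  return std::includes(a.begin(), a.end(), b.begin(), b.end());
}

int main() {
  KeyVector discreteKeys{7, 3, 5};  // hypothetical keys of the pruned factor
  KeyVector cliqueParents{5, 3};    // hypothetical parents of one clique
  // True: {3, 5} is contained in {3, 5, 7}, so this clique would be pruned.
  return IsSubset(discreteKeys, cliqueParents) ? 0 : 1;
}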
10 changes: 10 additions & 0 deletions gtsam/hybrid/HybridGaussianISAM.h
@@ -48,6 +48,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
void updateInternal(
const HybridGaussianFactorGraph& newFactors,
HybridBayesTree::Cliques* orphans,
const boost::optional<Ordering>& ordering = boost::none,
const HybridBayesTree::Eliminate& function =
HybridBayesTree::EliminationTraitsType::DefaultEliminate);

@@ -59,8 +60,17 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
* @param function Elimination function.
*/
void update(const HybridGaussianFactorGraph& newFactors,
const boost::optional<Ordering>& ordering = boost::none,
const HybridBayesTree::Eliminate& function =
HybridBayesTree::EliminationTraitsType::DefaultEliminate);

/**
* @brief Prune the discrete decision tree rooted at `root` and the Gaussian
* mixtures in the cliques that depend on it.
*
* @param root The root key in the discrete conditional decision tree.
* @param maxNumberLeaves The maximum number of leaves to keep in the pruned
* decision tree.
*/
void prune(const Key& root, const size_t maxNumberLeaves);
};

/// traits
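
Taken together, the new signatures let a caller pass its own elimination ordering to update and then bound the size of the discrete tree with prune. A hedged usage sketch against this API (the factor graph contents, the keys X(0), X(1), M(1), and the leaf budget of 10 are all hypothetical and assume GTSAM's usual symbol shorthand):

#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
#include <gtsam/inference/Symbol.h>

using namespace gtsam;
using symbol_shorthand::M;  // discrete mode keys (hypothetical)
using symbol_shorthand::X;  // continuous keys (hypothetical)

void updateAndPrune(HybridGaussianISAM &isam,
                    const HybridGaussianFactorGraph &newFactors) {
  // Optionally constrain the elimination ordering so the continuous
  // keys are eliminated before the discrete mode.
  Ordering ordering;
  ordering.push_back(X(0));
  ordering.push_back(X(1));
  ordering.push_back(M(1));

  isam.update(newFactors, ordering);

  // Keep at most 10 leaves in the discrete decision tree rooted at M(1).
  isam.prune(M(1), 10);
}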