Skip to content

Commit

Permalink
Merge pull request #1273 from borglab/hybrid-incremental
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Aug 22, 2022
2 parents 4c9c106 + 893c5f7 commit 84456f4
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 263 deletions.
35 changes: 32 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// sum out frontals, this is the factor on the separator
GaussianMixtureFactor::Sum sum = sumFrontals(factors);

// If a tree leaf contains nullptr,
// convert that leaf to an empty GaussianFactorGraph.
// Needed since the DecisionTree will otherwise create
// a GFG with a single (null) factor.
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
bool hasNull =
std::any_of(gfg.begin(), gfg.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });

return hasNull ? GaussianFactorGraph() : gfg;
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);

using EliminationPair = GaussianFactorGraph::EliminationResult;

KeyVector keysOfEliminated; // Not the ordering
Expand All @@ -195,7 +208,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (graph.empty()) {
return {nullptr, nullptr};
}
auto result = EliminatePreferCholesky(graph, frontalKeys);
std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>>
result = EliminatePreferCholesky(graph, frontalKeys);

if (keysOfEliminated.empty()) {
keysOfEliminated =
result.first->keys(); // Initialize the keysOfEliminated to be the
Expand Down Expand Up @@ -235,14 +251,27 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};

} else {
// Create a resulting DCGaussianMixture on the separator.
// Create a resulting GaussianMixtureFactor on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>(
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor};
}
}
/* ************************************************************************ */
/* ************************************************************************
* Function to eliminate variables **under the following assumptions**:
* 1. When the ordering is fully continuous, and the graph only contains
* continuous and hybrid factors
* 2. When the ordering is fully discrete, and the graph only contains discrete
* factors
*
* Any usage outside of this is considered incorrect.
*
* \warning This function is not meant to be used with arbitrary hybrid factor
* graphs. For example, if there exists continuous parents, and one tries to
* eliminate a discrete variable (as specified in the ordering), the result will
* be INCORRECT and there will be NO error raised.
*/
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
Expand Down
20 changes: 17 additions & 3 deletions gtsam/hybrid/HybridGaussianISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ void HybridGaussianISAM::updateInternal(
std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast));

// KeyVector new

// Get an ordering where the new keys are eliminated last
const VariableIndex index(factors);
Ordering elimination_ordering;
Expand Down Expand Up @@ -107,6 +105,22 @@ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
this->updateInternal(newFactors, &orphans, ordering, function);
}

/* ************************************************************************* */
/**
* @brief Check if `b` is a subset of `a`.
* Non-const since they need to be sorted.
*
* @param a KeyVector
* @param b KeyVector
* @return True if the keys of b is a subset of a, else false.
*/
bool IsSubset(KeyVector a, KeyVector b) {
std::sort(a.begin(), a.end());
std::sort(b.begin(), b.end());
return std::includes(a.begin(), a.end(), b.begin(), b.end());
}

/* ************************************************************************* */
void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->clique(root)->conditional()->inner());
Expand All @@ -133,7 +147,7 @@ void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
parents.push_back(parent);
}

if (parents == decisionTree->keys()) {
if (IsSubset(parents, decisionTree->keys())) {
auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
clique.second->conditional()->inner());

Expand Down
Loading

0 comments on commit 84456f4

Please sign in to comment.