Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid ISAM #1273

Merged
merged 6 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,19 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// sum out frontals, this is the factor on the separator
GaussianMixtureFactor::Sum sum = sumFrontals(factors);

// If a tree leaf contains nullptr,
// convert that leaf to an empty GaussianFactorGraph.
// Needed since the DecisionTree will otherwise create
// a GFG with a single (null) factor.
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
bool hasNull =
std::any_of(gfg.begin(), gfg.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });

return hasNull ? GaussianFactorGraph() : gfg;
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);

using EliminationPair = GaussianFactorGraph::EliminationResult;

KeyVector keysOfEliminated; // Not the ordering
Expand All @@ -195,7 +208,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (graph.empty()) {
return {nullptr, nullptr};
}
auto result = EliminatePreferCholesky(graph, frontalKeys);
std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>>
result = EliminatePreferCholesky(graph, frontalKeys);

if (keysOfEliminated.empty()) {
keysOfEliminated =
result.first->keys(); // Initialize the keysOfEliminated to be the
Expand Down Expand Up @@ -235,14 +251,27 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
boost::make_shared<HybridDiscreteFactor>(discreteFactor)};

} else {
// Create a resulting DCGaussianMixture on the separator.
// Create a resulting GaussianMixtureFactor on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>(
KeyVector(continuousSeparator.begin(), continuousSeparator.end()),
discreteSeparator, separatorFactors);
return {boost::make_shared<HybridConditional>(conditional), factor};
}
}
/* ************************************************************************ */
/* ************************************************************************
* Function to eliminate variables **under the following assumptions**:
* 1. When the ordering is fully continuous, and the graph only contains
* continuous and hybrid factors
* 2. When the ordering is fully discrete, and the graph only contains discrete
* factors
*
* Any usage outside of this is considered incorrect.
*
* \warning This function is not meant to be used with arbitrary hybrid factor
* graphs. For example, if there exists continuous parents, and one tries to
* eliminate a discrete variable (as specified in the ordering), the result will
* be INCORRECT and there will be NO error raised.
*/
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
EliminateHybrid(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
Expand Down
20 changes: 17 additions & 3 deletions gtsam/hybrid/HybridGaussianISAM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ void HybridGaussianISAM::updateInternal(
std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast));

// KeyVector new

// Get an ordering where the new keys are eliminated last
const VariableIndex index(factors);
Ordering elimination_ordering;
Expand Down Expand Up @@ -107,6 +105,22 @@ void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors,
this->updateInternal(newFactors, &orphans, ordering, function);
}

/* ************************************************************************* */
/**
* @brief Check if `b` is a subset of `a`.
* Non-const since they need to be sorted.
*
* @param a KeyVector
* @param b KeyVector
* @return True if the keys of b is a subset of a, else false.
*/
bool IsSubset(KeyVector a, KeyVector b) {
std::sort(a.begin(), a.end());
std::sort(b.begin(), b.end());
return std::includes(a.begin(), a.end(), b.begin(), b.end());
}

/* ************************************************************************* */
void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->clique(root)->conditional()->inner());
Expand All @@ -133,7 +147,7 @@ void HybridGaussianISAM::prune(const Key& root, const size_t maxNrLeaves) {
parents.push_back(parent);
}

if (parents == decisionTree->keys()) {
if (IsSubset(parents, decisionTree->keys())) {
auto gaussianMixture = boost::dynamic_pointer_cast<GaussianMixture>(
clique.second->conditional()->inner());

Expand Down
Loading