Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Bayes Net pruning #1300

Merged
merged 2 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,32 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;

// The canonical decision tree factor which will get the discrete conditionals
// added to it.
DecisionTreeFactor dtFactor;

for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
dtFactor = dtFactor * f;
}
}
return boost::make_shared<DecisionTreeFactor>(dtFactor);
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr discreteFactor =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));

/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
Expand Down
14 changes: 11 additions & 3 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;

/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
protected:
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;

public:
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const;

/// @}

Expand Down
2 changes: 0 additions & 2 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@

namespace gtsam {

class HybridGaussianFactorGraph;

/**
* Hybrid Conditional Density
*
Expand Down
19 changes: 18 additions & 1 deletion gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ TEST(HybridBayesNet, Add) {
EXPECT(bayesNet.equals(other));
}


/* ****************************************************************************/
// Test choosing an assignment of conditionals
TEST(HybridBayesNet, Choose) {
Expand Down Expand Up @@ -184,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}

/* ****************************************************************************/
// Test bayes net pruning
TEST(HybridBayesNet, Prune) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);

HybridValues delta = hybridBayesNet->optimize();

auto prunedBayesNet = hybridBayesNet->prune(2);
HybridValues pruned_delta = prunedBayesNet.optimize();

EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
}

/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Expand Down
1 change: 0 additions & 1 deletion gtsam/inference/BayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ namespace gtsam {
// Forward declarations
template<class FACTOR> class FactorGraph;
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
class HybridBayesTreeClique;

/* ************************************************************************* */
/** clique statistics */
Expand Down