Skip to content

Commit

Permalink
Merge pull request #1300 from borglab/hybrid/improved-prune-2
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Oct 4, 2022
2 parents 3407f97 + d6d44fc commit cae787a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 9 deletions.
28 changes: 26 additions & 2 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,32 @@ static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const {
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;

// The canonical decision tree factor which will get the discrete conditionals
// added to it.
DecisionTreeFactor dtFactor;

for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
dtFactor = dtFactor * f;
}
}
return boost::make_shared<DecisionTreeFactor>(dtFactor);
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals();
const DecisionTreeFactor::shared_ptr discreteFactor =
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));

/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
Expand Down
14 changes: 11 additions & 3 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,17 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;

/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
protected:
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;

public:
/// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
HybridBayesNet prune(size_t maxNrLeaves) const;

/// @}

Expand Down
2 changes: 0 additions & 2 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@

namespace gtsam {

class HybridGaussianFactorGraph;

/**
* Hybrid Conditional Density
*
Expand Down
19 changes: 18 additions & 1 deletion gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ TEST(HybridBayesNet, Add) {
EXPECT(bayesNet.equals(other));
}


/* ****************************************************************************/
// Test choosing an assignment of conditionals
TEST(HybridBayesNet, Choose) {
Expand Down Expand Up @@ -184,6 +183,24 @@ TEST(HybridBayesNet, OptimizeMultifrontal) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}

/* ****************************************************************************/
// Test bayes net pruning
TEST(HybridBayesNet, Prune) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);

HybridValues delta = hybridBayesNet->optimize();

auto prunedBayesNet = hybridBayesNet->prune(2);
HybridValues pruned_delta = prunedBayesNet.optimize();

EXPECT(assert_equal(delta.discrete(), pruned_delta.discrete()));
EXPECT(assert_equal(delta.continuous(), pruned_delta.continuous()));
}

/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Expand Down
1 change: 0 additions & 1 deletion gtsam/inference/BayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ namespace gtsam {
// Forward declarations
template<class FACTOR> class FactorGraph;
template<class BAYESTREE, class GRAPH> class EliminatableClusterTree;
class HybridBayesTreeClique;

/* ************************************************************************* */
/** clique statistics */
Expand Down

0 comments on commit cae787a

Please sign in to comment.