Skip to content

Commit

Permalink
Merge pull request #1284 from borglab/hybrid/misc
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Aug 31, 2022
2 parents f7e1d2a + 2c4866e commit 7c84020
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 22 deletions.
10 changes: 10 additions & 0 deletions gtsam/discrete/DiscreteKey.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ namespace gtsam {
push_back(key);
return *this;
}

/// Print the keys and cardinalities.
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
}

}; // DiscreteKeys

/// Create a list from two keys
Expand Down
25 changes: 23 additions & 2 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,35 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering(
OptionalOrderingType orderingType) const {
const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
KeySet discrete_keys;
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}

/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
KeySet keys;
for (auto &factor : factors_) {
for (const Key &key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = getDiscreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
Expand Down
17 changes: 11 additions & 6 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
}
}

/// Get all the discrete keys in the factor graph.
const KeySet getDiscreteKeys() const;

/// Get all the continuous keys in the factor graph.
const KeySet getContinuousKeys() const;

/**
* @brief
*
* @param orderingType
* @return const Ordering
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* @return const Ordering
*/
const Ordering getHybridOrdering(
OptionalOrderingType orderingType = boost::none) const;
const Ordering getHybridOrdering() const;
};

} // namespace gtsam
18 changes: 9 additions & 9 deletions gtsam/hybrid/HybridValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
namespace gtsam {

/**
* HybridValues represents a collection of DiscreteValues and VectorValues. It
* is typically used to store the variables of a HybridGaussianFactorGraph.
* HybridValues represents a collection of DiscreteValues and VectorValues.
* It is typically used to store the variables of a HybridGaussianFactorGraph.
* Optimizing a HybridGaussianBayesNet returns this class.
*/
class GTSAM_EXPORT HybridValues {
Expand All @@ -47,18 +47,18 @@ class GTSAM_EXPORT HybridValues {
/// @name Standard Constructors
/// @{

// Default constructor creates an empty HybridValues.
/// Default constructor creates an empty HybridValues.
HybridValues() = default;

// Construct from DiscreteValues and VectorValues.
/// Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete_(dv), continuous_(cv){};

/// @}
/// @name Testable
/// @{

// print required by Testable for unit testing
/// print required by Testable for unit testing
void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n";
Expand All @@ -67,7 +67,7 @@ class GTSAM_EXPORT HybridValues {
keyFormatter); // print continuous components
};

// equals required by Testable for unit testing
/// equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete_.equals(other.discrete_, tol) &&
continuous_.equals(other.continuous_, tol);
Expand All @@ -83,13 +83,13 @@ class GTSAM_EXPORT HybridValues {
/// Return the delta update for the continuous vectors
VectorValues continuous() const { return continuous_; }

// Check whether a variable with key \c j exists in DiscreteValue.
/// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };

// Check whether a variable with key \c j exists in VectorValue.
/// Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j) { return continuous_.exists(j); };

// Check whether a variable with key \c j exists.
/// Check whether a variable with key \c j exists.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };

/** Insert a discrete \c value with key \c j. Replaces the existing value if
Expand Down
2 changes: 2 additions & 0 deletions gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class HybridBayesTree {
bool empty() const;
const HybridBayesTreeClique* operator[](size_t j) const;

gtsam::HybridValues optimize() const;

string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
Expand Down
9 changes: 4 additions & 5 deletions gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
hfg.add(DecisionTreeFactor(m1, {2, 8}));
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));

HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal(
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)}));
HybridBayesTree::shared_ptr result =
hfg.eliminateMultifrontal(hfg.getHybridOrdering());

// The bayes tree should have 3 cliques
EXPECT_LONGS_EQUAL(3, result->size());
Expand Down Expand Up @@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));

// Get a constrained ordering keeping c1 last
auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)});
auto ordering_full = hfg.getHybridOrdering();

// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
Expand Down Expand Up @@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
}
HybridBayesNet::shared_ptr hbn;
HybridGaussianFactorGraph::shared_ptr remaining;
std::tie(hbn, remaining) =
hfg->eliminatePartialSequential(ordering_partial);
std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial);

EXPECT_LONGS_EQUAL(14, hbn->size());
EXPECT_LONGS_EQUAL(11, remaining->size());
Expand Down

0 comments on commit 7c84020

Please sign in to comment.