Skip to content

Commit

Permalink
Merge pull request #1286 from borglab/hybrid/serialization
Browse files Browse the repository at this point in the history
Hybrid Serialization
  • Loading branch information
varunagrawal authored Sep 2, 2022
2 parents 7c84020 + 27a9d56 commit 30c913e
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 14 deletions.
21 changes: 21 additions & 0 deletions gtsam/discrete/DiscreteKey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,25 @@ namespace gtsam {
return keys & key2;
}

void DiscreteKeys::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
}

bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const {
if (this->size() != other.size()) {
return false;
}

for (size_t i = 0; i < this->size(); i++) {
if (this->at(i).first != other.at(i).first ||
this->at(i).second != other.at(i).second) {
return false;
}
}
return true;
}
}
25 changes: 19 additions & 6 deletions gtsam/discrete/DiscreteKey.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <gtsam/global_includes.h>
#include <gtsam/inference/Key.h>

#include <boost/serialization/vector.hpp>
#include <map>
#include <string>
#include <vector>
Expand Down Expand Up @@ -72,15 +73,27 @@ namespace gtsam {

/// Print the keys and cardinalities.
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// Check equality to another DiscreteKeys object.
bool equals(const DiscreteKeys& other, double tol = 0) const;

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& boost::serialization::make_nvp(
"DiscreteKeys",
boost::serialization::base_object<std::vector<DiscreteKey>>(*this));
}

}; // DiscreteKeys

/// Create a list from two keys
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
}

// traits
template <>
struct traits<DiscreteKeys> : public Testable<DiscreteKeys> {};

} // namespace gtsam
18 changes: 16 additions & 2 deletions gtsam/discrete/tests/testDiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,33 @@
* @author Duy-Nguyen Ta
*/

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <CppUnitLite/TestHarness.h>

#include <boost/assign/std/map.hpp>
using namespace boost::assign;

using namespace std;
using namespace gtsam;
using namespace gtsam::serializationTestHelpers;

/* ************************************************************************* */
TEST(DisreteKeys, Serialization) {
DiscreteKeys keys;
keys& DiscreteKey(0, 2);
keys& DiscreteKey(1, 3);
keys& DiscreteKey(2, 4);

EXPECT(equalsObj<DiscreteKeys>(keys));
EXPECT(equalsXML<DiscreteKeys>(keys));
EXPECT(equalsBinary<DiscreteKeys>(keys));
}

/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

54 changes: 48 additions & 6 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/global_includes.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
Expand All @@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
using sharedConditional = boost::shared_ptr<ConditionalType>;

/// @name Standard Constructors
/// @{

/** Construct empty bayes net */
HybridBayesNet() = default;

/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
/// @}
/// @name Testable
/// @{

/** Check equality */
bool equals(const This &bn, double tol = 1e-9) const {
return Base::equals(bn, tol);
}

/// print graph
void print(
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}

/// @}
/// @name Standard Interface
/// @{

/// Add HybridConditional to Bayes Net
using Base::add;
Expand Down Expand Up @@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;

/// Solve the HybridBayesNet by back-substitution.
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on
* the MPE assignment.
*
* @return HybridValues
*/
HybridValues optimize() const;

/**
Expand All @@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @return Values
*/
VectorValues optimize(const DiscreteValues &assignment) const;

/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};

/// traits
template <>
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};

} // namespace gtsam
12 changes: 12 additions & 0 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,20 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
VectorValues optimize(const DiscreteValues& assignment) const;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};

/// traits
template <>
struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};

/**
* @brief Class for Hybrid Bayes tree orphan subtrees.
*
Expand Down
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,15 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }

private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}

}; // HybridConditional

// traits
Expand Down
14 changes: 14 additions & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
bool isContinuous_ = false;
bool isHybrid_ = false;

// TODO(Varun) remove
size_t nrContinuous_ = 0;

protected:
Expand Down Expand Up @@ -129,6 +130,19 @@ class GTSAM_EXPORT HybridFactor : public Factor {
const KeyVector &continuousKeys() const { return continuousKeys_; }

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
}
};
// HybridFactor

Expand Down
15 changes: 15 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
* @date December 2021
*/

#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>

Expand All @@ -28,6 +29,8 @@

using namespace std;
using namespace gtsam;
using namespace gtsam::serializationTestHelpers;

using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
Expand Down Expand Up @@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}

/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));

EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
15 changes: 15 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* @date August 2022
*/

#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
Expand Down Expand Up @@ -143,6 +144,20 @@ TEST(HybridBayesTree, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous()));
}

/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));

using namespace gtsam::serializationTestHelpers;
EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
27 changes: 27 additions & 0 deletions gtsam/linear/tests/testSerializationLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) {
EXPECT(equalsBinary(graph));
}

/* ****************************************************************************/
TEST(Serialization, gaussian_bayes_net) {
// Create an arbitrary Bayes Net
GaussianBayesNet gbn;
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3,
(Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4,
(Matrix2() << 11.0, 12.0, 13.0, 14.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(),
2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4,
(Matrix2() << 25.0, 26.0, 27.0, 28.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(),
3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(),
4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished()));

std::string serialized = serialize(gbn);
GaussianBayesNet actual;
deserialize(serialized, actual);
EXPECT(assert_equal(gbn, actual));
}

/* ************************************************************************* */
TEST (Serialization, gaussian_bayes_tree) {
const Key x1=1, x2=2, x3=3, x4=4;
Expand Down

0 comments on commit 30c913e

Please sign in to comment.