Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Serialization #1286

Merged
merged 7 commits into from
Sep 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions gtsam/discrete/DiscreteKey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,25 @@ namespace gtsam {
return keys & key2;
}

void DiscreteKeys::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
}

bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const {
if (this->size() != other.size()) {
return false;
}

for (size_t i = 0; i < this->size(); i++) {
if (this->at(i).first != other.at(i).first ||
this->at(i).second != other.at(i).second) {
return false;
}
}
return true;
}
}
25 changes: 19 additions & 6 deletions gtsam/discrete/DiscreteKey.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <gtsam/global_includes.h>
#include <gtsam/inference/Key.h>

#include <boost/serialization/vector.hpp>
#include <map>
#include <string>
#include <vector>
Expand Down Expand Up @@ -72,15 +73,27 @@ namespace gtsam {

/// Print the keys and cardinalities.
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// Check equality to another DiscreteKeys object.
bool equals(const DiscreteKeys& other, double tol = 0) const;

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& boost::serialization::make_nvp(
"DiscreteKeys",
boost::serialization::base_object<std::vector<DiscreteKey>>(*this));
}

}; // DiscreteKeys

/// Create a list from two keys
GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2);
}

// traits
template <>
struct traits<DiscreteKeys> : public Testable<DiscreteKeys> {};

} // namespace gtsam
18 changes: 16 additions & 2 deletions gtsam/discrete/tests/testDiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,33 @@
* @author Duy-Nguyen Ta
*/

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <CppUnitLite/TestHarness.h>

#include <boost/assign/std/map.hpp>
using namespace boost::assign;

using namespace std;
using namespace gtsam;
using namespace gtsam::serializationTestHelpers;

/* ************************************************************************* */
TEST(DisreteKeys, Serialization) {
DiscreteKeys keys;
keys& DiscreteKey(0, 2);
keys& DiscreteKey(1, 3);
keys& DiscreteKey(2, 4);

EXPECT(equalsObj<DiscreteKeys>(keys));
EXPECT(equalsXML<DiscreteKeys>(keys));
EXPECT(equalsBinary<DiscreteKeys>(keys));
}

/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */

54 changes: 48 additions & 6 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/global_includes.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
Expand All @@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
using sharedConditional = boost::shared_ptr<ConditionalType>;

/// @name Standard Constructors
/// @{

/** Construct empty bayes net */
HybridBayesNet() = default;

/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
/// @}
/// @name Testable
/// @{

/** Check equality */
bool equals(const This &bn, double tol = 1e-9) const {
return Base::equals(bn, tol);
}

/// print graph
void print(
const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}

/// @}
/// @name Standard Interface
/// @{

/// Add HybridConditional to Bayes Net
using Base::add;
Expand Down Expand Up @@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;

/// Solve the HybridBayesNet by back-substitution.
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on
* the MPE assignment.
*
* @return HybridValues
*/
HybridValues optimize() const;

/**
Expand All @@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @return Values
*/
VectorValues optimize(const DiscreteValues &assignment) const;

/// Prune the Hybrid Bayes Net given the discrete decision tree.
HybridBayesNet prune(
const DecisionTreeFactor::shared_ptr &discreteFactor) const;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};

/// traits
template <>
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};

} // namespace gtsam
12 changes: 12 additions & 0 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,20 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
VectorValues optimize(const DiscreteValues& assignment) const;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};

/// traits
template <>
struct traits<HybridBayesTree> : public Testable<HybridBayesTree> {};

/**
* @brief Class for Hybrid Bayes tree orphan subtrees.
*
Expand Down
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,15 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }

private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}

}; // HybridConditional

// traits
Expand Down
14 changes: 14 additions & 0 deletions gtsam/hybrid/HybridFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
bool isContinuous_ = false;
bool isHybrid_ = false;

// TODO(Varun) remove
size_t nrContinuous_ = 0;

protected:
Expand Down Expand Up @@ -129,6 +130,19 @@ class GTSAM_EXPORT HybridFactor : public Factor {
const KeyVector &continuousKeys() const { return continuousKeys_; }

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
ar &BOOST_SERIALIZATION_NVP(isContinuous_);
ar &BOOST_SERIALIZATION_NVP(isHybrid_);
ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
}
};
// HybridFactor

Expand Down
15 changes: 15 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
* @date December 2021
*/

#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>

Expand All @@ -28,6 +29,8 @@

using namespace std;
using namespace gtsam;
using namespace gtsam::serializationTestHelpers;

using noiseModel::Isotropic;
using symbol_shorthand::M;
using symbol_shorthand::X;
Expand Down Expand Up @@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}

/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));

EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
15 changes: 15 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* @date August 2022
*/

#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
Expand Down Expand Up @@ -143,6 +144,20 @@ TEST(HybridBayesTree, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous()));
}

/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));

using namespace gtsam::serializationTestHelpers;
EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
27 changes: 27 additions & 0 deletions gtsam/linear/tests/testSerializationLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) {
EXPECT(equalsBinary(graph));
}

/* ****************************************************************************/
TEST(Serialization, gaussian_bayes_net) {
// Create an arbitrary Bayes Net
GaussianBayesNet gbn;
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3,
(Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4,
(Matrix2() << 11.0, 12.0, 13.0, 14.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(),
2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4,
(Matrix2() << 25.0, 26.0, 27.0, 28.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(),
3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(),
4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished()));
gbn += GaussianConditional::shared_ptr(new GaussianConditional(
4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished()));

std::string serialized = serialize(gbn);
GaussianBayesNet actual;
deserialize(serialized, actual);
EXPECT(assert_equal(gbn, actual));
}

/* ************************************************************************* */
TEST (Serialization, gaussian_bayes_tree) {
const Key x1=1, x2=2, x3=3, x4=4;
Expand Down