Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Gaussian Conditional Elimination #1301

Merged
merged 4 commits into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals(
}

} else if (f->isContinuous()) {
deferredFactors.push_back(
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
deferredFactors.push_back(gf->inner());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
deferredFactors.push_back(cg->asGaussian());
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting!


} else if (f->isDiscrete()) {
// Don't do anything for discrete-only factors
Expand Down
11 changes: 9 additions & 2 deletions gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
/* ***************************************************************************
*/
using MotionModel = BetweenFactor<double>;
// using MotionMixture = MixtureFactor<MotionModel>;

// Test fixture with switching network.
struct Switching {
Expand All @@ -125,7 +124,13 @@ struct Switching {
HybridGaussianFactorGraph linearizedFactorGraph;
Values linearizationPoint;

/// Create with given number of time steps.
/**
* @brief Create with given number of time steps.
*
* @param K The total number of timesteps.
* @param between_sigma The stddev between poses.
* @param prior_sigma The stddev on priors (also used for measurements).
*/
Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1)
: K(K) {
// Create DiscreteKeys for binary K modes, modes[0] will not be used.
Expand Down Expand Up @@ -166,6 +171,8 @@ struct Switching {
linearizationPoint.insert<double>(X(k), static_cast<double>(k));
}

// The ground truth is robot moving forward
// and one less than the linearization point
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
}
}

/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, optimize) {
HybridGaussianFactorGraph hfg;

Expand All @@ -521,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) {

EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
}

/* ************************************************************************* */
// Test adding of gaussian conditional and re-elimination.
TEST(HybridGaussianFactorGraph, Conditionals) {
Switching switching(4);
HybridGaussianFactorGraph hfg;

hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1)
Ordering ordering;
ordering.push_back(X(1));
HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering);

hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1)
hfg.push_back(*bayes_net);
hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2)
hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1)
ordering.push_back(X(2));
ordering.push_back(X(3));
ordering.push_back(M(1));
ordering.push_back(M(2));

bayes_net = hfg.eliminateSequential(ordering);

HybridValues result = bayes_net->optimize();

Values expected_continuous;
expected_continuous.insert<double>(X(1), 0);
expected_continuous.insert<double>(X(2), 1);
expected_continuous.insert<double>(X(3), 2);
expected_continuous.insert<double>(X(4), 4);
Values result_continuous =
switching.linearizationPoint.retract(result.continuous());
EXPECT(assert_equal(expected_continuous, result_continuous));

DiscreteValues expected_discrete;
expected_discrete[M(1)] = 1;
expected_discrete[M(2)] = 1;
EXPECT(assert_equal(expected_discrete, result.discrete()));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down