Skip to content

Commit

Permalink
Merge pull request #1301 from borglab/hybrid/gaussian-conditional
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Oct 7, 2022
2 parents cae787a + 6238a1f commit fc9fc72
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 4 deletions.
8 changes: 6 additions & 2 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ GaussianMixtureFactor::Sum sumFrontals(
}

} else if (f->isContinuous()) {
deferredFactors.push_back(
boost::dynamic_pointer_cast<HybridGaussianFactor>(f)->inner());
if (auto gf = boost::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
deferredFactors.push_back(gf->inner());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(f)) {
deferredFactors.push_back(cg->asGaussian());
}

} else if (f->isDiscrete()) {
// Don't do anything for discrete-only factors
Expand Down
11 changes: 9 additions & 2 deletions gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
/* ***************************************************************************
*/
using MotionModel = BetweenFactor<double>;
// using MotionMixture = MixtureFactor<MotionModel>;

// Test fixture with switching network.
struct Switching {
Expand All @@ -125,7 +124,13 @@ struct Switching {
HybridGaussianFactorGraph linearizedFactorGraph;
Values linearizationPoint;

/// Create with given number of time steps.
/**
* @brief Create with given number of time steps.
*
* @param K The total number of timesteps.
* @param between_sigma The stddev between poses.
* @param prior_sigma The stddev on priors (also used for measurements).
*/
Switching(size_t K, double between_sigma = 1.0, double prior_sigma = 0.1)
: K(K) {
// Create DiscreteKeys for binary K modes, modes[0] will not be used.
Expand Down Expand Up @@ -166,6 +171,8 @@ struct Switching {
linearizationPoint.insert<double>(X(k), static_cast<double>(k));
}

// The ground truth is robot moving forward
// and one less than the linearization point
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
}
}

/* ************************************************************************* */
TEST(HybridGaussianFactorGraph, optimize) {
HybridGaussianFactorGraph hfg;

Expand All @@ -521,6 +522,46 @@ TEST(HybridGaussianFactorGraph, optimize) {

EXPECT(assert_equal(hv.atDiscrete(C(1)), int(0)));
}

/* ************************************************************************* */
// Test adding of gaussian conditional and re-elimination.
TEST(HybridGaussianFactorGraph, Conditionals) {
Switching switching(4);
HybridGaussianFactorGraph hfg;

hfg.push_back(switching.linearizedFactorGraph.at(0)); // P(X1)
Ordering ordering;
ordering.push_back(X(1));
HybridBayesNet::shared_ptr bayes_net = hfg.eliminateSequential(ordering);

hfg.push_back(switching.linearizedFactorGraph.at(1)); // P(X1, X2 | M1)
hfg.push_back(*bayes_net);
hfg.push_back(switching.linearizedFactorGraph.at(2)); // P(X2, X3 | M2)
hfg.push_back(switching.linearizedFactorGraph.at(5)); // P(M1)
ordering.push_back(X(2));
ordering.push_back(X(3));
ordering.push_back(M(1));
ordering.push_back(M(2));

bayes_net = hfg.eliminateSequential(ordering);

HybridValues result = bayes_net->optimize();

Values expected_continuous;
expected_continuous.insert<double>(X(1), 0);
expected_continuous.insert<double>(X(2), 1);
expected_continuous.insert<double>(X(3), 2);
expected_continuous.insert<double>(X(4), 4);
Values result_continuous =
switching.linearizationPoint.retract(result.continuous());
EXPECT(assert_equal(expected_continuous, result_continuous));

DiscreteValues expected_discrete;
expected_discrete[M(1)] = 1;
expected_discrete[M(2)] = 1;
EXPECT(assert_equal(expected_discrete, result.discrete()));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down

0 comments on commit fc9fc72

Please sign in to comment.