Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DefaultOrderingFunc in EliminationTraits #1373

Merged
merged 7 commits into from
Jan 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,17 @@ template<> struct EliminationTraits<DiscreteFactorGraph>
typedef DiscreteBayesTree BayesTreeType; ///< Type of Bayes tree
typedef DiscreteJunctionTree JunctionTreeType; ///< Type of Junction tree
/// The default dense elimination function
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
static std::pair<boost::shared_ptr<ConditionalType>,
boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateDiscrete(factors, keys); }
return EliminateDiscrete(factors, keys);
}
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
boost::optional<const VariableIndex&> variableIndex) {
return Ordering::Colamd(*variableIndex);
}
};

/* ************************************************************************* */
Expand Down
30 changes: 15 additions & 15 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;

/* ************************************************************************ */
const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
KeySet discrete_keys = graph.discreteKeys();
for (auto &factor : graph) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(graph);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@varunagrawal you potentially introduced a performance loss here, as you are not taking an optional variable index. It might have already been computed and passed in (in fact, we do that by default).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DefaultOrderingFunc takes an optional VariableIndex. We'll just have to update the HybridOrdering function to deal with that.

Indeed the continuous only version handles the variable index so there should be no performance loss in terms of backwards compatibility.

Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
return ordering;
}

/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree,
Expand Down Expand Up @@ -448,21 +463,6 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = discreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
return ordering;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
Expand Down
23 changes: 15 additions & 8 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ GTSAM_EXPORT
std::pair<boost::shared_ptr<HybridConditional>, HybridFactor::shared_ptr>
EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys);

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* @return const Ordering
*/
GTSAM_EXPORT const Ordering
HybridOrdering(const HybridGaussianFactorGraph& graph);

/* ************************************************************************* */
template <>
struct EliminationTraits<HybridGaussianFactorGraph> {
Expand All @@ -74,6 +83,12 @@ struct EliminationTraits<HybridGaussianFactorGraph> {
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminateHybrid(factors, keys);
}
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
boost::optional<const VariableIndex&> variableIndex) {
return HybridOrdering(graph);
}
};

/**
Expand Down Expand Up @@ -228,14 +243,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
*/
double probPrime(const HybridValues& values) const;

/**
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* @return const Ordering
*/
const Ordering getHybridOrdering() const;

/**
* @brief Create a decision tree of factor graphs out of this hybrid factor
* graph.
Expand Down
15 changes: 5 additions & 10 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,8 @@ TEST(HybridBayesNet, OptimizeAssignment) {
TEST(HybridBayesNet, Optimize) {
Switching s(4, 1.0, 0.1, {0, 1, 2, 3}, "1/1 1/1");

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
s.linearizedFactorGraph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();

Expand All @@ -212,9 +211,8 @@ TEST(HybridBayesNet, Optimize) {
TEST(HybridBayesNet, Error) {
Switching s(3);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
s.linearizedFactorGraph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();
auto error_tree = hybridBayesNet->error(delta.continuous());
Expand Down Expand Up @@ -266,9 +264,8 @@ TEST(HybridBayesNet, Error) {
TEST(HybridBayesNet, Prune) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
s.linearizedFactorGraph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();

Expand All @@ -284,9 +281,8 @@ TEST(HybridBayesNet, Prune) {
TEST(HybridBayesNet, UpdateDiscreteConditionals) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);
s.linearizedFactorGraph.eliminateSequential();

size_t maxNrLeaves = 3;
auto discreteConditionals = hybridBayesNet->discreteConditionals();
Expand Down Expand Up @@ -353,8 +349,7 @@ TEST(HybridBayesNet, Sampling) {
// Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial);
// Eliminate into BN
Ordering ordering = fg->getHybridOrdering();
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
HybridBayesNet::shared_ptr bn = fg->eliminateSequential();

// Set up sampling
std::mt19937_64 gen(11);
Expand Down
15 changes: 3 additions & 12 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,8 @@ using symbol_shorthand::X;
TEST(HybridBayesTree, OptimizeMultifrontal) {
Switching s(4);

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree::shared_ptr hybridBayesTree =
s.linearizedFactorGraph.eliminateMultifrontal(hybridOrdering);
s.linearizedFactorGraph.eliminateMultifrontal();
HybridValues delta = hybridBayesTree->optimize();

VectorValues expectedValues;
Expand Down Expand Up @@ -203,16 +202,8 @@ TEST(HybridBayesTree, Choose) {

GaussianBayesTree gbt = isam.choose(assignment);

Ordering ordering;
ordering += X(0);
ordering += X(1);
ordering += X(2);
ordering += X(3);
ordering += M(0);
ordering += M(1);
ordering += M(2);

// TODO(Varun) get segfault if ordering not provided
// Specify ordering so it matches that of HybridGaussianISAM.
Ordering ordering(KeyVector{X(0), X(1), X(2), X(3), M(0), M(1), M(2)});
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);

auto expected_gbt = bayesTree->choose(assignment);
Expand Down
5 changes: 2 additions & 3 deletions gtsam/hybrid/tests/testHybridEstimation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ TEST(HybridEstimation, Full) {
}

HybridBayesNet::shared_ptr bayesNet =
graph.eliminateSequential(hybridOrdering);
graph.eliminateSequential();

EXPECT_LONGS_EQUAL(2 * K - 1, bayesNet->size());

Expand Down Expand Up @@ -481,8 +481,7 @@ TEST(HybridEstimation, CorrectnessViaSampling) {
const auto fg = createHybridGaussianFactorGraph();

// 2. Eliminate into BN
const Ordering ordering = fg->getHybridOrdering();
const HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
const HybridBayesNet::shared_ptr bn = fg->eliminateSequential();

// Set up sampling
std::mt19937_64 rng(11);
Expand Down
38 changes: 10 additions & 28 deletions gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {

hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt));

auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)}));
auto result = hfg.eliminateSequential();

auto dc = result->at(2)->asDiscrete();
DiscreteValues dv;
Expand Down Expand Up @@ -161,8 +160,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialSimple) {
// Joint discrete probability table for c1, c2
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));

HybridBayesNet::shared_ptr result = hfg.eliminateSequential(
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)}));
HybridBayesNet::shared_ptr result = hfg.eliminateSequential();

// There are 4 variables (2 continuous + 2 discrete) in the bayes net.
EXPECT_LONGS_EQUAL(4, result->size());
Expand All @@ -187,8 +185,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
// variable throws segfault
// hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));

HybridBayesTree::shared_ptr result =
hfg.eliminateMultifrontal(hfg.getHybridOrdering());
HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal();

// The bayes tree should have 3 cliques
EXPECT_LONGS_EQUAL(3, result->size());
Expand Down Expand Up @@ -218,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));

// Get a constrained ordering keeping c1 last
auto ordering_full = hfg.getHybridOrdering();
auto ordering_full = HybridOrdering(hfg);

// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
Expand Down Expand Up @@ -518,8 +515,7 @@ TEST(HybridGaussianFactorGraph, optimize) {

hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt));

auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)}));
auto result = hfg.eliminateSequential();

HybridValues hv = result->optimize();

Expand Down Expand Up @@ -572,9 +568,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrime) {

HybridGaussianFactorGraph graph = s.linearizedFactorGraph;

Ordering hybridOrdering = graph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

const HybridValues delta = hybridBayesNet->optimize();
const double error = graph.error(delta);
Expand All @@ -593,9 +587,7 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {

HybridGaussianFactorGraph graph = s.linearizedFactorGraph;

Ordering hybridOrdering = graph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
graph.eliminateSequential(hybridOrdering);
HybridBayesNet::shared_ptr hybridBayesNet = graph.eliminateSequential();

HybridValues delta = hybridBayesNet->optimize();
auto error_tree = graph.error(delta.continuous());
Expand Down Expand Up @@ -684,10 +676,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
expectedBayesNet.emplace_back(new DiscreteConditional(mode, "74/26"));

// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
const auto posterior = fg.eliminateSequential(ordering);
const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
}

Expand Down Expand Up @@ -719,10 +708,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
expectedBayesNet.emplace_back(new DiscreteConditional(mode, "23/77"));

// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
const auto posterior = fg.eliminateSequential(ordering);
const auto posterior = fg.eliminateSequential();
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
}

Expand All @@ -741,11 +727,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny22) {
EXPECT_LONGS_EQUAL(5, fg.size());

// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
ordering.push_back(M(1));
const auto posterior = fg.eliminateSequential(ordering);
const auto posterior = fg.eliminateSequential();

// Compute the log-ratio between the Bayes net and the factor graph.
auto compute_ratio = [&](HybridValues *sample) -> double {
Expand Down
7 changes: 2 additions & 5 deletions gtsam/hybrid/tests/testSerializationHybrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ TEST(HybridSerialization, GaussianMixture) {
// Test HybridBayesNet serialization.
TEST(HybridSerialization, HybridBayesNet) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential());

EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
Expand All @@ -162,9 +161,7 @@ TEST(HybridSerialization, HybridBayesNet) {
// Test HybridBayesTree serialization.
TEST(HybridSerialization, HybridBayesTree) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
HybridBayesTree hbt = *(s.linearizedFactorGraph.eliminateMultifrontal());

EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
Expand Down
18 changes: 16 additions & 2 deletions gtsam/inference/EliminateableFactorGraph-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,16 @@ namespace gtsam {
if (orderingType == Ordering::METIS) {
Ordering computedOrdering = Ordering::Metis(asDerived());
return eliminateSequential(computedOrdering, function, variableIndex);
} else {
} else if (orderingType == Ordering::COLAMD) {
Ordering computedOrdering = Ordering::Colamd(*variableIndex);
return eliminateSequential(computedOrdering, function, variableIndex);
} else if (orderingType == Ordering::NATURAL) {
Ordering computedOrdering = Ordering::Natural(asDerived());
return eliminateSequential(computedOrdering, function, variableIndex);
} else {
Ordering computedOrdering = EliminationTraitsType::DefaultOrderingFunc(
asDerived(), variableIndex);
return eliminateSequential(computedOrdering, function, variableIndex);
}
}
}
Expand Down Expand Up @@ -100,9 +107,16 @@ namespace gtsam {
if (orderingType == Ordering::METIS) {
Ordering computedOrdering = Ordering::Metis(asDerived());
return eliminateMultifrontal(computedOrdering, function, variableIndex);
} else {
} else if (orderingType == Ordering::COLAMD) {
Ordering computedOrdering = Ordering::Colamd(*variableIndex);
return eliminateMultifrontal(computedOrdering, function, variableIndex);
} else if (orderingType == Ordering::NATURAL) {
Ordering computedOrdering = Ordering::Natural(asDerived());
return eliminateMultifrontal(computedOrdering, function, variableIndex);
} else {
Ordering computedOrdering = EliminationTraitsType::DefaultOrderingFunc(
asDerived(), variableIndex);
return eliminateMultifrontal(computedOrdering, function, variableIndex);
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions gtsam/linear/GaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ namespace gtsam {
static std::pair<boost::shared_ptr<ConditionalType>, boost::shared_ptr<FactorType> >
DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) {
return EliminatePreferCholesky(factors, keys); }
/// The default ordering generation function
static Ordering DefaultOrderingFunc(
const FactorGraphType& graph,
boost::optional<const VariableIndex&> variableIndex) {
return Ordering::Colamd(*variableIndex);
}
};

/* ************************************************************************* */
Expand Down
Loading