Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper Hybrid Elimination #1319

Merged
merged 63 commits into from
Dec 30, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
0839331
minor typo fixes
varunagrawal Oct 25, 2022
7dec7bb
remove if guards and add timing counters
varunagrawal Oct 25, 2022
dcdcf30
new WIP test to check the discrete probabilities after elimination
varunagrawal Oct 25, 2022
1789bb7
showing difference in computed probabilities
varunagrawal Oct 25, 2022
96afdff
Increase the number of time steps for incremental test case
varunagrawal Nov 1, 2022
5bfce89
Merge branch 'hybrid/error' into hybrid/tests
varunagrawal Nov 4, 2022
a97d27e
Merge branch 'hybrid/error' into hybrid/tests
varunagrawal Nov 7, 2022
a6d1a57
fix hybrid elimination
varunagrawal Nov 7, 2022
2f2f8c9
figured out how to get the correct continuous errors
varunagrawal Nov 7, 2022
98febf2
allow optional discrete transition probability for Switching test fix…
varunagrawal Nov 7, 2022
610a535
set up unit test to verify that the probPrimeTree has the same values…
varunagrawal Nov 7, 2022
1815433
add methods to perform correct elimination
varunagrawal Nov 7, 2022
090cc42
update HybridSmoother to use the new method
varunagrawal Nov 7, 2022
083fd21
use long sequence in HybridEstimation test
varunagrawal Nov 7, 2022
64cd7c9
add docs
varunagrawal Nov 7, 2022
3987b03
add optional ordering and fix bug
varunagrawal Nov 8, 2022
1a3b343
minor clean up and get tests to pass
varunagrawal Nov 8, 2022
1b168ce
update test in testHybridEstimation
varunagrawal Nov 8, 2022
cb55af3
separate HybridGaussianFactorGraph::error() using both continuous and…
varunagrawal Nov 8, 2022
eb94ad9
add HybridGaussianFactorGraph::probPrime method
varunagrawal Nov 8, 2022
0938159
overload multifrontal elimination
varunagrawal Nov 9, 2022
98d3186
add copy constructor for HybridBayesTreeClique
varunagrawal Nov 10, 2022
7ae4e57
fix HybridBayesTree::optimize to account for pruned nodes
varunagrawal Nov 10, 2022
d54cf48
fix creation of updatedBayesTree
varunagrawal Nov 10, 2022
318f738
fixup the final tests
varunagrawal Nov 10, 2022
6e6bbff
update docstring for Ordering::+=
varunagrawal Nov 10, 2022
5e2cdfd
make continuousProbPrimes and continuousDeltas as templates
varunagrawal Nov 13, 2022
2394129
address review comments
varunagrawal Nov 15, 2022
05b2d31
better printing
varunagrawal Dec 3, 2022
3eaf4cc
move multifrontal optimize test to testHybridBayesTree and update doc…
varunagrawal Dec 3, 2022
cd3cfa0
moved sequential elimination code to HybridEliminationTree
varunagrawal Dec 3, 2022
15fffeb
add getters to HybridEliminationTree
varunagrawal Dec 4, 2022
addbe2a
override eliminate in HybridJunctionTree
varunagrawal Dec 4, 2022
ae0b3e3
split up the eliminate method to constituent parts
varunagrawal Dec 4, 2022
bed56e0
mark helper methods as protected and add docstrings
varunagrawal Dec 4, 2022
5fc114f
more unit tests
varunagrawal Dec 4, 2022
22e4a73
Add details about the role of the HybridEliminationTree in hybrid eli…
varunagrawal Dec 4, 2022
0596b2f
remove unrequired code
varunagrawal Dec 10, 2022
62bc9f2
update hybrid elimination and corresponding tests
varunagrawal Dec 10, 2022
6beffeb
remove commented out code
varunagrawal Dec 10, 2022
da5d3a2
Merge pull request #1339 from borglab/hybrid/new-elimination
varunagrawal Dec 10, 2022
812bf52
minor cleanup
varunagrawal Dec 21, 2022
6b6731a
Merge branch 'hybrid/error' into hybrid/tests
varunagrawal Dec 21, 2022
46380ca
Merge branch 'hybrid/tests' into hybrid/multifrontal
varunagrawal Dec 21, 2022
583d121
Merge pull request #1323 from borglab/hybrid/multifrontal
varunagrawal Dec 23, 2022
ffd1802
add optional model parameter to sample method
varunagrawal Dec 23, 2022
bdb7836
sampling test
varunagrawal Dec 23, 2022
ae0df47
Merge branch 'develop' into hybrid/tests
varunagrawal Dec 24, 2022
153c12e
Merge branch 'develop' into hybrid/tests
varunagrawal Dec 24, 2022
aa86af2
Revert "add optional model parameter to sample method"
varunagrawal Dec 24, 2022
6b834db
Merge branch 'hybrid/tests' into hybrid/verification
varunagrawal Dec 24, 2022
798c51a
update sampling test to use new sample method
varunagrawal Dec 24, 2022
13d22b1
address review comments
varunagrawal Dec 24, 2022
1e17dd3
compute sampling ratio for one sample and then for multiple samples
varunagrawal Dec 24, 2022
f3c85ae
Merge pull request #1346 from borglab/hybrid/verification
varunagrawal Dec 25, 2022
76e838b
Implement printing rather than calling factor graph version
dellaert Dec 25, 2022
4ad482a
Small comments
dellaert Dec 26, 2022
a7573e8
Refactor elimination setup to not use C declaration style
dellaert Dec 26, 2022
8a319c5
Separated out NFG setup and added test.
dellaert Dec 26, 2022
db17a04
Fix print test
dellaert Dec 26, 2022
cfcbdda
Merge pull request #1349 from borglab/hybrid/two_ways
varunagrawal Dec 27, 2022
cf46c36
Merge branch 'develop' into hybrid/tests
varunagrawal Dec 30, 2022
28f349c
minor fixes
varunagrawal Dec 30, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add methods to perform correct elimination
  • Loading branch information
varunagrawal committed Nov 7, 2022
commit 1815433cbbde1052237427f774a086b1eabe8430
6 changes: 6 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ HybridValues HybridBayesNet::optimize() const {
/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);

// Check if there exists a nullptr in the GaussianBayesNet
// If yes, return an empty VectorValues
if (std::find(gbn.begin(), gbn.end(), nullptr) != gbn.end()) {
return VectorValues();
}
return gbn.optimize();
}

Expand Down
110 changes: 78 additions & 32 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,61 +493,107 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateHybridSequential() const {
Ordering continuous_ordering(this->continuousKeys()),
discrete_ordering(this->discreteKeys());

// Eliminate continuous
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
BaseEliminateable::eliminatePartialSequential(
continuous_ordering, EliminationTraitsType::DefaultEliminate);

// Get the last continuous conditional which will have all the discrete keys
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
// Get all the discrete assignments
DiscreteKeys discrete_keys = last_conditional->discreteKeys();
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(discrete_keys);

// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());

DecisionTree<Key, VectorValues::shared_ptr>
HybridGaussianFactorGraph::continuousDelta(
const DiscreteKeys &discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
const std::vector<DiscreteValues> &assignments) const {
// Create a decision tree of all the different VectorValues
std::vector<VectorValues::shared_ptr> vector_values;
for (const DiscreteValues &assignment : assignments) {
VectorValues values = bayesNet->optimize(assignment);
VectorValues values = continuousBayesNet->optimize(assignment);
vector_values.push_back(boost::make_shared<VectorValues>(values));
}
DecisionTree<Key, VectorValues::shared_ptr> delta_tree(discrete_keys,
vector_values);

return delta_tree;
}

/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::continuousProbPrimes(
const DiscreteKeys &discrete_keys,
const boost::shared_ptr<BayesNetType> &continuousBayesNet,
const std::vector<DiscreteValues> &assignments) const {
// Create a decision tree of all the different VectorValues
DecisionTree<Key, VectorValues::shared_ptr> delta_tree =
this->continuousDelta(discrete_keys, continuousBayesNet, assignments);

// Get the probPrime tree with the correct leaf probabilities
std::vector<double> probPrimes;
for (const DiscreteValues &assignment : assignments) {
double error = 0.0;
VectorValues delta = *delta_tree(assignment);
for (auto factor : *this) {

// If VectorValues is empty, it means this is a pruned branch.
// Set thr probPrime to 0.0.
if (delta.size() == 0) {
probPrimes.push_back(0.0);
continue;
}

double error = 0.0;

for (size_t idx = 0; idx < size(); idx++) {
auto factor = factors_.at(idx);

if (factor->isHybrid()) {
auto f = boost::static_pointer_cast<GaussianMixtureFactor>(factor);
error += f->error(delta, assignment);
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(delta, assignment);
}
if (auto f =
boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(delta, assignment);
}

} else if (factor->isContinuous()) {
auto f = boost::static_pointer_cast<HybridGaussianFactor>(factor);
error += f->inner()->error(delta);
if (auto f =
boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(delta);
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(delta);
}
}
}
probPrimes.push_back(exp(-error));
}
AlgebraicDecisionTree<Key> probPrimeTree(discrete_keys, probPrimes);
discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree));
return probPrimeTree;
}

/* ************************************************************************ */
boost::shared_ptr<HybridGaussianFactorGraph::BayesNetType>
HybridGaussianFactorGraph::eliminateHybridSequential() const {
Ordering continuous_ordering(this->continuousKeys()),
discrete_ordering(this->discreteKeys());

// Eliminate continuous
HybridBayesNet::shared_ptr bayesNet;
HybridGaussianFactorGraph::shared_ptr discreteGraph;
std::tie(bayesNet, discreteGraph) =
BaseEliminateable::eliminatePartialSequential(continuous_ordering);

// Get the last continuous conditional which will have all the discrete keys
auto last_conditional = bayesNet->at(bayesNet->size() - 1);
DiscreteKeys discrete_keys = last_conditional->discreteKeys();

const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(discrete_keys);

// Save a copy of the original discrete key ordering
DiscreteKeys orig_discrete_keys(discrete_keys);
// Reverse discrete keys order for correct tree construction
std::reverse(discrete_keys.begin(), discrete_keys.end());

AlgebraicDecisionTree<Key> probPrimeTree =
continuousProbPrimes(discrete_keys, bayesNet, assignments);

discreteGraph->add(DecisionTreeFactor(orig_discrete_keys, probPrimeTree));

// Perform discrete elimination
HybridBayesNet::shared_ptr discreteBayesNet =
discreteGraph->eliminateSequential(
discrete_ordering, EliminationTraitsType::DefaultEliminate);
discreteGraph->eliminateSequential(discrete_ordering);

bayesNet->add(*discreteBayesNet);

return bayesNet;
Expand Down
39 changes: 39 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/VectorValues.h>

namespace gtsam {

Expand Down Expand Up @@ -190,6 +191,44 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
AlgebraicDecisionTree<Key> probPrime(
const VectorValues& continuousValues) const;

/**
* @brief Compute the VectorValues solution for the continuous variables for
* each mode.
*
* @param discrete_keys The discrete keys which form all the modes.
* @param continuousBayesNet The Bayes Net representing the continuous
* eliminated variables.
* @param assignments List of all discrete assignments to create the final
* decision tree.
* @return DecisionTree<Key, VectorValues::shared_ptr>
*/
DecisionTree<Key, VectorValues::shared_ptr> continuousDelta(
const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
const std::vector<DiscreteValues>& assignments) const;

/**
* @brief Compute the unnormalized probabilities of the continuous variables
* for each of the modes.
*
* @param discrete_keys The discrete keys which form all the modes.
* @param continuousBayesNet The Bayes Net representing the continuous
* eliminated variables.
* @param assignments List of all discrete assignments to create the final
* decision tree.
* @return AlgebraicDecisionTree<Key>
*/
AlgebraicDecisionTree<Key> continuousProbPrimes(
const DiscreteKeys& discrete_keys,
const boost::shared_ptr<BayesNetType>& continuousBayesNet,
const std::vector<DiscreteValues>& assignments) const;

/**
* @brief Custom elimination function which computes the correct
* continuous probabilities.
*
* @return boost::shared_ptr<BayesNetType>
*/
boost::shared_ptr<BayesNetType> eliminateHybridSequential() const;

/**
Expand Down