Skip to content

Commit

Permalink
Merge pull request #1290 from borglab/hybrid/improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Oct 4, 2022
2 parents 0909c46 + d6feb4d commit 903d7c6
Show file tree
Hide file tree
Showing 9 changed files with 879 additions and 22 deletions.
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianISAM.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class GTSAM_EXPORT HybridGaussianISAM : public ISAM<HybridBayesTree> {
HybridBayesTree::EliminationTraitsType::DefaultEliminate);

/**
* @brief
* @brief Prune the underlying Bayes tree.
*
* @param root The root key in the discrete conditional decision tree.
* @param maxNumberLeaves
Expand Down
21 changes: 10 additions & 11 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ void HybridNonlinearFactorGraph::add(
}

/* ************************************************************************* */
void HybridNonlinearFactorGraph::add(
boost::shared_ptr<DiscreteFactor> factor) {
void HybridNonlinearFactorGraph::add(boost::shared_ptr<DiscreteFactor> factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

Expand All @@ -49,12 +48,12 @@ void HybridNonlinearFactorGraph::print(const std::string& s,
}

/* ************************************************************************* */
HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize(
HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
const Values& continuousValues) const {
// create an empty linear FG
HybridGaussianFactorGraph linearFG;
auto linearFG = boost::make_shared<HybridGaussianFactorGraph>();

linearFG.reserve(size());
linearFG->reserve(size());

// linearize all hybrid factors
for (auto&& factor : factors_) {
Expand All @@ -66,9 +65,9 @@ HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize(
if (factor->isHybrid()) {
// Check if it is a nonlinear mixture factor
if (auto nlmf = boost::dynamic_pointer_cast<MixtureFactor>(factor)) {
linearFG.push_back(nlmf->linearize(continuousValues));
linearFG->push_back(nlmf->linearize(continuousValues));
} else {
linearFG.push_back(factor);
linearFG->push_back(factor);
}

// Now check if the factor is a continuous only factor.
Expand All @@ -80,18 +79,18 @@ HybridGaussianFactorGraph HybridNonlinearFactorGraph::linearize(
boost::dynamic_pointer_cast<NonlinearFactor>(nlhf->inner())) {
auto hgf = boost::make_shared<HybridGaussianFactor>(
nlf->linearize(continuousValues));
linearFG.push_back(hgf);
linearFG->push_back(hgf);
} else {
linearFG.push_back(factor);
linearFG->push_back(factor);
}
// Finally if nothing else, we are discrete-only which doesn't need
// lineariztion.
} else {
linearFG.push_back(factor);
linearFG->push_back(factor);
}

} else {
linearFG.push_back(GaussianFactor::shared_ptr());
linearFG->push_back(GaussianFactor::shared_ptr());
}
}
return linearFG;
Expand Down
28 changes: 27 additions & 1 deletion gtsam/hybrid/HybridNonlinearFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
using IsNonlinear = typename std::enable_if<
std::is_base_of<NonlinearFactor, FACTOR>::value>::type;

/// Check if T has a value_type derived from FactorType.
template <typename T>
using HasDerivedValueType = typename std::enable_if<
std::is_base_of<HybridFactor, typename T::value_type>::value>::type;

/// Check if T has a pointer type derived from FactorType.
template <typename T>
using HasDerivedElementType = typename std::enable_if<std::is_base_of<
HybridFactor, typename T::value_type::element_type>::value>::type;

public:
using Base = HybridFactorGraph;
using This = HybridNonlinearFactorGraph; ///< this class
Expand Down Expand Up @@ -109,6 +119,21 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
}
}

/**
* Push back many factors as shared_ptr's in a container (factors are not
* copied)
*/
template <typename CONTAINER>
HasDerivedElementType<CONTAINER> push_back(const CONTAINER& container) {
Base::push_back(container.begin(), container.end());
}

/// Push back non-pointer objects in a container (factors are copied).
template <typename CONTAINER>
HasDerivedValueType<CONTAINER> push_back(const CONTAINER& container) {
Base::push_back(container.begin(), container.end());
}

/// Add a nonlinear factor as a shared ptr.
void add(boost::shared_ptr<NonlinearFactor> factor);

Expand All @@ -127,7 +152,8 @@ class GTSAM_EXPORT HybridNonlinearFactorGraph : public HybridFactorGraph {
* @param continuousValues: Dictionary of continuous values.
* @return HybridGaussianFactorGraph::shared_ptr
*/
HybridGaussianFactorGraph linearize(const Values& continuousValues) const;
HybridGaussianFactorGraph::shared_ptr linearize(
const Values& continuousValues) const;
};

template <>
Expand Down
111 changes: 111 additions & 0 deletions gtsam/hybrid/HybridNonlinearISAM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file HybridNonlinearISAM.cpp
* @date Sep 12, 2022
* @author Varun Agrawal
*/

#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridNonlinearISAM.h>
#include <gtsam/inference/Ordering.h>

#include <iostream>

using namespace std;

namespace gtsam {

/* ************************************************************************* */
void HybridNonlinearISAM::saveGraph(const string& s,
const KeyFormatter& keyFormatter) const {
isam_.saveGraph(s, keyFormatter);
}

/* ************************************************************************* */
void HybridNonlinearISAM::update(const HybridNonlinearFactorGraph& newFactors,
const Values& initialValues) {
if (newFactors.size() > 0) {
// Reorder and relinearize every reorderInterval updates
if (reorderInterval_ > 0 && ++reorderCounter_ >= reorderInterval_) {
reorder_relinearize();
reorderCounter_ = 0;
}

factors_.push_back(newFactors);

// Linearize new factors and insert them
// TODO: optimize for whole config?
linPoint_.insert(initialValues);

boost::shared_ptr<HybridGaussianFactorGraph> linearizedNewFactors =
newFactors.linearize(linPoint_);

// Update ISAM
isam_.update(*linearizedNewFactors, boost::none, eliminationFunction_);
}
}

/* ************************************************************************* */
void HybridNonlinearISAM::reorder_relinearize() {
if (factors_.size() > 0) {
// Obtain the new linearization point
const Values newLinPoint = estimate();

isam_.clear();

// Just recreate the whole BayesTree
// TODO: allow for constrained ordering here
// TODO: decouple relinearization and reordering to avoid
isam_.update(*factors_.linearize(newLinPoint), boost::none,
eliminationFunction_);

// Update linearization point
linPoint_ = newLinPoint;
}
}

/* ************************************************************************* */
Values HybridNonlinearISAM::estimate() {
Values result;
if (isam_.size() > 0) {
HybridValues values = isam_.optimize();
assignment_ = values.discrete();
return linPoint_.retract(values.continuous());
} else {
return linPoint_;
}
}

// /* *************************************************************************
// */ Matrix HybridNonlinearISAM::marginalCovariance(Key key) const {
// return isam_.marginalCovariance(key);
// }

/* ************************************************************************* */
void HybridNonlinearISAM::print(const string& s,
const KeyFormatter& keyFormatter) const {
cout << s << "ReorderInterval: " << reorderInterval_
<< " Current Count: " << reorderCounter_ << endl;
isam_.print("HybridGaussianISAM:\n");
linPoint_.print("Linearization Point:\n", keyFormatter);
factors_.print("Nonlinear Graph:\n", keyFormatter);
}

/* ************************************************************************* */
void HybridNonlinearISAM::printStats() const {
isam_.getCliqueData().getStats().print();
}

/* ************************************************************************* */

} // namespace gtsam
132 changes: 132 additions & 0 deletions gtsam/hybrid/HybridNonlinearISAM.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file HybridNonlinearISAM.h
* @date Sep 12, 2022
* @author Varun Agrawal
*/

#pragma once

#include <gtsam/hybrid/HybridGaussianISAM.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>

namespace gtsam {
/**
* Wrapper class to manage ISAM in a nonlinear context
*/
class GTSAM_EXPORT HybridNonlinearISAM {
protected:
/** The internal iSAM object */
gtsam::HybridGaussianISAM isam_;

/** The current linearization point */
Values linPoint_;

/// The discrete assignment
DiscreteValues assignment_;

/** The original factors, used when relinearizing */
HybridNonlinearFactorGraph factors_;

/** The reordering interval and counter */
int reorderInterval_;
int reorderCounter_;

/** The elimination function */
HybridGaussianFactorGraph::Eliminate eliminationFunction_;

public:
/// @name Standard Constructors
/// @{

/**
* Periodically reorder and relinearize
* @param reorderInterval is the number of updates between reorderings,
* 0 never reorders (and is dangerous for memory consumption)
* 1 (default) reorders every time, in worse case is batch every update
* typical values are 50 or 100
*/
HybridNonlinearISAM(
int reorderInterval = 1,
const HybridGaussianFactorGraph::Eliminate& eliminationFunction =
HybridGaussianFactorGraph::EliminationTraitsType::DefaultEliminate)
: reorderInterval_(reorderInterval),
reorderCounter_(0),
eliminationFunction_(eliminationFunction) {}

/// @}
/// @name Standard Interface
/// @{

/** Return the current solution estimate */
Values estimate();

// /** find the marginal covariance for a single variable */
// Matrix marginalCovariance(Key key) const;

// access

/** access the underlying bayes tree */
const HybridGaussianISAM& bayesTree() const { return isam_; }

/**
* @brief Prune the underlying Bayes tree.
*
* @param root The root key in the discrete conditional decision tree.
* @param maxNumberLeaves
*/
void prune(const Key& root, const size_t maxNumberLeaves) {
isam_.prune(root, maxNumberLeaves);
}

/** Return the current linearization point */
const Values& getLinearizationPoint() const { return linPoint_; }

/** Return the current discrete assignment */
const DiscreteValues& getAssignment() const { return assignment_; }

/** get underlying nonlinear graph */
const HybridNonlinearFactorGraph& getFactorsUnsafe() const {
return factors_;
}

/** get counters */
int reorderInterval() const { return reorderInterval_; } ///< TODO: comment
int reorderCounter() const { return reorderCounter_; } ///< TODO: comment

/** prints out all contents of the system */
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/** prints out clique statistics */
void printStats() const;

/** saves the Tree to a text file in GraphViz format */
void saveGraph(const std::string& s,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
/// @name Advanced Interface
/// @{

/** Add new factors along with their initial linearization points */
void update(const HybridNonlinearFactorGraph& newFactors,
const Values& initialValues);

/** Relinearization and reordering of variables */
void reorder_relinearize();

/// @}
};

} // namespace gtsam
2 changes: 1 addition & 1 deletion gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ struct Switching {
linearizationPoint.insert<double>(X(k), static_cast<double>(k));
}

linearizedFactorGraph = nonlinearFactorGraph.linearize(linearizationPoint);
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
}

// Create motion models for a given time step
Expand Down
Loading

0 comments on commit 903d7c6

Please sign in to comment.