Skip to content

Commit

Permalink
Merge pull request #1055 from borglab/feature/hybrid_base
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Jan 23, 2022
2 parents b441eea + 6aeb3db commit 9d71c90
Show file tree
Hide file tree
Showing 19 changed files with 189 additions and 24 deletions.
13 changes: 13 additions & 0 deletions gtsam/base/utilities.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <gtsam/base/utilities.h>

namespace gtsam {

std::string RedirectCout::str() const {
return ssBuffer_.str();
}

RedirectCout::~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}

}
12 changes: 6 additions & 6 deletions gtsam/base/utilities.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#pragma once

#include <string>
#include <iostream>
#include <sstream>

namespace gtsam {
/**
* For Python __str__().
Expand All @@ -12,14 +16,10 @@ struct RedirectCout {
RedirectCout() : ssBuffer_(), coutBuffer_(std::cout.rdbuf(ssBuffer_.rdbuf())) {}

/// return the string
std::string str() const {
return ssBuffer_.str();
}
std::string str() const;

/// destructor -- redirect stdout buffer to its original buffer
~RedirectCout() {
std::cout.rdbuf(coutBuffer_);
}
~RedirectCout();

private:
std::stringstream ssBuffer_;
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ namespace gtsam {
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.2g") % v).str();
return (boost::format("%4.4g") % v).str();
};
Base::print(s, labelFormatter, valueFormatter);
}
Expand Down
6 changes: 5 additions & 1 deletion gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ namespace gtsam {
using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = boost::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument(
"DecisionTree::Convert: Invalid NodePtr");
"DecisionTree::convertFrom: Invalid NodePtr");

// get new label
const M oldLabel = choice->label();
Expand Down Expand Up @@ -634,6 +634,8 @@ namespace gtsam {

using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::Visit: Invalid NodePtr");
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
}
};
Expand Down Expand Up @@ -663,6 +665,8 @@ namespace gtsam {

using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse!
Expand Down
9 changes: 8 additions & 1 deletion gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace gtsam {
* Y = function range (any algebra), e.g., bool, int, double
*/
template<typename L, typename Y>
class GTSAM_EXPORT DecisionTree {
class DecisionTree {

protected:
/// Default method for comparison of two objects of type Y.
Expand Down Expand Up @@ -340,4 +340,11 @@ namespace gtsam {
return f.apply(g, op);
}

/// unzip a DecisionTree if its leaves are `std::pair`
template<typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(const DecisionTree<L, std::pair<T1, T2> > &input) {
return std::make_pair(DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; }));
}

} // namespace gtsam
14 changes: 13 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ namespace gtsam {
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("Potentials:", formatter);
ADT::print("", formatter);
}

/* ************************************************************************* */
Expand Down Expand Up @@ -168,6 +168,18 @@ namespace gtsam {
return result;
}

/* ************************************************************************* */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
DiscreteKey dkey(key, cardinality(key));
if (std::find(result.begin(), result.end(), dkey) == result.end()) {
result.push_back(dkey);
}
}
return result;
}

/* ************************************************************************* */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ namespace gtsam {
/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;

/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;

/// @}
/// @name Wrapper support
/// @{
Expand Down
47 changes: 47 additions & 0 deletions gtsam/discrete/DiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,59 @@
* @author Frank Dellaert
*/

#include <gtsam/base/Vector.h>
#include <gtsam/discrete/DiscreteFactor.h>

#include <cmath>
#include <sstream>

using namespace std;

namespace gtsam {

/* ************************************************************************* */
std::vector<double> expNormalize(const std::vector<double>& logProbs) {
double maxLogProb = -std::numeric_limits<double>::infinity();
for (size_t i = 0; i < logProbs.size(); i++) {
double logProb = logProbs[i];
if ((logProb != std::numeric_limits<double>::infinity()) &&
logProb > maxLogProb) {
maxLogProb = logProb;
}
}

// After computing the max = "Z" of the log probabilities L_i, we compute
// the log of the normalizing constant, log S, where S = sum_j exp(L_j - Z).
double total = 0.0;
for (size_t i = 0; i < logProbs.size(); i++) {
double probPrime = exp(logProbs[i] - maxLogProb);
total += probPrime;
}
double logTotal = log(total);

// Now we compute the (normalized) probability (for each i):
// p_i = exp(L_i - Z - log S)
double checkNormalization = 0.0;
std::vector<double> probs;
for (size_t i = 0; i < logProbs.size(); i++) {
double prob = exp(logProbs[i] - maxLogProb - logTotal);
probs.push_back(prob);
checkNormalization += prob;
}

// Numerical tolerance for floating point comparisons
double tol = 1e-9;

if (!gtsam::fpEqual(checkNormalization, 1.0, tol)) {
std::string errMsg =
std::string("expNormalize failed to normalize probabilities. ") +
std::string("Expected normalization constant = 1.0. Got value: ") +
std::to_string(checkNormalization) +
std::string(
"\n This could have resulted from numerical overflow/underflow.");
throw std::logic_error(errMsg);
}
return probs;
}

} // namespace gtsam
20 changes: 20 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,24 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
// traits
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};


/**
* @brief Normalize a set of log probabilities.
*
* Normalizing a set of log probabilities in a numerically stable way is
* tricky. To avoid overflow/underflow issues, we compute the largest
* (finite) log probability and subtract it from each log probability before
* normalizing. This comes from the observation that if:
* p_i = exp(L_i) / ( sum_j exp(L_j) ),
* Then,
* p_i = exp(Z) exp(L_i - Z) / (exp(Z) sum_j exp(L_j - Z)),
* = exp(L_i - Z) / ( sum_j exp(L_j - Z) )
*
* Setting Z = max_j L_j, we can avoid numerical issues that arise when all
* of the (unnormalized) log probabilities are either very large or very
* small.
*/
std::vector<double> expNormalize(const std::vector<double> &logProbs);


}// namespace gtsam
18 changes: 16 additions & 2 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,25 @@ namespace gtsam {
/* ************************************************************************* */
KeySet DiscreteFactorGraph::keys() const {
KeySet keys;
for(const sharedFactor& factor: *this)
if (factor) keys.insert(factor->begin(), factor->end());
for (const sharedFactor& factor : *this) {
if (factor) keys.insert(factor->begin(), factor->end());
}
return keys;
}

/* ************************************************************************* */
DiscreteKeys DiscreteFactorGraph::discreteKeys() const {
DiscreteKeys result;
for (auto&& factor : *this) {
if (auto p = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
DiscreteKeys factor_keys = p->discreteKeys();
result.insert(result.end(), factor_keys.begin(), factor_keys.end());
}
}

return result;
}

/* ************************************************************************* */
DecisionTreeFactor DiscreteFactorGraph::product() const {
DecisionTreeFactor result;
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class GTSAM_EXPORT DiscreteFactorGraph
/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;

/// Return the DiscreteKeys in this factor graph.
DiscreteKeys discreteKeys() const;

/** return product of all factors as a single factor */
DecisionTreeFactor product() const;

Expand Down
9 changes: 7 additions & 2 deletions gtsam/discrete/DiscreteKey.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
namespace gtsam {

/**
* Key type for discrete conditionals
* Includes name and cardinality
* Key type for discrete variables.
* Includes Key and cardinality.
*/
using DiscreteKey = std::pair<Key,size_t>;

Expand All @@ -45,6 +45,11 @@ namespace gtsam {
/// Construct from a key
explicit DiscreteKeys(const DiscreteKey& key) { push_back(key); }

/// Construct from cardinalities.
explicit DiscreteKeys(std::map<Key, size_t> cardinalities) {
for (auto&& kv : cardinalities) emplace_back(kv);
}

/// Construct from a vector of keys
DiscreteKeys(const std::vector<DiscreteKey>& keys) :
std::vector<DiscreteKey>(keys) {
Expand Down
2 changes: 2 additions & 0 deletions gtsam/discrete/DiscreteMarginals.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class GTSAM_EXPORT DiscreteMarginals {

public:

DiscreteMarginals() {}

/** Construct a marginals class.
* @param graph The factor graph defining the full joint density on all variables.
*/
Expand Down
25 changes: 25 additions & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,31 @@ TEST(DecisionTree, labels) {
EXPECT_LONGS_EQUAL(2, labels.size());
}

/* ******************************************************************************** */
// Test retrieving all labels.
TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>;
using DT2 = DecisionTree<string, string>;

// Create small two-level tree
string A("A"), B("B"), C("C");
DTP tree(B,
DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {2, "two"}, {1337, "l33t"})
);

DT1 dt1;
DT2 dt2;
std::tie(dt1, dt2) = unzip(tree);

DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));

EXPECT(tree1.equals(dt1));
EXPECT(tree2.equals(dt2));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
1 change: 0 additions & 1 deletion gtsam/inference/Factor.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ typedef FastSet<FactorIndex> FactorIndexSet;

/// @}

public:
/// @name Advanced Interface
/// @{

Expand Down
9 changes: 7 additions & 2 deletions gtsam/inference/FactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ class FactorGraph {
/** Collection of factors */
FastVector<sharedFactor> factors_;

/// Check exact equality of the factor pointers. Useful for derived ==.
bool isEqual(const FactorGraph& other) const {
return factors_ == other.factors_;
}

/// @name Standard Constructors
/// @{

Expand Down Expand Up @@ -290,11 +295,11 @@ class FactorGraph {
/// @name Testable
/// @{

/// print out graph
/// Print out graph to std::cout, with optional key formatter.
virtual void print(const std::string& s = "FactorGraph",
const KeyFormatter& formatter = DefaultKeyFormatter) const;

/** Check equality */
/// Check equality up to tolerance.
bool equals(const This& fg, double tol = 1e-9) const;
/// @}

Expand Down
4 changes: 2 additions & 2 deletions gtsam/inference/MetisIndex-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
namespace gtsam {

/* ************************************************************************* */
template<class FACTOR>
void MetisIndex::augment(const FactorGraph<FACTOR>& factors) {
template<class FACTORGRAPH>
void MetisIndex::augment(const FACTORGRAPH& factors) {
std::map<int32_t, std::set<int32_t> > iAdjMap; // Stores a set of keys that are adjacent to key x, with adjMap.first
std::map<int32_t, std::set<int32_t> >::iterator iAdjMapIt;
std::set<Key> keySet;
Expand Down
8 changes: 4 additions & 4 deletions gtsam/inference/MetisIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class GTSAM_EXPORT MetisIndex {
nKeys_(0) {
}

template<class FG>
MetisIndex(const FG& factorGraph) :
template<class FACTORGRAPH>
MetisIndex(const FACTORGRAPH& factorGraph) :
nKeys_(0) {
augment(factorGraph);
}
Expand All @@ -78,8 +78,8 @@ class GTSAM_EXPORT MetisIndex {
* Augment the variable index with new factors. This can be used when
* solving problems incrementally.
*/
template<class FACTOR>
void augment(const FactorGraph<FACTOR>& factors);
template<class FACTORGRAPH>
void augment(const FACTORGRAPH& factors);

const std::vector<int32_t>& xadj() const {
return xadj_;
Expand Down
Loading

0 comments on commit 9d71c90

Please sign in to comment.