Skip to content

Commit

Permalink
Merge pull request #990 from borglab/feature/api_improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Dec 29, 2021
2 parents 04629e4 + 4c0761d commit e5b928c
Show file tree
Hide file tree
Showing 22 changed files with 647 additions and 160 deletions.
14 changes: 10 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,18 @@ endif()

# Set the version number for the library
set (GTSAM_VERSION_MAJOR 4)
set (GTSAM_VERSION_MINOR 1)
set (GTSAM_VERSION_PATCH 1)
set (GTSAM_VERSION_MINOR 2)
set (GTSAM_VERSION_PATCH 0)
set (GTSAM_PRERELEASE_VERSION "a0")
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}")

set (CMAKE_PROJECT_VERSION ${GTSAM_VERSION_STRING})
if (${GTSAM_VERSION_PATCH} EQUAL 0)
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}${GTSAM_PRERELEASE_VERSION}")
else()
set (GTSAM_VERSION_STRING "${GTSAM_VERSION_MAJOR}.${GTSAM_VERSION_MINOR}.${GTSAM_VERSION_PATCH}${GTSAM_PRERELEASE_VERSION}")
endif()
message(STATUS "GTSAM Version: ${GTSAM_VERSION_STRING}")

set (CMAKE_PROJECT_VERSION_MAJOR ${GTSAM_VERSION_MAJOR})
set (CMAKE_PROJECT_VERSION_MINOR ${GTSAM_VERSION_MINOR})
set (CMAKE_PROJECT_VERSION_PATCH ${GTSAM_VERSION_PATCH})
Expand Down
29 changes: 23 additions & 6 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,34 @@ namespace gtsam {
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}

/* ************************************************************************* */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
pairs.emplace_back(key, cardinalities_.at(key));
}
// Reverse to make cartesianProduct output a more natural ordering.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);

// Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment));
}
return result;
}

/* ************************************************************************* */
std::string DecisionTreeFactor::markdown(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;

// Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
ss << "|";
for (auto& key : keys()) {
ss << keyFormatter(key) << "|";
pairs.emplace_back(key, cardinalities_.at(key));
}
ss << "value|\n";

Expand All @@ -154,12 +171,12 @@ namespace gtsam {
ss << ":-:|\n";

// Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
for (const auto& assignment : assignments) {
auto rows = enumerate();
for (const auto& kv : rows) {
ss << "|";
auto assignment = kv.first;
for (auto& key : keys()) ss << assignment.at(key) << "|";
ss << operator()(assignment) << "|\n";
ss << kv.second << "|\n";
}
return ss.str();
}
Expand Down
12 changes: 12 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ namespace gtsam {
DiscreteFactor(keys.indices()), Potentials(keys, table) {
}

/// Single-key specialization
template <class SOURCE>
DecisionTreeFactor(const DiscreteKey& key, SOURCE table)
: DecisionTreeFactor(DiscreteKeys{key}, table) {}

/// Single-key specialization, with vector of doubles.
DecisionTreeFactor(const DiscreteKey& key, const std::vector<double>& row)
: DecisionTreeFactor(DiscreteKeys{key}, row) {}

/** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c);

Expand Down Expand Up @@ -162,6 +171,9 @@ namespace gtsam {
// Potentials::reduceWithInverse(inverseReduction);
// }

/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;

/// @}
/// @name Wrapper support
/// @{
Expand Down
6 changes: 6 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <boost/shared_ptr.hpp>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>
#include <gtsam/discrete/DiscretePrior.h>
#include <gtsam/discrete/DiscreteConditional.h>

namespace gtsam {
Expand Down Expand Up @@ -75,6 +76,11 @@ namespace gtsam {
// Add inherited versions of add.
using Base::add;

/** Add a DiscretePrior using a table or a string */
void add(const DiscreteKey& key, const std::string& spec) {
emplace_shared<DiscretePrior>(key, spec);
}

/** Add a DiscreteCondtional */
template <typename... Args>
void add(Args&&... args) {
Expand Down
104 changes: 82 additions & 22 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,45 +97,90 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
}

/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& parentsValues) {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional);
size_t value;
for (Key j : conditional.parents()) {
try {
value = parentsValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) {
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
return adt;
}

/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::choose(
const DiscreteValues& parentsValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT pFS(*this);
ADT adt(*this);
size_t value;
for (Key j : parents()) {
try {
value = parentsValues.at(j);
pFS = pFS.choose(j, value); // ADT keeps getting smaller.
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
cout << "Key: " << j << " Value: " << value << endl;
parentsValues.print("parentsValues: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}
return pFS;

// Convert ADT to factor.
DiscreteKeys discreteKeys;
for (Key j : frontals()) {
discreteKeys.emplace_back(j, this->cardinality(j));
}
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}

/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor(
const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues);
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
const DiscreteValues& frontalValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the frontal variables.
ADT adt(*this);
size_t value;
for (Key j : frontals()) {
try {
value = frontalValues.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
frontalValues.print("frontalValues: ");
throw runtime_error("DiscreteConditional::choose: frontal value missing");
};
}

// Convert ADT to factor.
if (nrFrontals() != 1) {
throw std::runtime_error("Expected only one frontal variable in choose.");
DiscreteKeys discreteKeys;
for (Key j : parents()) {
discreteKeys.emplace_back(j, this->cardinality(j));
}
DiscreteKeys keys;
const Key frontalKey = keys_[0];
size_t frontalCardinality = this->cardinality(frontalKey);
keys.push_back(DiscreteKey(frontalKey, frontalCardinality));
return boost::make_shared<DecisionTreeFactor>(keys, pFS);
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}

/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const {
if (nrFrontals() != 1)
throw std::invalid_argument(
"Single value likelihood can only be invoked on single-variable "
"conditional");
DiscreteValues values;
values.emplace(keys_[0], parent_value);
return likelihood(values);
}

/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
ADT pFS = choose(*values); // P(F|S=parentsValues)
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)

// Initialize
DiscreteValues mpe;
Expand Down Expand Up @@ -177,7 +222,7 @@ void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {

// TODO: is this really the fastest way? I think it is.
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)

// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
Expand All @@ -203,10 +248,14 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator

// Get the correct conditional density
ADT pFS = choose(parentsValues); // P(F|S=parentsValues)
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)

// TODO(Duy): only works for one key now, seems horribly slow this way
assert(nrFrontals() == 1);
if (nrFrontals() != 1) {
throw std::invalid_argument(
"DiscreteConditional::sample can only be called on single variable "
"conditionals");
}
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
Expand All @@ -222,13 +271,24 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng);
}

/* ******************************************************************************** */
size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1)
throw std::invalid_argument(
"Single value sample() can only be invoked on single-parent "
"conditional");
DiscreteValues values;
values.emplace(keys_.back(), parent_value);
return sample(values);
}

/* ************************************************************************* */
std::string DiscreteConditional::markdown(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;

// Print out signature.
ss << " $P(";
ss << " *P(";
bool first = true;
for (Key key : frontals()) {
if (!first) ss << ",";
Expand All @@ -237,7 +297,7 @@ std::string DiscreteConditional::markdown(
}
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << ")$:" << std::endl;
ss << ")*:\n" << std::endl;
ss << DecisionTreeFactor::markdown(keyFormatter);
return ss.str();
}
Expand All @@ -250,7 +310,7 @@ std::string DiscreteConditional::markdown(
ss << keyFormatter(parent);
first = false;
}
ss << ")$:" << std::endl;
ss << ")*:\n" << std::endl;

// Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
Expand Down
25 changes: 17 additions & 8 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
* conditional probability table (CPT) in 00 01 10 11 order. For
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
*
* The first string is parsed to add a key and parents.
*
* Example: DiscreteConditional P(D, {B,E}, table);
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
Expand All @@ -75,15 +73,18 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
* be 00 01 02 10 11 12 20 21 22, etc....
*
* The first string is parsed to add a key and parents. The second string
* parses into a table.
* The string is parsed into a Signature::Table.
*
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {}

/// No-parent specialization; can also use DiscretePrior.
DiscreteConditional(const DiscreteKey& key, const std::string& spec)
: DiscreteConditional(Signature(key, {}, spec)) {}

/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
Expand Down Expand Up @@ -135,13 +136,17 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this));
}

/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const DiscreteValues& parentsValues) const;

/** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr chooseAsFactor(
DecisionTreeFactor::shared_ptr choose(
const DiscreteValues& parentsValues) const;

/** Convert to a likelihood factor by providing value before bar. */
DecisionTreeFactor::shared_ptr likelihood(
const DiscreteValues& frontalValues) const;

/** Single variable version of likelihood. */
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;

/**
* solve a conditional
* @param parentsValues Known values of the parents
Expand All @@ -156,6 +161,10 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
*/
size_t sample(const DiscreteValues& parentsValues) const;


/// Single value version.
size_t sample(size_t parent_value) const;

/// @}
/// @name Advanced Interface
/// @{
Expand Down
27 changes: 5 additions & 22 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,29 +101,12 @@ public EliminateableFactorGraph<DiscreteFactorGraph> {

/// @}

// Add single key decision-tree factor.
template <class SOURCE>
void add(const DiscreteKey& j, SOURCE table) {
DiscreteKeys keys;
keys.push_back(j);
emplace_shared<DecisionTreeFactor>(keys, table);
/** Add a decision-tree factor */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DecisionTreeFactor>(std::forward<Args>(args)...);
}

// Add binary key decision-tree factor.
template <class SOURCE>
void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) {
DiscreteKeys keys;
keys.push_back(j1);
keys.push_back(j2);
emplace_shared<DecisionTreeFactor>(keys, table);
}

// Add shared discreteFactor immediately from arguments.
template <class SOURCE>
void add(const DiscreteKeys& keys, SOURCE table) {
emplace_shared<DecisionTreeFactor>(keys, table);
}


/** Return the set of variables involved in the factors (set union) */
KeySet keys() const;

Expand Down
Loading

0 comments on commit e5b928c

Please sign in to comment.