Skip to content

Commit

Permalink
Merge pull request #986 from borglab/feature/markdown
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Dec 25, 2021
2 parents fb3f00d + 38f0a40 commit 501a6db
Show file tree
Hide file tree
Showing 35 changed files with 751 additions and 281 deletions.
32 changes: 31 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,35 @@ namespace gtsam {
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}

/* ************************************************************************* */
/* ************************************************************************* */
std::string DecisionTreeFactor::markdown(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;

// Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
ss << "|";
for (auto& key : keys()) {
ss << keyFormatter(key) << "|";
pairs.emplace_back(key, cardinalities_.at(key));
}
ss << "value|\n";

// Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < size(); j++) ss << ":-:|";
ss << ":-:|\n";

// Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
for (const auto& assignment : assignments) {
ss << "|";
for (auto& key : keys()) ss << assignment.at(key) << "|";
ss << operator()(assignment) << "|\n";
}
return ss.str();
}

/* ************************************************************************* */
} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ namespace gtsam {
// }

/// @}
/// @name Wrapper support
/// @{

/// Render as markdown table.
std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

/// @}

};
// DecisionTreeFactor

Expand Down
12 changes: 11 additions & 1 deletion gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ namespace gtsam {
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply
double result = 1.0;
for(DiscreteConditional::shared_ptr conditional: *this)
for(const DiscreteConditional::shared_ptr& conditional: *this)
result *= (*conditional)(values);
return result;
}
Expand All @@ -61,5 +61,15 @@ namespace gtsam {
return result;
}

/* ************************************************************************* */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesNet` of size " << size() << endl << endl;
for(const DiscreteConditional::shared_ptr& conditional: *this)
ss << conditional->markdown(keyFormatter) << endl;
return ss.str();
}
/* ************************************************************************* */
} // namespace
9 changes: 9 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* @file DiscreteBayesNet.h
* @date Feb 15, 2011
* @author Duy-Nguyen Ta
* @author Frank dellaert
*/

#pragma once
Expand Down Expand Up @@ -97,6 +98,14 @@ namespace gtsam {
DiscreteValues sample() const;

///@}
/// @name Wrapper support
/// @{

/// Render as markdown table.
std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}

private:
/** Serialization function */
Expand Down
21 changes: 17 additions & 4 deletions gtsam/discrete/DiscreteBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,21 @@ namespace gtsam {
return result;
}

} // \namespace gtsam



/* **************************************************************************/
std::string DiscreteBayesTree::markdown(
const KeyFormatter& keyFormatter) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteBayesTree` of size " << nodes_.size() << endl << endl;
auto visitor = [&](const DiscreteBayesTreeClique::shared_ptr& clique,
size_t& indent) {
ss << "\n" << clique->conditional()->markdown(keyFormatter);
return indent + 1;
};
size_t indent;
treeTraversal::DepthFirstForest(*this, indent, visitor);
return ss.str();
}

/* **************************************************************************/
} // namespace gtsam
13 changes: 12 additions & 1 deletion gtsam/discrete/DiscreteBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class GTSAM_EXPORT DiscreteBayesTree
typedef DiscreteBayesTree This;
typedef boost::shared_ptr<This> shared_ptr;

/// @name Standard interface
/// @{
/** Default constructor, creates an empty Bayes tree */
DiscreteBayesTree() {}

Expand All @@ -82,10 +84,19 @@ class GTSAM_EXPORT DiscreteBayesTree
double evaluate(const DiscreteValues& values) const;

//** (Preferred) sugar for the above for given DiscreteValues */
double operator()(const DiscreteValues & values) const {
double operator()(const DiscreteValues& values) const {
return evaluate(values);
}

/// @}
/// @name Wrapper support
/// @{

/// Render as markdown table.
std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
};

} // namespace gtsam
83 changes: 79 additions & 4 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
keys & dk;
}
// Get all Possible Configurations
vector<DiscreteValues> allPosbValues = cartesianProduct(keys);
const auto allPosbValues = cartesianProduct(keys);

// Find the MPE
for(DiscreteValues& frontalVals: allPosbValues) {
for(const auto& frontalVals: allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
Expand Down Expand Up @@ -222,6 +222,81 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng);
}

/* ******************************************************************************** */
/* ************************************************************************* */
std::string DiscreteConditional::markdown(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;

// Print out signature.
ss << " $P(";
bool first = true;
for (Key key : frontals()) {
if (!first) ss << ",";
ss << keyFormatter(key);
first = false;
}
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << ")$:" << std::endl;
ss << DecisionTreeFactor::markdown();
return ss.str();
}

// We have parents, continue signature and do custom print.
ss << "|";
first = true;
for (Key parent : parents()) {
if (!first) ss << ",";
ss << keyFormatter(parent);
first = false;
}
ss << ")$:" << std::endl;

// Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
ss << "|";
const_iterator it;
for(Key parent: parents()) {
ss << keyFormatter(parent) << "|";
pairs.emplace_back(parent, cardinalities_.at(parent));
}

size_t n = 1;
for(Key key: frontals()) {
size_t k = cardinalities_.at(key);
pairs.emplace_back(key, k);
n *= k;
}
std::vector<std::pair<Key, size_t>> slatnorf(pairs.rbegin(),
pairs.rend() - nrParents());
const auto frontal_assignments = cartesianProduct(slatnorf);
for (const auto& a : frontal_assignments) {
for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it);
ss << "|";
}
ss << "\n";

// Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
ss << "\n";

// Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
size_t count = 0;
for (const auto& a : assignments) {
if (count == 0) {
ss << "|";
for (it = beginParents(); it != endParents(); ++it)
ss << a.at(*it) << "|";
}
ss << operator()(a) << "|";
count = (count + 1) % n;
if (count == 0) ss << "\n";
}
return ss.str();
}
/* ************************************************************************* */

}// namespace
} // namespace gtsam
7 changes: 7 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,14 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
void sampleInPlace(DiscreteValues* parentsValues) const;

/// @}
/// @name Wrapper support
/// @{

/// Render as markdown table.
std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const override;

/// @}
};
// DiscreteConditional

Expand Down
8 changes: 8 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// @}
/// @name Wrapper support
/// @{

/// Render as markdown table.
virtual std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const = 0;

/// @}
};
// DiscreteFactor
Expand Down
16 changes: 14 additions & 2 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ namespace gtsam {
return std::make_pair(cond, sum);
}

/* ************************************************************************* */
} // namespace
/* ************************************************************************* */
std::string DiscreteFactorGraph::markdown(
const KeyFormatter& keyFormatter) const {
using std::endl;
std::stringstream ss;
ss << "`DiscreteFactorGraph` of size " << size() << endl << endl;
for (size_t i = 0; i < factors_.size(); i++) {
ss << "factor " << i << ":\n";
ss << factors_[i]->markdown(keyFormatter) << endl;
}
return ss.str();
}

/* ************************************************************************* */
} // namespace gtsam
8 changes: 8 additions & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ public EliminateableFactorGraph<DiscreteFactorGraph> {
// /** Apply a reduction, which is a remapping of variable indices. */
// GTSAM_EXPORT void reduceWithInverse(const internal::Reduction& inverseReduction);

/// @name Wrapper support
/// @{

/// Render as markdown table.
std::string markdown(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
}; // \ DiscreteFactorGraph

/// traits
Expand Down
20 changes: 19 additions & 1 deletion gtsam/discrete/DiscreteValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,25 @@ namespace gtsam {
* stores cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the variable's type (domain)
*/
using DiscreteValues = Assignment<Key>;
class DiscreteValues : public Assignment<Key> {
public:
using Assignment::Assignment; // all constructors

// Define the implicit default constructor.
DiscreteValues() = default;

// Construct from assignment.
DiscreteValues(const Assignment<Key>& a) : Assignment<Key>(a) {}

void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": ";
for (const typename Assignment::value_type& keyValue : *this)
std::cout << "(" << keyFormatter(keyValue.first) << ", "
<< keyValue.second << ")";
std::cout << std::endl;
}
};

// traits
template<> struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
Expand Down
12 changes: 12 additions & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
string dot(bool showZero = false) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/discrete/DiscreteConditional.h>
Expand All @@ -65,6 +67,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/discrete/DiscreteBayesNet.h>
Expand All @@ -89,6 +93,8 @@ class DiscreteBayesNet {
double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const;
gtsam::DiscreteValues sample() const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/discrete/DiscreteBayesTree.h>
Expand Down Expand Up @@ -120,6 +126,9 @@ class DiscreteBayesTree {
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const;

string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/inference/DotWriter.h>
Expand Down Expand Up @@ -160,6 +169,9 @@ class DiscreteFactorGraph {
gtsam::DiscreteBayesNet eliminateSequential(const gtsam::Ordering& ordering);
gtsam::DiscreteBayesTree eliminateMultifrontal();
gtsam::DiscreteBayesTree eliminateMultifrontal(const gtsam::Ordering& ordering);

string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

} // namespace gtsam
Loading

0 comments on commit 501a6db

Please sign in to comment.