Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create markdown representation #986

Merged
merged 19 commits into from
Dec 25, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,35 @@ namespace gtsam {
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}

/* ************************************************************************* */
/* ************************************************************************* */
std::string DecisionTreeFactor::_repr_markdown_(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;

// Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
ss << "|";
for (auto& key : keys()) {
ss << keyFormatter(key) << "|";
pairs.emplace_back(key, cardinalities_.at(key));
}
ss << "value|\n";

// Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < size(); j++) ss << ":-:|";
ss << ":-:|\n";

// Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
for (const auto& assignment : assignments) {
ss << "|";
for (auto& key : keys()) ss << assignment.at(key) << "|";
ss << operator()(assignment) << "|\n";
}
return ss.str();
}

/* ************************************************************************* */
} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ namespace gtsam {
// }

/// @}
/// @name Wrapper support
/// @{

/// Render as markdown table.
std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}

};
// DecisionTreeFactor

Expand Down
73 changes: 69 additions & 4 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
keys & dk;
}
// Get all Possible Configurations
vector<DiscreteValues> allPosbValues = cartesianProduct(keys);
const auto allPosbValues = cartesianProduct(keys);

// Find the MPE
for(DiscreteValues& frontalVals: allPosbValues) {
for(const auto& frontalVals: allPosbValues) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comment: auto&& will automatically take care of constness.

double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better
if (pValueS > maxP) {
Expand Down Expand Up @@ -222,6 +222,71 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng);
}

/* ******************************************************************************** */
/* ************************************************************************* */
std::string DiscreteConditional::_repr_markdown_(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;

// Print out signature.
ss << " $P(";
for(Key key: frontals())
ss << keyFormatter(key);
if (nrParents() > 0)
ss << "|";
bool first = true;
for (Key parent : parents()) {
if (!first) ss << ",";
ss << keyFormatter(parent);
first = false;
}
ss << ")$:" << std::endl;

// Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
ss << "|";
const_iterator it;
for(Key parent: parents()) {
ss << keyFormatter(parent) << "|";
pairs.emplace_back(parent, cardinalities_.at(parent));
}

size_t n = 1;
for(Key key: frontals()) {
size_t k = cardinalities_.at(key);
pairs.emplace_back(key, k);
n *= k;
}
size_t nrParents = size() - nrFrontals_;
std::vector<std::pair<Key, size_t>> slatnorf(pairs.rbegin(),
pairs.rend() - nrParents);
const auto frontal_assignments = cartesianProduct(slatnorf);
for (const auto& a : frontal_assignments) {
for (it = beginFrontals(); it != endFrontals(); ++it) ss << a.at(*it);
ss << "|";
}
ss << "\n";

// Print out separator with alignment hints.
ss << "|";
for (size_t j = 0; j < nrParents + n; j++) ss << ":-:|";
ss << "\n";

// Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
size_t count = 0;
for (const auto& a : assignments) {
if (count == 0) {
ss << "|";
for (it = beginParents(); it != endParents(); ++it)
ss << a.at(*it) << "|";
}
ss << operator()(a) << "|";
count = (count + 1) % n;
if (count == 0) ss << "\n";
}
return ss.str();
}
/* ************************************************************************* */

}// namespace
} // namespace gtsam
7 changes: 7 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,14 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
void sampleInPlace(DiscreteValues* parentsValues) const;

/// @}
/// @name Wrapper support
/// @{

/// Render as markdown table.
std::string _repr_markdown_(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;

/// @}
};
// DiscreteConditional

Expand Down
20 changes: 19 additions & 1 deletion gtsam/discrete/DiscreteValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,25 @@ namespace gtsam {
* stores cardinality of a Discrete variable. It should be handled naturally in
* the new class DiscreteValue, as the variable's type (domain)
*/
using DiscreteValues = Assignment<Key>;
class DiscreteValues : public Assignment<Key> {
public:
using Assignment::Assignment; // all constructors

// Define the implicit default constructor.
DiscreteValues() = default;

// Construct from assignment.
DiscreteValues(const Assignment<Key>& a) : Assignment<Key>(a) {}

void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": ";
for (const typename Assignment::value_type& keyValue : *this)
std::cout << "(" << keyFormatter(keyValue.first) << ", "
<< keyValue.second << ")";
std::cout << std::endl;
}
};

// traits
template<> struct traits<DiscreteValues> : public Testable<DiscreteValues> {};
Expand Down
4 changes: 4 additions & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ virtual class DecisionTreeFactor: gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
string dot(bool showZero = false) const;
string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/discrete/DiscreteConditional.h>
Expand All @@ -65,6 +67,8 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
size_t sample(const gtsam::DiscreteValues& parentsValues) const;
void solveInPlace(gtsam::DiscreteValues@ parentsValues) const;
void sampleInPlace(gtsam::DiscreteValues@ parentsValues) const;
string _repr_markdown_(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};

#include <gtsam/discrete/DiscreteBayesNet.h>
Expand Down
37 changes: 20 additions & 17 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,17 @@ using namespace gtsam;
/* ************************************************************************* */
TEST( DecisionTreeFactor, constructors)
{
// Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2);

// Create factors
DecisionTreeFactor f1(X, "2 8");
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1,f1.size());
EXPECT_LONGS_EQUAL(2,f2.size());
EXPECT_LONGS_EQUAL(3,f3.size());

// f1.print("f1:");
// f2.print("f2:");
// f3.print("f3:");

DiscreteValues values;
values[0] = 1; // x
values[1] = 2; // y
Expand All @@ -55,47 +53,52 @@ TEST( DecisionTreeFactor, constructors)
/* ************************************************************************* */
TEST_UNSAFE( DecisionTreeFactor, multiplication)
{
// Declare a bunch of keys
DiscreteKey v0(0,2), v1(1,2), v2(2,2);

// Create a factor
DecisionTreeFactor f1(v0 & v1, "1 2 3 4");
DecisionTreeFactor f2(v1 & v2, "5 6 7 8");
// f1.print("f1:");
// f2.print("f2:");

DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32");

DecisionTreeFactor actual = f1 * f2;
// actual.print("actual: ");
CHECK(assert_equal(expected, actual));
}

/* ************************************************************************* */
TEST( DecisionTreeFactor, sum_max)
{
// Declare a bunch of keys
DiscreteKey v0(0,3), v1(1,2);

// Create a factor
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");

DecisionTreeFactor expected(v1, "9 12");
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
CHECK(assert_equal(expected, *actual, 1e-5));
// f1.print("f1:");
// actual->print("actual: ");
// actual->printCache("actual cache: ");

DecisionTreeFactor expected2(v1, "5 6");
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
CHECK(assert_equal(expected2, *actual2));

DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
// f2.print("f2: ");
// actual22->print("actual22: ");
}

/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DecisionTreeFactor, markdown) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
string expected =
"|A|B|value|\n"
"|:-:|:-:|:-:|\n"
"|0|0|1|\n"
"|0|1|2|\n"
"|1|0|3|\n"
"|1|1|4|\n"
"|2|0|5|\n"
"|2|1|6|\n";
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
string actual = f1._repr_markdown_(formatter);
EXPECT(actual == expected);
}

/* ************************************************************************* */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/tests/testDiscreteBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ TEST(DiscreteBayesTree, ThinTree) {
auto R = self.bayesTree->roots().front();

// Check whether BN and BT give the same answer on all configurations
vector<DiscreteValues> allPosbValues =
auto allPosbValues =
cartesianProduct(keys[0] & keys[1] & keys[2] & keys[3] & keys[4] &
keys[5] & keys[6] & keys[7] & keys[8] & keys[9] &
keys[10] & keys[11] & keys[12] & keys[13] & keys[14]);
Expand Down
61 changes: 58 additions & 3 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using namespace boost::assign;
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h>

using namespace std;
using namespace gtsam;
Expand Down Expand Up @@ -101,9 +102,63 @@ TEST(DiscreteConditional, Combine) {
c.push_back(boost::make_shared<DiscreteConditional>(A | B = "1/2 2/1"));
c.push_back(boost::make_shared<DiscreteConditional>(B % "1/2"));
DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222");
DiscreteConditional actual(2, factor);
auto expected = DiscreteConditional::Combine(c.begin(), c.end());
EXPECT(assert_equal(*expected, actual, 1e-5));
DiscreteConditional expected(2, factor);
auto actual = DiscreteConditional::Combine(c.begin(), c.end());
EXPECT(assert_equal(expected, *actual, 1e-5));
}

/* ************************************************************************* */
// Check markdown representation looks as expected, no parents.
TEST(DiscreteConditional, markdown_prior) {
DiscreteKey A(Symbol('x', 1), 2);
DiscreteConditional conditional(A % "1/3");
string expected =
" $P(x1)$:\n"
"|0|1|\n"
"|:-:|:-:|\n"
"|0.25|0.75|\n";
string actual = conditional._repr_markdown_();
EXPECT(actual == expected);
}

/* ************************************************************************* */
// Check markdown representation looks as expected, multivalued.
TEST(DiscreteConditional, markdown_multivalued) {
DiscreteKey A(Symbol('a', 1), 3), B(Symbol('b', 1), 5);
DiscreteConditional conditional(
A | B = "2/88/10 2/20/78 33/33/34 33/33/34 95/2/3");
string expected =
" $P(a1|b1)$:\n"
"|b1|0|1|2|\n"
"|:-:|:-:|:-:|:-:|\n"
"|0|0.02|0.88|0.1|\n"
"|1|0.02|0.2|0.78|\n"
"|2|0.33|0.33|0.34|\n"
"|3|0.33|0.33|0.34|\n"
"|4|0.95|0.02|0.03|\n";
string actual = conditional._repr_markdown_();
EXPECT(actual == expected);
}

/* ************************************************************************* */
// Check markdown representation looks as expected, two parents.
TEST(DiscreteConditional, markdown) {
DiscreteKey A(2, 2), B(1, 2), C(0, 3);
DiscreteConditional conditional(A, {B, C}, "0/1 1/3 1/1 3/1 0/1 1/0");
string expected =
" $P(A|B,C)$:\n"
"|B|C|0|1|\n"
"|:-:|:-:|:-:|:-:|\n"
"|0|0|0|1|\n"
"|0|1|0.25|0.75|\n"
"|0|2|0.5|0.5|\n"
"|1|0|0.75|0.25|\n"
"|1|1|0|1|\n"
"|1|2|1|0|\n";
vector<string> names{"C", "B", "A"};
auto formatter = [names](Key key) { return names[key]; };
string actual = conditional._repr_markdown_(formatter);
EXPECT(actual == expected);
}

/* ************************************************************************* */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/tests/testDiscreteMarginals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ TEST_UNSAFE(DiscreteMarginals, truss2) {
graph.add(key[2] & key[3] & key[4], "1 2 3 4 5 6 7 8");

// Calculate the marginals by brute force
vector<DiscreteValues> allPosbValues =
auto allPosbValues =
cartesianProduct(key[0] & key[1] & key[2] & key[3] & key[4]);
Vector T = Z_5x1, F = Z_5x1;
for (size_t i = 0; i < allPosbValues.size(); ++i) {
Expand Down