Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API improvements for discrete #990

Merged
merged 20 commits into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
enumerate
  • Loading branch information
dellaert committed Dec 27, 2021
commit 911819c7f2d4acd8c6076e63551b094ed82380f5
28 changes: 22 additions & 6 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,17 +134,33 @@ namespace gtsam {
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}

/* ************************************************************************* */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
pairs.emplace_back(key, cardinalities_.at(key));
}
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
dellaert marked this conversation as resolved.
Show resolved Hide resolved
const auto assignments = cartesianProduct(rpairs);

// Construct unordered_map with values
std::vector<std::pair<DiscreteValues, double>> result;
for (const auto& assignment : assignments) {
result.emplace_back(assignment, operator()(assignment));
}
return result;
}

/* ************************************************************************* */
std::string DecisionTreeFactor::markdown(
const KeyFormatter& keyFormatter) const {
std::stringstream ss;

// Print out header and construct argument for `cartesianProduct`.
std::vector<std::pair<Key, size_t>> pairs;
ss << "|";
for (auto& key : keys()) {
ss << keyFormatter(key) << "|";
pairs.emplace_back(key, cardinalities_.at(key));
}
ss << "value|\n";

Expand All @@ -154,12 +170,12 @@ namespace gtsam {
ss << ":-:|\n";

// Print out all rows.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = cartesianProduct(rpairs);
for (const auto& assignment : assignments) {
auto rows = enumerate();
for (const auto& kv : rows) {
ss << "|";
auto assignment = kv.first;
for (auto& key : keys()) ss << assignment.at(key) << "|";
ss << operator()(assignment) << "|\n";
ss << kv.second << "|\n";
}
return ss.str();
}
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ namespace gtsam {
// Potentials::reduceWithInverse(inverseReduction);
// }

/// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const;

/// @}
/// @name Wrapper support
/// @{
Expand Down
1 change: 1 addition & 0 deletions gtsam/discrete/discrete.i
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor {
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const;
string dot(bool showZero = false) const;
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
Expand Down
22 changes: 20 additions & 2 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,29 @@ TEST( DecisionTreeFactor, sum_max)
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
}

/* ************************************************************************* */
// Check enumerate yields the correct list of assignment/value pairs.
TEST(DecisionTreeFactor, enumerate) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
auto actual = f.enumerate();
std::vector<std::pair<DiscreteValues, double>> expected;
DiscreteValues values;
for (size_t a : {0, 1, 2}) {
for (size_t b : {0, 1}) {
values[12] = a;
values[5] = b;
expected.emplace_back(values, f(values));
}
}
EXPECT(actual == expected);
}

/* ************************************************************************* */
// Check markdown representation looks as expected.
TEST(DecisionTreeFactor, markdown) {
DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
string expected =
"|A|B|value|\n"
"|:-:|:-:|:-:|\n"
Expand All @@ -97,7 +115,7 @@ TEST(DecisionTreeFactor, markdown) {
"|2|0|5|\n"
"|2|1|6|\n";
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
string actual = f1.markdown(formatter);
string actual = f.markdown(formatter);
EXPECT(actual == expected);
}

Expand Down
54 changes: 54 additions & 0 deletions python/gtsam/tests/test_DecisionTreeFactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
GTSAM Copyright 2010-2021, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved

See LICENSE for the license information

Unit tests for DecisionTreeFactors.
Author: Frank Dellaert
"""

# pylint: disable=no-name-in-module, invalid-name

import unittest

from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys
from gtsam.utils.test_case import GtsamTestCase


class TestDecisionTreeFactor(GtsamTestCase):
"""Tests for DecisionTreeFactors."""

def setUp(self):
A = (12, 3)
B = (5, 2)
self.factor = DecisionTreeFactor(A, B, "1 2 3 4 5 6")

def test_enumerate(self):
actual = self.factor.enumerate()
_, values = zip(*actual)
self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

def test_markdown(self):
"""Test whether the _repr_markdown_ method."""

expected = \
"|A|B|value|\n" \
"|:-:|:-:|:-:|\n" \
"|0|0|1|\n" \
"|0|1|2|\n" \
"|1|0|3|\n" \
"|1|1|4|\n" \
"|2|0|5|\n" \
"|2|1|6|\n"

def formatter(x: int):
return "A" if x == 12 else "B"

actual = self.factor._repr_markdown_(formatter)
self.assertEqual(actual, expected)


if __name__ == "__main__":
unittest.main()