Skip to content

Commit

Permalink
Merge pull request #1135 from borglab/fix/visitWith
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Mar 18, 2022
2 parents 9be5967 + fa542a2 commit e5fb4cd
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
2 changes: 2 additions & 0 deletions gtsam/discrete/Assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ namespace gtsam {
template <class L>
class Assignment : public std::map<L, size_t> {
public:
using std::map<L, size_t>::operator=;

void print(const std::string& s = "Assignment: ") const {
std::cout << s << ": ";
for (const typename Assignment::value_type& keyValue : *this)
Expand Down
7 changes: 6 additions & 1 deletion gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -666,8 +666,13 @@ namespace gtsam {
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
choices[choice->label()] = i; // Set assignment for label to i

(*this)(choice->branches()[i]); // recurse!

// Remove the choice so we are backtracking
auto choice_it = choices.find(choice->label());
choices.erase(choice_it);
}
}
};
Expand Down
38 changes: 38 additions & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,44 @@ TEST(DecisionTree, visitWith) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
}

/* ************************************************************************** */
// Test visit, with Choices argument.
TEST(DecisionTree, VisitWithPruned) {
// Create pruned tree
std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
std::vector<std::pair<string, size_t>> labels = {C, B, A};
std::vector<int> nodes = {0, 0, 2, 3, 4, 4, 6, 7};
DT tree(labels, nodes);

std::vector<Assignment<string>> choices;
auto func = [&](const Assignment<string>& choice, const int& d) {
choices.push_back(choice);
};
tree.visitWith(func);

EXPECT_LONGS_EQUAL(6, choices.size());

Assignment<string> expectedAssignment;

expectedAssignment = {{"B", 0}, {"C", 0}};
EXPECT(expectedAssignment == choices.at(0));

expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 0}};
EXPECT(expectedAssignment == choices.at(1));

expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 0}};
EXPECT(expectedAssignment == choices.at(2));

expectedAssignment = {{"B", 0}, {"C", 1}};
EXPECT(expectedAssignment == choices.at(3));

expectedAssignment = {{"A", 0}, {"B", 1}, {"C", 1}};
EXPECT(expectedAssignment == choices.at(4));

expectedAssignment = {{"A", 1}, {"B", 1}, {"C", 1}};
EXPECT(expectedAssignment == choices.at(5));
}

/* ************************************************************************** */
// Test fold.
TEST(DecisionTree, fold) {
Expand Down

0 comments on commit e5fb4cd

Please sign in to comment.