Skip to content

Commit

Permalink
Merge pull request #1144 from borglab/decision-tree-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Mar 24, 2022
2 parents c693d7d + d5cc455 commit 239a978
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
16 changes: 13 additions & 3 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -635,11 +635,13 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>;

// ugliness below because apparently we can't have templated virtual
// functions If leaf, apply unary conversion "op" and create a unique leaf
// Ugliness below because apparently we can't have templated virtual
// functions.
// If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) {
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
}

// Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice;
Expand Down Expand Up @@ -727,6 +729,14 @@ namespace gtsam {
visit(root_);
}

/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrLeaves() const {
size_t total = 0;
visit([&total](const Y& node) { total += 1; });
return total;
}

/****************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ namespace gtsam {
template <typename Func>
void visitWith(Func f) const;

/// Return the number of leaves in the tree.
size_t nrLeaves() const;

/**
* @brief Fold a binary function over the tree, returning accumulator.
*
Expand Down
5 changes: 1 addition & 4 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,7 @@ namespace gtsam {
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
pairs.emplace_back(key, cardinalities_.at(key));
}
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
// Reverse to make cartesian product output a more natural ordering.
std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
const auto assignments = DiscreteValues::CartesianProduct(rpairs);
Expand Down

0 comments on commit 239a978

Please sign in to comment.