diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 0ebfc86bce..b6e5482978 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -635,11 +635,13 @@ namespace gtsam { std::function Y_of_X) const { using LY = DecisionTree; - // ugliness below because apparently we can't have templated virtual - // functions If leaf, apply unary conversion "op" and create a unique leaf + // Ugliness below because apparently we can't have templated virtual + // functions. + // If leaf, apply unary conversion "op" and create a unique leaf. using MXLeaf = typename DecisionTree::Leaf; - if (auto leaf = boost::dynamic_pointer_cast(f)) + if (auto leaf = boost::dynamic_pointer_cast(f)) { return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + } // Check if Choice using MXChoice = typename DecisionTree::Choice; @@ -727,6 +729,14 @@ namespace gtsam { visit(root_); } + /****************************************************************************/ + template + size_t DecisionTree::nrLeaves() const { + size_t total = 0; + visit([&total](const Y& node) { total += 1; }); + return total; + } + /****************************************************************************/ // fold is just done with a visit template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 13ff0a8c65..c0a2a7a1c6 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -262,6 +262,9 @@ namespace gtsam { template void visitWith(Func f) const; + /// Return the number of leaves in the tree. + size_t nrLeaves() const; + /** * @brief Fold a binary function over the tree, returning accumulator. * diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ef4cc48f69..e95b8fe374 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -156,10 +156,7 @@ namespace gtsam { std::vector> DecisionTreeFactor::enumerate() const { // Get all possible assignments - std::vector> pairs; - for (auto& key : keys()) { - pairs.emplace_back(key, cardinalities_.at(key)); - } + std::vector> pairs = discreteKeys(); // Reverse to make cartesian product output a more natural ordering. std::vector> rpairs(pairs.rbegin(), pairs.rend()); const auto assignments = DiscreteValues::CartesianProduct(rpairs);