From 3c62ab77de142229b10ec137515120b3e346a6fc Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 24 Mar 2022 14:18:43 -0400 Subject: [PATCH 1/2] remove redundancy in enumerate --- gtsam/discrete/DecisionTreeFactor.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ef4cc48f69..e95b8fe374 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -156,10 +156,7 @@ namespace gtsam { std::vector> DecisionTreeFactor::enumerate() const { // Get all possible assignments - std::vector> pairs; - for (auto& key : keys()) { - pairs.emplace_back(key, cardinalities_.at(key)); - } + std::vector> pairs = discreteKeys(); // Reverse to make cartesian product output a more natural ordering. std::vector> rpairs(pairs.rbegin(), pairs.rend()); const auto assignments = DiscreteValues::CartesianProduct(rpairs); From d5cc4554db9b96c78bce58d37d643cb5caa0ad01 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 24 Mar 2022 14:35:50 -0400 Subject: [PATCH 2/2] add new nrLeaves method for DecisionTree --- gtsam/discrete/DecisionTree-inl.h | 16 +++++++++++++--- gtsam/discrete/DecisionTree.h | 3 +++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 0ebfc86bce..b6e5482978 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -635,11 +635,13 @@ namespace gtsam { std::function Y_of_X) const { using LY = DecisionTree; - // ugliness below because apparently we can't have templated virtual - // functions If leaf, apply unary conversion "op" and create a unique leaf + // Ugliness below because apparently we can't have templated virtual + // functions. + // If leaf, apply unary conversion "op" and create a unique leaf. using MXLeaf = typename DecisionTree::Leaf; - if (auto leaf = boost::dynamic_pointer_cast(f)) + if (auto leaf = boost::dynamic_pointer_cast(f)) { return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + } // Check if Choice using MXChoice = typename DecisionTree::Choice; @@ -727,6 +729,14 @@ namespace gtsam { visit(root_); } + /****************************************************************************/ + template + size_t DecisionTree::nrLeaves() const { + size_t total = 0; + visit([&total](const Y& node) { total += 1; }); + return total; + } + /****************************************************************************/ // fold is just done with a visit template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 13ff0a8c65..c0a2a7a1c6 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -262,6 +262,9 @@ namespace gtsam { template void visitWith(Func f) const; + /// Return the number of leaves in the tree. + size_t nrLeaves() const; + /** * @brief Fold a binary function over the tree, returning accumulator. *