From d0ff3ab97ecdfb8466eae676a096baf55693aacd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 12:40:29 -0500 Subject: [PATCH 01/10] Fix most lint errors --- gtsam/discrete/DecisionTree-inl.h | 166 +++++++++++++++--------------- gtsam/discrete/DecisionTree.h | 68 ++++++------ 2 files changed, 120 insertions(+), 114 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 84116ccd5f..01c7b689c1 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -21,42 +21,44 @@ #include +#include #include #include +#include #include #include #include #include #include -#include #include #include #include +#include +#include #include +#include +#include using boost::assign::operator+=; namespace gtsam { - /*********************************************************************************/ + /****************************************************************************/ // Node - /*********************************************************************************/ + /****************************************************************************/ #ifdef DT_DEBUG_MEMORY template int DecisionTree::Node::nrNodes = 0; #endif - /*********************************************************************************/ + /****************************************************************************/ // Leaf - /*********************************************************************************/ - template - class DecisionTree::Leaf: public DecisionTree::Node { - + /****************************************************************************/ + template + struct DecisionTree::Leaf : public DecisionTree::Node { /** constant stored in this leaf */ Y constant_; - public: - /** Constructor from constant */ Leaf(const Y& constant) : constant_(constant) {} @@ -96,7 +98,7 @@ namespace gtsam { std::string value = valueFormatter(constant_); if (showZero || value.compare("0")) os << "\"" << this->id() << "\" [label=\"" << value - << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, + << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; } /** evaluate */ @@ -121,13 +123,13 @@ namespace gtsam { // Applying binary operator to two leaves results in a leaf NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { - NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL + NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL return h; } // If second argument is a Choice node, call it's apply with leaf as second NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { - return fC.apply_fC_op_gL(*this, op); // operand order back to normal + return fC.apply_fC_op_gL(*this, op); // operand order back to normal } /** choose a branch, create new memory ! */ @@ -136,32 +138,30 @@ namespace gtsam { } bool isLeaf() const override { return true; } + }; // Leaf - }; // Leaf - - /*********************************************************************************/ + /****************************************************************************/ // Choice - /*********************************************************************************/ + /****************************************************************************/ template - class DecisionTree::Choice: public DecisionTree::Node { - + struct DecisionTree::Choice: public DecisionTree::Node { /** the label of the variable on which we split */ L label_; /** The children of this Choice node. */ std::vector branches_; - private: + private: /** incremental allSame */ size_t allSame_; using ChoicePtr = boost::shared_ptr; - public: - + public: ~Choice() override { #ifdef DT_DEBUG_MEMORY - std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; + std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() + << std::std::endl; #endif } @@ -172,7 +172,8 @@ namespace gtsam { assert(f->branches().size() > 0); NodePtr f0 = f->branches_[0]; assert(f0->isLeaf()); - NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast(f0)->constant())); + NodePtr newLeaf( + new Leaf(boost::dynamic_pointer_cast(f0)->constant())); return newLeaf; } else #endif @@ -192,7 +193,6 @@ namespace gtsam { */ Choice(const Choice& f, const Choice& g, const Binary& op) : allSame_(true) { - // Choose what to do based on label if (f.label() > g.label()) { // f higher than g @@ -318,10 +318,8 @@ namespace gtsam { */ Choice(const L& label, const Choice& f, const Unary& op) : label_(label), allSame_(true) { - - branches_.reserve(f.branches_.size()); // reserve space - for (const NodePtr& branch: f.branches_) - push_back(branch->apply(op)); + branches_.reserve(f.branches_.size()); // reserve space + for (const NodePtr& branch : f.branches_) push_back(branch->apply(op)); } /** apply unary operator */ @@ -364,8 +362,7 @@ namespace gtsam { /** choose a branch, recursively */ NodePtr choose(const L& label, size_t index) const override { - if (label_ == label) - return branches_[index]; // choose branch + if (label_ == label) return branches_[index]; // choose branch // second case, not label of interest, just recurse auto r = boost::make_shared(label_, branches_.size()); @@ -373,12 +370,11 @@ namespace gtsam { r->push_back(branch->choose(label, index)); return Unique(r); } + }; // Choice - }; // Choice - - /*********************************************************************************/ + /****************************************************************************/ // DecisionTree - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree() { } @@ -388,13 +384,13 @@ namespace gtsam { root_(root) { } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const Y& y) { root_ = NodePtr(new Leaf(y)); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const L& label, const Y& y1, const Y& y2) { auto a = boost::make_shared(label, 2); @@ -404,7 +400,7 @@ namespace gtsam { root_ = Choice::Unique(a); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const LabelC& labelC, const Y& y1, const Y& y2) { @@ -417,7 +413,7 @@ namespace gtsam { root_ = Choice::Unique(a); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const std::vector& labelCs, const std::vector& ys) { @@ -425,29 +421,28 @@ namespace gtsam { root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const std::vector& labelCs, const std::string& table) { - // Convert std::string to values of type Y std::vector ys; std::istringstream iss(table); copy(std::istream_iterator(iss), std::istream_iterator(), - back_inserter(ys)); + back_inserter(ys)); // now call recursive Create root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } - /*********************************************************************************/ + /****************************************************************************/ template template DecisionTree::DecisionTree( Iterator begin, Iterator end, const L& label) { root_ = compose(begin, end, label); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree::DecisionTree(const L& label, const DecisionTree& f0, const DecisionTree& f1) { @@ -456,17 +451,17 @@ namespace gtsam { root_ = compose(functions.begin(), functions.end(), label); } - /*********************************************************************************/ + /****************************************************************************/ template template DecisionTree::DecisionTree(const DecisionTree& other, Func Y_of_X) { // Define functor for identity mapping of node label. - auto L_of_L = [](const L& label) { return label; }; + auto L_of_L = [](const L& label) { return label; }; root_ = convertFrom(other.root_, L_of_L, Y_of_X); } - /*********************************************************************************/ + /****************************************************************************/ template template DecisionTree::DecisionTree(const DecisionTree& other, @@ -475,16 +470,16 @@ namespace gtsam { root_ = convertFrom(other.root_, L_of_M, Y_of_X); } - /*********************************************************************************/ + /****************************************************************************/ // Called by two constructors above. - // Takes a label and a corresponding range of decision trees, and creates a new - // decision tree. However, the order of the labels needs to be respected, so we - // cannot just create a root Choice node on the label: if the label is not the - // highest label, we need to do a complicated and expensive recursive call. - template template - typename DecisionTree::NodePtr DecisionTree::compose(Iterator begin, - Iterator end, const L& label) const { - + // Takes a label and a corresponding range of decision trees, and creates a + // new decision tree. However, the order of the labels needs to be respected, + // so we cannot just create a root Choice node on the label: if the label is + // not the highest label, we need a complicated/ expensive recursive call. + template + template + typename DecisionTree::NodePtr DecisionTree::compose( + Iterator begin, Iterator end, const L& label) const { // find highest label among branches boost::optional highestLabel; size_t nrChoices = 0; @@ -527,7 +522,7 @@ namespace gtsam { } } - /*********************************************************************************/ + /****************************************************************************/ // "create" is a bit of a complicated thing, but very useful. // It takes a range of labels and a corresponding range of values, // and creates a decision tree, as follows: @@ -552,7 +547,6 @@ namespace gtsam { template typename DecisionTree::NodePtr DecisionTree::create( It begin, It end, ValueIt beginY, ValueIt endY) const { - // get crucial counts size_t nrChoices = begin->second; size_t size = endY - beginY; @@ -564,7 +558,11 @@ namespace gtsam { // Create a simple choice node with values as leaves. if (size != nrChoices) { std::cout << "Trying to create DD on " << begin->first << std::endl; - std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; + std::cout << boost::format( + "DecisionTree::create: expected %d values but got %d " + "instead") % + nrChoices % size + << std::endl; throw std::invalid_argument("DecisionTree::create invalid argument"); } auto choice = boost::make_shared(begin->first, endY - beginY); @@ -585,7 +583,7 @@ namespace gtsam { return compose(functions.begin(), functions.end(), begin->first); } - /*********************************************************************************/ + /****************************************************************************/ template template typename DecisionTree::NodePtr DecisionTree::convertFrom( @@ -594,11 +592,11 @@ namespace gtsam { std::function Y_of_X) const { using LY = DecisionTree; - // ugliness below because apparently we can't have templated virtual functions - // If leaf, apply unary conversion "op" and create a unique leaf + // ugliness below because apparently we can't have templated virtual + // functions If leaf, apply unary conversion "op" and create a unique leaf using MXLeaf = typename DecisionTree::Leaf; if (auto leaf = boost::dynamic_pointer_cast(f)) - return NodePtr(new Leaf(Y_of_X(leaf->constant()))); + return NodePtr(new Leaf(Y_of_X(leaf->constant()))); // Check if Choice using MXChoice = typename DecisionTree::Choice; @@ -612,19 +610,19 @@ namespace gtsam { // put together via Shannon expansion otherwise not sorted. std::vector functions; - for(auto && branch: choice->branches()) { + for (auto&& branch : choice->branches()) { functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } return LY::compose(functions.begin(), functions.end(), newLabel); } - /*********************************************************************************/ + /****************************************************************************/ // Functor performing depth-first visit without Assignment argument. template struct Visit { using F = std::function; - Visit(F f) : f(f) {} ///< Construct from folding function. - F f; ///< folding function object. + explicit Visit(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. /// Do a depth-first visit on the tree rooted at node. void operator()(const typename DecisionTree::NodePtr& node) const { @@ -647,15 +645,15 @@ namespace gtsam { visit(root_); } - /*********************************************************************************/ + /****************************************************************************/ // Functor performing depth-first visit with Assignment argument. template struct VisitWith { using Choices = Assignment; using F = std::function; - VisitWith(F f) : f(f) {} ///< Construct from folding function. - Choices choices; ///< Assignment, mutating through recursion. - F f; ///< folding function object. + explicit VisitWith(F f) : f(f) {} ///< Construct from folding function. + Choices choices; ///< Assignment, mutating through recursion. + F f; ///< folding function object. /// Do a depth-first visit on the tree rooted at node. void operator()(const typename DecisionTree::NodePtr& node) { @@ -681,7 +679,7 @@ namespace gtsam { visit(root_); } - /*********************************************************************************/ + /****************************************************************************/ // fold is just done with a visit template template @@ -690,7 +688,7 @@ namespace gtsam { return x0; } - /*********************************************************************************/ + /****************************************************************************/ // labels is just done with a visit template std::set DecisionTree::labels() const { @@ -702,7 +700,7 @@ namespace gtsam { return unique; } -/*********************************************************************************/ +/****************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, const CompareFunc& compare) const { @@ -736,7 +734,7 @@ namespace gtsam { return DecisionTree(root_->apply(op)); } - /*********************************************************************************/ + /****************************************************************************/ template DecisionTree DecisionTree::apply(const DecisionTree& g, const Binary& op) const { @@ -752,7 +750,7 @@ namespace gtsam { return result; } - /*********************************************************************************/ + /****************************************************************************/ // The way this works: // We have an ADT, picture it as a tree. // At a certain depth, we have a branch on "label". @@ -772,7 +770,7 @@ namespace gtsam { return result; } - /*********************************************************************************/ + /****************************************************************************/ template void DecisionTree::dot(std::ostream& os, const LabelFormatter& labelFormatter, @@ -790,9 +788,11 @@ namespace gtsam { bool showZero) const { std::ofstream os((name + ".dot").c_str()); dot(os, labelFormatter, valueFormatter, showZero); - int result = system( - ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); - if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); + int result = + system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null") + .c_str()); + if (result == -1) + throw std::runtime_error("DecisionTree::dot system call failed"); } template @@ -804,8 +804,6 @@ namespace gtsam { return ss.str(); } -/*********************************************************************************/ - -} // namespace gtsam - +/******************************************************************************/ + } // namespace gtsam diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 78f3a75b72..53782ef5e3 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -26,9 +26,11 @@ #include #include #include +#include #include +#include +#include #include -#include namespace gtsam { @@ -39,15 +41,13 @@ namespace gtsam { */ template class DecisionTree { - protected: /// Default method for comparison of two objects of type Y. static bool DefaultCompare(const Y& a, const Y& b) { return a == b; } - public: - + public: using LabelFormatter = std::function; using ValueFormatter = std::function; using CompareFunc = std::function; @@ -57,15 +57,14 @@ namespace gtsam { using Binary = std::function; /** A label annotated with cardinality */ - using LabelC = std::pair; + using LabelC = std::pair; /** DTs consist of Leaf and Choice nodes, both subclasses of Node */ - class Leaf; - class Choice; + struct Leaf; + struct Choice; /** ------------------------ Node base class --------------------------- */ - class Node { - public: + struct Node { using Ptr = boost::shared_ptr; #ifdef DT_DEBUG_MEMORY @@ -75,14 +74,16 @@ namespace gtsam { // Constructor Node() { #ifdef DT_DEBUG_MEMORY - std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); + std::cout << ++nrNodes << " constructed " << id() << std::endl; + std::cout.flush(); #endif } // Destructor virtual ~Node() { #ifdef DT_DEBUG_MEMORY - std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); + std::cout << --nrNodes << " destructed " << id() << std::endl; + std::cout.flush(); #endif } @@ -110,17 +111,17 @@ namespace gtsam { }; /** ------------------------ Node base class --------------------------- */ - public: - + public: /** A function is a shared pointer to the root of a DT */ using NodePtr = typename Node::Ptr; /// A DecisionTree just contains the root. TODO(dellaert): make protected. NodePtr root_; - protected: - - /** Internal recursive function to create from keys, cardinalities, and Y values */ + protected: + /** Internal recursive function to create from keys, cardinalities, + * and Y values + */ template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; @@ -140,7 +141,6 @@ namespace gtsam { std::function Y_of_X) const; public: - /// @name Standard Constructors /// @{ @@ -148,7 +148,7 @@ namespace gtsam { DecisionTree(); /** Create a constant */ - DecisionTree(const Y& y); + explicit DecisionTree(const Y& y); /** Create a new leaf function splitting on a variable */ DecisionTree(const L& label, const Y& y1, const Y& y2); @@ -167,8 +167,8 @@ namespace gtsam { DecisionTree(Iterator begin, Iterator end, const L& label); /** Create DecisionTree from two others */ - DecisionTree(const L& label, // - const DecisionTree& f0, const DecisionTree& f1); + DecisionTree(const L& label, const DecisionTree& f0, + const DecisionTree& f1); /** * @brief Convert from a different value type. @@ -289,7 +289,8 @@ namespace gtsam { } /** combine subtrees on key with binary operation "op" */ - DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; + DecisionTree combine(const L& label, size_t cardinality, + const Binary& op) const; /** combine with LabelC for convenience */ DecisionTree combine(const LabelC& labelC, const Binary& op) const { @@ -313,15 +314,14 @@ namespace gtsam { /// @{ // internal use only - DecisionTree(const NodePtr& root); + explicit DecisionTree(const NodePtr& root); // internal use only template NodePtr compose(Iterator begin, Iterator end, const L& label) const; /// @} - - }; // DecisionTree + }; // DecisionTree /** free versions of apply */ @@ -340,11 +340,19 @@ namespace gtsam { return f.apply(g, op); } - /// unzip a DecisionTree if its leaves are `std::pair` - template - std::pair, DecisionTree > unzip(const DecisionTree > &input) { - return std::make_pair(DecisionTree(input, [](std::pair i) { return i.first; }), - DecisionTree(input, [](std::pair i) { return i.second; })); + /** + * @brief unzip a DecisionTree with `std::pair` values. + * + * @param input the DecisionTree with `(T1,T2)` values. + * @return a pair of DecisionTree on T1 and T2, respectively. + */ + template + std::pair, DecisionTree > unzip( + const DecisionTree >& input) { + return std::make_pair( + DecisionTree(input, [](std::pair i) { return i.first; }), + DecisionTree(input, + [](std::pair i) { return i.second; })); } -} // namespace gtsam +} // namespace gtsam From 9317e94452c5374bd25fa6764dc315d54e68b5a8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 09:04:27 -0500 Subject: [PATCH 02/10] Fix formatting --- gtsam/discrete/tests/testDecisionTree.cpp | 122 ++++++++++------------ 1 file changed, 56 insertions(+), 66 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 1029417764..c157a25433 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -31,14 +31,14 @@ using namespace boost::assign; using namespace std; using namespace gtsam; -template -void dot(const T&f, const string& filename) { +template +void dot(const T& f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif } -#define DOT(x)(dot(x,#x)) +#define DOT(x) (dot(x, #x)) struct Crazy { int a; @@ -65,14 +65,15 @@ struct CrazyDecisionTree : public DecisionTree { // traits namespace gtsam { -template<> struct traits : public Testable {}; -} +template <> +struct traits : public Testable {}; +} // namespace gtsam GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test string labels and int range -/* ******************************************************************************** */ +/* ************************************************************************** */ struct DT : public DecisionTree { using Base = DecisionTree; @@ -98,30 +99,21 @@ struct DT : public DecisionTree { // traits namespace gtsam { -template<> struct traits
: public Testable
{}; -} +template <> +struct traits
: public Testable
{}; +} // namespace gtsam GTSAM_CONCEPT_TESTABLE_INST(DT) struct Ring { - static inline int zero() { - return 0; - } - static inline int one() { - return 1; - } - static inline int id(const int& a) { - return a; - } - static inline int add(const int& a, const int& b) { - return a + b; - } - static inline int mul(const int& a, const int& b) { - return a * b; - } + static inline int zero() { return 0; } + static inline int one() { return 1; } + static inline int id(const int& a) { return a; } + static inline int add(const int& a, const int& b) { return a + b; } + static inline int mul(const int& a, const int& b) { return a * b; } }; -/* ******************************************************************************** */ +/* ************************************************************************** */ // test DT TEST(DecisionTree, example) { // Create labels @@ -139,20 +131,20 @@ TEST(DecisionTree, example) { // A DT a(A, 0, 5); - LONGS_EQUAL(0,a(x00)) - LONGS_EQUAL(5,a(x10)) + LONGS_EQUAL(0, a(x00)) + LONGS_EQUAL(5, a(x10)) DOT(a); // pruned DT p(A, 2, 2); - LONGS_EQUAL(2,p(x00)) - LONGS_EQUAL(2,p(x10)) + LONGS_EQUAL(2, p(x00)) + LONGS_EQUAL(2, p(x10)) DOT(p); // \neg B DT notb(B, 5, 0); - LONGS_EQUAL(5,notb(x00)) - LONGS_EQUAL(5,notb(x10)) + LONGS_EQUAL(5, notb(x00)) + LONGS_EQUAL(5, notb(x10)) DOT(notb); // Check supplying empty trees yields an exception @@ -162,34 +154,34 @@ TEST(DecisionTree, example) { // apply, two nodes, in natural order DT anotb = apply(a, notb, &Ring::mul); - LONGS_EQUAL(0,anotb(x00)) - LONGS_EQUAL(0,anotb(x01)) - LONGS_EQUAL(25,anotb(x10)) - LONGS_EQUAL(0,anotb(x11)) + LONGS_EQUAL(0, anotb(x00)) + LONGS_EQUAL(0, anotb(x01)) + LONGS_EQUAL(25, anotb(x10)) + LONGS_EQUAL(0, anotb(x11)) DOT(anotb); // check pruning DT pnotb = apply(p, notb, &Ring::mul); - LONGS_EQUAL(10,pnotb(x00)) - LONGS_EQUAL( 0,pnotb(x01)) - LONGS_EQUAL(10,pnotb(x10)) - LONGS_EQUAL( 0,pnotb(x11)) + LONGS_EQUAL(10, pnotb(x00)) + LONGS_EQUAL(0, pnotb(x01)) + LONGS_EQUAL(10, pnotb(x10)) + LONGS_EQUAL(0, pnotb(x11)) DOT(pnotb); // check pruning DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); - LONGS_EQUAL(0,zeros(x00)) - LONGS_EQUAL(0,zeros(x01)) - LONGS_EQUAL(0,zeros(x10)) - LONGS_EQUAL(0,zeros(x11)) + LONGS_EQUAL(0, zeros(x00)) + LONGS_EQUAL(0, zeros(x01)) + LONGS_EQUAL(0, zeros(x10)) + LONGS_EQUAL(0, zeros(x11)) DOT(zeros); // apply, two nodes, in switched order DT notba = apply(a, notb, &Ring::mul); - LONGS_EQUAL(0,notba(x00)) - LONGS_EQUAL(0,notba(x01)) - LONGS_EQUAL(25,notba(x10)) - LONGS_EQUAL(0,notba(x11)) + LONGS_EQUAL(0, notba(x00)) + LONGS_EQUAL(0, notba(x01)) + LONGS_EQUAL(25, notba(x10)) + LONGS_EQUAL(0, notba(x11)) DOT(notba); // Test choose 0 @@ -204,10 +196,10 @@ TEST(DecisionTree, example) { // apply, two nodes at same level DT a_and_a = apply(a, a, &Ring::mul); - LONGS_EQUAL(0,a_and_a(x00)) - LONGS_EQUAL(0,a_and_a(x01)) - LONGS_EQUAL(25,a_and_a(x10)) - LONGS_EQUAL(25,a_and_a(x11)) + LONGS_EQUAL(0, a_and_a(x00)) + LONGS_EQUAL(0, a_and_a(x01)) + LONGS_EQUAL(25, a_and_a(x10)) + LONGS_EQUAL(25, a_and_a(x11)) DOT(a_and_a); // create a function on C @@ -219,16 +211,16 @@ TEST(DecisionTree, example) { // mul notba with C DT notbac = apply(notba, c, &Ring::mul); - LONGS_EQUAL(125,notbac(x101)) + LONGS_EQUAL(125, notbac(x101)) DOT(notbac); // mul now in different order DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); - LONGS_EQUAL(125,acnotb(x101)) + LONGS_EQUAL(125, acnotb(x101)) DOT(acnotb); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test Conversion of values bool bool_of_int(const int& y) { return y != 0; }; typedef DecisionTree StringBoolTree; @@ -249,11 +241,9 @@ TEST(DecisionTree, ConvertValuesOnly) { EXPECT(!f2(x00)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test Conversion of both values and labels. -enum Label { - U, V, X, Y, Z -}; +enum Label { U, V, X, Y, Z }; typedef DecisionTree LabelBoolTree; TEST(DecisionTree, ConvertBoth) { @@ -281,7 +271,7 @@ TEST(DecisionTree, ConvertBoth) { EXPECT(!f2(x11)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test Compose expansion TEST(DecisionTree, Compose) { // Create labels @@ -292,7 +282,7 @@ TEST(DecisionTree, Compose) { // Create from string vector keys; - keys += DT::LabelC(A,2), DT::LabelC(B,2); + keys += DT::LabelC(A, 2), DT::LabelC(B, 2); DT f2(keys, "0 2 1 3"); EXPECT(assert_equal(f2, f1, 1e-9)); @@ -302,13 +292,13 @@ TEST(DecisionTree, Compose) { DOT(f4); // a bigger tree - keys += DT::LabelC(C,2); + keys += DT::LabelC(C, 2); DT f5(keys, "0 4 2 6 1 5 3 7"); EXPECT(assert_equal(f5, f4, 1e-9)); DOT(f5); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Check we can create a decision tree of containers. TEST(DecisionTree, Containers) { using Container = std::vector; @@ -330,7 +320,7 @@ TEST(DecisionTree, Containers) { StringContainerTree converted(stringIntTree, container_of_int); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test visit. TEST(DecisionTree, visit) { // Create small two-level tree @@ -342,7 +332,7 @@ TEST(DecisionTree, visit) { EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test visit, with Choices argument. TEST(DecisionTree, visitWith) { // Create small two-level tree @@ -354,7 +344,7 @@ TEST(DecisionTree, visitWith) { EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test fold. TEST(DecisionTree, fold) { // Create small two-level tree @@ -365,7 +355,7 @@ TEST(DecisionTree, fold) { EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test retrieving all labels. TEST(DecisionTree, labels) { // Create small two-level tree From 241906d2c95f9dac6ed0034cdc73fd2fa597eb54 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 10:43:46 -0500 Subject: [PATCH 03/10] Thresholding test --- gtsam/discrete/tests/testDecisionTree.cpp | 35 +++++++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index c157a25433..c338bb86fa 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -308,7 +308,7 @@ TEST(DecisionTree, Containers) { StringContainerTree tree; // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3)); // Check conversion @@ -324,7 +324,7 @@ TEST(DecisionTree, Containers) { // Test visit. TEST(DecisionTree, visit) { // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); double sum = 0.0; auto visitor = [&](int y) { sum += y; }; @@ -336,7 +336,7 @@ TEST(DecisionTree, visit) { // Test visit, with Choices argument. TEST(DecisionTree, visitWith) { // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); double sum = 0.0; auto visitor = [&](const Assignment& choices, int y) { sum += y; }; @@ -348,7 +348,7 @@ TEST(DecisionTree, visitWith) { // Test fold. TEST(DecisionTree, fold) { // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); auto add = [](const int& y, double x) { return y + x; }; double sum = tree.fold(add, 0.0); @@ -359,14 +359,14 @@ TEST(DecisionTree, fold) { // Test retrieving all labels. TEST(DecisionTree, labels) { // Create small two-level tree - string A("A"), B("B"), C("C"); + string A("A"), B("B"); DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); auto labels = tree.labels(); EXPECT_LONGS_EQUAL(2, labels.size()); } /* ******************************************************************************** */ -// Test retrieving all labels. +// Test unzip method. TEST(DecisionTree, unzip) { using DTP = DecisionTree>; using DT1 = DecisionTree; @@ -390,6 +390,29 @@ TEST(DecisionTree, unzip) { EXPECT(tree2.equals(dt2)); } +/* ************************************************************************** */ +// Test thresholding. +TEST(DecisionTree, threshold) { + // Create three level tree + vector keys; + keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2); + DT tree(keys, "0 1 2 3 4 5 6 7"); + + // Check number of elements equal to zero + auto count = [](const int& value, int count) { + return value == 0 ? count + 1 : count; + }; + EXPECT_LONGS_EQUAL(1, tree.fold(count, 0)); + + // Now threshold + auto threshold = [](int value) { return value < 5 ? 0 : value; }; + DT thresholded(tree, threshold); + + // Check number of elements equal to zero now = 5 + // TODO(frank): it is 2, because the pruned branches are counted as 1! + EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0)); +} + /* ************************************************************************* */ int main() { TestResult tr; From 94c692ddd1eb1e62067bd44d3237c4dbd6e15559 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 11:59:48 -0500 Subject: [PATCH 04/10] New test on marginal --- .../tests/testDiscreteConditional.cpp | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index c2d941eaa7..13a34dd19d 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) { DiscreteConditional prior(B % "1/2"); DiscreteConditional pAB = prior * conditional; + // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5 + // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 DiscreteConditional actualA = pAB.marginal(A.first); DiscreteConditional pA(A % "5/4"); EXPECT(assert_equal(pA, actualA)); - EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); + EXPECT(actualA.frontals() == KeyVector{1}); EXPECT_LONGS_EQUAL(0, actualA.nrParents()); - KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals()); - EXPECT((frontalsA == KeyVector{1})); DiscreteConditional actualB = pAB.marginal(B.first); EXPECT(assert_equal(prior, actualB)); - EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); + EXPECT(actualB.frontals() == KeyVector{0}); EXPECT_LONGS_EQUAL(0, actualB.nrParents()); - KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); - EXPECT((frontalsB == KeyVector{0})); +} + +/* ************************************************************************* */ +// Check calculation of marginals in case branches are pruned +TEST(DiscreteConditional, marginals2) { + DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen! + DiscreteConditional conditional(A | B = "2/2 3/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + GTSAM_PRINT(pAB); + // P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8 + // P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4 + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "8/4"); + EXPECT(assert_equal(pA, actualA)); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); } /* ************************************************************************* */ From ca329daa13dd93cc2d284951bd1da1f8595a6b6a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 12:50:35 -0500 Subject: [PATCH 05/10] linting --- gtsam/discrete/DecisionTreeFactor.cpp | 146 ++++++++++++++------------ gtsam/discrete/DecisionTreeFactor.h | 82 ++++++--------- 2 files changed, 106 insertions(+), 122 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1e8f5aa3e7..ef4cc48f69 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -17,9 +17,9 @@ * @author Frank Dellaert */ +#include #include #include -#include #include #include @@ -29,42 +29,42 @@ using namespace std; namespace gtsam { - /* ******************************************************************************** */ - DecisionTreeFactor::DecisionTreeFactor() { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor() {} - /* ******************************************************************************** */ + /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const ADT& potentials) : - DiscreteFactor(keys.indices()), ADT(potentials), - cardinalities_(keys.cardinalities()) { - } + const ADT& potentials) + : DiscreteFactor(keys.indices()), + ADT(potentials), + cardinalities_(keys.cardinalities()) {} - /* *************************************************************************/ - DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : - DiscreteFactor(c.keys()), AlgebraicDecisionTree(c), cardinalities_(c.cardinalities_) { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) + : DiscreteFactor(c.keys()), + AlgebraicDecisionTree(c), + cardinalities_(c.cardinalities_) {} - /* ************************************************************************* */ - bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const { - if(!dynamic_cast(&other)) { + /* ************************************************************************ */ + bool DecisionTreeFactor::equals(const DiscreteFactor& other, + double tol) const { + if (!dynamic_cast(&other)) { return false; - } - else { + } else { const auto& f(static_cast(other)); return ADT::equals(f, tol); } } - /* ************************************************************************* */ - double DecisionTreeFactor::safe_div(const double &a, const double &b) { + /* ************************************************************************ */ + double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum // factor. If the product or sum is zero, we accord zero probability to the // event. return (a == 0 || b == 0) ? 0 : (a / b); } - /* ************************************************************************* */ + /* ************************************************************************ */ void DecisionTreeFactor::print(const string& s, const KeyFormatter& formatter) const { cout << s; @@ -75,31 +75,32 @@ namespace gtsam { ADT::print("", formatter); } - /* ************************************************************************* */ + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, - ADT::Binary op) const { - map cs; // new cardinalities + ADT::Binary op) const { + map cs; // new cardinalities // make unique key-cardinality map - for(Key j: keys()) cs[j] = cardinality(j); - for(Key j: f.keys()) cs[j] = f.cardinality(j); + for (Key j : keys()) cs[j] = cardinality(j); + for (Key j : f.keys()) cs[j] = f.cardinality(j); // Convert map into keys DiscreteKeys keys; - for(const std::pair& key: cs) - keys.push_back(key); + for (const std::pair& key : cs) keys.push_back(key); // apply operand ADT result = ADT::apply(f, op); // Make a new factor return DecisionTreeFactor(keys, result); } - /* ************************************************************************* */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, - ADT::Binary op) const { - - if (nrFrontals > size()) throw invalid_argument( - (boost::format( - "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") - % nrFrontals % size()).str()); + /* ************************************************************************ */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( + size_t nrFrontals, ADT::Binary op) const { + if (nrFrontals > size()) + throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal " + "keys %d, nr.keys=%d") % + nrFrontals % size()) + .str()); // sum over nrFrontals keys size_t i; @@ -113,20 +114,21 @@ namespace gtsam { DiscreteKeys dkeys; for (; i < keys().size(); i++) { Key j = keys()[i]; - dkeys.push_back(DiscreteKey(j,cardinality(j))); + dkeys.push_back(DiscreteKey(j, cardinality(j))); } return boost::make_shared(dkeys, result); } - - /* ************************************************************************* */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys, - ADT::Binary op) const { - - if (frontalKeys.size() > size()) throw invalid_argument( - (boost::format( - "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") - % frontalKeys.size() % size()).str()); + /* ************************************************************************ */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( + const Ordering& frontalKeys, ADT::Binary op) const { + if (frontalKeys.size() > size()) + throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal " + "keys %d, nr.keys=%d") % + frontalKeys.size() % size()) + .str()); // sum over nrFrontals keys size_t i; @@ -137,20 +139,22 @@ namespace gtsam { } // create new factor, note we collect keys that are not in frontalKeys - // TODO: why do we need this??? result should contain correct keys!!! + // TODO(frank): why do we need this??? result should contain correct keys!!! DiscreteKeys dkeys; for (i = 0; i < keys().size(); i++) { Key j = keys()[i]; - // TODO: inefficient! - if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end()) + // TODO(frank): inefficient! + if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != + frontalKeys.end()) continue; - dkeys.push_back(DiscreteKey(j,cardinality(j))); + dkeys.push_back(DiscreteKey(j, cardinality(j))); } return boost::make_shared(dkeys, result); } - /* ************************************************************************* */ - std::vector> DecisionTreeFactor::enumerate() const { + /* ************************************************************************ */ + std::vector> DecisionTreeFactor::enumerate() + const { // Get all possible assignments std::vector> pairs; for (auto& key : keys()) { @@ -168,7 +172,7 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ DiscreteKeys DecisionTreeFactor::discreteKeys() const { DiscreteKeys result; for (auto&& key : keys()) { @@ -180,7 +184,7 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ static std::string valueFormatter(const double& v) { return (boost::format("%4.2g") % v).str(); } @@ -194,8 +198,8 @@ namespace gtsam { /** output to graphviz format, open a file */ void DecisionTreeFactor::dot(const std::string& name, - const KeyFormatter& keyFormatter, - bool showZero) const { + const KeyFormatter& keyFormatter, + bool showZero) const { ADT::dot(name, keyFormatter, valueFormatter, showZero); } @@ -205,8 +209,8 @@ namespace gtsam { return ADT::dot(keyFormatter, valueFormatter, showZero); } - // Print out header. - /* ************************************************************************* */ + // Print out header. + /* ************************************************************************ */ string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, const Names& names) const { stringstream ss; @@ -271,17 +275,19 @@ namespace gtsam { return ss.str(); } - /* ************************************************************************* */ - DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector &table) : - DiscreteFactor(keys.indices()), AlgebraicDecisionTree(keys, table), - cardinalities_(keys.cardinalities()) { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const vector& table) + : DiscreteFactor(keys.indices()), + AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) {} - /* ************************************************************************* */ - DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) : - DiscreteFactor(keys.indices()), AlgebraicDecisionTree(keys, table), - cardinalities_(keys.cardinalities()) { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const string& table) + : DiscreteFactor(keys.indices()), + AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) {} - /* ************************************************************************* */ -} // namespace gtsam + /* ************************************************************************ */ +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 751b8c62c4..91fa7c4849 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -18,16 +18,18 @@ #pragma once +#include #include #include -#include #include +#include #include - -#include -#include +#include #include +#include +#include +#include namespace gtsam { @@ -36,21 +38,19 @@ namespace gtsam { /** * A discrete probabilistic factor */ - class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree { - - public: - + class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor, + public AlgebraicDecisionTree { + public: // typedefs needed to play nice with gtsam typedef DecisionTreeFactor This; - typedef DiscreteFactor Base; ///< Typedef to base class + typedef DiscreteFactor Base; ///< Typedef to base class typedef boost::shared_ptr shared_ptr; typedef AlgebraicDecisionTree ADT; - protected: - std::map cardinalities_; - - public: + protected: + std::map cardinalities_; + public: /// @name Standard Constructors /// @{ @@ -61,7 +61,8 @@ namespace gtsam { DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); /** Constructor from doubles */ - DecisionTreeFactor(const DiscreteKeys& keys, const std::vector& table); + DecisionTreeFactor(const DiscreteKeys& keys, + const std::vector& table); /** Constructor from string */ DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); @@ -86,7 +87,8 @@ namespace gtsam { bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; // print - void print(const std::string& s = "DecisionTreeFactor:\n", + void print( + const std::string& s = "DecisionTreeFactor:\n", const KeyFormatter& formatter = DefaultKeyFormatter) const override; /// @} @@ -105,7 +107,7 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - size_t cardinality(Key j) const { return cardinalities_.at(j);} + size_t cardinality(Key j) const { return cardinalities_.at(j); } /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { @@ -113,9 +115,7 @@ namespace gtsam { } /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - return *this; - } + DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values shared_ptr sum(size_t nrFrontals) const { @@ -164,27 +164,6 @@ namespace gtsam { */ shared_ptr combine(const Ordering& keys, ADT::Binary op) const; - -// /** -// * @brief Permutes the keys in Potentials and DiscreteFactor -// * -// * This re-implements the permuteWithInverse() in both Potentials -// * and DiscreteFactor by doing both of them together. -// */ -// -// void permuteWithInverse(const Permutation& inversePermutation){ -// DiscreteFactor::permuteWithInverse(inversePermutation); -// Potentials::permuteWithInverse(inversePermutation); -// } -// -// /** -// * Apply a reduction, which is a remapping of variable indices. -// */ -// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) { -// DiscreteFactor::reduceWithInverse(inverseReduction); -// Potentials::reduceWithInverse(inverseReduction); -// } - /// Enumerate all values into a map from values to double. std::vector> enumerate() const; @@ -194,16 +173,16 @@ namespace gtsam { /// @} /// @name Wrapper support /// @{ - + /** output to graphviz format, stream version */ void dot(std::ostream& os, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - bool showZero = true) const; + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; /** output to graphviz format, open a file */ void dot(const std::string& name, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - bool showZero = true) const; + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; /** output to graphviz format string */ std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, @@ -217,7 +196,7 @@ namespace gtsam { * @return std::string a markdown string. */ std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + const Names& names = {}) const override; /** * @brief Render as html table @@ -227,14 +206,13 @@ namespace gtsam { * @return std::string a html string. */ std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + const Names& names = {}) const override; /// @} - -}; -// DecisionTreeFactor + }; // traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; -}// namespace gtsam +} // namespace gtsam From 8acf67d4c86838fbe1401bea98ab7db013bb2d80 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 12:58:12 -0500 Subject: [PATCH 06/10] linting --- gtsam/discrete/AlgebraicDecisionTree.h | 87 ++++++++++++-------------- 1 file changed, 41 insertions(+), 46 deletions(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 566357a485..6ce36a688a 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -20,6 +20,10 @@ #include +#include +#include +#include +#include namespace gtsam { /** @@ -27,13 +31,14 @@ namespace gtsam { * Just has some nice constructors and some syntactic sugar * TODO: consider eliminating this class altogether? */ - template - class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree { + template + class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree { /** - * @brief Default method used by `labelFormatter` or `valueFormatter` when printing. - * + * @brief Default method used by `labelFormatter` or `valueFormatter` when + * printing. + * * @param x The value passed to format. - * @return std::string + * @return std::string */ static std::string DefaultFormatter(const L& x) { std::stringstream ss; @@ -42,17 +47,12 @@ namespace gtsam { } public: - using Base = DecisionTree; /** The Real ring with addition and multiplication */ struct Ring { - static inline double zero() { - return 0.0; - } - static inline double one() { - return 1.0; - } + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } static inline double add(const double& a, const double& b) { return a + b; } @@ -65,54 +65,49 @@ namespace gtsam { static inline double div(const double& a, const double& b) { return a / b; } - static inline double id(const double& x) { - return x; - } + static inline double id(const double& x) { return x; } }; - AlgebraicDecisionTree() : - Base(1.0) { - } + AlgebraicDecisionTree() : Base(1.0) {} - AlgebraicDecisionTree(const Base& add) : - Base(add) { - } + explicit AlgebraicDecisionTree(const Base& add) : Base(add) {} /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const L& label, double y1, double y2) : - Base(label, y1, y2) { - } + AlgebraicDecisionTree(const L& label, double y1, double y2) + : Base(label, y1, y2) {} /** Create a new leaf function splitting on a variable */ - AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : - Base(labelC, y1, y2) { - } + AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, + double y2) + : Base(labelC, y1, y2) {} /** Create from keys and vector table */ - AlgebraicDecisionTree // - (const std::vector& labelCs, const std::vector& ys) { - this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), - ys.end()); + AlgebraicDecisionTree // + (const std::vector& labelCs, + const std::vector& ys) { + this->root_ = + Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create from keys and string table */ - AlgebraicDecisionTree // - (const std::vector& labelCs, const std::string& table) { + AlgebraicDecisionTree // + (const std::vector& labelCs, + const std::string& table) { // Convert string to doubles std::vector ys; std::istringstream iss(table); std::copy(std::istream_iterator(iss), - std::istream_iterator(), std::back_inserter(ys)); + std::istream_iterator(), std::back_inserter(ys)); // now call recursive Create - this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), - ys.end()); + this->root_ = + Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); } /** Create a new function splitting on a variable */ - template - AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : - Base(nullptr) { + template + AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) + : Base(nullptr) { this->root_ = compose(begin, end, label); } @@ -122,7 +117,7 @@ namespace gtsam { * @param other: The AlgebraicDecisionTree with label type M to convert. * @param map: Map from label type M to label type L. */ - template + template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, const std::map& map) { // Functor for label conversion so we can use `convertFrom`. @@ -160,8 +155,8 @@ namespace gtsam { /// print method customized to value type `double`. void print(const std::string& s, - const typename Base::LabelFormatter& labelFormatter = - &DefaultFormatter) const { + const typename Base::LabelFormatter& labelFormatter = + &DefaultFormatter) const { auto valueFormatter = [](const double& v) { return (boost::format("%4.4g") % v).str(); }; @@ -177,8 +172,8 @@ namespace gtsam { return Base::equals(other, compare); } }; -// AlgebraicDecisionTree -template struct traits> : public Testable> {}; -} -// namespace gtsam +template +struct traits> + : public Testable> {}; +} // namespace gtsam From 289382ea7654658158d212ef87543eb43ab5159a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 13:07:20 -0500 Subject: [PATCH 07/10] linting --- .../tests/testAlgebraicDecisionTree.cpp | 150 +++++++++--------- 1 file changed, 71 insertions(+), 79 deletions(-) diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp index 910515b5c4..9d130a1f66 100644 --- a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -17,38 +17,39 @@ */ #include -#include // make sure we have traits +#include // make sure we have traits #include // headers first to make sure no missing headers //#define DT_NO_PRUNING #include -#include // for convert only +#include // for convert only #define DISABLE_TIMING -#include #include #include +#include using namespace boost::assign; #include -#include #include +#include using namespace std; using namespace gtsam; -/* ******************************************************************************** */ +/* ************************************************************************** */ typedef AlgebraicDecisionTree ADT; // traits namespace gtsam { -template<> struct traits : public Testable {}; -} +template <> +struct traits : public Testable {}; +} // namespace gtsam #define DISABLE_DOT -template -void dot(const T&f, const string& filename) { +template +void dot(const T& f, const string& filename) { #ifndef DISABLE_DOT f.dot(filename); #endif @@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) { // If second argument of binary op is Leaf template - typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( - Cache& cache, const Leaf& gL, Mul op) const { + typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const { Ptr h(new Choice(label(), cardinality())); for(const NodePtr& branch: branches_) h->push_back(branch->apply_f_op_g(cache, gL, op)); @@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) { } */ -/* ******************************************************************************** */ +/* ************************************************************************** */ // instrumented operators -/* ******************************************************************************** */ +/* ************************************************************************** */ size_t muls = 0, adds = 0; double elapsed; void resetCounts() { @@ -83,8 +84,9 @@ void resetCounts() { } void printCounts(const string& s) { #ifndef DISABLE_TIMING - cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds - % (1000 * elapsed) << endl; + cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds % + (1000 * elapsed) + << endl; #endif resetCounts(); } @@ -97,12 +99,11 @@ double add_(const double& a, const double& b) { return a + b; } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test ADT -TEST(ADT, example3) -{ +TEST(ADT, example3) { // Create labels - DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2); // Literals ADT a(A, 0.5, 0.5); @@ -114,22 +115,21 @@ TEST(ADT, example3) ADT cnotb = c * notb; dot(cnotb, "ADT-cnotb"); -// a.print("a: "); -// cnotb.print("cnotb: "); + // a.print("a: "); + // cnotb.print("cnotb: "); ADT acnotb = a * cnotb; -// acnotb.print("acnotb: "); -// acnotb.printCache("acnotb Cache:"); + // acnotb.print("acnotb: "); + // acnotb.printCache("acnotb Cache:"); dot(acnotb, "ADT-acnotb"); - ADT big = apply(apply(d, note, &mul), acnotb, &add_); dot(big, "ADT-big"); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Asia Bayes Network -/* ******************************************************************************** */ +/* ************************************************************************** */ /** Convert Signature into CPT */ ADT create(const Signature& signature) { @@ -143,9 +143,9 @@ ADT create(const Signature& signature) { /* ************************************************************************* */ // test Asia Joint -TEST(ADT, joint) -{ - DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); +TEST(ADT, joint) { + DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), + D(7, 2); resetCounts(); gttic_(asiaCPTs); @@ -204,10 +204,9 @@ TEST(ADT, joint) /* ************************************************************************* */ // test Inference with joint -TEST(ADT, inference) -{ - DiscreteKey A(0,2), D(1,2),// - B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2); +TEST(ADT, inference) { + DiscreteKey A(0, 2), D(1, 2), // + B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); resetCounts(); gttic_(infCPTs); @@ -244,7 +243,7 @@ TEST(ADT, inference) dot(joint, "Joint-Product-ASTLBEX"); joint = apply(joint, pD, &mul); dot(joint, "Joint-Product-ASTLBEXD"); - EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering + EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering gttoc_(asiaProd); tictoc_getNode(asiaProdNode, asiaProd); elapsed = asiaProdNode->secs() + asiaProdNode->wall(); @@ -271,9 +270,8 @@ TEST(ADT, inference) } /* ************************************************************************* */ -TEST(ADT, factor_graph) -{ - DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2); +TEST(ADT, factor_graph) { + DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); resetCounts(); gttic_(createCPTs); @@ -403,18 +401,19 @@ TEST(ADT, factor_graph) /* ************************************************************************* */ // test equality -TEST(ADT, equality_noparser) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, equality_noparser) { + DiscreteKey A(0, 2), B(1, 2); Signature::Table tableA, tableB; Signature::Row rA, rB; - rA += 80, 20; rB += 60, 40; - tableA += rA; tableB += rB; + rA += 80, 20; + rB += 60, 40; + tableA += rA; + tableB += rB; // Check straight equality ADT pA1 = create(A % tableA); ADT pA2 = create(A % tableA); - EXPECT(pA1.equals(pA2)); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % tableB); @@ -425,13 +424,12 @@ TEST(ADT, equality_noparser) /* ************************************************************************* */ // test equality -TEST(ADT, equality_parser) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, equality_parser) { + DiscreteKey A(0, 2), B(1, 2); // Check straight equality ADT pA1 = create(A % "80/20"); ADT pA2 = create(A % "80/20"); - EXPECT(pA1.equals(pA2)); // should be equal + EXPECT(pA1.equals(pA2)); // should be equal // Check equality after apply ADT pB = create(B % "60/40"); @@ -440,12 +438,11 @@ TEST(ADT, equality_parser) EXPECT(pAB2.equals(pAB1)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Factor graph construction // test constructor from strings -TEST(ADT, constructor) -{ - DiscreteKey v0(0,2), v1(1,3); +TEST(ADT, constructor) { + DiscreteKey v0(0, 2), v1(1, 3); DiscreteValues x00, x01, x02, x10, x11, x12; x00[0] = 0, x00[1] = 0; x01[0] = 0, x01[1] = 1; @@ -470,11 +467,10 @@ TEST(ADT, constructor) EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); - DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); + DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2); vector table(5 * 4 * 3 * 2); double x = 0; - for(double& t: table) - t = x++; + for (double& t : table) t = x++; ADT f3(z0 & z1 & z2 & z3, table); DiscreteValues assignment; assignment[0] = 0; @@ -487,9 +483,8 @@ TEST(ADT, constructor) /* ************************************************************************* */ // test conversion to integer indices // Only works if DiscreteKeys are binary, as size_t has binary cardinality! -TEST(ADT, conversion) -{ - DiscreteKey X(0,2), Y(1,2); +TEST(ADT, conversion) { + DiscreteKey X(0, 2), Y(1, 2); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); dot(fDiscreteKey, "conversion-f1"); @@ -513,11 +508,10 @@ TEST(ADT, conversion) EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test operations in elimination -TEST(ADT, elimination) -{ - DiscreteKey A(0,2), B(1,3), C(2,2); +TEST(ADT, elimination) { + DiscreteKey A(0, 2), B(1, 3), C(2, 2); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); dot(f1, "elimination-f1"); @@ -525,53 +519,51 @@ TEST(ADT, elimination) // sum out lower key ADT actualSum = f1.sum(C); ADT expectedSum(A & B, "3 7 11 9 6 10"); - CHECK(assert_equal(expectedSum,actualSum)); + CHECK(assert_equal(expectedSum, actualSum)); // normalize ADT actual = f1 / actualSum; vector cpt; - cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // - 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; + cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; ADT expected(A & B & C, cpt); - CHECK(assert_equal(expected,actual)); + CHECK(assert_equal(expected, actual)); } { // sum out lower 2 keys ADT actualSum = f1.sum(C).sum(B); ADT expectedSum(A, 21, 25); - CHECK(assert_equal(expectedSum,actualSum)); + CHECK(assert_equal(expectedSum, actualSum)); // normalize ADT actual = f1 / actualSum; vector cpt; - cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // - 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; + cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // + 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; ADT expected(A & B & C, cpt); - CHECK(assert_equal(expected,actual)); + CHECK(assert_equal(expected, actual)); } } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test non-commutative op -TEST(ADT, div) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, div) { + DiscreteKey A(0, 2), B(1, 2); // Literals ADT a(A, 8, 16); ADT b(B, 2, 4); - ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 - ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 + ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 + ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_b_div_a, b / a)); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // test zero shortcut -TEST(ADT, zero) -{ - DiscreteKey A(0,2), B(1,2); +TEST(ADT, zero) { + DiscreteKey A(0, 2), B(1, 2); // Literals ADT a(A, 0, 1); From beb3985c8c0e363b034702fa941647a0b16627f8 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 13:28:40 -0500 Subject: [PATCH 08/10] Added missing header --- gtsam/discrete/AlgebraicDecisionTree.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index 6ce36a688a..828f0b1a27 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -18,6 +18,7 @@ #pragma once +#include #include #include @@ -70,7 +71,8 @@ namespace gtsam { AlgebraicDecisionTree() : Base(1.0) {} - explicit AlgebraicDecisionTree(const Base& add) : Base(add) {} + // Explicitly non-explicit constructor + AlgebraicDecisionTree(const Base& add) : Base(add) {} /** Create a new leaf function splitting on a variable */ AlgebraicDecisionTree(const L& label, double y1, double y2) From fa1cde2f602c6aecf039d225eb35749cafa1bf6a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 13:28:56 -0500 Subject: [PATCH 09/10] Added cautionary notes about fold/visit --- gtsam/discrete/DecisionTree.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 53782ef5e3..d655756b86 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -234,6 +234,8 @@ namespace gtsam { * * @param f side-effect taking a value. * + * @note Due to pruning, leaves might not exhaust choices. + * * Example: * int sum = 0; * auto visitor = [&](int y) { sum += y; }; @@ -247,6 +249,8 @@ namespace gtsam { * * @param f side-effect taking an assignment and a value. * + * @note Due to pruning, leaves might not exhaust choices. + * * Example: * int sum = 0; * auto visitor = [&](const Assignment& choices, int y) { sum += y; }; @@ -264,6 +268,7 @@ namespace gtsam { * @return X final value for accumulator. * * @note X is always passed by value. + * @note Due to pruning, leaves might not exhaust choices. * * Example: * auto add = [](const double& y, double x) { return y + x; }; From 8db7f250216fbe37e4d14dee1842bde7113d6722 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 22 Jan 2022 13:29:16 -0500 Subject: [PATCH 10/10] Fixed thresholding and fold example --- gtsam/discrete/tests/testDecisionTree.cpp | 26 +++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index c338bb86fa..dbfb2dc403 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -24,8 +24,8 @@ using namespace boost::assign; #include #include -//#define DT_DEBUG_MEMORY -//#define DT_NO_PRUNING +// #define DT_DEBUG_MEMORY +// #define DT_NO_PRUNING #define DISABLE_DOT #include using namespace std; @@ -349,10 +349,10 @@ TEST(DecisionTree, visitWith) { TEST(DecisionTree, fold) { // Create small two-level tree string A("A"), B("B"); - DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); + DT tree(B, DT(A, 1, 1), DT(A, 2, 3)); auto add = [](const int& y, double x) { return y + x; }; double sum = tree.fold(add, 0.0); - EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); + EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning! } /* ************************************************************************** */ @@ -365,7 +365,7 @@ TEST(DecisionTree, labels) { EXPECT_LONGS_EQUAL(2, labels.size()); } -/* ******************************************************************************** */ +/* ************************************************************************** */ // Test unzip method. TEST(DecisionTree, unzip) { using DTP = DecisionTree>; @@ -374,15 +374,13 @@ TEST(DecisionTree, unzip) { // Create small two-level tree string A("A"), B("B"), C("C"); - DTP tree(B, - DTP(A, {0, "zero"}, {1, "one"}), - DTP(A, {2, "two"}, {1337, "l33t"}) - ); + DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}), + DTP(A, {2, "two"}, {1337, "l33t"})); DT1 dt1; DT2 dt2; std::tie(dt1, dt2) = unzip(tree); - + DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337)); DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t")); @@ -398,7 +396,7 @@ TEST(DecisionTree, threshold) { keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2); DT tree(keys, "0 1 2 3 4 5 6 7"); - // Check number of elements equal to zero + // Check number of leaves equal to zero auto count = [](const int& value, int count) { return value == 0 ? count + 1 : count; }; @@ -408,9 +406,9 @@ TEST(DecisionTree, threshold) { auto threshold = [](int value) { return value < 5 ? 0 : value; }; DT thresholded(tree, threshold); - // Check number of elements equal to zero now = 5 - // TODO(frank): it is 2, because the pruned branches are counted as 1! - EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0)); + // Check number of leaves equal to zero now = 2 + // Note: it is 2, because the pruned branches are counted as 1! + EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0)); } /* ************************************************************************* */