Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DecisionTree Refactor #1155

Merged
merged 5 commits into from
Apr 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 61 additions & 47 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace gtsam {
/** constant stored in this leaf */
Y constant_;

/** The number of assignments contained within this leaf
/** The number of assignments contained within this leaf.
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;
Expand All @@ -68,7 +68,7 @@ namespace gtsam {
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}

/** return the constant */
/// Return the constant
const Y& constant() const {
return constant_;
}
Expand All @@ -81,19 +81,19 @@ namespace gtsam {
return constant_ == q.constant_;
}

/// polymorphic equality: is q is a leaf, could be
/// polymorphic equality: is q a leaf and is it the same as this leaf?
bool sameLeaf(const Node& q) const override {
return (q.isLeaf() && q.sameLeaf(*this));
}

/** equality up to tolerance */
/// equality up to tolerance
bool equals(const Node& q, const CompareFunc& compare) const override {
const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false;
return compare(this->constant_, other->constant_);
}

/** print */
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
Expand Down Expand Up @@ -122,8 +122,8 @@ namespace gtsam {

/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
const Assignment<L>& assignment) const override {
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
return f;
}

Expand Down Expand Up @@ -168,7 +168,10 @@ namespace gtsam {
std::vector<NodePtr> branches_;

private:
/** incremental allSame */
/**
* Incremental allSame.
* Records if all the branches are the same leaf.
*/
size_t allSame_;

using ChoicePtr = boost::shared_ptr<const Choice>;
Expand All @@ -181,9 +184,9 @@ namespace gtsam {
#endif
}

/** If all branches of a choice node f are the same, just return a branch */
/// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) {
#ifndef DT_NO_PRUNING
#ifndef GTSAM_DT_NO_PRUNING
if (f->allSame_) {
assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0];
Expand All @@ -205,15 +208,13 @@ namespace gtsam {

bool isLeaf() const override { return false; }

/** Constructor, given choice label and mandatory expected branch count */
/// Constructor, given choice label and mandatory expected branch count.
Choice(const L& label, size_t count) :
label_(label), allSame_(true) {
branches_.reserve(count);
}

/**
* Construct from applying binary op to two Choice nodes
*/
/// Construct from applying binary op to two Choice nodes.
Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) {
// Choose what to do based on label
Expand Down Expand Up @@ -241,6 +242,7 @@ namespace gtsam {
}
}

/// Return the label of this choice node.
const L& label() const {
return label_;
}
Expand All @@ -262,7 +264,7 @@ namespace gtsam {
branches_.push_back(node);
}

/** print (as a tree) */
/// print (as a tree).
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice(";
Expand Down Expand Up @@ -308,7 +310,7 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this));
}

/** equality */
/// equality
bool equals(const Node& q, const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false;
Expand All @@ -321,7 +323,7 @@ namespace gtsam {
return true;
}

/** evaluate */
/// evaluate
const Y& operator()(const Assignment<L>& x) const override {
#ifndef NDEBUG
typename Assignment<L>::const_iterator it = x.find(label_);
Expand All @@ -336,13 +338,13 @@ namespace gtsam {
return (*child)(x);
}

/**
* Construct from applying unary op to a Choice node
*/
/// Construct from applying unary op to a Choice node.
Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
for (const NodePtr& branch : f.branches_) {
push_back(branch->apply(op));
}
}

/**
Expand All @@ -353,37 +355,37 @@ namespace gtsam {
* @param f The original choice node to apply the op on.
* @param op Function to apply on the choice node. Takes Assignment and
* value as arguments.
* @param choices The Assignment that will go to op.
* @param assignment The Assignment that will go to op.
*/
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
const Assignment<L>& choices)
const Assignment<L>& assignment)
: label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space

Assignment<L> choices_ = choices;
Assignment<L> assignment_ = assignment;

for (size_t i = 0; i < f.branches_.size(); i++) {
choices_[label_] = i; // Set assignment for label to i
assignment_[label_] = i; // Set assignment for label to i

const NodePtr branch = f.branches_[i];
push_back(branch->apply(op, choices_));
push_back(branch->apply(op, assignment_));

// Remove the choice so we are backtracking
auto choice_it = choices_.find(label_);
choices_.erase(choice_it);
// Remove the assignment so we are backtracking
auto assignment_it = assignment_.find(label_);
assignment_.erase(assignment_it);
}
}

/** apply unary operator */
/// apply unary operator.
NodePtr apply(const Unary& op) const override {
auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r);
}

/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
const Assignment<L>& assignment) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
return Unique(r);
}

Expand Down Expand Up @@ -678,7 +680,14 @@ namespace gtsam {
}

/****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument.
/**
* Functor performing depth-first visit without Assignment<L> argument.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the argument then?

*
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, what does this Functor do? Leaves, right? And should it not now pass the number of assignments captured by the leaf?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VisitLeaf functor passes the Leaf object which has the number of assignments recorded in it.

* binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have <8 leaves. For example, if a tree has all assignment values as 1,
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
*/
template <typename L, typename Y>
struct Visit {
using F = std::function<void(const Y&)>;
Expand Down Expand Up @@ -707,33 +716,36 @@ namespace gtsam {
}

/****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument.
/**
* Functor performing depth-first visit with Assignment<L> argument.
*
* NOTE: Follows the same pruning semantics as `visit`.
*/
template <typename L, typename Y>
struct VisitWith {
using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>;
using F = std::function<void(const Assignment<L>&, const Y&)>;
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
Assignment<L> assignment; ///< Assignment, mutating through recursion.
F f; ///< folding function object.

/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(choices, leaf->constant());
return f(assignment, leaf->constant());

using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
assignment[choice->label()] = i; // Set assignment for label to i

(*this)(choice->branches()[i]); // recurse!

// Remove the choice so we are backtracking
auto choice_it = choices.find(choice->label());
choices.erase(choice_it);
auto choice_it = assignment.find(choice->label());
assignment.erase(choice_it);
}
}
};
Expand Down Expand Up @@ -763,12 +775,14 @@ namespace gtsam {
}

/****************************************************************************/
// labels is just done with a visit
// Get (partial) labels by performing a visit.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain more. Maybe with example.

template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const {
std::set<L> unique;
auto f = [&](const Assignment<L>& choices, const Y&) {
for (auto&& kv : choices) unique.insert(kv.first);
auto f = [&](const Assignment<L>& assignment, const Y&) {
for (auto&& kv : assignment) {
unique.insert(kv.first);
}
};
visitWith(f);
return unique;
Expand Down Expand Up @@ -817,8 +831,8 @@ namespace gtsam {
throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree.");
}
Assignment<L> choices;
return DecisionTree(root_->apply(op, choices));
Assignment<L> assignment;
return DecisionTree(root_->apply(op, assignment));
}

/****************************************************************************/
Expand Down
37 changes: 20 additions & 17 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ namespace gtsam {
virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const = 0;
const Assignment<L>& assignment) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
Expand Down Expand Up @@ -153,7 +153,7 @@ namespace gtsam {
/** Create a constant */
explicit DecisionTree(const Y& y);

/** Create a new leaf function splitting on a variable */
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
DecisionTree(const L& label, const Y& y1, const Y& y2);

/** Allow Label+Cardinality for convenience */
Expand Down Expand Up @@ -219,9 +219,8 @@ namespace gtsam {
/// @name Standard Interface
/// @{

/** Make virtual */
virtual ~DecisionTree() {
}
/// Make virtual
virtual ~DecisionTree() {}

/// Check if tree is empty.
bool empty() const { return !root_; }
Expand All @@ -234,11 +233,13 @@ namespace gtsam {

/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these functions passed in should now also take an argument for the number of assignments captured in a particular leaf.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added visitLeaf as an alternative to that. Should I remove that and update the rest to do this (will be API breaking)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shoot, I merged too soon! I just noticed this new visitLeaf, which in this PR seems identical??

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the examples are wrong I think... We can talk about it in our meeting.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay! Though I don't want to spend more than 5 minutes on this since it is secondary to the workshop paper.

* @param f (side-effect) Function taking a value.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
*
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
Expand All @@ -249,14 +250,16 @@ namespace gtsam {

/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking an assignment and a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
*
* @param f (side-effect) Function taking an assignment and a value.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
* auto visitor = [&](const Assignment<L>& assignment, int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
Expand All @@ -275,7 +278,7 @@ namespace gtsam {
*
* @note X is always passed by value.
* @note Due to pruning, leaves might not exhaust choices.
*
*
* Example:
* auto add = [](const double& y, double x) { return y + x; };
* double sum = tree.fold(add, 0.0);
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/tests/testAlgebraicDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers
//#define DT_NO_PRUNING
//#define GTSAM_DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

// #define DT_DEBUG_MEMORY
// #define DT_NO_PRUNING
// #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>

Expand Down