-
Notifications
You must be signed in to change notification settings - Fork 768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DecisionTree Refactor #1155
DecisionTree Refactor #1155
Changes from 2 commits
d5d5ecc
e81e04a
039ecfc
dac84e9
8e6a583
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,7 +59,7 @@ namespace gtsam { | |
/** constant stored in this leaf */ | ||
Y constant_; | ||
|
||
/** The number of assignments contained within this leaf | ||
/** The number of assignments contained within this leaf. | ||
* Particularly useful when leaves have been pruned. | ||
*/ | ||
size_t nrAssignments_; | ||
|
@@ -68,7 +68,7 @@ namespace gtsam { | |
Leaf(const Y& constant, size_t nrAssignments = 1) | ||
: constant_(constant), nrAssignments_(nrAssignments) {} | ||
|
||
/** return the constant */ | ||
/// Return the constant | ||
const Y& constant() const { | ||
return constant_; | ||
} | ||
|
@@ -81,19 +81,19 @@ namespace gtsam { | |
return constant_ == q.constant_; | ||
} | ||
|
||
/// polymorphic equality: is q is a leaf, could be | ||
/// polymorphic equality: is q a leaf and is it the same as this leaf? | ||
bool sameLeaf(const Node& q) const override { | ||
return (q.isLeaf() && q.sameLeaf(*this)); | ||
} | ||
|
||
/** equality up to tolerance */ | ||
/// equality up to tolerance | ||
bool equals(const Node& q, const CompareFunc& compare) const override { | ||
const Leaf* other = dynamic_cast<const Leaf*>(&q); | ||
if (!other) return false; | ||
return compare(this->constant_, other->constant_); | ||
} | ||
|
||
/** print */ | ||
void print(const std::string& s, const LabelFormatter& labelFormatter, | ||
const ValueFormatter& valueFormatter) const override { | ||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; | ||
|
@@ -122,8 +122,8 @@ namespace gtsam { | |
|
||
/// Apply unary operator with assignment | ||
NodePtr apply(const UnaryAssignment& op, | ||
const Assignment<L>& choices) const override { | ||
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_)); | ||
const Assignment<L>& assignment) const override { | ||
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_)); | ||
return f; | ||
} | ||
|
||
|
@@ -168,7 +168,10 @@ namespace gtsam { | |
std::vector<NodePtr> branches_; | ||
|
||
private: | ||
/** incremental allSame */ | ||
/** | ||
* Incremental allSame. | ||
* Records if all the branches are the same leaf. | ||
*/ | ||
size_t allSame_; | ||
|
||
using ChoicePtr = boost::shared_ptr<const Choice>; | ||
|
@@ -181,9 +184,9 @@ namespace gtsam { | |
#endif | ||
} | ||
|
||
/** If all branches of a choice node f are the same, just return a branch */ | ||
/// If all branches of a choice node f are the same, just return a branch. | ||
static NodePtr Unique(const ChoicePtr& f) { | ||
#ifndef DT_NO_PRUNING | ||
#ifndef GTSAM_DT_NO_PRUNING | ||
if (f->allSame_) { | ||
assert(f->branches().size() > 0); | ||
NodePtr f0 = f->branches_[0]; | ||
|
@@ -205,15 +208,13 @@ namespace gtsam { | |
|
||
bool isLeaf() const override { return false; } | ||
|
||
/** Constructor, given choice label and mandatory expected branch count */ | ||
/// Constructor, given choice label and mandatory expected branch count. | ||
Choice(const L& label, size_t count) : | ||
label_(label), allSame_(true) { | ||
branches_.reserve(count); | ||
} | ||
|
||
/** | ||
* Construct from applying binary op to two Choice nodes | ||
*/ | ||
/// Construct from applying binary op to two Choice nodes. | ||
Choice(const Choice& f, const Choice& g, const Binary& op) : | ||
allSame_(true) { | ||
// Choose what to do based on label | ||
|
@@ -241,6 +242,7 @@ namespace gtsam { | |
} | ||
} | ||
|
||
/// Return the label of this choice node. | ||
const L& label() const { | ||
return label_; | ||
} | ||
|
@@ -262,7 +264,7 @@ namespace gtsam { | |
branches_.push_back(node); | ||
} | ||
|
||
/** print (as a tree) */ | ||
/// print (as a tree). | ||
void print(const std::string& s, const LabelFormatter& labelFormatter, | ||
const ValueFormatter& valueFormatter) const override { | ||
std::cout << s << " Choice("; | ||
|
@@ -308,7 +310,7 @@ namespace gtsam { | |
return (q.isLeaf() && q.sameLeaf(*this)); | ||
} | ||
|
||
/** equality */ | ||
/// equality | ||
bool equals(const Node& q, const CompareFunc& compare) const override { | ||
const Choice* other = dynamic_cast<const Choice*>(&q); | ||
if (!other) return false; | ||
|
@@ -321,7 +323,7 @@ namespace gtsam { | |
return true; | ||
} | ||
|
||
/** evaluate */ | ||
/// evaluate | ||
const Y& operator()(const Assignment<L>& x) const override { | ||
#ifndef NDEBUG | ||
typename Assignment<L>::const_iterator it = x.find(label_); | ||
|
@@ -336,13 +338,13 @@ namespace gtsam { | |
return (*child)(x); | ||
} | ||
|
||
/** | ||
* Construct from applying unary op to a Choice node | ||
*/ | ||
/// Construct from applying unary op to a Choice node. | ||
Choice(const L& label, const Choice& f, const Unary& op) : | ||
label_(label), allSame_(true) { | ||
branches_.reserve(f.branches_.size()); // reserve space | ||
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op)); | ||
for (const NodePtr& branch : f.branches_) { | ||
push_back(branch->apply(op)); | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -353,37 +355,37 @@ namespace gtsam { | |
* @param f The original choice node to apply the op on. | ||
* @param op Function to apply on the choice node. Takes Assignment and | ||
* value as arguments. | ||
* @param choices The Assignment that will go to op. | ||
* @param assignment The Assignment that will go to op. | ||
*/ | ||
Choice(const L& label, const Choice& f, const UnaryAssignment& op, | ||
const Assignment<L>& choices) | ||
const Assignment<L>& assignment) | ||
: label_(label), allSame_(true) { | ||
branches_.reserve(f.branches_.size()); // reserve space | ||
|
||
Assignment<L> choices_ = choices; | ||
Assignment<L> assignment_ = assignment; | ||
|
||
for (size_t i = 0; i < f.branches_.size(); i++) { | ||
choices_[label_] = i; // Set assignment for label to i | ||
assignment_[label_] = i; // Set assignment for label to i | ||
|
||
const NodePtr branch = f.branches_[i]; | ||
push_back(branch->apply(op, choices_)); | ||
push_back(branch->apply(op, assignment_)); | ||
|
||
// Remove the choice so we are backtracking | ||
auto choice_it = choices_.find(label_); | ||
choices_.erase(choice_it); | ||
// Remove the assignment so we are backtracking | ||
auto assignment_it = assignment_.find(label_); | ||
assignment_.erase(assignment_it); | ||
} | ||
} | ||
|
||
/** apply unary operator */ | ||
/// apply unary operator. | ||
NodePtr apply(const Unary& op) const override { | ||
auto r = boost::make_shared<Choice>(label_, *this, op); | ||
return Unique(r); | ||
} | ||
|
||
/// Apply unary operator with assignment | ||
NodePtr apply(const UnaryAssignment& op, | ||
const Assignment<L>& choices) const override { | ||
auto r = boost::make_shared<Choice>(label_, *this, op, choices); | ||
const Assignment<L>& assignment) const override { | ||
auto r = boost::make_shared<Choice>(label_, *this, op, assignment); | ||
return Unique(r); | ||
} | ||
|
||
|
@@ -678,7 +680,14 @@ namespace gtsam { | |
} | ||
|
||
/****************************************************************************/ | ||
// Functor performing depth-first visit without Assignment<L> argument. | ||
/** | ||
* Functor performing depth-first visit without Assignment<L> argument. | ||
* | ||
* NOTE: We differentiate between leaves and assignments. Concretely, a 3 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, what does this Functor do? Leaves, right? And should it not now pass the number of assignments captured by the leaf? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
* binary variable tree will have 2^3=8 assignments, but based on pruning, it | ||
* can have <8 leaves. For example, if a tree has all assignment values as 1, | ||
* then pruning will cause the tree to have only 1 leaf yet 8 assignments. | ||
*/ | ||
template <typename L, typename Y> | ||
struct Visit { | ||
using F = std::function<void(const Y&)>; | ||
|
@@ -707,33 +716,36 @@ namespace gtsam { | |
} | ||
|
||
/****************************************************************************/ | ||
// Functor performing depth-first visit with Assignment<L> argument. | ||
/** | ||
* Functor performing depth-first visit with Assignment<L> argument. | ||
* | ||
* NOTE: Follows the same pruning semantics as `visit`. | ||
*/ | ||
template <typename L, typename Y> | ||
struct VisitWith { | ||
using Choices = Assignment<L>; | ||
using F = std::function<void(const Choices&, const Y&)>; | ||
using F = std::function<void(const Assignment<L>&, const Y&)>; | ||
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function. | ||
Choices choices; ///< Assignment, mutating through recursion. | ||
F f; ///< folding function object. | ||
Assignment<L> assignment; ///< Assignment, mutating through recursion. | ||
F f; ///< folding function object. | ||
|
||
/// Do a depth-first visit on the tree rooted at node. | ||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) { | ||
using Leaf = typename DecisionTree<L, Y>::Leaf; | ||
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node)) | ||
return f(choices, leaf->constant()); | ||
return f(assignment, leaf->constant()); | ||
|
||
using Choice = typename DecisionTree<L, Y>::Choice; | ||
auto choice = boost::dynamic_pointer_cast<const Choice>(node); | ||
if (!choice) | ||
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); | ||
for (size_t i = 0; i < choice->nrChoices(); i++) { | ||
choices[choice->label()] = i; // Set assignment for label to i | ||
assignment[choice->label()] = i; // Set assignment for label to i | ||
|
||
(*this)(choice->branches()[i]); // recurse! | ||
|
||
// Remove the choice so we are backtracking | ||
auto choice_it = choices.find(choice->label()); | ||
choices.erase(choice_it); | ||
auto choice_it = assignment.find(choice->label()); | ||
assignment.erase(choice_it); | ||
} | ||
} | ||
}; | ||
|
@@ -763,12 +775,14 @@ namespace gtsam { | |
} | ||
|
||
/****************************************************************************/ | ||
// labels is just done with a visit | ||
// Get (partial) labels by performing a visit. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Explain more. Maybe with example. |
||
template <typename L, typename Y> | ||
std::set<L> DecisionTree<L, Y>::labels() const { | ||
std::set<L> unique; | ||
auto f = [&](const Assignment<L>& choices, const Y&) { | ||
for (auto&& kv : choices) unique.insert(kv.first); | ||
auto f = [&](const Assignment<L>& assignment, const Y&) { | ||
for (auto&& kv : assignment) { | ||
unique.insert(kv.first); | ||
} | ||
}; | ||
visitWith(f); | ||
return unique; | ||
|
@@ -817,8 +831,8 @@ namespace gtsam { | |
throw std::runtime_error( | ||
"DecisionTree::apply(unary op) undefined for empty tree."); | ||
} | ||
Assignment<L> choices; | ||
return DecisionTree(root_->apply(op, choices)); | ||
Assignment<L> assignment; | ||
return DecisionTree(root_->apply(op, assignment)); | ||
} | ||
|
||
/****************************************************************************/ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,7 +105,7 @@ namespace gtsam { | |
virtual const Y& operator()(const Assignment<L>& x) const = 0; | ||
virtual Ptr apply(const Unary& op) const = 0; | ||
virtual Ptr apply(const UnaryAssignment& op, | ||
const Assignment<L>& choices) const = 0; | ||
const Assignment<L>& assignment) const = 0; | ||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; | ||
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; | ||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; | ||
|
@@ -153,7 +153,7 @@ namespace gtsam { | |
/** Create a constant */ | ||
explicit DecisionTree(const Y& y); | ||
|
||
/** Create a new leaf function splitting on a variable */ | ||
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` | ||
DecisionTree(const L& label, const Y& y1, const Y& y2); | ||
|
||
/** Allow Label+Cardinality for convenience */ | ||
|
@@ -219,9 +219,8 @@ namespace gtsam { | |
/// @name Standard Interface | ||
/// @{ | ||
|
||
/** Make virtual */ | ||
virtual ~DecisionTree() { | ||
} | ||
/// Make virtual | ||
virtual ~DecisionTree() {} | ||
|
||
/// Check if tree is empty. | ||
bool empty() const { return !root_; } | ||
|
@@ -234,11 +233,13 @@ namespace gtsam { | |
|
||
/** | ||
* @brief Visit all leaves in depth-first fashion. | ||
* | ||
* @param f side-effect taking a value. | ||
* | ||
* @note Due to pruning, leaves might not exhaust choices. | ||
* | ||
* | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think these functions passed in should now also take an argument for the number of assignments captured in a particular leaf. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shoot, I merged too soon! I just noticed this new visitLeaf, which in this PR seems identical?? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And the examples are wrong I think... We can talk about it in our meeting. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay! Though I don't want to spend more than 5 minutes on this since it is secondary to the workshop paper. |
||
* @param f (side-effect) Function taking a value. | ||
* | ||
* @note Due to pruning, the number of leaves may not be the same as the | ||
* number of assignments. E.g. if we have a tree on 2 binary variables with | ||
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf. | ||
* | ||
* Example: | ||
* int sum = 0; | ||
* auto visitor = [&](int y) { sum += y; }; | ||
|
@@ -249,14 +250,16 @@ namespace gtsam { | |
|
||
/** | ||
* @brief Visit all leaves in depth-first fashion. | ||
* | ||
* @param f side-effect taking an assignment and a value. | ||
* | ||
* @note Due to pruning, leaves might not exhaust choices. | ||
* | ||
* | ||
* @param f (side-effect) Function taking an assignment and a value. | ||
* | ||
* @note Due to pruning, the number of leaves may not be the same as the | ||
* number of assignments. E.g. if we have a tree on 2 binary variables with | ||
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf. | ||
* | ||
* Example: | ||
* int sum = 0; | ||
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; }; | ||
* auto visitor = [&](const Assignment<L>& assignment, int y) { sum += y; }; | ||
* tree.visitWith(visitor); | ||
*/ | ||
template <typename Func> | ||
|
@@ -275,7 +278,7 @@ namespace gtsam { | |
* | ||
* @note X is always passed by value. | ||
* @note Due to pruning, leaves might not exhaust choices. | ||
* | ||
* | ||
* Example: | ||
* auto add = [](const double& y, double x) { return y + x; }; | ||
* double sum = tree.fold(add, 0.0); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the argument then?