Skip to content

Commit

Permalink
Merge pull request #1056 from borglab/feature/dt_threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert authored Jan 23, 2022
2 parents 9d71c90 + 8db7f25 commit 3d86bc7
Show file tree
Hide file tree
Showing 8 changed files with 460 additions and 449 deletions.
89 changes: 43 additions & 46 deletions gtsam/discrete/AlgebraicDecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,28 @@

#pragma once

#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree-inl.h>

#include <algorithm>
#include <map>
#include <string>
#include <vector>
namespace gtsam {

/**
* Algebraic Decision Trees fix the range to double
* Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether?
*/
template<typename L>
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> {
template <typename L>
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
/**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
*
* @brief Default method used by `labelFormatter` or `valueFormatter` when
* printing.
*
* @param x The value passed to format.
* @return std::string
* @return std::string
*/
static std::string DefaultFormatter(const L& x) {
std::stringstream ss;
Expand All @@ -42,17 +48,12 @@ namespace gtsam {
}

public:

using Base = DecisionTree<L, double>;

/** The Real ring with addition and multiplication */
struct Ring {
static inline double zero() {
return 0.0;
}
static inline double one() {
return 1.0;
}
static inline double zero() { return 0.0; }
static inline double one() { return 1.0; }
static inline double add(const double& a, const double& b) {
return a + b;
}
Expand All @@ -65,54 +66,50 @@ namespace gtsam {
static inline double div(const double& a, const double& b) {
return a / b;
}
static inline double id(const double& x) {
return x;
}
static inline double id(const double& x) { return x; }
};

AlgebraicDecisionTree() :
Base(1.0) {
}
AlgebraicDecisionTree() : Base(1.0) {}

AlgebraicDecisionTree(const Base& add) :
Base(add) {
}
// Explicitly non-explicit constructor
AlgebraicDecisionTree(const Base& add) : Base(add) {}

/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) :
Base(label, y1, y2) {
}
AlgebraicDecisionTree(const L& label, double y1, double y2)
: Base(label, y1, y2) {}

/** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
Base(labelC, y1, y2) {
}
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
double y2)
: Base(labelC, y1, y2) {}

/** Create from keys and vector table */
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::vector<double>& ys) {
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}

/** Create from keys and string table */
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
// Convert string to doubles
std::vector<double> ys;
std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss),
std::istream_iterator<double>(), std::back_inserter(ys));
std::istream_iterator<double>(), std::back_inserter(ys));

// now call recursive Create
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
ys.end());
this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
}

/** Create a new function splitting on a variable */
template<typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
Base(nullptr) {
template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
: Base(nullptr) {
this->root_ = compose(begin, end, label);
}

Expand All @@ -122,7 +119,7 @@ namespace gtsam {
* @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L.
*/
template<typename M>
template <typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) {
// Functor for label conversion so we can use `convertFrom`.
Expand Down Expand Up @@ -160,8 +157,8 @@ namespace gtsam {

/// print method customized to value type `double`.
void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const {
auto valueFormatter = [](const double& v) {
return (boost::format("%4.4g") % v).str();
};
Expand All @@ -177,8 +174,8 @@ namespace gtsam {
return Base::equals(other, compare);
}
};
// AlgebraicDecisionTree

template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
}
// namespace gtsam
template <typename T>
struct traits<AlgebraicDecisionTree<T>>
: public Testable<AlgebraicDecisionTree<T>> {};
} // namespace gtsam
Loading

0 comments on commit 3d86bc7

Please sign in to comment.