diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 1e8f5aa3e7..ef4cc48f69 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -17,9 +17,9 @@ * @author Frank Dellaert */ +#include #include #include -#include #include #include @@ -29,42 +29,42 @@ using namespace std; namespace gtsam { - /* ******************************************************************************** */ - DecisionTreeFactor::DecisionTreeFactor() { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor() {} - /* ******************************************************************************** */ + /* ************************************************************************ */ DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, - const ADT& potentials) : - DiscreteFactor(keys.indices()), ADT(potentials), - cardinalities_(keys.cardinalities()) { - } + const ADT& potentials) + : DiscreteFactor(keys.indices()), + ADT(potentials), + cardinalities_(keys.cardinalities()) {} - /* *************************************************************************/ - DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : - DiscreteFactor(c.keys()), AlgebraicDecisionTree(c), cardinalities_(c.cardinalities_) { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) + : DiscreteFactor(c.keys()), + AlgebraicDecisionTree(c), + cardinalities_(c.cardinalities_) {} - /* ************************************************************************* */ - bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const { - if(!dynamic_cast(&other)) { + /* ************************************************************************ */ + bool DecisionTreeFactor::equals(const DiscreteFactor& other, + double tol) const { + if (!dynamic_cast(&other)) { return false; - } - else { + } else { const auto& f(static_cast(other)); return ADT::equals(f, tol); } } - /* ************************************************************************* */ - double DecisionTreeFactor::safe_div(const double &a, const double &b) { + /* ************************************************************************ */ + double DecisionTreeFactor::safe_div(const double& a, const double& b) { // The use for safe_div is when we divide the product factor by the sum // factor. If the product or sum is zero, we accord zero probability to the // event. return (a == 0 || b == 0) ? 0 : (a / b); } - /* ************************************************************************* */ + /* ************************************************************************ */ void DecisionTreeFactor::print(const string& s, const KeyFormatter& formatter) const { cout << s; @@ -75,31 +75,32 @@ namespace gtsam { ADT::print("", formatter); } - /* ************************************************************************* */ + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, - ADT::Binary op) const { - map cs; // new cardinalities + ADT::Binary op) const { + map cs; // new cardinalities // make unique key-cardinality map - for(Key j: keys()) cs[j] = cardinality(j); - for(Key j: f.keys()) cs[j] = f.cardinality(j); + for (Key j : keys()) cs[j] = cardinality(j); + for (Key j : f.keys()) cs[j] = f.cardinality(j); // Convert map into keys DiscreteKeys keys; - for(const std::pair& key: cs) - keys.push_back(key); + for (const std::pair& key : cs) keys.push_back(key); // apply operand ADT result = ADT::apply(f, op); // Make a new factor return DecisionTreeFactor(keys, result); } - /* ************************************************************************* */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, - ADT::Binary op) const { - - if (nrFrontals > size()) throw invalid_argument( - (boost::format( - "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") - % nrFrontals % size()).str()); + /* ************************************************************************ */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( + size_t nrFrontals, ADT::Binary op) const { + if (nrFrontals > size()) + throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal " + "keys %d, nr.keys=%d") % + nrFrontals % size()) + .str()); // sum over nrFrontals keys size_t i; @@ -113,20 +114,21 @@ namespace gtsam { DiscreteKeys dkeys; for (; i < keys().size(); i++) { Key j = keys()[i]; - dkeys.push_back(DiscreteKey(j,cardinality(j))); + dkeys.push_back(DiscreteKey(j, cardinality(j))); } return boost::make_shared(dkeys, result); } - - /* ************************************************************************* */ - DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys, - ADT::Binary op) const { - - if (frontalKeys.size() > size()) throw invalid_argument( - (boost::format( - "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") - % frontalKeys.size() % size()).str()); + /* ************************************************************************ */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( + const Ordering& frontalKeys, ADT::Binary op) const { + if (frontalKeys.size() > size()) + throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal " + "keys %d, nr.keys=%d") % + frontalKeys.size() % size()) + .str()); // sum over nrFrontals keys size_t i; @@ -137,20 +139,22 @@ namespace gtsam { } // create new factor, note we collect keys that are not in frontalKeys - // TODO: why do we need this??? result should contain correct keys!!! + // TODO(frank): why do we need this??? result should contain correct keys!!! DiscreteKeys dkeys; for (i = 0; i < keys().size(); i++) { Key j = keys()[i]; - // TODO: inefficient! - if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end()) + // TODO(frank): inefficient! + if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != + frontalKeys.end()) continue; - dkeys.push_back(DiscreteKey(j,cardinality(j))); + dkeys.push_back(DiscreteKey(j, cardinality(j))); } return boost::make_shared(dkeys, result); } - /* ************************************************************************* */ - std::vector> DecisionTreeFactor::enumerate() const { + /* ************************************************************************ */ + std::vector> DecisionTreeFactor::enumerate() + const { // Get all possible assignments std::vector> pairs; for (auto& key : keys()) { @@ -168,7 +172,7 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ DiscreteKeys DecisionTreeFactor::discreteKeys() const { DiscreteKeys result; for (auto&& key : keys()) { @@ -180,7 +184,7 @@ namespace gtsam { return result; } - /* ************************************************************************* */ + /* ************************************************************************ */ static std::string valueFormatter(const double& v) { return (boost::format("%4.2g") % v).str(); } @@ -194,8 +198,8 @@ namespace gtsam { /** output to graphviz format, open a file */ void DecisionTreeFactor::dot(const std::string& name, - const KeyFormatter& keyFormatter, - bool showZero) const { + const KeyFormatter& keyFormatter, + bool showZero) const { ADT::dot(name, keyFormatter, valueFormatter, showZero); } @@ -205,8 +209,8 @@ namespace gtsam { return ADT::dot(keyFormatter, valueFormatter, showZero); } - // Print out header. - /* ************************************************************************* */ + // Print out header. + /* ************************************************************************ */ string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, const Names& names) const { stringstream ss; @@ -271,17 +275,19 @@ namespace gtsam { return ss.str(); } - /* ************************************************************************* */ - DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector &table) : - DiscreteFactor(keys.indices()), AlgebraicDecisionTree(keys, table), - cardinalities_(keys.cardinalities()) { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const vector& table) + : DiscreteFactor(keys.indices()), + AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) {} - /* ************************************************************************* */ - DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) : - DiscreteFactor(keys.indices()), AlgebraicDecisionTree(keys, table), - cardinalities_(keys.cardinalities()) { - } + /* ************************************************************************ */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const string& table) + : DiscreteFactor(keys.indices()), + AlgebraicDecisionTree(keys, table), + cardinalities_(keys.cardinalities()) {} - /* ************************************************************************* */ -} // namespace gtsam + /* ************************************************************************ */ +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 751b8c62c4..91fa7c4849 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -18,16 +18,18 @@ #pragma once +#include #include #include -#include #include +#include #include - -#include -#include +#include #include +#include +#include +#include namespace gtsam { @@ -36,21 +38,19 @@ namespace gtsam { /** * A discrete probabilistic factor */ - class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree { - - public: - + class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor, + public AlgebraicDecisionTree { + public: // typedefs needed to play nice with gtsam typedef DecisionTreeFactor This; - typedef DiscreteFactor Base; ///< Typedef to base class + typedef DiscreteFactor Base; ///< Typedef to base class typedef boost::shared_ptr shared_ptr; typedef AlgebraicDecisionTree ADT; - protected: - std::map cardinalities_; - - public: + protected: + std::map cardinalities_; + public: /// @name Standard Constructors /// @{ @@ -61,7 +61,8 @@ namespace gtsam { DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); /** Constructor from doubles */ - DecisionTreeFactor(const DiscreteKeys& keys, const std::vector& table); + DecisionTreeFactor(const DiscreteKeys& keys, + const std::vector& table); /** Constructor from string */ DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); @@ -86,7 +87,8 @@ namespace gtsam { bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; // print - void print(const std::string& s = "DecisionTreeFactor:\n", + void print( + const std::string& s = "DecisionTreeFactor:\n", const KeyFormatter& formatter = DefaultKeyFormatter) const override; /// @} @@ -105,7 +107,7 @@ namespace gtsam { static double safe_div(const double& a, const double& b); - size_t cardinality(Key j) const { return cardinalities_.at(j);} + size_t cardinality(Key j) const { return cardinalities_.at(j); } /// divide by factor f (safely) DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { @@ -113,9 +115,7 @@ namespace gtsam { } /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override { - return *this; - } + DecisionTreeFactor toDecisionTreeFactor() const override { return *this; } /// Create new factor by summing all values with the same separator values shared_ptr sum(size_t nrFrontals) const { @@ -164,27 +164,6 @@ namespace gtsam { */ shared_ptr combine(const Ordering& keys, ADT::Binary op) const; - -// /** -// * @brief Permutes the keys in Potentials and DiscreteFactor -// * -// * This re-implements the permuteWithInverse() in both Potentials -// * and DiscreteFactor by doing both of them together. -// */ -// -// void permuteWithInverse(const Permutation& inversePermutation){ -// DiscreteFactor::permuteWithInverse(inversePermutation); -// Potentials::permuteWithInverse(inversePermutation); -// } -// -// /** -// * Apply a reduction, which is a remapping of variable indices. -// */ -// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) { -// DiscreteFactor::reduceWithInverse(inverseReduction); -// Potentials::reduceWithInverse(inverseReduction); -// } - /// Enumerate all values into a map from values to double. std::vector> enumerate() const; @@ -194,16 +173,16 @@ namespace gtsam { /// @} /// @name Wrapper support /// @{ - + /** output to graphviz format, stream version */ void dot(std::ostream& os, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - bool showZero = true) const; + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; /** output to graphviz format, open a file */ void dot(const std::string& name, - const KeyFormatter& keyFormatter = DefaultKeyFormatter, - bool showZero = true) const; + const KeyFormatter& keyFormatter = DefaultKeyFormatter, + bool showZero = true) const; /** output to graphviz format string */ std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, @@ -217,7 +196,7 @@ namespace gtsam { * @return std::string a markdown string. */ std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + const Names& names = {}) const override; /** * @brief Render as html table @@ -227,14 +206,13 @@ namespace gtsam { * @return std::string a html string. */ std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + const Names& names = {}) const override; /// @} - -}; -// DecisionTreeFactor + }; // traits -template<> struct traits : public Testable {}; +template <> +struct traits : public Testable {}; -}// namespace gtsam +} // namespace gtsam