Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
dellaert committed Jan 22, 2022
1 parent 94c692d commit ca329da
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 122 deletions.
146 changes: 76 additions & 70 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
* @author Frank Dellaert
*/

#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/base/FastSet.h>

#include <boost/make_shared.hpp>
#include <boost/format.hpp>
Expand All @@ -29,42 +29,42 @@ using namespace std;

namespace gtsam {

/* ******************************************************************************** */
DecisionTreeFactor::DecisionTreeFactor() {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor() {}

/* ******************************************************************************** */
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) :
DiscreteFactor(keys.indices()), ADT(potentials),
cardinalities_(keys.cardinalities()) {
}
const ADT& potentials)
: DiscreteFactor(keys.indices()),
ADT(potentials),
cardinalities_(keys.cardinalities()) {}

/* *************************************************************************/
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
: DiscreteFactor(c.keys()),
AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}

/* ************************************************************************* */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) {
/* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false;
}
else {
} else {
const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol);
}
}

/* ************************************************************************* */
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
/* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}

/* ************************************************************************* */
/* ************************************************************************ */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
cout << s;
Expand All @@ -75,31 +75,32 @@ namespace gtsam {
ADT::print("", formatter);
}

/* ************************************************************************* */
/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const {
map<Key,size_t> cs; // new cardinalities
ADT::Binary op) const {
map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map
for(Key j: keys()) cs[j] = cardinality(j);
for(Key j: f.keys()) cs[j] = f.cardinality(j);
for (Key j : keys()) cs[j] = cardinality(j);
for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys
DiscreteKeys keys;
for(const std::pair<const Key,size_t>& key: cs)
keys.push_back(key);
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
// apply operand
ADT result = ADT::apply(f, op);
// Make a new factor
return DecisionTreeFactor(keys, result);
}

/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
ADT::Binary op) const {

if (nrFrontals > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% nrFrontals % size()).str());
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
if (nrFrontals > size())
throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal "
"keys %d, nr.keys=%d") %
nrFrontals % size())
.str());

// sum over nrFrontals keys
size_t i;
Expand All @@ -113,20 +114,21 @@ namespace gtsam {
DiscreteKeys dkeys;
for (; i < keys().size(); i++) {
Key j = keys()[i];
dkeys.push_back(DiscreteKey(j,cardinality(j)));
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}


/* ************************************************************************* */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
ADT::Binary op) const {

if (frontalKeys.size() > size()) throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
% frontalKeys.size() % size()).str());
/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
const Ordering& frontalKeys, ADT::Binary op) const {
if (frontalKeys.size() > size())
throw invalid_argument(
(boost::format(
"DecisionTreeFactor::combine: invalid number of frontal "
"keys %d, nr.keys=%d") %
frontalKeys.size() % size())
.str());

// sum over nrFrontals keys
size_t i;
Expand All @@ -137,20 +139,22 @@ namespace gtsam {
}

// create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!!
// TODO(frank): why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) {
Key j = keys()[i];
// TODO: inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
// TODO(frank): inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
frontalKeys.end())
continue;
dkeys.push_back(DiscreteKey(j,cardinality(j)));
dkeys.push_back(DiscreteKey(j, cardinality(j)));
}
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
}

/* ************************************************************************* */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
/* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) {
Expand All @@ -168,7 +172,7 @@ namespace gtsam {
return result;
}

/* ************************************************************************* */
/* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result;
for (auto&& key : keys()) {
Expand All @@ -180,7 +184,7 @@ namespace gtsam {
return result;
}

/* ************************************************************************* */
/* ************************************************************************ */
static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str();
}
Expand All @@ -194,8 +198,8 @@ namespace gtsam {

/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
const KeyFormatter& keyFormatter,
bool showZero) const {
ADT::dot(name, keyFormatter, valueFormatter, showZero);
}

Expand All @@ -205,8 +209,8 @@ namespace gtsam {
return ADT::dot(keyFormatter, valueFormatter, showZero);
}

// Print out header.
/* ************************************************************************* */
// Print out header.
/* ************************************************************************ */
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
Expand Down Expand Up @@ -271,17 +275,19 @@ namespace gtsam {
return ss.str();
}

/* ************************************************************************* */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const vector<double>& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}

/* ************************************************************************* */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const string& table)
: DiscreteFactor(keys.indices()),
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}

/* ************************************************************************* */
} // namespace gtsam
/* ************************************************************************ */
} // namespace gtsam
Loading

0 comments on commit ca329da

Please sign in to comment.