Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapped classes in discrete #967

Merged
merged 23 commits into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e022084
Added wrapper files
dellaert Oct 5, 2021
055d8c7
Added WIP python test
dellaert Oct 5, 2021
c2f7827
Merge remote-tracking branch 'origin/develop' into feature/discrete_w…
ProfFan Oct 8, 2021
64bbc79
Add wrapping and tests
ProfFan Oct 8, 2021
f50f963
Add main
dellaert Oct 27, 2021
7891154
Merge branch 'develop' into feature/discrete_wrapper
dellaert Dec 6, 2021
16672da
Merge branch 'develop' into feature/discrete_wrapper
dellaert Dec 13, 2021
02dbcb4
Get rid of "and" business
dellaert Dec 13, 2021
44b4f21
Merge branch 'feature/DiscreteValues' into feature/discrete_wrapper
dellaert Dec 13, 2021
e22f389
Added value, for wrapper
dellaert Dec 14, 2021
ebc37ee
Wrapped more DiscreteFactorGraph functionality
dellaert Dec 14, 2021
f593428
Use evaluate not value
dellaert Dec 15, 2021
fd7640b
Simplified parsing as we moved on from this boost version
dellaert Dec 15, 2021
4e5530b
New, non-fancy constructors
dellaert Dec 15, 2021
8f4b15b
Added chooseAsFactor method for wrapper
dellaert Dec 16, 2021
a4dab12
Wrapped and test Discrete Bayes Nets
dellaert Dec 16, 2021
995e7a5
add default constructor for DiscreteKeys and minor improvements
varunagrawal Dec 16, 2021
fefa991
Add operators
dellaert Dec 16, 2021
b2e3654
Add documentation and test for it
dellaert Dec 16, 2021
7257797
Wrap () operators
dellaert Dec 16, 2021
6bcd129
Attempt at fixing CI issue
dellaert Dec 16, 2021
7401b6e
Merge branch 'feature/discrete_wrapper' into feature/discrete_wrapper_2
varunagrawal Dec 16, 2021
93978cf
Merge pull request #968 from borglab/feature/discrete_wrapper_2
dellaert Dec 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@ namespace gtsam {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
// void DiscreteBayesNet::add_front(const Signature& s) {
// push_front(boost::make_shared<DiscreteConditional>(s));
// }

/* ************************************************************************* */
void DiscreteBayesNet::add(const Signature& s) {
push_back(boost::make_shared<DiscreteConditional>(s));
}

/* ************************************************************************* */
double DiscreteBayesNet::evaluate(const DiscreteValues & values) const {
// evaluate all conditionals and multiply
Expand Down
13 changes: 8 additions & 5 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,15 @@ namespace gtsam {
/// @name Standard Interface
/// @{

/** Add a DiscreteCondtional */
void add(const Signature& s);

// /** Add a DiscreteCondtional in front, when listing parents first*/
// GTSAM_EXPORT void add_front(const Signature& s);
// Add inherited versions of add.
using Base::add;

/** Add a DiscreteCondtional */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<DiscreteConditional>(std::forward<Args>(args)...);
}

//** evaluate for given DiscreteValues */
double evaluate(const DiscreteValues & values) const;

Expand Down
30 changes: 23 additions & 7 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,25 +97,41 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
}

/* ******************************************************************************** */
Potentials::ADT DiscreteConditional::choose(const DiscreteValues& parentsValues) const {
Potentials::ADT DiscreteConditional::choose(
const DiscreteValues& parentsValues) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
ADT pFS(*this);
Key j; size_t value;
for(Key key: parents()) {
size_t value;
for (Key j : parents()) {
try {
j = (key);
value = parentsValues.at(j);
pFS = pFS.choose(j, value);
pFS = pFS.choose(j, value); // ADT keeps getting smaller.
} catch (exception&) {
cout << "Key: " << j << " Value: " << value << endl;
parentsValues.print("parentsValues: ");
// pFS.print("pFS: ");
throw runtime_error("DiscreteConditional::choose: parent value missing");
};
}

return pFS;
}

/* ******************************************************************************** */
DecisionTreeFactor::shared_ptr DiscreteConditional::chooseAsFactor(
const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues);

// Convert ADT to factor.
if (nrFrontals() != 1) {
throw std::runtime_error("Expected only one frontal variable in choose.");
}
DiscreteKeys keys;
const Key frontalKey = keys_[0];
size_t frontalCardinality = this->cardinality(frontalKey);
keys.push_back(DiscreteKey(frontalKey, frontalCardinality));
return boost::make_shared<DecisionTreeFactor>(keys, pFS);
}

/* ******************************************************************************** */
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// TODO: Abhijit asks: is this really the fastest way? He thinks it is.
Expand Down
27 changes: 27 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,29 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
/** Construct from signature */
DiscreteConditional(const Signature& signature);

/**
* Construct from key, parents, and a Table specifying the CPT.
dellaert marked this conversation as resolved.
Show resolved Hide resolved
*
* The first string is parsed to add a key and parents.
*
* Example: DiscreteConditional P(D, {B,E}, table);
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const Signature::Table& table)
: DiscreteConditional(Signature(key, parents, table)) {}

/**
* Construct from key, parents, and a string specifying the CPT.
*
* The first string is parsed to add a key and parents. The second string
* parses into a table.
*
* Example: DiscreteConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
DiscreteConditional(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: DiscreteConditional(Signature(key, parents, spec)) {}

/** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */
DiscreteConditional(const DecisionTreeFactor& joint,
const DecisionTreeFactor& marginal);
Expand Down Expand Up @@ -111,6 +134,10 @@ class GTSAM_EXPORT DiscreteConditional: public DecisionTreeFactor,
/** Restrict to given parent values, returns AlgebraicDecisionDiagram */
ADT choose(const DiscreteValues& parentsValues) const;

/** Restrict to given parent values, returns DecisionTreeFactor */
DecisionTreeFactor::shared_ptr chooseAsFactor(
const DiscreteValues& parentsValues) const;

/**
* solve a conditional
* @param parentsValues Known values of the parents
Expand Down
3 changes: 3 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ class GTSAM_EXPORT DiscreteFactor: public Factor {
/// Find value for given assignment of values to variables
virtual double operator()(const DiscreteValues&) const = 0;

/// Synonym for operator(), mostly for wrapper
double evaluate(const DiscreteValues& values) const { return operator()(values); }

/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;

Expand Down
10 changes: 8 additions & 2 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,14 @@ public EliminateableFactorGraph<DiscreteFactorGraph> {
/** return product of all factors as a single factor */
DecisionTreeFactor product() const;

/** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/
double operator()(const DiscreteValues & values) const;
/**
* Evaluates the factor graph given values, returns the joint probability of
* the factor graph given specific instantiation of values
*/
double operator()(const DiscreteValues& values) const;

/// Synonym for operator(), mostly for wrapper
double evaluate(const DiscreteValues& values) const { return operator()(values); }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The wrapper supports operator overloading. I am pretty sure we added support for the callable operator.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The syntax would be double operator()(const gtsam::DiscreteValues& values) const;

I used this in the DecisionTreeFactor and it works as expected. 🙂

In [1]: f = gtsam.DecisionTreeFactor()

In [2]: f
Out[2]: 
DecisionTreeFactor
Potentials:
  Cardinalities: {}
  Leaf 1

In [3]: f()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-c43e34e6d405> in <module>
----> 1 f()

TypeError: __call__(): incompatible function arguments. The following argument types are supported:
    1. (self: gtsam.gtsam.DiscreteFactor, arg0: gtsam::Assignment<unsigned long>) -> float

Invoked with: DecisionTreeFactor
Potentials:
  Cardinalities: {}
  Leaf 1


In [4]: values = gtsam.DiscreteValues()

In [5]: f(values)
Out[5]: 1.0

In [6]: 


/// print
void print(
Expand Down
5 changes: 2 additions & 3 deletions gtsam/discrete/DiscreteKey.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@ namespace gtsam {
/// DiscreteKeys is a set of keys that can be assembled using the & operator
struct DiscreteKeys: public std::vector<DiscreteKey> {

/// Default constructor
DiscreteKeys() {
}
// Forward all constructors.
using std::vector<DiscreteKey>::vector;

/// Construct from a key
DiscreteKeys(const DiscreteKey& key) {
Expand Down
66 changes: 18 additions & 48 deletions gtsam/discrete/Signature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,7 @@ namespace gtsam {
using boost::phoenix::push_back;

// Special rows, true and false
Signature::Row createF() {
Signature::Row r(2);
r[0] = 1;
r[1] = 0;
return r;
}
Signature::Row createT() {
Signature::Row r(2);
r[0] = 0;
r[1] = 1;
return r;
}
Signature::Row T = createT(), F = createF();
Signature::Row F{1, 0}, T{0, 1};

// Special tables (inefficient, but do we care for user input?)
Signature::Table logic(bool ff, bool ft, bool tf, bool tt) {
Expand All @@ -69,40 +57,13 @@ namespace gtsam {
table = or_ | and_ | rows;
or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)];
and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)];
rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42
rows = +(row | true_ | false_);
row = qi::double_ >> +("/" >> qi::double_);
true_ = qi::lit("T")[qi::_val = T];
false_ = qi::lit("F")[qi::_val = F];
}
} grammar;

// Create simpler parsing function to avoid the issue of only parsing a single row
bool parse_table(const string& spec, Signature::Table& table) {
// check for OR, AND on whole phrase
It f = spec.begin(), l = spec.end();
if (qi::parse(f, l,
qi::lit("OR")[ph::ref(table) = logic(false, true, true, true)]) ||
qi::parse(f, l,
qi::lit("AND")[ph::ref(table) = logic(false, false, false, true)]))
return true;

// tokenize into separate rows
istringstream iss(spec);
string token;
while (iss >> token) {
Signature::Row values;
It tf = token.begin(), tl = token.end();
bool r = qi::parse(tf, tl,
qi::double_[push_back(ph::ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ph::ref(values), qi::_1)]) |
qi::lit("T")[ph::ref(values) = T] |
qi::lit("F")[ph::ref(values) = F] );
if (!r)
return false;
table.push_back(values);
}

return true;
}
} // \namespace parser

ostream& operator <<(ostream &os, const Signature::Row &row) {
Expand All @@ -118,6 +79,18 @@ namespace gtsam {
return os;
}

Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const Table& table)
: key_(key), parents_(parents) {
operator=(table);
}

Signature::Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec)
: key_(key), parents_(parents) {
operator=(spec);
}

Signature::Signature(const DiscreteKey& key) :
key_(key) {
}
Expand Down Expand Up @@ -166,14 +139,11 @@ namespace gtsam {
Signature& Signature::operator=(const string& spec) {
spec_.reset(spec);
Table table;
// NOTE: using simpler parse function to ensure boost back compatibility
// parser::It f = spec.begin(), l = spec.end();
bool success = //
// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar
parser::parse_table(spec, table);
parser::It f = spec.begin(), l = spec.end();
bool success =
qi::phrase_parse(f, l, parser::grammar.table, qi::space, table);
if (success) {
for(Row& row: table)
normalize(row);
for (Row& row : table) normalize(row);
table_.reset(table);
}
return *this;
Expand Down
104 changes: 62 additions & 42 deletions gtsam/discrete/Signature.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ namespace gtsam {
* T|A = "99/1 95/5"
* L|S = "99/1 90/10"
* B|S = "70/30 40/60"
* E|T,L = "F F F 1"
* (E|T,L) = "F F F 1"
* X|E = "95/5 2/98"
* D|E,B = "9/1 2/8 3/7 1/9"
* (D|E,B) = "9/1 2/8 3/7 1/9"
*/
class GTSAM_EXPORT Signature {

Expand All @@ -72,45 +72,66 @@ namespace gtsam {
boost::optional<Table> table_;

public:

/** Constructor from DiscreteKey */
Signature(const DiscreteKey& key);

/** the variable key */
const DiscreteKey& key() const {
return key_;
}

/** the parent keys */
const DiscreteKeys& parents() const {
return parents_;
}

/** All keys, with variable key first */
DiscreteKeys discreteKeys() const;

/** All key indices, with variable key first */
KeyVector indices() const;

// the CPT as parsed, if successful
const boost::optional<Table>& table() const {
return table_;
}

// the CPT as a vector of doubles, with key's values most rapidly changing
std::vector<double> cpt() const;

/** Add a parent */
Signature& operator,(const DiscreteKey& parent);

/** Add the CPT spec - Fails in boost 1.40 */
Signature& operator=(const std::string& spec);

/** Add the CPT spec directly as a table */
Signature& operator=(const Table& table);

/** provide streaming */
GTSAM_EXPORT friend std::ostream& operator <<(std::ostream &os, const Signature &s);
/**
* Construct from key, parents, and a Table specifying the CPT.
dellaert marked this conversation as resolved.
Show resolved Hide resolved
*
* The first string is parsed to add a key and parents.
*
* Example: Signature sig(D, {B,E}, table);
*/
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const Table& table);

/**
* Construct from key, parents, and a string specifying the CPT.
*
* The first string is parsed to add a key and parents. The second string
* parses into a table.
*
* Example: Signature sig(D, {B,E}, "9/1 2/8 3/7 1/9");
*/
Signature(const DiscreteKey& key, const DiscreteKeys& parents,
const std::string& spec);

/**
* Construct from a single DiscreteKey.
*
* The resulting signature has no parents or CPT table. Typical use then
* either adds parents with | and , operators below, or assigns a table with
* operator=().
*/
Signature(const DiscreteKey& key);

/** the variable key */
const DiscreteKey& key() const { return key_; }

/** the parent keys */
const DiscreteKeys& parents() const { return parents_; }

/** All keys, with variable key first */
DiscreteKeys discreteKeys() const;

/** All key indices, with variable key first */
KeyVector indices() const;

// the CPT as parsed, if successful
const boost::optional<Table>& table() const { return table_; }

// the CPT as a vector of doubles, with key's values most rapidly changing
std::vector<double> cpt() const;

/** Add a parent */
Signature& operator,(const DiscreteKey& parent);

/** Add the CPT spec */
Signature& operator=(const std::string& spec);

/** Add the CPT spec directly as a table */
Signature& operator=(const Table& table);

/** provide streaming */
GTSAM_EXPORT friend std::ostream& operator<<(std::ostream& os,
const Signature& s);
};

/**
Expand All @@ -122,7 +143,6 @@ namespace gtsam {
/**
* Helper function to create Signature objects
* example: Signature s(D % "99/1");
* Uses string parser, which requires BOOST 1.42 or higher
*/
GTSAM_EXPORT Signature operator%(const DiscreteKey& key, const std::string& parent);

Expand Down
Loading