Skip to content

Commit

Permalink
Fixed computing marginals in BayesTree
Browse files Browse the repository at this point in the history
  • Loading branch information
richardroberts committed Oct 11, 2010
1 parent ccea5c7 commit 96eb939
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 136 deletions.
120 changes: 65 additions & 55 deletions inference/BayesTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,27 +329,27 @@ namespace gtsam {
return p_S_R;
}

// /* ************************************************************************* */
// // P(C) = \int_R P(F|S) P(S|R) P(R)
// // TODO: Maybe we should integrate given parent marginal P(Cp),
// // \int(Cp\S) P(F|S)P(S|Cp)P(Cp)
// // Because the root clique could be very big.
// /* ************************************************************************* */
// template<class Conditional>
// template<class Factor>
// FactorGraph<Factor>
// BayesTree<Conditional>::Clique::marginal(shared_ptr R) {
// // If we are the root, just return this root
// if (R.get()==this) return *R;
//
// // Combine P(F|S), P(S|R), and P(R)
// BayesNet<Conditional> p_FSR = this->shortcut<Factor>(R);
// p_FSR.push_front(*this);
// p_FSR.push_back(*R);
//
// // Find marginal on the keys we are interested in
// return marginalize<Factor,Conditional>(p_FSR,keys());
// }
/* ************************************************************************* */
// P(C) = \int_R P(F|S) P(S|R) P(R)
// TODO: Maybe we should integrate given parent marginal P(Cp),
// \int(Cp\S) P(F|S)P(S|Cp)P(Cp)
// Because the root clique could be very big.
/* ************************************************************************* */
template<class Conditional>
template<class FactorGraph>
FactorGraph
BayesTree<Conditional>::Clique::marginal(shared_ptr R) {
// If we are the root, just return this root
if (R.get()==this) return *R;

// Combine P(F|S), P(S|R), and P(R)
BayesNet<Conditional> p_FSR = this->shortcut<FactorGraph>(R);
p_FSR.push_front(*this);
p_FSR.push_back(*R);

// Find marginal on the keys we are interested in
return FactorGraph(*Inference::Marginal(FactorGraph(p_FSR), keys()));
}

// /* ************************************************************************* */
// // P(C1,C2) = \int_R P(F1|S1) P(S1|R) P(F2|S1) P(S2|R) P(R)
Expand Down Expand Up @@ -676,41 +676,51 @@ namespace gtsam {
}
}

// /* ************************************************************************* */
// // First finds clique marginal then marginalizes that
// /* ************************************************************************* */
// template<class Conditional>
// template<class Factor>
// FactorGraph<Factor>
// BayesTree<Conditional>::marginal(varid_t key) const {
//
// // get clique containing key
// sharedClique clique = (*this)[key];
//
// // calculate or retrieve its marginal
// FactorGraph<Factor> cliqueMarginal = clique->marginal<Factor>(root_);
//
// // create an ordering where only the requested key is not eliminated
// vector<varid_t> ord = clique->keys();
// ord.remove(key);
//
// // partially eliminate, remaining factor graph is requested marginal
// eliminate<Factor,Conditional>(cliqueMarginal,ord);
// return cliqueMarginal;
// }
/* ************************************************************************* */
// First finds clique marginal then marginalizes that
/* ************************************************************************* */
template<class Conditional>
template<class FactorGraph>
FactorGraph
BayesTree<Conditional>::marginal(varid_t key) const {

// get clique containing key
sharedClique clique = (*this)[key];

// calculate or retrieve its marginal
FactorGraph cliqueMarginal = clique->marginal<FactorGraph>(root_);

// Reorder so that only the requested key is not eliminated
typename FactorGraph::variableindex_type varIndex(cliqueMarginal);
vector<varid_t> keyAsVector(1); keyAsVector[0] = key;
Permutation toBack(Permutation::PushToBack(keyAsVector, varIndex.size()));
Permutation::shared_ptr toBackInverse(toBack.inverse());
varIndex.permute(toBack);
BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, cliqueMarginal) {
factor->permuteWithInverse(*toBackInverse);
}

// /* ************************************************************************* */
// template<class Conditional>
// template<class Factor>
// BayesNet<Conditional>
// BayesTree<Conditional>::marginalBayesNet(varid_t key) const {
//
// // calculate marginal as a factor graph
// FactorGraph<Factor> fg = this->marginal<Factor>(key);
//
// // eliminate further to Bayes net
// return eliminate<Factor,Conditional>(fg,Ordering(key));
// }
// partially eliminate, remaining factor graph is requested marginal
Inference::EliminateUntil(cliqueMarginal, varIndex.size()-1, varIndex);
BOOST_FOREACH(const typename FactorGraph::sharedFactor& factor, cliqueMarginal) {
if(factor)
factor->permuteWithInverse(toBack);
}
return cliqueMarginal;
}

/* ************************************************************************* */
template<class Conditional>
template<class FactorGraph>
BayesNet<Conditional>
BayesTree<Conditional>::marginalBayesNet(varid_t key) const {

// calculate marginal as a factor graph
FactorGraph fg = this->marginal<FactorGraph>(key);

// eliminate further to Bayes net
return *Inference::Eliminate(fg);
}

// /* ************************************************************************* */
// // Find two cliques, their joint, then marginalizes
Expand Down
24 changes: 12 additions & 12 deletions inference/BayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ namespace gtsam {
template<class FactorGraph>
BayesNet<Conditional> shortcut(shared_ptr root);

// /** return the marginal P(C) of the clique */
// template<class Factor>
// FactorGraph<Factor> marginal(shared_ptr root);
//
/** return the marginal P(C) of the clique */
template<class FactorGraph>
FactorGraph marginal(shared_ptr root);

// /** return the joint P(C1,C2), where C1==this. TODO: not a method? */
// template<class Factor>
// std::pair<FactorGraph<Factor>,Ordering> joint(shared_ptr C2, shared_ptr root);
Expand Down Expand Up @@ -245,14 +245,14 @@ namespace gtsam {
/** Gather data on all cliques */
CliqueData getCliqueData() const;

// /** return marginal on any variable */
// template<class Factor>
// FactorGraph<Factor> marginal(varid_t key) const;
//
// /** return marginal on any variable, as a Bayes Net */
// template<class Factor>
// BayesNet<Conditional> marginalBayesNet(varid_t key) const;
//
/** return marginal on any variable */
template<class FactorGraph>
FactorGraph marginal(varid_t key) const;

/** return marginal on any variable, as a Bayes Net */
template<class FactorGraph>
BayesNet<Conditional> marginalBayesNet(varid_t key) const;

// /** return joint on two variables */
// template<class Factor>
// FactorGraph<Factor> joint(varid_t key1, varid_t key2) const;
Expand Down
31 changes: 30 additions & 1 deletion inference/Permutation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ Permutation Permutation::PullToFront(const vector<varid_t>& toFront, size_t size
// Mask of which variables have been pulled, used to reorder
vector<bool> pulled(size, false);

// Put the pulled variables at the front of the permutation and set up the
// pulled flags.
for(varid_t j=0; j<toFront.size(); ++j) {
ret[j] = toFront[j];
pulled[toFront[j]] = true;
assert(toFront[j] < size);
}

// Fill in the rest of the variables
varid_t nextVar = toFront.size();
for(varid_t j=0; j<size; ++j)
if(!pulled[j])
Expand All @@ -45,6 +47,33 @@ Permutation Permutation::PullToFront(const vector<varid_t>& toFront, size_t size
return ret;
}

/* ************************************************************************* */
Permutation Permutation::PushToBack(const std::vector<varid_t>& toBack, size_t size) {

Permutation ret(size);

// Mask of which variables have been pushed, used to reorder
vector<bool> pushed(size, false);

// Put the pushed variables at the back of the permutation and set up the
// pushed flags;
varid_t nextVar = size - toBack.size();
for(varid_t j=0; j<toBack.size(); ++j) {
ret[nextVar++] = toBack[j];
pushed[toBack[j]] = true;
}
assert(nextVar == size);

// Fill in the rest of the variables
nextVar = 0;
for(varid_t j=0; j<size; ++j)
if(!pushed[j])
ret[nextVar++] = j;
assert(nextVar == size - toBack.size());

return ret;
}

/* ************************************************************************* */
Permutation::shared_ptr Permutation::permute(const Permutation& permutation) const {
const size_t nVars = permutation.size();
Expand Down
6 changes: 6 additions & 0 deletions inference/Permutation.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ class Permutation {
*/
static Permutation PullToFront(const std::vector<varid_t>& toFront, size_t size);

/**
* Create a permutation that pulls the given variables to the front while
* pushing the rest to the back.
*/
static Permutation PushToBack(const std::vector<varid_t>& toBack, size_t size);

iterator begin() { return rangeIndices_.begin(); }
const_iterator begin() const { return rangeIndices_.begin(); }
iterator end() { return rangeIndices_.end(); }
Expand Down
11 changes: 11 additions & 0 deletions linear/VectorValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class VectorValues : public Testable<VectorValues> {
template<class Container>
VectorValues(const Container& dimensions);

/** Construct to hold nVars vectors of varDim dimension each. */
VectorValues(varid_t nVars, size_t varDim);

/** Construct from a container of variable dimensions in variable order and
* a combined Vector of all of the variables in order.
*/
Expand Down Expand Up @@ -179,6 +182,14 @@ inline VectorValues::VectorValues(const Container& dimensions) : varStarts_(dime
values_.resize(varStarts_.back(), false);
}

inline VectorValues::VectorValues(varid_t nVars, size_t varDim) : varStarts_(nVars+1) {
varStarts_[0] = 0;
size_t varStart = 0;
for(varid_t j=1; j<=nVars; ++j)
varStarts_[j] = (varStart += varDim);
values_.resize(varStarts_.back(), false);
}

inline VectorValues::VectorValues(const std::vector<size_t>& dimensions, const Vector& values) :
values_(values), varStarts_(dimensions.size()+1) {
varStarts_[0] = 0;
Expand Down
Loading

0 comments on commit 96eb939

Please sign in to comment.