Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add logDensity and evaluate to GaussianBN and HybridBN #1352

Merged
merged 10 commits into from
Dec 29, 2022
77 changes: 45 additions & 32 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
DecisionTreeFactor f(*conditional->asDiscrete());
dtFactor = dtFactor * f;
}
}
Expand Down Expand Up @@ -108,7 +108,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl;
auto discrete = conditional->asDiscreteConditional();
auto discrete = conditional->asDiscrete();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());

Expand Down Expand Up @@ -150,16 +150,11 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {

// Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);

if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();

for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*decisionTree);
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
prunedGaussianMixture->prune(*decisionTree); // imperative :-(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should make an issue for this? There may have been a reason why I (a Lisp lover) made this imperative, so it'll be good to re-examine this now.


// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
Expand All @@ -186,24 +181,21 @@ GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return at(i)->asDiscreteConditional();
return at(i)->asDiscrete();
}

/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
if (auto gm = conditional->asMixture()) {
Copy link
Collaborator

@varunagrawal varunagrawal Dec 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel this is less readable just to save 1 line of code, but okay.

// If conditional is hybrid, select based on assignment.
GaussianMixture gm = *conditional->asMixture();
gbn.push_back(gm(assignment));

} else if (conditional->isContinuous()) {
gbn.push_back((*gm)(assignment));
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional.
gbn.push_back((conditional->asGaussian()));

} else if (conditional->isDiscrete()) {
gbn.push_back(gc);
} else if (auto dc = conditional->asDiscrete()) {
// If conditional is discrete-only, we simply continue.
continue;
}
Expand All @@ -218,31 +210,55 @@ HybridValues HybridBayesNet::optimize() const {
DiscreteBayesNet discrete_bn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
discrete_bn.push_back(conditional->asDiscrete());
}
}

DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();

// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
GaussianBayesNet gbn = choose(mpe);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally like using this because:

  1. Specifies this is a class method similar to self in python.
  2. Auto-complete works better.

return HybridValues(mpe, gbn.optimize());
}

/* ************************************************************************* */
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
GaussianBayesNet gbn = choose(assignment);
return gbn.optimize();
}

/* ************************************************************************* */
double HybridBayesNet::evaluate(const HybridValues &values) const {
const DiscreteValues &discreteValues = values.discrete();
const VectorValues &continuousValues = values.continuous();

double logDensity = 0.0, probability = 1.0;

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply.
logDensity += gc->logDensity(continuousValues);
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues);
}
}

return probability * exp(logDensity);
}

/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given,
std::mt19937_64 *rng) const {
DiscreteBayesNet dbn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscreteConditional());
dbn.push_back(conditional->asDiscrete());
}
}
// Sample a discrete assignment.
Expand Down Expand Up @@ -273,7 +289,7 @@ HybridValues HybridBayesNet::sample() const {
/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = this->choose(discreteValues);
GaussianBayesNet gbn = choose(discreteValues);
return gbn.error(continuousValues);
}

Expand All @@ -284,23 +300,20 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(

// Iterate over each conditional.
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
if (auto gm = conditional->asMixture()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this is more clever but sacrifices readability...

// If conditional is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = conditional->asMixture();
AlgebraicDecisionTree<Key> conditional_error =
gm->error(continuousValues);

error_tree = error_tree + conditional_error;

} else if (conditional->isContinuous()) {
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = conditional->asGaussian()->error(continuousValues);
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });

} else if (conditional->isDiscrete()) {
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, we skip.
continue;
}
Expand Down
8 changes: 8 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;

/// Evaluate hybrid probability density for given HybridValues.
double evaluate(const HybridValues &values) const;

/// Evaluate hybrid probability density for given HybridValues, sugar.
double operator()(const HybridValues &values) const {
return evaluate(values);
}

/**
* @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ HybridValues HybridBayesTree::optimize() const {

// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscreteConditional());
dbn.push_back(root_conditional->asDiscrete());
mpe = DiscreteFactorGraph(dbn).optimize();
} else {
throw std::runtime_error(
Expand Down Expand Up @@ -147,7 +147,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree =
this->roots_.at(0)->conditional()->asDiscreteConditional();
this->roots_.at(0)->conditional()->asDiscrete();

DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
Expand Down
23 changes: 9 additions & 14 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,34 +131,29 @@ class GTSAM_EXPORT HybridConditional

/**
* @brief Return HybridConditional as a GaussianMixture
*
* @return GaussianMixture::shared_ptr
* @return nullptr if not a mixture
* @return GaussianMixture::shared_ptr otherwise
*/
GaussianMixture::shared_ptr asMixture() {
if (!isHybrid()) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixture>(inner_);
return boost::dynamic_pointer_cast<GaussianMixture>(inner_);
}

/**
* @brief Return HybridConditional as a GaussianConditional
*
* @return GaussianConditional::shared_ptr
* @return nullptr if not a GaussianConditional
* @return GaussianConditional::shared_ptr otherwise
*/
GaussianConditional::shared_ptr asGaussian() {
if (!isContinuous())
throw std::invalid_argument("Not a continuous conditional");
return boost::static_pointer_cast<GaussianConditional>(inner_);
return boost::dynamic_pointer_cast<GaussianConditional>(inner_);
}

/**
* @brief Return conditional as a DiscreteConditional
*
* @return nullptr if not a DiscreteConditional
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscreteConditional() {
if (!isDiscrete())
throw std::invalid_argument("Not a discrete conditional");
return boost::static_pointer_cast<DiscreteConditional>(inner_);
DiscreteConditional::shared_ptr asDiscrete() {
return boost::dynamic_pointer_cast<DiscreteConditional>(inner_);
}

/// @}
Expand Down
4 changes: 2 additions & 2 deletions gtsam/hybrid/HybridValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ class GTSAM_EXPORT HybridValues {
/// @{

/// Return the discrete MPE assignment
DiscreteValues discrete() const { return discrete_; }
const DiscreteValues& discrete() const { return discrete_; }

/// Return the delta update for the continuous vectors
VectorValues continuous() const { return continuous_; }
const VectorValues& continuous() const { return continuous_; }

/// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };
Expand Down
Loading