Skip to content

Commit

Permalink
add optional model parameter to sample method
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Dec 23, 2022
1 parent 583d121 commit ffd1802
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 16 deletions.
10 changes: 6 additions & 4 deletions gtsam/linear/GaussianBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,18 @@ namespace gtsam {
}

/* ************************************************************************ */
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const {
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng,
const SharedDiagonal& model) const {
VectorValues result; // no missing variables -> create an empty vector
return sample(result, rng);
return sample(result, rng, model);
}

VectorValues GaussianBayesNet::sample(VectorValues result,
std::mt19937_64* rng) const {
std::mt19937_64* rng,
const SharedDiagonal& model) const {
// sample each node in reverse topological sort order (parents first)
for (auto cg : boost::adaptors::reverse(*this)) {
const VectorValues sampled = cg->sample(result, rng);
const VectorValues sampled = cg->sample(result, rng, model);
result.insert(sampled);
}
return result;
Expand Down
6 changes: 4 additions & 2 deletions gtsam/linear/GaussianBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ namespace gtsam {
* std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng);
*/
VectorValues sample(std::mt19937_64* rng) const;
VectorValues sample(std::mt19937_64* rng,
const SharedDiagonal& model = nullptr) const;

/**
* Sample from an incomplete BayesNet, given missing variables
Expand All @@ -110,7 +111,8 @@ namespace gtsam {
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
*/
VectorValues sample(VectorValues given, std::mt19937_64* rng) const;
VectorValues sample(VectorValues given, std::mt19937_64* rng,
const SharedDiagonal& model = nullptr) const;

/// Sample using ancestral sampling, use default rng
VectorValues sample() const;
Expand Down
22 changes: 15 additions & 7 deletions gtsam/linear/GaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,30 +279,38 @@ namespace gtsam {

/* ************************************************************************ */
VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const {
std::mt19937_64* rng,
const SharedDiagonal& model) const {
if (nrFrontals() != 1) {
throw std::invalid_argument(
"GaussianConditional::sample can only be called on single variable "
"conditionals");
}
if (!model_) {

VectorValues solution = solve(parentsValues);
Key key = firstFrontalKey();

Vector sigmas;
if (model_) {
sigmas = model_->sigmas();
} else if (model) {
sigmas = model->sigmas();
} else {
throw std::invalid_argument(
"GaussianConditional::sample can only be called if a diagonal noise "
"model was specified at construction.");
}
VectorValues solution = solve(parentsValues);
Key key = firstFrontalKey();
const Vector& sigmas = model_->sigmas();
solution[key] += Sampler::sampleDiagonal(sigmas, rng);
return solution;
}

VectorValues GaussianConditional::sample(std::mt19937_64* rng) const {
VectorValues GaussianConditional::sample(std::mt19937_64* rng,
const SharedDiagonal& model) const {
if (nrParents() != 0)
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
VectorValues values;
return sample(values);
return sample(values, rng, model);
}

/* ************************************************************************ */
Expand Down
7 changes: 4 additions & 3 deletions gtsam/linear/GaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ namespace gtsam {
* std::mt19937_64 rng(42);
* auto sample = gbn.sample(&rng);
*/
VectorValues sample(std::mt19937_64* rng) const;
VectorValues sample(std::mt19937_64* rng,
const SharedDiagonal& model = nullptr) const;

/**
* Sample from conditional, given missing variables
Expand All @@ -175,8 +176,8 @@ namespace gtsam {
* VectorValues given = ...;
* auto sample = gbn.sample(given, &rng);
*/
VectorValues sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const;
VectorValues sample(const VectorValues& parentsValues, std::mt19937_64* rng,
const SharedDiagonal& model = nullptr) const;

/// Sample, use default rng
VectorValues sample() const;
Expand Down

0 comments on commit ffd1802

Please sign in to comment.