Use provenance tracking to compute downstream costs in `TraceGraph_ELBO` (pyro-ppl#3081)

* use provenance to compute downstream costs

* fix comments

* rm test_compute_downstream_costs

* update the docstring

* fix typo

* update the docstring

* add tests back

* fix air example
ordabayevy authored May 31, 2022
1 parent e2aad5e commit 2b653e0
Showing 4 changed files with 305 additions and 77 deletions.
2 changes: 1 addition & 1 deletion examples/air/main.py
@@ -50,7 +50,7 @@ def count_vec_to_mat(vec, max_index):
         true_counts_m = count_vec_to_mat(true_counts_batch, 2)
         inferred_counts_m = count_vec_to_mat(inferred_counts, 3)
         counts += torch.mm(true_counts_m.t(), inferred_counts_m)
-        error_ind = 1 - (true_counts_batch == inferred_counts)
+        error_ind = 1 - (true_counts_batch == inferred_counts).long()
         error_ix = error_ind.nonzero(as_tuple=False).squeeze()
         error_latents.append(
             latents_to_tensor((z_where, z_pres)).index_select(0, error_ix)
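For context on the one-line fix above: `true_counts_batch == inferred_counts` yields a boolean tensor, and recent PyTorch releases refuse to subtract a bool tensor from an integer, so the comparison is cast with `.long()` first. A minimal sketch with hypothetical toy tensors (not taken from the AIR example):

```python
import torch

true_counts_batch = torch.tensor([1, 2, 0])
inferred_counts = torch.tensor([1, 0, 0])

match = true_counts_batch == inferred_counts   # bool tensor: [True, False, True]
# `1 - match` raises a RuntimeError about subtracting a bool tensor on recent
# PyTorch, hence the explicit cast before subtracting:
error_ind = 1 - match.long()                   # long tensor: [0, 1, 0]
error_ix = error_ind.nonzero(as_tuple=False).squeeze()  # indices of mismatched rows
```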
115 changes: 80 additions & 35 deletions pyro/infer/tracegraph_elbo.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import weakref
+from collections import defaultdict
 from operator import itemgetter
 
 import torch
@@ -18,6 +19,9 @@
     torch_backward,
     torch_item,
 )
+from pyro.ops.provenance import detach_provenance, get_provenance, track_provenance
+from pyro.poutine.messenger import Messenger
+from pyro.poutine.subsample_messenger import _Subsample
 from pyro.util import check_if_enumerated, warn_if_nan


@@ -172,7 +176,7 @@ def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes): #
     return downstream_costs, downstream_guide_cost_nodes
 
 
-def _compute_elbo_reparam(model_trace, guide_trace):
+def _compute_elbo(model_trace, guide_trace):
 
     # In ref [1], section 3.2, the part of the surrogate loss computed here is
     # \sum{cost}, which in this case is the ELBO. Instead of using the ELBO,
@@ -182,12 +186,18 @@ def _compute_elbo_reparam(model_trace, guide_trace):
 
     elbo = 0.0
     surrogate_elbo = 0.0
+    baseline_loss = 0.0
+    # mapping from non-reparameterizable sample sites to cost terms influenced by each of them
+    downstream_costs = defaultdict(lambda: MultiFrameTensor())
 
     # Bring log p(x, z|...) terms into both the ELBO and the surrogate
     for name, site in model_trace.nodes.items():
         if site["type"] == "sample":
             elbo += site["log_prob_sum"]
             surrogate_elbo += site["log_prob_sum"]
+            # add the log_prob to each non-reparam sample site upstream
+            for key in get_provenance(site["log_prob_sum"]):
+                downstream_costs[key].add((site["cond_indep_stack"], site["log_prob"]))
 
     # Bring log q(z|...) terms into the ELBO, and effective terms into the
     # surrogate. Depending on the parameterization of a site, its log q(z|...)
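To make the bookkeeping in the hunk above concrete, here is a rough plain-Python analogue (hypothetical site names and values; the real code stores `MultiFrameTensor`s keyed by plate context rather than scalar tensors): each `log_prob` term is added to the running downstream cost of every non-reparameterizable site in its provenance.

```python
import torch
from collections import defaultdict

# log_prob terms as they might appear in a model trace, each tagged with the
# non-reparameterizable sample sites it depends on (what get_provenance returns)
log_probs = {
    "a": (torch.tensor(-0.5), frozenset({"a"})),
    "b": (torch.tensor(-0.9), frozenset({"a", "b"})),   # log p(b | a)
    "c": (torch.tensor(-1.2), frozenset({"a", "b"})),   # log p(c | b), observed
    "x": (torch.tensor(-2.0), frozenset()),             # no non-reparam ancestors
}

downstream_costs = defaultdict(lambda: torch.tensor(0.0))
for log_prob, provenance in log_probs.values():
    for key in provenance:            # mirrors `for key in get_provenance(...)`
        downstream_costs[key] = downstream_costs[key] + log_prob

print(dict(downstream_costs))         # {'a': tensor(-2.6000), 'b': tensor(-2.1000)}
```

Terms with empty provenance (like the purely observed `"x"` above) never enter any downstream cost, which is exactly the variance reduction the commit message refers to.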
@@ -202,19 +212,16 @@ def _compute_elbo_reparam(model_trace, guide_trace):
             # For fully non-reparameterized terms, it is zero
             if not is_identically_zero(entropy_term):
                 surrogate_elbo -= entropy_term.sum()
+            # add the -log_prob to each non-reparam sample site upstream
+            for key in get_provenance(site["log_prob_sum"]):
+                downstream_costs[key].add((site["cond_indep_stack"], -site["log_prob"]))
 
-    return elbo, surrogate_elbo
-
-
-def _compute_elbo_non_reparam(guide_trace, non_reparam_nodes, downstream_costs):
     # construct all the reinforce-like terms.
     # we include only downstream costs to reduce variance
     # optionally include baselines to further reduce variance
-    surrogate_elbo = 0.0
-    baseline_loss = 0.0
-    for node in non_reparam_nodes:
+    for node, downstream_cost in downstream_costs.items():
         guide_site = guide_trace.nodes[node]
-        downstream_cost = downstream_costs[node]
         downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
         score_function = guide_site["score_parts"].score_function
+
         use_baseline, baseline_loss_term, baseline = _construct_baseline(
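Continuing the toy sketch for the reinforce-like terms above: each non-reparameterizable site contributes `score_function * downstream_cost.detach()` to the surrogate, so gradients reach the guide parameters only through the score function while the downstream cost acts as a constant weight. The values below are assumed for illustration (in Pyro the score function is the `score_parts.score_function` term of the guide site):

```python
import torch

phi = torch.tensor([0.1, -0.3], requires_grad=True)   # stand-in guide parameters
score_functions = {"a": phi[0], "b": phi[1]}           # stand-ins for log q(z) terms
downstream_costs = {"a": torch.tensor(-2.6), "b": torch.tensor(-2.1)}

surrogate_elbo = sum(
    score_functions[node] * downstream_costs[node].detach()   # cost is a constant factor
    for node in score_functions
)
surrogate_loss = -surrogate_elbo
surrogate_loss.backward()     # gradients flow to phi only through the score functions
print(phi.grad)               # tensor([2.6000, 2.1000])
```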
@@ -227,7 +234,59 @@ def _compute_elbo_non_reparam(guide_trace, non_reparam_nodes, downstream_costs):
 
         surrogate_elbo += (score_function * downstream_cost.detach()).sum()
 
-    return surrogate_elbo, baseline_loss
+    surrogate_loss = -surrogate_elbo + baseline_loss
+    return detach_provenance(elbo), detach_provenance(surrogate_loss)
+
+
+class TrackNonReparam(Messenger):
+    """
+    Track non-reparameterizable sample sites.
+
+    **References:**
+
+    1. *Nonstandard Interpretations of Probabilistic Programs for Efficient Inference*,
+       David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
+
+    **Example:**
+
+    .. doctest::
+
+       >>> import torch
+       >>> import pyro
+       >>> import pyro.distributions as dist
+       >>> from pyro.infer.tracegraph_elbo import TrackNonReparam
+       >>> from pyro.ops.provenance import get_provenance
+       >>> from pyro.poutine import trace
+       >>> def model():
+       ...     probs_a = torch.tensor([0.3, 0.7])
+       ...     probs_b = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
+       ...     probs_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])
+       ...     a = pyro.sample("a", dist.Categorical(probs_a))
+       ...     b = pyro.sample("b", dist.Categorical(probs_b[a]))
+       ...     pyro.sample("c", dist.Categorical(probs_c[b]), obs=torch.tensor(0))
+       >>> with TrackNonReparam():
+       ...     model_tr = trace(model).get_trace()
+       >>> model_tr.compute_log_prob()
+       >>> print(get_provenance(model_tr.nodes["a"]["log_prob"]))  # doctest: +SKIP
+       frozenset({'a'})
+       >>> print(get_provenance(model_tr.nodes["b"]["log_prob"]))  # doctest: +SKIP
+       frozenset({'b', 'a'})
+       >>> print(get_provenance(model_tr.nodes["c"]["log_prob"]))  # doctest: +SKIP
+       frozenset({'b', 'a'})
+    """
+
+    def _pyro_post_sample(self, msg):
+        if (
+            msg["type"] == "sample"
+            and not isinstance(msg["fn"], _Subsample)
+            and not msg["is_observed"]
+            and not getattr(msg["fn"], "has_rsample", False)
+        ):
+            provenance = frozenset({msg["name"]})
+            msg["value"] = track_provenance(msg["value"], provenance)
 
 
 class TraceGraph_ELBO(ELBO):
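As background for the `has_rsample` check in `_pyro_post_sample` above: distributions that implement `rsample` are reparameterizable and are left untagged, so provenance (and hence a reinforce-like term) is only attached to genuinely non-reparameterizable sample sites. For example:

```python
import pyro.distributions as dist

print(dist.Normal(0.0, 1.0).has_rsample)   # True  -> reparameterized, not tagged
print(dist.Bernoulli(0.5).has_rsample)     # False -> tagged by TrackNonReparam
```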
@@ -236,13 +295,10 @@ class TraceGraph_ELBO(ELBO):
     is constructed along the lines of reference [1] specialized to the case
     of the ELBO. It supports arbitrary dependency structure for the model
     and guide as well as baselines for non-reparameterizable random variables.
-    Where possible, conditional dependency information as recorded in the
+    Fine-grained conditional dependency information as recorded in the
     :class:`~pyro.poutine.trace.Trace` is used to reduce the variance of the gradient estimator.
-    In particular two kinds of conditional dependency information are
-    used to reduce variance:
-
-    - the sequential order of samples (z is sampled after y => y does not depend on z)
-    - :class:`~pyro.plate` generators
+    In particular provenance tracking [3] is used to find the ``cost`` terms
+    that depend on each non-reparameterizable sample site.
 
     References
@@ -251,16 +307,20 @@ class TraceGraph_ELBO(ELBO):
     [2] `Neural Variational Inference and Learning in Belief Networks`
         Andriy Mnih, Karol Gregor
+
+    [3] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`,
+        David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
     """
 
     def _get_trace(self, model, guide, args, kwargs):
         """
         Returns a single trace from the guide, and the model that is run
         against it.
         """
-        model_trace, guide_trace = get_importance_trace(
-            "dense", self.max_plate_nesting, model, guide, args, kwargs
-        )
+        with TrackNonReparam():
+            model_trace, guide_trace = get_importance_trace(
+                "dense", self.max_plate_nesting, model, guide, args, kwargs
+            )
         if is_validation_enabled():
             check_if_enumerated(guide_trace)
         return model_trace, guide_trace
@@ -319,22 +379,7 @@ def _loss_and_surrogate_loss(self, model, guide, args, kwargs):
 
     def _loss_and_surrogate_loss_particle(self, model_trace, guide_trace):
 
-        # compute elbo for reparameterized nodes
-        elbo, surrogate_elbo = _compute_elbo_reparam(model_trace, guide_trace)
-        baseline_loss = 0.0
-
-        # the following computations are only necessary if we have non-reparameterizable nodes
-        non_reparam_nodes = set(guide_trace.nonreparam_stochastic_nodes)
-        if non_reparam_nodes:
-            downstream_costs, _ = _compute_downstream_costs(
-                model_trace, guide_trace, non_reparam_nodes
-            )
-            surrogate_elbo_term, baseline_loss = _compute_elbo_non_reparam(
-                guide_trace, non_reparam_nodes, downstream_costs
-            )
-            surrogate_elbo += surrogate_elbo_term
-
-        surrogate_loss = -surrogate_elbo + baseline_loss
+        elbo, surrogate_loss = _compute_elbo(model_trace, guide_trace)
 
         return elbo, surrogate_loss

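To round off the `tracegraph_elbo.py` changes, a hedged usage sketch of `TraceGraph_ELBO` with `SVI` (a hypothetical toy model, not taken from this PR): the loss is still dropped into `SVI` as before, and per-site baselines can still be requested through the `infer={"baseline": ...}` argument on a guide sample site.

```python
import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceGraph_ELBO
from pyro.optim import Adam

data = torch.tensor([0.1, 1.2, 0.9, -0.3, 1.1])

def model(data):
    p = pyro.param("p", torch.tensor(0.5), constraint=constraints.unit_interval)
    with pyro.plate("data", len(data)):
        z = pyro.sample("z", dist.Bernoulli(p))          # discrete, non-reparameterizable
        pyro.sample("obs", dist.Normal(z, 1.0), obs=data)

def guide(data):
    q = pyro.param("q", torch.full((len(data),), 0.5), constraint=constraints.unit_interval)
    with pyro.plate("data", len(data)):
        pyro.sample(
            "z",
            dist.Bernoulli(q),
            # optional decaying-average baseline to further reduce variance
            infer={"baseline": {"use_decaying_avg_baseline": True, "baseline_beta": 0.9}},
        )

svi = SVI(model, guide, Adam({"lr": 0.01}), loss=TraceGraph_ELBO())
for step in range(100):
    svi.step(data)
```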
6 changes: 6 additions & 0 deletions pyro/ops/provenance.py
@@ -104,6 +104,12 @@ def _track_provenance_list(x, provenance: frozenset):
     return type(x)(track_provenance(part, provenance) for part in x)
 
 
+@track_provenance.register
+def _track_provenance_provenancetensor(x: ProvenanceTensor, provenance: frozenset):
+    x_value, old_provenance = extract_provenance(x)
+    return track_provenance(x_value, old_provenance | provenance)
+
+
 @singledispatch
 def extract_provenance(x) -> Tuple[object, frozenset]:
     """
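The new overload above is what lets provenance accumulate when a value is tagged more than once. A small sketch of the assumed behaviour of `pyro.ops.provenance` after this change (provenance propagation through tensor ops relies on `ProvenanceTensor`'s tensor-function overrides):

```python
import torch
from pyro.ops.provenance import get_provenance, track_provenance

x = track_provenance(torch.tensor(2.0), frozenset({"a"}))
x = track_provenance(x, frozenset({"b"}))   # new overload: provenances are merged
y = x * 3.0 + 1.0                           # provenance rides along through tensor ops
print(get_provenance(y))                    # frozenset({'a', 'b'}) (order may vary)
print(get_provenance(torch.tensor(1.0)))    # frozenset()
```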
