Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix infer_discrete to work under the PyTorch jit #1646

Merged
merged 2 commits into from
Dec 7, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix infer_discrete to work under the PyTorch jit
  • Loading branch information
fritzo committed Dec 6, 2018
commit f8c00d0c0f65f3c4d631e52869587b291af27496
19 changes: 19 additions & 0 deletions pyro/distributions/torch_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os

import opt_einsum
import torch

if 'READTHEDOCS' not in os.environ:
Expand Down Expand Up @@ -84,4 +85,22 @@ def _einsum(equation, *operands):
return _einsum._pyro_unpatched(equation, *operands)


# This can be removed after https://github.com/dgasmith/opt_einsum/pull/77 is released.
@patch_dependency('opt_einsum.helpers.compute_size_by_dict', opt_einsum)
def _compute_size_by_dict(indices, idx_dict):
    """
    Return the product of the sizes of ``indices``, looked up in ``idx_dict``.

    Under the jit tracer the sizes in ``idx_dict`` may be 0-dim tensors, so
    each newly-seen dict is normalized in place to plain ints exactly once
    (the most recent dict is remembered on the function object).
    """
    if torch._C._get_tracing_state():
        if idx_dict is not getattr(_compute_size_by_dict, '_last_idx_dict', None):
            _compute_size_by_dict._last_idx_dict = idx_dict
            # Mutate values in place; keys are unchanged, so iteration is safe.
            for name in idx_dict:
                idx_dict[name] = int(idx_dict[name])

    size = 1
    for name in indices:
        size *= idx_dict[name]
    return size


__all__ = []
3 changes: 2 additions & 1 deletion pyro/infer/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pyro.poutine.enumerate_messenger import EnumerateMessenger
from pyro.poutine.replay_messenger import ReplayMessenger
from pyro.poutine.util import prune_subsample_sites
from pyro.util import jit_iter

_RINGS = {0: MapRing, 1: SampleRing}

Expand Down Expand Up @@ -104,7 +105,7 @@ def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
sample = log_prob._pyro_backward_result
if sample is not None:
new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
for index, dim in zip(sample, sample._pyro_sample_dims):
for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
if dim in new_value._pyro_dims:
index._pyro_dims = sample._pyro_dims[1:]
new_value = packed.gather(new_value, index, dim)
Expand Down
3 changes: 2 additions & 1 deletion pyro/ops/einsum/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from six import add_metaclass

from pyro.ops import packed
from pyro.util import jit_iter

SAMPLE_SYMBOL = " " # must be unique and precede alphanumeric characters

Expand Down Expand Up @@ -89,7 +90,7 @@ def einsum_backward_sample(operands, sample1, sample2):
# Slice sample1 down based on choices in sample2.
assert set(sample1._pyro_sample_dims).isdisjoint(sample2._pyro_sample_dims)
sample_dims = sample1._pyro_sample_dims + sample2._pyro_sample_dims
for dim, index in zip(sample2._pyro_sample_dims, sample2):
for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
if dim in sample1._pyro_dims:
index._pyro_dims = sample2._pyro_dims[1:]
sample1 = packed.gather(sample1, index, dim)
Expand Down
3 changes: 2 additions & 1 deletion pyro/ops/einsum/torch_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pyro.ops import packed
from pyro.ops.einsum.adjoint import Backward, einsum_backward_sample, transpose, unflatten
from pyro.ops.einsum.util import Tensordot
from pyro.util import jit_iter


class _EinsumBackward(Backward):
Expand All @@ -25,7 +26,7 @@ def process(self, message):
# Slice down operands before combining terms.
sample2 = message
if sample2 is not None:
for dim, index in zip(sample2._pyro_sample_dims, sample2):
for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
batch_dims = batch_dims.replace(dim, '')
for i, x in enumerate(operands):
if dim in x._pyro_dims:
Expand Down
12 changes: 12 additions & 0 deletions pyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,18 @@ def ignore_jit_warnings(filter=None):
yield


def jit_iter(tensor):
    """
    Return a list of ``tensor``'s slices along dim 0, suppressing the
    warnings that PyTorch emits when iterating a tensor under the jit.
    """
    with warnings.catch_warnings():
        # "Iterating over a tensor" is (erroneously) raised as a plain
        # RuntimeWarning rather than a TracerWarning, so filter both by
        # message and by category.
        for kwargs in ({"message": "Iterating over a tensor"},
                       {"category": torch.jit.TracerWarning}):
            warnings.filterwarnings("ignore", **kwargs)
        return [row for row in tensor]


@contextmanager
def optional(context_manager, condition):
"""
Expand Down
36 changes: 35 additions & 1 deletion tests/infer/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import pyro.ops.jit
import pyro.poutine as poutine
from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, JitTraceMeanField_ELBO, Trace_ELBO,
TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO)
TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, infer_discrete)
from pyro.optim import Adam
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.util import ignore_jit_warnings
from tests.common import assert_equal


Expand Down Expand Up @@ -399,6 +400,39 @@ def guide(data):
svi.step(data)


@pytest.mark.parametrize('length', [1, 2, 10])
@pytest.mark.parametrize('temperature', [0, 1], ids=['map', 'sample'])
def test_discrete(temperature, length):
    # Regression test: a decoder built by infer_discrete must be traceable by
    # pyro.ops.jit.trace, and the traced decoder must agree with the eager one.

    @ignore_jit_warnings()
    def hmm(transition, means, data):
        # HMM with enumerated discrete hidden states and Normal emissions;
        # returns the tuple of sampled states so the decoders can be compared.
        states = [torch.tensor(0)]
        for t in pyro.markov(range(len(data))):
            states.append(pyro.sample("states_{}".format(t),
                                      dist.Categorical(transition[states[-1]]),
                                      infer={"enumerate": "parallel"}))
            pyro.sample("obs_{}".format(t),
                        dist.Normal(means[states[-1]], 1.),
                        obs=data[t])
        return tuple(states)

    hidden_dim = 10
    # Mostly-sticky transition matrix: 0.7 on the diagonal plus a uniform 0.3.
    transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
    means = torch.arange(float(hidden_dim))
    data = 1 + 2 * torch.randn(length)

    decoder = infer_discrete(hmm, first_available_dim=-1, temperature=temperature)
    jit_decoder = pyro.ops.jit.trace(decoder)

    states = decoder(transition, means, data)
    jit_states = jit_decoder(transition, means, data)
    assert len(states) == len(jit_states)
    for state, jit_state in zip(states, jit_states):
        assert state.shape == jit_state.shape
        if temperature == 0:
            # MAP decoding (temperature=0) is deterministic, so eager and
            # traced results must match exactly; at temperature=1 the states
            # are sampled, so only shapes are compared.
            assert_equal(state, jit_state)


@pytest.mark.parametrize("x,y", [
(CondIndepStackFrame("a", -1, torch.tensor(2000), 2), CondIndepStackFrame("a", -1, 2000, 2)),
(CondIndepStackFrame("a", -1, 1, 2), CondIndepStackFrame("a", -1, torch.tensor(1), 2)),
Expand Down