Fix infer_discrete to work under the PyTorch jit (#1646)
* Fix infer_discrete to work under the PyTorch jit

* Fix python3 error
fritzo authored and eb8680 committed Dec 7, 2018
1 parent 1c86dba commit a9ef595
Showing 6 changed files with 78 additions and 7 deletions.
19 changes: 19 additions & 0 deletions pyro/distributions/torch_patch.py
@@ -2,6 +2,7 @@

import os

import opt_einsum
import torch

if 'READTHEDOCS' not in os.environ:
@@ -84,4 +85,22 @@ def _einsum(equation, *operands):
return _einsum._pyro_unpatched(equation, *operands)


# This can be removed after https://github.com/dgasmith/opt_einsum/pull/77 is released.
@patch_dependency('opt_einsum.helpers.compute_size_by_dict', opt_einsum)
def _compute_size_by_dict(indices, idx_dict):
if torch._C._get_tracing_state():
# If running under the jit, convert all sizes from tensors to ints, the
# first time each idx_dict is seen.
last_idx_dict = getattr(_compute_size_by_dict, '_last_idx_dict', None)
if idx_dict is not last_idx_dict:
_compute_size_by_dict._last_idx_dict = idx_dict
for key, value in idx_dict.items():
idx_dict[key] = int(value)

ret = 1
for i in indices:
ret *= idx_dict[i]
return ret


__all__ = []
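For reference, a minimal sketch (not part of the commit) of what this helper computes in eager mode, where the patch changes nothing: the product of the sizes of the requested indices. It assumes opt_einsum.helpers is reachable as an attribute, as the patch target above implies; the sizes mapping is hypothetical.

import opt_einsum

# Eager-mode behavior, identical before and after the patch: a plain
# product of index sizes looked up in the dict.
sizes = {'a': 2, 'b': 3, 'c': 5}  # hypothetical index sizes
assert opt_einsum.helpers.compute_size_by_dict('ab', sizes) == 6   # 2 * 3
assert opt_einsum.helpers.compute_size_by_dict('', sizes) == 1     # empty product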
3 changes: 2 additions & 1 deletion pyro/infer/discrete.py
@@ -11,6 +11,7 @@
from pyro.poutine.enumerate_messenger import EnumerateMessenger
from pyro.poutine.replay_messenger import ReplayMessenger
from pyro.poutine.util import prune_subsample_sites
from pyro.util import jit_iter

_RINGS = {0: MapRing, 1: SampleRing}

@@ -104,7 +105,7 @@ def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
sample = log_prob._pyro_backward_result
if sample is not None:
new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
- for index, dim in zip(sample, sample._pyro_sample_dims):
+ for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
if dim in new_value._pyro_dims:
index._pyro_dims = sample._pyro_dims[1:]
new_value = packed.gather(new_value, index, dim)
12 changes: 8 additions & 4 deletions pyro/ops/einsum/adjoint.py
@@ -7,6 +7,7 @@
from six import add_metaclass

from pyro.ops import packed
from pyro.util import jit_iter

SAMPLE_SYMBOL = " " # must be unique and precede alphanumeric characters

@@ -89,7 +90,7 @@ def einsum_backward_sample(operands, sample1, sample2):
# Slice sample1 down based on choices in sample2.
assert set(sample1._pyro_sample_dims).isdisjoint(sample2._pyro_sample_dims)
sample_dims = sample1._pyro_sample_dims + sample2._pyro_sample_dims
- for dim, index in zip(sample2._pyro_sample_dims, sample2):
+ for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
if dim in sample1._pyro_dims:
index._pyro_dims = sample2._pyro_dims[1:]
sample1 = packed.gather(sample1, index, dim)
@@ -100,7 +101,8 @@ def einsum_backward_sample(operands, sample1, sample2):
sample._pyro_dims = parts[0]._pyro_dims
sample._pyro_sample_dims = sample_dims
assert sample.dim() == len(sample._pyro_dims)
- assert sample.size(0) == len(sample._pyro_sample_dims)
+ if not torch._C._get_tracing_state():
+     assert sample.size(0) == len(sample._pyro_sample_dims)

# Select sample dimensions to pass on to downstream sites.
for x in operands:
@@ -122,7 +124,8 @@ def einsum_backward_sample(operands, sample1, sample2):
x_sample._pyro_dims = sample._pyro_dims
x_sample._pyro_sample_dims = x_sample_dims
assert x_sample.dim() == len(x_sample._pyro_dims)
- assert x_sample.size(0) == len(x_sample._pyro_sample_dims)
+ if not torch._C._get_tracing_state():
+     assert x_sample.size(0) == len(x_sample._pyro_sample_dims)
yield x._pyro_backward, x_sample


@@ -142,5 +145,6 @@ def unflatten(flat_sample, output_dims, contract_dims, contract_shape):
sample._pyro_dims = SAMPLE_SYMBOL + output_dims
sample._pyro_sample_dims = contract_dims
assert sample.dim() == len(sample._pyro_dims)
- assert sample.size(0) == len(sample._pyro_sample_dims)
+ if not torch._C._get_tracing_state():
+     assert sample.size(0) == len(sample._pyro_sample_dims)
return sample
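The guard added above relies on torch._C._get_tracing_state(), which returns a truthy tracing state only while torch.jit.trace is recording; in that mode tensor sizes may be traced values rather than plain Python ints, so the eager-mode shape assertions are skipped. A minimal sketch of the pattern (the function name is illustrative):

import torch

def checked_add_one(x):
    # Truthy only while torch.jit.trace is recording this call.
    if not torch._C._get_tracing_state():
        assert x.size(0) == 3  # eager-mode sanity check, skipped under tracing
    return x + 1

checked_add_one(torch.zeros(3))                            # eager: assert runs
traced = torch.jit.trace(checked_add_one, torch.zeros(3))  # traced: assert skipped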
3 changes: 2 additions & 1 deletion pyro/ops/einsum/torch_sample.py
@@ -9,6 +9,7 @@
from pyro.ops import packed
from pyro.ops.einsum.adjoint import Backward, einsum_backward_sample, transpose, unflatten
from pyro.ops.einsum.util import Tensordot
from pyro.util import jit_iter


class _EinsumBackward(Backward):
@@ -25,7 +26,7 @@ def process(self, message):
# Slice down operands before combining terms.
sample2 = message
if sample2 is not None:
- for dim, index in zip(sample2._pyro_sample_dims, sample2):
+ for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
batch_dims = batch_dims.replace(dim, '')
for i, x in enumerate(operands):
if dim in x._pyro_dims:
12 changes: 12 additions & 0 deletions pyro/util.py
@@ -348,6 +348,18 @@ def ignore_jit_warnings(filter=None):
yield


def jit_iter(tensor):
"""
Iterate over a tensor, ignoring jit warnings.
"""
# The "Iterating over a tensor" warning is erroneously a RuntimeWarning
# so we use a custom filter here.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "Iterating over a tensor")
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
return list(tensor)


@contextmanager
def optional(context_manager, condition):
"""
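A minimal usage sketch for the new helper (tensor values here are illustrative): jit_iter eagerly materializes the rows of a tensor as a list while suppressing the tracer's iteration warnings, which is what lets the zip(...) calls in the files above run under the jit.

import torch
from pyro.util import jit_iter

samples = torch.arange(6).reshape(3, 2)
rows = jit_iter(samples)  # [tensor([0, 1]), tensor([2, 3]), tensor([4, 5])]
assert len(rows) == samples.size(0)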
36 changes: 35 additions & 1 deletion tests/infer/test_jit.py
@@ -14,9 +14,10 @@
import pyro.ops.jit
import pyro.poutine as poutine
from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, JitTraceMeanField_ELBO, Trace_ELBO,
- TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO)
+ TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, infer_discrete)
from pyro.optim import Adam
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.util import ignore_jit_warnings
from tests.common import assert_equal


@@ -399,6 +400,39 @@ def guide(data):
svi.step(data)


@pytest.mark.parametrize('length', [1, 2, 10])
@pytest.mark.parametrize('temperature', [0, 1], ids=['map', 'sample'])
def test_discrete(temperature, length):

@ignore_jit_warnings()
def hmm(transition, means, data):
states = [torch.tensor(0)]
for t in pyro.markov(range(len(data))):
states.append(pyro.sample("states_{}".format(t),
dist.Categorical(transition[states[-1]]),
infer={"enumerate": "parallel"}))
pyro.sample("obs_{}".format(t),
dist.Normal(means[states[-1]], 1.),
obs=data[t])
return tuple(states)

hidden_dim = 10
transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
means = torch.arange(float(hidden_dim))
data = 1 + 2 * torch.randn(length)

decoder = infer_discrete(hmm, first_available_dim=-1, temperature=temperature)
jit_decoder = pyro.ops.jit.trace(decoder)

states = decoder(transition, means, data)
jit_states = jit_decoder(transition, means, data)
assert len(states) == len(jit_states)
for state, jit_state in zip(states, jit_states):
assert state.shape == jit_state.shape
if temperature == 0:
assert_equal(state, jit_state)


@pytest.mark.parametrize("x,y", [
(CondIndepStackFrame("a", -1, torch.tensor(2000), 2), CondIndepStackFrame("a", -1, 2000, 2)),
(CondIndepStackFrame("a", -1, 1, 2), CondIndepStackFrame("a", -1, torch.tensor(1), 2)),
