Fix infer_discrete to work under the PyTorch jit #1646

Merged 2 commits on Dec 7, 2018.
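This PR makes infer_discrete work under the PyTorch jit. As the diffs below show, it patches opt_einsum's compute_size_by_dict to coerce traced tensor sizes to Python ints, skips shape assertions while the tracer is active, routes tensor iteration through a new jit_iter helper that silences tracer warnings, and adds a jit test for discrete inference.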
pyro/distributions/torch_patch.py (+19, -0)

@@ -2,6 +2,7 @@

 import os

+import opt_einsum
 import torch

 if 'READTHEDOCS' not in os.environ:
@@ -84,4 +85,22 @@ def _einsum(equation, *operands):
     return _einsum._pyro_unpatched(equation, *operands)


+# This can be removed after https://github.com/dgasmith/opt_einsum/pull/77 is released.
+@patch_dependency('opt_einsum.helpers.compute_size_by_dict', opt_einsum)
+def _compute_size_by_dict(indices, idx_dict):
+    if torch._C._get_tracing_state():
+        # If running under the jit, convert all sizes from tensors to ints, the
+        # first time each idx_dict is seen.
+        last_idx_dict = getattr(_compute_size_by_dict, '_last_idx_dict', None)
+        if idx_dict is not last_idx_dict:
+            _compute_size_by_dict._last_idx_dict = idx_dict
+            for key, value in idx_dict.items():
+                idx_dict[key] = int(value)
+
+    ret = 1
+    for i in indices:
+        ret *= idx_dict[i]
+    return ret
+
+
 __all__ = []
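Why the int() coercion matters: opt_einsum multiplies dimension sizes when costing contraction paths, and the patch's own comment notes that under the jit those sizes arrive as tensors. A minimal sketch of the coercion (the dict values here are illustrative, not from the diff):

    import torch

    # Reassigning existing keys while iterating items() is safe in Python 3;
    # int() turns each 0-dim tensor into a concrete Python int.
    idx_dict = {'a': torch.tensor(2), 'b': torch.tensor(3)}
    for key, value in idx_dict.items():
        idx_dict[key] = int(value)
    assert idx_dict == {'a': 2, 'b': 3}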
pyro/infer/discrete.py (+2, -1)

@@ -11,6 +11,7 @@
 from pyro.poutine.enumerate_messenger import EnumerateMessenger
 from pyro.poutine.replay_messenger import ReplayMessenger
 from pyro.poutine.util import prune_subsample_sites
+from pyro.util import jit_iter

 _RINGS = {0: MapRing, 1: SampleRing}

@@ -104,7 +105,7 @@ def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
             sample = log_prob._pyro_backward_result
             if sample is not None:
                 new_value = packed.pack(node["value"], node["infer"]["_dim_to_symbol"])
-                for index, dim in zip(sample, sample._pyro_sample_dims):
+                for index, dim in zip(jit_iter(sample), sample._pyro_sample_dims):
                     if dim in new_value._pyro_dims:
                         index._pyro_dims = sample._pyro_dims[1:]
                         new_value = packed.gather(new_value, index, dim)
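For context, iterating a tensor yields its slices along dim 0, one per enumerated sample dim here; the tracer warns on that because the slice count gets frozen into the trace, which is what jit_iter suppresses. A toy sketch with illustrative values:

    import torch

    sample = torch.tensor([[0, 1], [1, 0]])  # one row of indices per sample dim
    sample_dims = ['x', 'y']
    # jit_iter(sample) returns the same list as list(sample), minus the warnings.
    for index, dim in zip(list(sample), sample_dims):
        assert index.shape == (2,)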
pyro/ops/einsum/adjoint.py (+8, -4)

@@ -7,6 +7,7 @@
 from six import add_metaclass

 from pyro.ops import packed
+from pyro.util import jit_iter

 SAMPLE_SYMBOL = " "  # must be unique and precede alphanumeric characters

@@ -89,7 +90,7 @@ def einsum_backward_sample(operands, sample1, sample2):
         # Slice sample1 down based on choices in sample2.
         assert set(sample1._pyro_sample_dims).isdisjoint(sample2._pyro_sample_dims)
         sample_dims = sample1._pyro_sample_dims + sample2._pyro_sample_dims
-        for dim, index in zip(sample2._pyro_sample_dims, sample2):
+        for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
             if dim in sample1._pyro_dims:
                 index._pyro_dims = sample2._pyro_dims[1:]
                 sample1 = packed.gather(sample1, index, dim)
@@ -100,7 +101,8 @@ def einsum_backward_sample(operands, sample1, sample2):
         sample._pyro_dims = parts[0]._pyro_dims
         sample._pyro_sample_dims = sample_dims
         assert sample.dim() == len(sample._pyro_dims)
-        assert sample.size(0) == len(sample._pyro_sample_dims)
+        if not torch._C._get_tracing_state():
+            assert sample.size(0) == len(sample._pyro_sample_dims)

     # Select sample dimensions to pass on to downstream sites.
     for x in operands:
@@ -122,7 +124,8 @@ def einsum_backward_sample(operands, sample1, sample2):
             x_sample._pyro_dims = sample._pyro_dims
             x_sample._pyro_sample_dims = x_sample_dims
             assert x_sample.dim() == len(x_sample._pyro_dims)
-            assert x_sample.size(0) == len(x_sample._pyro_sample_dims)
+            if not torch._C._get_tracing_state():
+                assert x_sample.size(0) == len(x_sample._pyro_sample_dims)
             yield x._pyro_backward, x_sample


@@ -142,5 +145,6 @@ def unflatten(flat_sample, output_dims, contract_dims, contract_shape):
     sample._pyro_dims = SAMPLE_SYMBOL + output_dims
     sample._pyro_sample_dims = contract_dims
     assert sample.dim() == len(sample._pyro_dims)
-    assert sample.size(0) == len(sample._pyro_sample_dims)
+    if not torch._C._get_tracing_state():
+        assert sample.size(0) == len(sample._pyro_sample_dims)
     return sample
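The recurring change above skips shape asserts while the tracer is recording, since trace-time sizes may be traced values rather than plain ints. A standalone sketch of the guard (torch._C._get_tracing_state is the same private API the diff uses; the function f is illustrative):

    import torch

    def f(x):
        # The state is None in eager mode, so the assert runs; during
        # torch.jit.trace it is truthy and the assert is skipped.
        if not torch._C._get_tracing_state():
            assert x.size(0) == 3
        return x * 2

    f(torch.zeros(3))                            # eager: assert checked
    traced = torch.jit.trace(f, torch.zeros(3))  # tracing: assert skipped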
pyro/ops/einsum/torch_sample.py (+2, -1)

@@ -9,6 +9,7 @@
 from pyro.ops import packed
 from pyro.ops.einsum.adjoint import Backward, einsum_backward_sample, transpose, unflatten
 from pyro.ops.einsum.util import Tensordot
+from pyro.util import jit_iter


 class _EinsumBackward(Backward):
@@ -25,7 +26,7 @@ def process(self, message):
         # Slice down operands before combining terms.
         sample2 = message
         if sample2 is not None:
-            for dim, index in zip(sample2._pyro_sample_dims, sample2):
+            for dim, index in zip(sample2._pyro_sample_dims, jit_iter(sample2)):
                 batch_dims = batch_dims.replace(dim, '')
                 for i, x in enumerate(operands):
                     if dim in x._pyro_dims:
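Aside on batch_dims.replace(dim, ''): packed dim names are single characters in a string, so removing a sampled dim is plain string replacement. For example:

    batch_dims = "abc"
    for dim in "ac":  # dims that have been sampled out
        batch_dims = batch_dims.replace(dim, '')
    assert batch_dims == "b"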
pyro/util.py (+12, -0)

@@ -348,6 +348,18 @@ def ignore_jit_warnings(filter=None):
         yield


+def jit_iter(tensor):
+    """
+    Iterate over a tensor, ignoring jit warnings.
+    """
+    # The "Iterating over a tensor" warning is erroneously a RuntimeWarning
+    # so we use a custom filter here.
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", "Iterating over a tensor")
+        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+        return list(tensor)
+
+
 @contextmanager
 def optional(context_manager, condition):
     """
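Usage sketch for the new helper, assuming this branch's pyro.util is importable: jit_iter(t) returns the same list as list(t), with the tracer's iteration warnings filtered out:

    import torch
    from pyro.util import jit_iter

    rows = jit_iter(torch.eye(3))  # three 1-D tensors, no TracerWarning under tracing
    assert len(rows) == 3 and rows[0].shape == (3,)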
tests/infer/test_jit.py (+35, -1)

@@ -14,9 +14,10 @@
 import pyro.ops.jit
 import pyro.poutine as poutine
 from pyro.infer import (SVI, JitTrace_ELBO, JitTraceEnum_ELBO, JitTraceGraph_ELBO, JitTraceMeanField_ELBO, Trace_ELBO,
-                        TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO)
+                        TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, infer_discrete)
 from pyro.optim import Adam
 from pyro.poutine.indep_messenger import CondIndepStackFrame
+from pyro.util import ignore_jit_warnings
 from tests.common import assert_equal


@@ -399,6 +400,39 @@ def guide(data):
         svi.step(data)


+@pytest.mark.parametrize('length', [1, 2, 10])
+@pytest.mark.parametrize('temperature', [0, 1], ids=['map', 'sample'])
+def test_discrete(temperature, length):
+
+    @ignore_jit_warnings()
+    def hmm(transition, means, data):
+        states = [torch.tensor(0)]
+        for t in pyro.markov(range(len(data))):
+            states.append(pyro.sample("states_{}".format(t),
+                                      dist.Categorical(transition[states[-1]]),
+                                      infer={"enumerate": "parallel"}))
+            pyro.sample("obs_{}".format(t),
+                        dist.Normal(means[states[-1]], 1.),
+                        obs=data[t])
+        return tuple(states)
+
+    hidden_dim = 10
+    transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
+    means = torch.arange(float(hidden_dim))
+    data = 1 + 2 * torch.randn(length)
+
+    decoder = infer_discrete(hmm, first_available_dim=-1, temperature=temperature)
+    jit_decoder = pyro.ops.jit.trace(decoder)
+
+    states = decoder(transition, means, data)
+    jit_states = jit_decoder(transition, means, data)
+    assert len(states) == len(jit_states)
+    for state, jit_state in zip(states, jit_states):
+        assert state.shape == jit_state.shape
+        if temperature == 0:
+            assert_equal(state, jit_state)
+
+
 @pytest.mark.parametrize("x,y", [
     (CondIndepStackFrame("a", -1, torch.tensor(2000), 2), CondIndepStackFrame("a", -1, 2000, 2)),
     (CondIndepStackFrame("a", -1, 1, 2), CondIndepStackFrame("a", -1, torch.tensor(1), 2)),
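Note on the test's parametrization: temperature=0 exercises MAP decoding, where the traced and eager decoders must agree exactly, so assert_equal compares values; temperature=1 draws posterior samples, whose values legitimately differ between the two decoders, so only shapes are compared.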