diff --git a/pyro/distributions/omt_mvn.py b/pyro/distributions/omt_mvn.py index d997f806f5..1aa1bb17a3 100644 --- a/pyro/distributions/omt_mvn.py +++ b/pyro/distributions/omt_mvn.py @@ -68,7 +68,7 @@ def backward(ctx, grad_output): diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2) Sigma_inv = torch.mm(R_inv, R_inv.t()) - V, D, _ = torch.svd(Sigma_inv + jitter) + V, D, _ = torch.linalg.svd(Sigma_inv + jitter) D_outer = D.unsqueeze(-1) + D.unsqueeze(0) expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim]) diff --git a/pyro/distributions/transforms/householder.py b/pyro/distributions/transforms/householder.py index 33f9f0c4fc..47183de925 100644 --- a/pyro/distributions/transforms/householder.py +++ b/pyro/distributions/transforms/householder.py @@ -29,7 +29,7 @@ def __init__(self, u_unnormed=None): # Construct normalized vectors for Householder transform def u(self): u_unnormed = self.u_unnormed() if callable(self.u_unnormed) else self.u_unnormed - norm = torch.norm(u_unnormed, p=2, dim=-1, keepdim=True) + norm = torch.linalg.norm(u_unnormed, ord=2, dim=-1, keepdim=True) return torch.div(u_unnormed, norm) def _call(self, x): diff --git a/pyro/distributions/transforms/sylvester.py b/pyro/distributions/transforms/sylvester.py index 0e63a3470a..a115873b90 100644 --- a/pyro/distributions/transforms/sylvester.py +++ b/pyro/distributions/transforms/sylvester.py @@ -92,11 +92,11 @@ def Q(self, x): u = self.u() partial_Q = torch.eye( self.input_dim, dtype=x.dtype, layout=x.layout, device=x.device - ) - 2.0 * torch.ger(u[0], u[0]) + ) - 2.0 * torch.outer(u[0], u[0]) for idx in range(1, self.u_unnormed.size(-2)): partial_Q = torch.matmul( - partial_Q, torch.eye(self.input_dim) - 2.0 * torch.ger(u[idx], u[idx]) + partial_Q, torch.eye(self.input_dim) - 2.0 * torch.outer(u[idx], u[idx]) ) return partial_Q diff --git a/pyro/infer/autoguide/effect.py b/pyro/infer/autoguide/effect.py index 37ed99351d..b29bfdff06 100644 --- a/pyro/infer/autoguide/effect.py +++ b/pyro/infer/autoguide/effect.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from operator import attrgetter from typing import Callable, Optional, Tuple, Union import torch @@ -14,7 +15,7 @@ from pyro.poutine.runtime import get_plates from .initialization import init_to_feasible, init_to_mean -from .utils import deep_getattr, deep_setattr, helpful_support_errors +from .utils import deep_setattr, helpful_support_errors class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)): @@ -175,8 +176,8 @@ def get_posterior( def _get_params(self, name: str, prior: Distribution): try: - loc = deep_getattr(self.locs, name) - scale = deep_getattr(self.scales, name) + loc = attrgetter(name)(self.locs) + scale = attrgetter(name)(self.scales) return loc, scale except AttributeError: pass @@ -287,10 +288,10 @@ def get_posterior( def _get_params(self, name: str, prior: Distribution): try: - loc = deep_getattr(self.locs, name) - scale = deep_getattr(self.scales, name) + loc = attrgetter(name)(self.locs) + scale = attrgetter(name)(self.scales) if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): - weight = deep_getattr(self.weights, name) + weight = attrgetter(name)(self.weights) return loc, scale, weight else: return loc, scale @@ -427,8 +428,8 @@ def get_posterior( def _get_params(self, name: str, prior: Distribution): try: - loc = deep_getattr(self.locs, name) - scale = deep_getattr(self.scales, name) + loc = attrgetter(name)(self.locs) + scale = attrgetter(name)(self.scales) return loc, scale except AttributeError: pass diff --git a/pyro/infer/autoguide/gaussian.py b/pyro/infer/autoguide/gaussian.py index de873471ea..a7ca708a59 100644 --- a/pyro/infer/autoguide/gaussian.py +++ b/pyro/infer/autoguide/gaussian.py @@ -5,6 +5,7 @@ from abc import ABCMeta, abstractmethod from collections import OrderedDict, defaultdict from contextlib import ExitStack +from operator import attrgetter from types import SimpleNamespace from typing import Callable, Dict, Optional, Set, Tuple, Union @@ -23,7 +24,7 @@ from .guides import AutoGuide from .initialization import InitMessenger, init_to_feasible -from .utils import deep_getattr, deep_setattr, helpful_support_errors +from .utils import deep_setattr, helpful_support_errors # Helper to dispatch to concrete subclasses of AutoGaussian, e.g. @@ -287,8 +288,8 @@ def _transform_values( for name, site in self._factors.items(): if site["is_observed"]: continue - loc = deep_getattr(self.locs, name) - scale = deep_getattr(self.scales, name) + loc = attrgetter(name)(self.locs) + scale = attrgetter(name)(self.scales) unconstrained = aux_values[name] * scale + loc # Transform to constrained space. @@ -335,7 +336,7 @@ def _setup_prototype(self, *args, **kwargs): # Create sparse -> dense precision scatter indices. self._dense_scatter = {} for d, site in self._factors.items(): - prec_sqrt_shape = deep_getattr(self.prec_sqrts, d).shape + prec_sqrt_shape = attrgetter(d)(self.prec_sqrts).shape info_vec_shape = prec_sqrt_shape[:-1] precision_shape = prec_sqrt_shape[:-1] + prec_sqrt_shape[-2:-1] index1 = torch.zeros(info_vec_shape, dtype=torch.long) @@ -425,8 +426,8 @@ def _dense_get_mvn(self): flat_info_vec = torch.zeros(self._dense_size) flat_precision = torch.zeros(self._dense_size**2) for d, (index1, index2) in self._dense_scatter.items(): - white_vec = deep_getattr(self.white_vecs, d) - prec_sqrt = deep_getattr(self.prec_sqrts, d) + white_vec = attrgetter(d)(self.white_vecs) + prec_sqrt = attrgetter(d)(self.prec_sqrts) info_vec = (prec_sqrt @ white_vec[..., None])[..., 0] precision = prec_sqrt @ prec_sqrt.transpose(-1, -2) flat_info_vec.scatter_add_(0, index1, info_vec.reshape(-1)) @@ -505,8 +506,8 @@ def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]: batch_shape = torch.Size( p.size for p in sorted(self._plates[d], key=lambda p: p.dim) ) - white_vec = deep_getattr(self.white_vecs, d) - prec_sqrt = deep_getattr(self.prec_sqrts, d) + white_vec = attrgetter(d)(self.white_vecs) + prec_sqrt = attrgetter(d)(self.prec_sqrts) factors[d] = funsor.gaussian.Gaussian( white_vec=white_vec.reshape(batch_shape + white_vec.shape[-1:]), prec_sqrt=prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:]), diff --git a/pyro/infer/autoguide/guides.py b/pyro/infer/autoguide/guides.py index 4fb5ab221a..40e401b561 100644 --- a/pyro/infer/autoguide/guides.py +++ b/pyro/infer/autoguide/guides.py @@ -20,6 +20,7 @@ def model(): import warnings import weakref from contextlib import ExitStack +from operator import attrgetter import torch from torch import nn @@ -38,7 +39,7 @@ def model(): from pyro.poutine.util import site_is_subsample from .initialization import InitMessenger, init_to_feasible, init_to_median -from .utils import _product, deep_getattr, deep_setattr, helpful_support_errors +from .utils import _product, deep_setattr, helpful_support_errors def prototype_hide_fn(msg): @@ -491,8 +492,8 @@ def _setup_prototype(self, *args, **kwargs): ) def _get_loc_and_scale(self, name): - site_loc = deep_getattr(self.locs, name) - site_scale = deep_getattr(self.scales, name) + site_loc = attrgetter(name)(self.locs) + site_scale = attrgetter(name)(self.scales) return site_loc, site_scale def forward(self, *args, **kwargs): diff --git a/pyro/infer/autoguide/structured.py b/pyro/infer/autoguide/structured.py index 4825030a2e..5c1108b1ac 100644 --- a/pyro/infer/autoguide/structured.py +++ b/pyro/infer/autoguide/structured.py @@ -3,6 +3,7 @@ from collections import OrderedDict, defaultdict from contextlib import ExitStack +from operator import attrgetter from types import SimpleNamespace from typing import Callable, Dict, Optional, Union @@ -19,7 +20,7 @@ from .guides import AutoGuide from .initialization import InitMessenger, init_to_feasible -from .utils import deep_getattr, deep_setattr, helpful_support_errors +from .utils import deep_setattr, helpful_support_errors def _config_auxiliary(msg): @@ -274,11 +275,11 @@ def get_deltas(self, save_params=None): # Sample zero-mean blockwise independent Delta/Normal/MVN. log_density = 0.0 - loc = deep_getattr(self.locs, name) + loc = attrgetter(name)(self.locs) zero = torch.zeros_like(loc) conditional = self.conditionals[name] if callable(conditional): - aux_value = deep_getattr(self.conds, name)() + aux_value = attrgetter(name)(self.conds)() elif conditional == "delta": aux_value = zero elif conditional == "normal": @@ -287,7 +288,7 @@ def get_deltas(self, save_params=None): dist.Normal(zero, 1).to_event(1), infer={"is_auxiliary": True}, ) - scale = deep_getattr(self.scales, name) + scale = attrgetter(name)(self.scales) aux_value = aux_value * scale if compute_density: log_density = (-scale.log()).expand_as(aux_value) @@ -299,8 +300,8 @@ def get_deltas(self, save_params=None): dist.Normal(zero, 1).to_event(1), infer={"is_auxiliary": True}, ) - scale = deep_getattr(self.scales, name) - scale_tril = deep_getattr(self.scale_trils, name) + scale = attrgetter(name)(self.scales) + scale_tril = attrgetter(name)(self.scale_trils) aux_value = aux_value @ scale_tril.T * scale if compute_density: log_density = ( @@ -318,9 +319,9 @@ def get_deltas(self, save_params=None): # Note: these shear transforms have no effect on the Jacobian # determinant, and can therefore be excluded from the log_density # computation below, even for nonlinear dep(). - deps = deep_getattr(self.deps, name) + deps = attrgetter(name)(self.deps) for upstream in self.dependencies.get(name, {}): - dep = deep_getattr(deps, upstream) + dep = attrgetter(upstream)(deps) aux_value = aux_value + dep(aux_values[upstream]) aux_values[name] = aux_value @@ -368,7 +369,7 @@ def forward(self, *args, **kwargs): def median(self, *args, **kwargs): result = {} for name, site in self._sorted_sites: - loc = deep_getattr(self.locs, name).detach() + loc = attrgetter(name)(self.locs).detach() shape = self._batch_shapes[name] + self._unconstrained_event_shapes[name] loc = loc.reshape(shape) result[name] = biject_to(site["fn"].support)(loc) diff --git a/pyro/infer/autoguide/utils.py b/pyro/infer/autoguide/utils.py index 6e4177a5df..b790f9ab17 100644 --- a/pyro/infer/autoguide/utils.py +++ b/pyro/infer/autoguide/utils.py @@ -18,12 +18,6 @@ def _product(shape): return result -def deep_getattr(obj, key): - for part in key.split("."): - obj = getattr(obj, part) - return obj - - def deep_setattr(obj, key, val): """ Set an attribute `key` on the object. If any of the prefix attributes do diff --git a/pyro/ops/welford.py b/pyro/ops/welford.py index 0affd404c8..b7f1c13c08 100644 --- a/pyro/ops/welford.py +++ b/pyro/ops/welford.py @@ -33,7 +33,7 @@ def update(self, sample): if self.diagonal: self._m2 += delta_pre * delta_post else: - self._m2 += torch.ger(delta_post, delta_pre) + self._m2 += torch.outer(delta_post, delta_pre) def get_covariance(self, regularize=True): if self.n_samples < 2: @@ -72,7 +72,7 @@ def update(self, sample): self._mean = self._mean + delta_pre / self.n_samples delta_post = sample - self._mean if self.head_size > 0: - self._m2_top = self._m2_top + torch.ger( + self._m2_top = self._m2_top + torch.outer( delta_post[: self.head_size], delta_pre ) else: diff --git a/pyro/primitives.py b/pyro/primitives.py index 9b36c0f46d..6ed6862c3c 100644 --- a/pyro/primitives.py +++ b/pyro/primitives.py @@ -6,6 +6,7 @@ from collections import OrderedDict from contextlib import ExitStack, contextmanager from inspect import isclass +from operator import attrgetter from typing import Callable, Iterator, Optional, Sequence, Union import torch @@ -28,7 +29,7 @@ effectful, ) from pyro.poutine.subsample_messenger import SubsampleMessenger -from pyro.util import deep_getattr, set_rng_seed # noqa: F401 +from pyro.util import set_rng_seed # noqa: F401 def get_param_store() -> ParamStoreDict: @@ -493,7 +494,7 @@ def module( mod_name = _name if _name in target_state_dict.keys(): if not is_param: - deep_getattr(nn_module, mod_name)._parameters[param_name] = ( + attrgetter(mod_name)(nn_module)._parameters[param_name] = ( target_state_dict[_name] ) else: diff --git a/pyro/util.py b/pyro/util.py index de0482bf7e..1fa6043520 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -1,7 +1,6 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -import functools import math import numbers import random @@ -704,14 +703,6 @@ def ignore_experimental_warning(): yield -def deep_getattr(obj: object, name: str) -> Any: - """ - Python getattr() for arbitrarily deep attributes - Throws an AttributeError if bad attribute - """ - return functools.reduce(getattr, name.split("."), obj) - - class timed: def __enter__(self, timer=timeit.default_timer): self.start = timer() diff --git a/tests/infer/test_sampling.py b/tests/infer/test_sampling.py index bd1ef247fe..f3cf071ea9 100644 --- a/tests/infer/test_sampling.py +++ b/tests/infer/test_sampling.py @@ -78,9 +78,13 @@ def test_importance_guide(self): self.model, guide=self.guide, num_samples=5000 ).run() marginal = EmpiricalMarginal(posterior) - assert_equal(0, torch.norm(marginal.mean - self.loc_mean).item(), prec=0.01) assert_equal( - 0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1 + 0, torch.linalg.norm(marginal.mean - self.loc_mean).item(), prec=0.01 + ) + assert_equal( + 0, + torch.linalg.norm(marginal.variance.sqrt() - self.loc_stddev).item(), + prec=0.1, ) @pytest.mark.init(rng_seed=0) @@ -89,7 +93,11 @@ def test_importance_prior(self): self.model, guide=None, num_samples=10000 ).run() marginal = EmpiricalMarginal(posterior) - assert_equal(0, torch.norm(marginal.mean - self.loc_mean).item(), prec=0.01) assert_equal( - 0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1 + 0, torch.linalg.norm(marginal.mean - self.loc_mean).item(), prec=0.01 + ) + assert_equal( + 0, + torch.linalg.norm(marginal.variance.sqrt() - self.loc_stddev).item(), + prec=0.1, ) diff --git a/tests/ops/test_linalg.py b/tests/ops/test_linalg.py index f476a9a1f2..5b4497567d 100644 --- a/tests/ops/test_linalg.py +++ b/tests/ops/test_linalg.py @@ -5,7 +5,7 @@ import torch from pyro.ops.linalg import rinverse -from tests.common import assert_close, assert_equal +from tests.common import assert_equal @pytest.mark.parametrize( @@ -35,29 +35,3 @@ def test_sym_rinverse(A, use_sym): batched_A = A.unsqueeze(0).unsqueeze(0).expand(5, 4, d, d) expected_A = torch.inverse(A).unsqueeze(0).unsqueeze(0).expand(5, 4, d, d) assert_equal(rinverse(batched_A, sym=use_sym), expected_A, prec=1e-8) - - -# Tests migration from torch.triangular_solve -> torch.linalg.solve_triangular -@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated") -@pytest.mark.parametrize("upper", [False, True], ids=["lower", "upper"]) -def test_triangular_solve(upper): - b = torch.randn(5, 6) - A = torch.randn(5, 5) - expected = torch.triangular_solve(b, A, upper=upper).solution - actual = torch.linalg.solve_triangular(A, b, upper=upper) - assert_close(actual, expected) - A = A.triu() if upper else A.tril() - assert_close(A @ actual, b) - - -# Tests migration from torch.triangular_solve -> torch.linalg.solve_triangular -@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated") -@pytest.mark.parametrize("upper", [False, True], ids=["lower", "upper"]) -def test_triangular_solve_transpose(upper): - b = torch.randn(5, 6) - A = torch.randn(5, 5) - expected = torch.triangular_solve(b, A, upper=upper, transpose=True).solution - actual = torch.linalg.solve_triangular(A.T, b, upper=not upper) - assert_close(actual, expected) - A = A.triu() if upper else A.tril() - assert_close(A.T @ actual, b) diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py index 63d4a2b73c..751cfecf1e 100644 --- a/tests/poutine/test_poutines.py +++ b/tests/poutine/test_poutines.py @@ -25,7 +25,7 @@ def eq(x, y, prec=1e-10): - return torch.norm(x - y).item() < prec + return torch.linalg.norm(x - y).item() < prec # XXX name is a bit silly