From eadf00a427e9491a7a52493691284da3f23470a1 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Mon, 16 Sep 2024 18:44:03 -0400 Subject: [PATCH] Type annotate `pyro.util.py` (#3393) * type annotate util.py * exclude ipynb --- pyproject.toml | 1 + pyro/util.py | 146 +++++++++++++++++++++++++++++++++++++++---------- setup.cfg | 4 -- 3 files changed, 119 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4ae021c471..b8b2241c80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 [tool.ruff] +extend-exclude = ["*.ipynb"] line-length = 120 diff --git a/pyro/util.py b/pyro/util.py index 6c89e8fa26..de0482bf7e 100644 --- a/pyro/util.py +++ b/pyro/util.py @@ -11,14 +11,31 @@ from collections import defaultdict from contextlib import contextmanager from itertools import zip_longest +from typing import ( + TYPE_CHECKING, + Any, + Dict, + FrozenSet, + List, + Optional, + Set, + Union, + overload, +) import numpy as np import torch from pyro.poutine.util import site_is_subsample +if TYPE_CHECKING: + from pyro.distributions.torch_distribution import TorchDistributionMixin + from pyro.poutine.indep_messenger import CondIndepStackFrame + from pyro.poutine.runtime import Message + from pyro.poutine.trace import Trace -def set_rng_seed(rng_seed): + +def set_rng_seed(rng_seed: int) -> None: """ Sets seeds of `torch` and `torch.cuda` (if available). @@ -29,7 +46,7 @@ def set_rng_seed(rng_seed): np.random.seed(rng_seed) -def get_rng_state(): +def get_rng_state() -> Dict[str, Any]: return { "torch": torch.get_rng_state(), "random": random.getstate(), @@ -37,7 +54,7 @@ def get_rng_state(): } -def set_rng_state(state): +def set_rng_state(state: Dict[str, Any]) -> None: torch.set_rng_state(state["torch"]) random.setstate(state["random"]) if "numpy" in state: @@ -46,7 +63,11 @@ def set_rng_state(state): np.random.set_state(state["numpy"]) -def torch_isnan(x): +@overload +def torch_isnan(x: numbers.Number) -> bool: ... +@overload +def torch_isnan(x: torch.Tensor) -> torch.Tensor: ... +def torch_isnan(x: Union[torch.Tensor, numbers.Number]) -> Union[bool, torch.Tensor]: """ A convenient function to check if a Tensor contains any nan; also works with numbers """ @@ -55,7 +76,11 @@ def torch_isnan(x): return torch.isnan(x).any() -def torch_isinf(x): +@overload +def torch_isinf(x: numbers.Number) -> bool: ... +@overload +def torch_isinf(x: torch.Tensor) -> torch.Tensor: ... +def torch_isinf(x: Union[torch.Tensor, numbers.Number]) -> Union[bool, torch.Tensor]: """ A convenient function to check if a Tensor contains any +inf; also works with numbers """ @@ -64,7 +89,29 @@ def torch_isinf(x): return (x == math.inf).any() or (x == -math.inf).any() -def warn_if_nan(value, msg="", *, filename=None, lineno=None): +@overload +def warn_if_nan( + value: numbers.Number, + msg: str = "", + *, + filename: Optional[str] = None, + lineno: Optional[int] = None, +) -> numbers.Number: ... +@overload +def warn_if_nan( + value: torch.Tensor, + msg: str = "", + *, + filename: Optional[str] = None, + lineno: Optional[int] = None, +) -> torch.Tensor: ... +def warn_if_nan( + value: Union[torch.Tensor, numbers.Number], + msg: str = "", + *, + filename: Optional[str] = None, + lineno: Optional[int] = None, +) -> Union[torch.Tensor, numbers.Number]: """ A convenient function to warn if a Tensor or its grad contains any nan, also works with numbers. @@ -79,7 +126,7 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None): filename = frame.f_code.co_filename lineno = frame.f_lineno - if torch.is_tensor(value) and value.requires_grad: + if isinstance(value, torch.Tensor) and value.requires_grad: value.register_hook( lambda x: warn_if_nan( x, "backward " + msg, filename=filename, lineno=lineno @@ -87,6 +134,7 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None): ) if torch_isnan(value): + assert isinstance(lineno, int) warnings.warn_explicit( "Encountered NaN{}".format(": " + msg if msg else "."), UserWarning, @@ -97,9 +145,35 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None): return value +@overload +def warn_if_inf( + value: numbers.Number, + msg: str = "", + allow_posinf: bool = False, + allow_neginf: bool = False, + *, + filename: Optional[str] = None, + lineno: Optional[int] = None, +) -> numbers.Number: ... +@overload +def warn_if_inf( + value: torch.Tensor, + msg: str = "", + allow_posinf: bool = False, + allow_neginf: bool = False, + *, + filename: Optional[str] = None, + lineno: Optional[int] = None, +) -> torch.Tensor: ... def warn_if_inf( - value, msg="", allow_posinf=False, allow_neginf=False, *, filename=None, lineno=None -): + value: Union[torch.Tensor, numbers.Number], + msg: str = "", + allow_posinf: bool = False, + allow_neginf: bool = False, + *, + filename: Optional[str] = None, + lineno: Optional[int] = None, +) -> Union[torch.Tensor, numbers.Number]: """ A convenient function to warn if a Tensor or its grad contains any inf, also works with numbers. @@ -114,7 +188,7 @@ def warn_if_inf( filename = frame.f_code.co_filename lineno = frame.f_lineno - if torch.is_tensor(value) and value.requires_grad: + if isinstance(value, torch.Tensor) and value.requires_grad: value.register_hook( lambda x: warn_if_inf( x, @@ -131,6 +205,7 @@ def warn_if_inf( if isinstance(value, numbers.Number) else (value == math.inf).any() ): + assert isinstance(lineno, int) warnings.warn_explicit( "Encountered +inf{}".format(": " + msg if msg else "."), UserWarning, @@ -142,6 +217,7 @@ def warn_if_inf( if isinstance(value, numbers.Number) else (value == -math.inf).any() ): + assert isinstance(lineno, int) warnings.warn_explicit( "Encountered -inf{}".format(": " + msg if msg else "."), UserWarning, @@ -152,7 +228,7 @@ def warn_if_inf( return value -def save_visualization(trace, graph_output): +def save_visualization(trace: "Trace", graph_output: str) -> None: """ DEPRECATED Use :func:`pyro.infer.inspect.render_model()` instead. @@ -206,7 +282,7 @@ def save_visualization(trace, graph_output): g.render(graph_output, view=False, cleanup=True) -def check_traces_match(trace1, trace2): +def check_traces_match(trace1: "Trace", trace2: "Trace") -> None: """ :param pyro.poutine.Trace trace1: Trace object of the model :param pyro.poutine.Trace trace2: Trace object of the guide @@ -236,7 +312,9 @@ def check_traces_match(trace1, trace2): ) -def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf): +def check_model_guide_match( + model_trace: "Trace", guide_trace: "Trace", max_plate_nesting: float = math.inf +) -> None: """ :param pyro.poutine.Trace model_trace: Trace object of the model :param pyro.poutine.Trace guide_trace: Trace object of the guide @@ -385,19 +463,20 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf ) -def check_site_shape(site, max_plate_nesting): +def check_site_shape(site: "Message", max_plate_nesting: int) -> None: actual_shape = list(site["log_prob"].shape) # Compute expected shape. - expected_shape = [] + expected_shape: List[Optional[int]] = [] for f in site["cond_indep_stack"]: if f.dim is not None: # Use the specified plate dimension, which counts from the right. assert f.dim < 0 if len(expected_shape) < -f.dim: - expected_shape = [None] * ( + extra_shape: List[Optional[int]] = [None] * ( -f.dim - len(expected_shape) - ) + expected_shape + ) + expected_shape = extra_shape + expected_shape if expected_shape[f.dim] is not None: raise ValueError( "\n ".join( @@ -448,6 +527,9 @@ def check_site_shape(site, max_plate_nesting): ) # Check parallel dimensions on the left of max_plate_nesting. + if TYPE_CHECKING: + assert site["infer"] is not None + assert isinstance(site["fn"], TorchDistributionMixin) enum_dim = site["infer"].get("_enumerate_dim") if enum_dim is not None: if ( @@ -464,7 +546,7 @@ def check_site_shape(site, max_plate_nesting): ) -def _are_independent(counters1, counters2): +def _are_independent(counters1: Dict[str, int], counters2: Dict[str, int]) -> bool: for name, counter1 in counters1.items(): if name in counters2: if counters2[name] != counter1: @@ -472,7 +554,7 @@ def _are_independent(counters1, counters2): return False -def check_traceenum_requirements(model_trace, guide_trace): +def check_traceenum_requirements(model_trace: "Trace", guide_trace: "Trace") -> None: """ Warn if user could easily rewrite the model or guide in a way that would clearly avoid invalid dependencies on enumerated variables. @@ -490,8 +572,10 @@ def check_traceenum_requirements(model_trace, guide_trace): if site["type"] == "sample" and site["infer"].get("enumerate") ) for role, trace in [("model", model_trace), ("guide", guide_trace)]: - plate_counters = {} # for sequential plates only - enumerated_contexts = defaultdict(set) + plate_counters: Dict[str, Dict[str, int]] = {} # for sequential plates only + enumerated_contexts: Dict[FrozenSet["CondIndepStackFrame"], Set[str]] = ( + defaultdict(set) + ) for name, site in trace.nodes.items(): if site["type"] != "sample": continue @@ -504,12 +588,12 @@ def check_traceenum_requirements(model_trace, guide_trace): for enumerated_context, names in enumerated_contexts.items(): if not (context < enumerated_context): continue - names = sorted( + names_list = sorted( n for n in names if not _are_independent(plate_counter, plate_counters[n]) ) - if not names: + if not names_list: continue diff = sorted(f.name for f in enumerated_context - context) warnings.warn( @@ -519,7 +603,7 @@ def check_traceenum_requirements(model_trace, guide_trace): role, name ), 'Expected site "{}" to precede sites "{}"'.format( - name, '", "'.join(sorted(names)) + name, '", "'.join(sorted(names_list)) ), 'to avoid breaking independence of plates "{}"'.format( '", "'.join(diff) @@ -534,7 +618,7 @@ def check_traceenum_requirements(model_trace, guide_trace): enumerated_contexts[context].add(name) -def check_if_enumerated(guide_trace): +def check_if_enumerated(guide_trace: "Trace") -> None: enumerated_sites = [ name for name, site in guide_trace.nodes.items() @@ -579,7 +663,7 @@ def ignore_jit_warnings(filter=None): yield -def jit_iter(tensor): +def jit_iter(tensor: torch.Tensor) -> List[torch.Tensor]: """ Iterate over a tensor, ignoring jit warnings. """ @@ -620,7 +704,7 @@ def ignore_experimental_warning(): yield -def deep_getattr(obj, name): +def deep_getattr(obj: object, name: str) -> Any: """ Python getattr() for arbitrarily deep attributes Throws an AttributeError if bad attribute @@ -639,5 +723,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.elapsed -def torch_float(x): +@overload +def torch_float(x: Union[float, int]) -> float: ... +@overload +def torch_float(x: torch.Tensor) -> torch.Tensor: ... +def torch_float( + x: Union[torch.Tensor, Union[float, int]] +) -> Union[torch.Tensor, float]: return x.float() if isinstance(x, torch.Tensor) else float(x) diff --git a/setup.cfg b/setup.cfg index a92c6f00d7..b21884cb87 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,10 +72,6 @@ warn_unused_ignores = True ignore_errors = True warn_unused_ignores = True -[mypy-pyro.util.*] -ignore_errors = True -warn_unused_ignores = True - [mypy-tests.test_primitives] ignore_errors = True warn_unused_ignores = True