Skip to content

Commit

Permalink
Type annotate pyro.util.py (#3393)
Browse files Browse the repository at this point in the history
* type annotate util.py

* exclude ipynb
  • Loading branch information
ordabayevy committed Sep 16, 2024
1 parent 414a4d5 commit eadf00a
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 32 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

[tool.ruff]
extend-exclude = ["*.ipynb"]
line-length = 120


Expand Down
146 changes: 118 additions & 28 deletions pyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,31 @@
from collections import defaultdict
from contextlib import contextmanager
from itertools import zip_longest
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
List,
Optional,
Set,
Union,
overload,
)

import numpy as np
import torch

from pyro.poutine.util import site_is_subsample

if TYPE_CHECKING:
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.runtime import Message
from pyro.poutine.trace import Trace

def set_rng_seed(rng_seed):

def set_rng_seed(rng_seed: int) -> None:
"""
Sets seeds of `torch` and `torch.cuda` (if available).
Expand All @@ -29,15 +46,15 @@ def set_rng_seed(rng_seed):
np.random.seed(rng_seed)


def get_rng_state():
def get_rng_state() -> Dict[str, Any]:
return {
"torch": torch.get_rng_state(),
"random": random.getstate(),
"numpy": np.random.get_state(),
}


def set_rng_state(state):
def set_rng_state(state: Dict[str, Any]) -> None:
torch.set_rng_state(state["torch"])
random.setstate(state["random"])
if "numpy" in state:
Expand All @@ -46,7 +63,11 @@ def set_rng_state(state):
np.random.set_state(state["numpy"])


def torch_isnan(x):
@overload
def torch_isnan(x: numbers.Number) -> bool: ...
@overload
def torch_isnan(x: torch.Tensor) -> torch.Tensor: ...
def torch_isnan(x: Union[torch.Tensor, numbers.Number]) -> Union[bool, torch.Tensor]:
"""
A convenient function to check if a Tensor contains any nan; also works with numbers
"""
Expand All @@ -55,7 +76,11 @@ def torch_isnan(x):
return torch.isnan(x).any()


def torch_isinf(x):
@overload
def torch_isinf(x: numbers.Number) -> bool: ...
@overload
def torch_isinf(x: torch.Tensor) -> torch.Tensor: ...
def torch_isinf(x: Union[torch.Tensor, numbers.Number]) -> Union[bool, torch.Tensor]:
"""
A convenient function to check if a Tensor contains any +inf; also works with numbers
"""
Expand All @@ -64,7 +89,29 @@ def torch_isinf(x):
return (x == math.inf).any() or (x == -math.inf).any()


def warn_if_nan(value, msg="", *, filename=None, lineno=None):
@overload
def warn_if_nan(
value: numbers.Number,
msg: str = "",
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> numbers.Number: ...
@overload
def warn_if_nan(
value: torch.Tensor,
msg: str = "",
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> torch.Tensor: ...
def warn_if_nan(
value: Union[torch.Tensor, numbers.Number],
msg: str = "",
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> Union[torch.Tensor, numbers.Number]:
"""
A convenient function to warn if a Tensor or its grad contains any nan,
also works with numbers.
Expand All @@ -79,14 +126,15 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None):
filename = frame.f_code.co_filename
lineno = frame.f_lineno

if torch.is_tensor(value) and value.requires_grad:
if isinstance(value, torch.Tensor) and value.requires_grad:
value.register_hook(
lambda x: warn_if_nan(
x, "backward " + msg, filename=filename, lineno=lineno
)
)

if torch_isnan(value):
assert isinstance(lineno, int)
warnings.warn_explicit(
"Encountered NaN{}".format(": " + msg if msg else "."),
UserWarning,
Expand All @@ -97,9 +145,35 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None):
return value


@overload
def warn_if_inf(
value: numbers.Number,
msg: str = "",
allow_posinf: bool = False,
allow_neginf: bool = False,
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> numbers.Number: ...
@overload
def warn_if_inf(
value: torch.Tensor,
msg: str = "",
allow_posinf: bool = False,
allow_neginf: bool = False,
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> torch.Tensor: ...
def warn_if_inf(
value, msg="", allow_posinf=False, allow_neginf=False, *, filename=None, lineno=None
):
value: Union[torch.Tensor, numbers.Number],
msg: str = "",
allow_posinf: bool = False,
allow_neginf: bool = False,
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> Union[torch.Tensor, numbers.Number]:
"""
A convenient function to warn if a Tensor or its grad contains any inf,
also works with numbers.
Expand All @@ -114,7 +188,7 @@ def warn_if_inf(
filename = frame.f_code.co_filename
lineno = frame.f_lineno

if torch.is_tensor(value) and value.requires_grad:
if isinstance(value, torch.Tensor) and value.requires_grad:
value.register_hook(
lambda x: warn_if_inf(
x,
Expand All @@ -131,6 +205,7 @@ def warn_if_inf(
if isinstance(value, numbers.Number)
else (value == math.inf).any()
):
assert isinstance(lineno, int)
warnings.warn_explicit(
"Encountered +inf{}".format(": " + msg if msg else "."),
UserWarning,
Expand All @@ -142,6 +217,7 @@ def warn_if_inf(
if isinstance(value, numbers.Number)
else (value == -math.inf).any()
):
assert isinstance(lineno, int)
warnings.warn_explicit(
"Encountered -inf{}".format(": " + msg if msg else "."),
UserWarning,
Expand All @@ -152,7 +228,7 @@ def warn_if_inf(
return value


def save_visualization(trace, graph_output):
def save_visualization(trace: "Trace", graph_output: str) -> None:
"""
DEPRECATED Use :func:`pyro.infer.inspect.render_model()` instead.
Expand Down Expand Up @@ -206,7 +282,7 @@ def save_visualization(trace, graph_output):
g.render(graph_output, view=False, cleanup=True)


def check_traces_match(trace1, trace2):
def check_traces_match(trace1: "Trace", trace2: "Trace") -> None:
"""
:param pyro.poutine.Trace trace1: Trace object of the model
:param pyro.poutine.Trace trace2: Trace object of the guide
Expand Down Expand Up @@ -236,7 +312,9 @@ def check_traces_match(trace1, trace2):
)


def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf):
def check_model_guide_match(
model_trace: "Trace", guide_trace: "Trace", max_plate_nesting: float = math.inf
) -> None:
"""
:param pyro.poutine.Trace model_trace: Trace object of the model
:param pyro.poutine.Trace guide_trace: Trace object of the guide
Expand Down Expand Up @@ -385,19 +463,20 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
)


def check_site_shape(site, max_plate_nesting):
def check_site_shape(site: "Message", max_plate_nesting: int) -> None:
actual_shape = list(site["log_prob"].shape)

# Compute expected shape.
expected_shape = []
expected_shape: List[Optional[int]] = []
for f in site["cond_indep_stack"]:
if f.dim is not None:
# Use the specified plate dimension, which counts from the right.
assert f.dim < 0
if len(expected_shape) < -f.dim:
expected_shape = [None] * (
extra_shape: List[Optional[int]] = [None] * (
-f.dim - len(expected_shape)
) + expected_shape
)
expected_shape = extra_shape + expected_shape
if expected_shape[f.dim] is not None:
raise ValueError(
"\n ".join(
Expand Down Expand Up @@ -448,6 +527,9 @@ def check_site_shape(site, max_plate_nesting):
)

# Check parallel dimensions on the left of max_plate_nesting.
if TYPE_CHECKING:
assert site["infer"] is not None
assert isinstance(site["fn"], TorchDistributionMixin)
enum_dim = site["infer"].get("_enumerate_dim")
if enum_dim is not None:
if (
Expand All @@ -464,15 +546,15 @@ def check_site_shape(site, max_plate_nesting):
)


def _are_independent(counters1, counters2):
def _are_independent(counters1: Dict[str, int], counters2: Dict[str, int]) -> bool:
for name, counter1 in counters1.items():
if name in counters2:
if counters2[name] != counter1:
return True
return False


def check_traceenum_requirements(model_trace, guide_trace):
def check_traceenum_requirements(model_trace: "Trace", guide_trace: "Trace") -> None:
"""
Warn if user could easily rewrite the model or guide in a way that would
clearly avoid invalid dependencies on enumerated variables.
Expand All @@ -490,8 +572,10 @@ def check_traceenum_requirements(model_trace, guide_trace):
if site["type"] == "sample" and site["infer"].get("enumerate")
)
for role, trace in [("model", model_trace), ("guide", guide_trace)]:
plate_counters = {} # for sequential plates only
enumerated_contexts = defaultdict(set)
plate_counters: Dict[str, Dict[str, int]] = {} # for sequential plates only
enumerated_contexts: Dict[FrozenSet["CondIndepStackFrame"], Set[str]] = (
defaultdict(set)
)
for name, site in trace.nodes.items():
if site["type"] != "sample":
continue
Expand All @@ -504,12 +588,12 @@ def check_traceenum_requirements(model_trace, guide_trace):
for enumerated_context, names in enumerated_contexts.items():
if not (context < enumerated_context):
continue
names = sorted(
names_list = sorted(
n
for n in names
if not _are_independent(plate_counter, plate_counters[n])
)
if not names:
if not names_list:
continue
diff = sorted(f.name for f in enumerated_context - context)
warnings.warn(
Expand All @@ -519,7 +603,7 @@ def check_traceenum_requirements(model_trace, guide_trace):
role, name
),
'Expected site "{}" to precede sites "{}"'.format(
name, '", "'.join(sorted(names))
name, '", "'.join(sorted(names_list))
),
'to avoid breaking independence of plates "{}"'.format(
'", "'.join(diff)
Expand All @@ -534,7 +618,7 @@ def check_traceenum_requirements(model_trace, guide_trace):
enumerated_contexts[context].add(name)


def check_if_enumerated(guide_trace):
def check_if_enumerated(guide_trace: "Trace") -> None:
enumerated_sites = [
name
for name, site in guide_trace.nodes.items()
Expand Down Expand Up @@ -579,7 +663,7 @@ def ignore_jit_warnings(filter=None):
yield


def jit_iter(tensor):
def jit_iter(tensor: torch.Tensor) -> List[torch.Tensor]:
"""
Iterate over a tensor, ignoring jit warnings.
"""
Expand Down Expand Up @@ -620,7 +704,7 @@ def ignore_experimental_warning():
yield


def deep_getattr(obj, name):
def deep_getattr(obj: object, name: str) -> Any:
"""
Python getattr() for arbitrarily deep attributes
Throws an AttributeError if bad attribute
Expand All @@ -639,5 +723,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self.elapsed


def torch_float(x):
@overload
def torch_float(x: Union[float, int]) -> float: ...
@overload
def torch_float(x: torch.Tensor) -> torch.Tensor: ...
def torch_float(
x: Union[torch.Tensor, Union[float, int]]
) -> Union[torch.Tensor, float]:
return x.float() if isinstance(x, torch.Tensor) else float(x)
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ warn_unused_ignores = True
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.util.*]
ignore_errors = True
warn_unused_ignores = True

[mypy-tests.test_primitives]
ignore_errors = True
warn_unused_ignores = True
Expand Down

0 comments on commit eadf00a

Please sign in to comment.