Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate pyro.util.py #3393

Merged
merged 2 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

[tool.ruff]
extend-exclude = ["*.ipynb"]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jupyter Notebooks are linted by default starting from ruff version 0.6.0.
https://docs.astral.sh/ruff/configuration/#jupyter-notebook-discovery

line-length = 120


Expand Down
146 changes: 118 additions & 28 deletions pyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,31 @@
from collections import defaultdict
from contextlib import contextmanager
from itertools import zip_longest
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
List,
Optional,
Set,
Union,
overload,
)

import numpy as np
import torch

from pyro.poutine.util import site_is_subsample

if TYPE_CHECKING:
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.indep_messenger import CondIndepStackFrame
from pyro.poutine.runtime import Message
from pyro.poutine.trace import Trace

def set_rng_seed(rng_seed):

def set_rng_seed(rng_seed: int) -> None:
"""
Sets seeds of `torch` and `torch.cuda` (if available).

Expand All @@ -29,15 +46,15 @@ def set_rng_seed(rng_seed):
np.random.seed(rng_seed)


def get_rng_state():
def get_rng_state() -> Dict[str, Any]:
return {
"torch": torch.get_rng_state(),
"random": random.getstate(),
"numpy": np.random.get_state(),
}


def set_rng_state(state):
def set_rng_state(state: Dict[str, Any]) -> None:
torch.set_rng_state(state["torch"])
random.setstate(state["random"])
if "numpy" in state:
Expand All @@ -46,7 +63,11 @@ def set_rng_state(state):
np.random.set_state(state["numpy"])


def torch_isnan(x):
@overload
def torch_isnan(x: numbers.Number) -> bool: ...
@overload
def torch_isnan(x: torch.Tensor) -> torch.Tensor: ...
def torch_isnan(x: Union[torch.Tensor, numbers.Number]) -> Union[bool, torch.Tensor]:
"""
A convenient function to check if a Tensor contains any nan; also works with numbers
"""
Expand All @@ -55,7 +76,11 @@ def torch_isnan(x):
return torch.isnan(x).any()


def torch_isinf(x):
@overload
def torch_isinf(x: numbers.Number) -> bool: ...
@overload
def torch_isinf(x: torch.Tensor) -> torch.Tensor: ...
def torch_isinf(x: Union[torch.Tensor, numbers.Number]) -> Union[bool, torch.Tensor]:
"""
A convenient function to check if a Tensor contains any +inf; also works with numbers
"""
Expand All @@ -64,7 +89,29 @@ def torch_isinf(x):
return (x == math.inf).any() or (x == -math.inf).any()


def warn_if_nan(value, msg="", *, filename=None, lineno=None):
@overload
def warn_if_nan(
value: numbers.Number,
msg: str = "",
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> numbers.Number: ...
@overload
def warn_if_nan(
value: torch.Tensor,
msg: str = "",
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> torch.Tensor: ...
def warn_if_nan(
value: Union[torch.Tensor, numbers.Number],
msg: str = "",
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> Union[torch.Tensor, numbers.Number]:
"""
A convenient function to warn if a Tensor or its grad contains any nan,
also works with numbers.
Expand All @@ -79,14 +126,15 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None):
filename = frame.f_code.co_filename
lineno = frame.f_lineno

if torch.is_tensor(value) and value.requires_grad:
if isinstance(value, torch.Tensor) and value.requires_grad:
value.register_hook(
lambda x: warn_if_nan(
x, "backward " + msg, filename=filename, lineno=lineno
)
)

if torch_isnan(value):
assert isinstance(lineno, int)
warnings.warn_explicit(
"Encountered NaN{}".format(": " + msg if msg else "."),
UserWarning,
Expand All @@ -97,9 +145,35 @@ def warn_if_nan(value, msg="", *, filename=None, lineno=None):
return value


@overload
def warn_if_inf(
value: numbers.Number,
msg: str = "",
allow_posinf: bool = False,
allow_neginf: bool = False,
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> numbers.Number: ...
@overload
def warn_if_inf(
value: torch.Tensor,
msg: str = "",
allow_posinf: bool = False,
allow_neginf: bool = False,
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> torch.Tensor: ...
def warn_if_inf(
value, msg="", allow_posinf=False, allow_neginf=False, *, filename=None, lineno=None
):
value: Union[torch.Tensor, numbers.Number],
msg: str = "",
allow_posinf: bool = False,
allow_neginf: bool = False,
*,
filename: Optional[str] = None,
lineno: Optional[int] = None,
) -> Union[torch.Tensor, numbers.Number]:
"""
A convenient function to warn if a Tensor or its grad contains any inf,
also works with numbers.
Expand All @@ -114,7 +188,7 @@ def warn_if_inf(
filename = frame.f_code.co_filename
lineno = frame.f_lineno

if torch.is_tensor(value) and value.requires_grad:
if isinstance(value, torch.Tensor) and value.requires_grad:
value.register_hook(
lambda x: warn_if_inf(
x,
Expand All @@ -131,6 +205,7 @@ def warn_if_inf(
if isinstance(value, numbers.Number)
else (value == math.inf).any()
):
assert isinstance(lineno, int)
warnings.warn_explicit(
"Encountered +inf{}".format(": " + msg if msg else "."),
UserWarning,
Expand All @@ -142,6 +217,7 @@ def warn_if_inf(
if isinstance(value, numbers.Number)
else (value == -math.inf).any()
):
assert isinstance(lineno, int)
warnings.warn_explicit(
"Encountered -inf{}".format(": " + msg if msg else "."),
UserWarning,
Expand All @@ -152,7 +228,7 @@ def warn_if_inf(
return value


def save_visualization(trace, graph_output):
def save_visualization(trace: "Trace", graph_output: str) -> None:
"""
DEPRECATED Use :func:`pyro.infer.inspect.render_model()` instead.

Expand Down Expand Up @@ -206,7 +282,7 @@ def save_visualization(trace, graph_output):
g.render(graph_output, view=False, cleanup=True)


def check_traces_match(trace1, trace2):
def check_traces_match(trace1: "Trace", trace2: "Trace") -> None:
"""
:param pyro.poutine.Trace trace1: Trace object of the model
:param pyro.poutine.Trace trace2: Trace object of the guide
Expand Down Expand Up @@ -236,7 +312,9 @@ def check_traces_match(trace1, trace2):
)


def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf):
def check_model_guide_match(
model_trace: "Trace", guide_trace: "Trace", max_plate_nesting: float = math.inf
) -> None:
"""
:param pyro.poutine.Trace model_trace: Trace object of the model
:param pyro.poutine.Trace guide_trace: Trace object of the guide
Expand Down Expand Up @@ -385,19 +463,20 @@ def check_model_guide_match(model_trace, guide_trace, max_plate_nesting=math.inf
)


def check_site_shape(site, max_plate_nesting):
def check_site_shape(site: "Message", max_plate_nesting: int) -> None:
actual_shape = list(site["log_prob"].shape)

# Compute expected shape.
expected_shape = []
expected_shape: List[Optional[int]] = []
for f in site["cond_indep_stack"]:
if f.dim is not None:
# Use the specified plate dimension, which counts from the right.
assert f.dim < 0
if len(expected_shape) < -f.dim:
expected_shape = [None] * (
extra_shape: List[Optional[int]] = [None] * (
-f.dim - len(expected_shape)
) + expected_shape
)
expected_shape = extra_shape + expected_shape
if expected_shape[f.dim] is not None:
raise ValueError(
"\n ".join(
Expand Down Expand Up @@ -448,6 +527,9 @@ def check_site_shape(site, max_plate_nesting):
)

# Check parallel dimensions on the left of max_plate_nesting.
if TYPE_CHECKING:
assert site["infer"] is not None
assert isinstance(site["fn"], TorchDistributionMixin)
enum_dim = site["infer"].get("_enumerate_dim")
if enum_dim is not None:
if (
Expand All @@ -464,15 +546,15 @@ def check_site_shape(site, max_plate_nesting):
)


def _are_independent(counters1, counters2):
def _are_independent(counters1: Dict[str, int], counters2: Dict[str, int]) -> bool:
for name, counter1 in counters1.items():
if name in counters2:
if counters2[name] != counter1:
return True
return False


def check_traceenum_requirements(model_trace, guide_trace):
def check_traceenum_requirements(model_trace: "Trace", guide_trace: "Trace") -> None:
"""
Warn if user could easily rewrite the model or guide in a way that would
clearly avoid invalid dependencies on enumerated variables.
Expand All @@ -490,8 +572,10 @@ def check_traceenum_requirements(model_trace, guide_trace):
if site["type"] == "sample" and site["infer"].get("enumerate")
)
for role, trace in [("model", model_trace), ("guide", guide_trace)]:
plate_counters = {} # for sequential plates only
enumerated_contexts = defaultdict(set)
plate_counters: Dict[str, Dict[str, int]] = {} # for sequential plates only
enumerated_contexts: Dict[FrozenSet["CondIndepStackFrame"], Set[str]] = (
defaultdict(set)
)
for name, site in trace.nodes.items():
if site["type"] != "sample":
continue
Expand All @@ -504,12 +588,12 @@ def check_traceenum_requirements(model_trace, guide_trace):
for enumerated_context, names in enumerated_contexts.items():
if not (context < enumerated_context):
continue
names = sorted(
names_list = sorted(
n
for n in names
if not _are_independent(plate_counter, plate_counters[n])
)
if not names:
if not names_list:
continue
diff = sorted(f.name for f in enumerated_context - context)
warnings.warn(
Expand All @@ -519,7 +603,7 @@ def check_traceenum_requirements(model_trace, guide_trace):
role, name
),
'Expected site "{}" to precede sites "{}"'.format(
name, '", "'.join(sorted(names))
name, '", "'.join(sorted(names_list))
),
'to avoid breaking independence of plates "{}"'.format(
'", "'.join(diff)
Expand All @@ -534,7 +618,7 @@ def check_traceenum_requirements(model_trace, guide_trace):
enumerated_contexts[context].add(name)


def check_if_enumerated(guide_trace):
def check_if_enumerated(guide_trace: "Trace") -> None:
enumerated_sites = [
name
for name, site in guide_trace.nodes.items()
Expand Down Expand Up @@ -579,7 +663,7 @@ def ignore_jit_warnings(filter=None):
yield


def jit_iter(tensor):
def jit_iter(tensor: torch.Tensor) -> List[torch.Tensor]:
"""
Iterate over a tensor, ignoring jit warnings.
"""
Expand Down Expand Up @@ -620,7 +704,7 @@ def ignore_experimental_warning():
yield


def deep_getattr(obj, name):
def deep_getattr(obj: object, name: str) -> Any:
"""
Python getattr() for arbitrarily deep attributes
Throws an AttributeError if bad attribute
Expand All @@ -639,5 +723,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return self.elapsed


def torch_float(x):
@overload
def torch_float(x: Union[float, int]) -> float: ...
@overload
def torch_float(x: torch.Tensor) -> torch.Tensor: ...
def torch_float(
x: Union[torch.Tensor, Union[float, int]]
) -> Union[torch.Tensor, float]:
return x.float() if isinstance(x, torch.Tensor) else float(x)
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ warn_unused_ignores = True
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.util.*]
ignore_errors = True
warn_unused_ignores = True

[mypy-tests.test_primitives]
ignore_errors = True
warn_unused_ignores = True
Expand Down
Loading