Skip to content

Commit

Permalink
Make step method state keep track of var_names
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Nov 8, 2024
1 parent c0c9e4e commit 196f668
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
2 changes: 2 additions & 0 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import field
from enum import IntEnum, unique
from typing import Any

Expand Down Expand Up @@ -91,6 +92,7 @@ def infer_warn_stats_info(

@dataclass_state
class StepMethodState(DataClassState):
var_names: list[str] = field(metadata={"tensor_name": True, "frozen": True})
rng: np.random.Generator


Expand Down
15 changes: 13 additions & 2 deletions pymc/step_methods/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ def sampling_state(self) -> DataClassState:
state_class = self._state_class
kwargs = {}
for field in fields(state_class):
val = getattr(self, field.name, field.default)
is_tensor_name = field.metadata.get("tensor_name", False)
val: Any
if is_tensor_name:
val = [var.name for var in getattr(self, "vars")]
else:
val = getattr(self, field.name, field.default)
if val is MISSING:
raise AttributeError(

Check warning on line 75 in pymc/step_methods/state.py

View check run for this annotation

Codecov / codecov/patch

pymc/step_methods/state.py#L75

Added line #L75 was not covered by tests
f"{type(self).__name__!r} object has no attribute {field.name!r}"
Expand All @@ -84,9 +89,15 @@ def sampling_state(self, state: DataClassState):
state, state_class
), f"Encountered invalid state class '{state.__class__}'. State must be '{state_class}'"
for field in fields(state_class):
is_tensor_name = field.metadata.get("tensor_name", False)
state_val = deepcopy(getattr(state, field.name))
self_val = getattr(self, field.name)
is_frozen = field.metadata.get("frozen", False)
self_val: Any
if is_tensor_name:
self_val = [var.name for var in getattr(self, "vars")]
assert is_frozen
else:
self_val = getattr(self, field.name, field.default)
if is_frozen:
if not equal_dataclass_values(state_val, self_val):
raise ValueError(
Expand Down

0 comments on commit 196f668

Please sign in to comment.