-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Simplify Model __new__ and metaclass #7473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ricardoV94
merged 14 commits into
pymc-devs:main
from
thomasaarholt:thomasaarholt/modelclass
Oct 10, 2024
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
ddf6f10
Type get_context correctly
thomasaarholt 74f6616
Import from future to use delayed evaluation of annotations
thomasaarholt aa253ac
New ModelManager class for managing model contexts
thomasaarholt d5beee6
Model class is now the context manager directly
thomasaarholt e2cefc3
Fix type of UNSET in type definition
thomasaarholt 5a1f0b8
Set model parent in init rather than in __new__
thomasaarholt 8846db9
Replace get_context in metaclass with classmethod
thomasaarholt 7318d75
Remove get_contexts from metaclass
thomasaarholt 560cd3b
Simplify ContextMeta
thomasaarholt 6ff3d5d
Type Model.register_rv for for downstream typing
thomasaarholt 2ebc48f
Include np.ndarray as possible type for coord values
thomasaarholt 165a35f
Use function-scoped new_dims to handle type hint varying throughout f…
thomasaarholt dbbb9a2
Fix case of dims = [None, None, ...]
thomasaarholt 8565965
Remove unused hack
thomasaarholt File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from __future__ import annotations | ||
|
||
import functools | ||
import sys | ||
|
@@ -19,13 +20,8 @@ | |
import warnings | ||
|
||
from collections.abc import Iterable, Sequence | ||
from sys import modules | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Literal, | ||
Optional, | ||
TypeVar, | ||
Union, | ||
cast, | ||
overload, | ||
) | ||
|
@@ -42,7 +38,6 @@ | |
from pytensor.tensor.random.op import RandomVariable | ||
from pytensor.tensor.random.type import RandomType | ||
from pytensor.tensor.variable import TensorConstant, TensorVariable | ||
from typing_extensions import Self | ||
|
||
from pymc.blocking import DictToArrayBijection, RaveledVars | ||
from pymc.data import is_valid_observed | ||
|
@@ -73,6 +68,7 @@ | |
VarName, | ||
WithMemoization, | ||
_add_future_warning_tag, | ||
_UnsetType, | ||
get_transformed_name, | ||
get_value_vars_from_user_vars, | ||
get_var_name, | ||
|
@@ -92,118 +88,36 @@ | |
] | ||
|
||
|
||
T = TypeVar("T", bound="ContextMeta") | ||
class ModelManager(threading.local): | ||
"""Keeps track of currently active model contexts. | ||
|
||
A global instance of this is created in this module on import. | ||
Use that instance, `MODEL_MANAGER` to inspect current contexts. | ||
|
||
class ContextMeta(type): | ||
"""Functionality for objects that put themselves in a context manager.""" | ||
|
||
def __new__(cls, name, bases, dct, **kwargs): | ||
"""Add __enter__ and __exit__ methods to the class.""" | ||
|
||
def __enter__(self): | ||
self.__class__.context_class.get_contexts().append(self) | ||
return self | ||
|
||
def __exit__(self, typ, value, traceback): | ||
self.__class__.context_class.get_contexts().pop() | ||
It inherits from threading.local so is thread-safe, if models | ||
can be entered/exited within individual threads. | ||
""" | ||
|
||
dct[__enter__.__name__] = __enter__ | ||
dct[__exit__.__name__] = __exit__ | ||
def __init__(self): | ||
self.active_contexts: list[Model] = [] | ||
|
||
# We strip off keyword args, per the warning from | ||
# StackExchange: | ||
# DO NOT send "**kwargs" to "type.__new__". It won't catch them and | ||
# you'll get a "TypeError: type() takes 1 or 3 arguments" exception. | ||
return super().__new__(cls, name, bases, dct) | ||
@property | ||
def current_context(self) -> Model | None: | ||
"""Return the innermost context of any current contexts.""" | ||
return self.active_contexts[-1] if self.active_contexts else None | ||
|
||
# FIXME: is there a more elegant way to automatically add methods to the class that | ||
# are instance methods instead of class methods? | ||
def __init__(cls, name, bases, nmspc, context_class: type | None = None, **kwargs): | ||
"""Add ``__enter__`` and ``__exit__`` methods to the new class automatically.""" | ||
if context_class is not None: | ||
cls._context_class = context_class | ||
super().__init__(name, bases, nmspc) | ||
@property | ||
def parent_context(self) -> Model | None: | ||
"""Return the parent context to the active context, if any.""" | ||
return self.active_contexts[-2] if len(self.active_contexts) > 1 else None | ||
|
||
def get_context(cls, error_if_none=True, allow_block_model_access=False) -> T | None: | ||
"""Return the most recently pushed context object of type ``cls`` on the stack, or ``None``. | ||
|
||
If ``error_if_none`` is True (default), raise a ``TypeError`` instead of returning ``None``. | ||
""" | ||
try: | ||
candidate: T | None = cls.get_contexts()[-1] | ||
except IndexError: | ||
# Calling code expects to get a TypeError if the entity | ||
# is unfound, and there's too much to fix. | ||
if error_if_none: | ||
raise TypeError(f"No {cls} on context stack") | ||
return None | ||
if isinstance(candidate, BlockModelAccess) and not allow_block_model_access: | ||
raise BlockModelAccessError(candidate.error_msg_on_access) | ||
return candidate | ||
|
||
def get_contexts(cls) -> list[T]: | ||
"""Return a stack of context instances for the ``context_class`` of ``cls``.""" | ||
# This lazily creates the context class's contexts | ||
# thread-local object, as needed. This seems inelegant to me, | ||
# but since the context class is not guaranteed to exist when | ||
# the metaclass is being instantiated, I couldn't figure out a | ||
# better way. [2019/10/11:rpg] | ||
|
||
# no race-condition here, contexts is a thread-local object | ||
# be sure not to override contexts in a subclass however! | ||
context_class = cls.context_class | ||
assert isinstance( | ||
context_class, type | ||
), f"Name of context class, {context_class} was not resolvable to a class" | ||
if not hasattr(context_class, "contexts"): | ||
context_class.contexts = threading.local() | ||
|
||
contexts = context_class.contexts | ||
|
||
if not hasattr(contexts, "stack"): | ||
contexts.stack = [] | ||
return contexts.stack | ||
|
||
# the following complex property accessor is necessary because the | ||
# context_class may not have been created at the point it is | ||
# specified, so the context_class may be a class *name* rather | ||
# than a class. | ||
@property | ||
def context_class(cls) -> type: | ||
def resolve_type(c: type | str) -> type: | ||
if isinstance(c, str): | ||
c = getattr(modules[cls.__module__], c) | ||
if isinstance(c, type): | ||
return c | ||
raise ValueError(f"Cannot resolve context class {c}") | ||
|
||
assert cls is not None | ||
if isinstance(cls._context_class, str): | ||
cls._context_class = resolve_type(cls._context_class) | ||
if not isinstance(cls._context_class, str | type): | ||
raise ValueError( | ||
f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type" | ||
) | ||
return cls._context_class | ||
|
||
# Inherit context class from parent | ||
def __init_subclass__(cls, **kwargs): | ||
super().__init_subclass__(**kwargs) | ||
cls.context_class = super().context_class | ||
|
||
# Initialize object in its own context... | ||
# Merged from InitContextMeta in the original. | ||
def __call__(cls, *args, **kwargs): | ||
# We type hint Model here so type checkers understand that Model is a context manager. | ||
# This metaclass is only used for Model, so this is safe to do. See #6809 for more info. | ||
instance: Model = cls.__new__(cls, *args, **kwargs) | ||
with instance: # appends context | ||
instance.__init__(*args, **kwargs) | ||
return instance | ||
# MODEL_MANAGER is instantiated at import, and serves as a truth for | ||
# what any currently active model contexts are. | ||
MODEL_MANAGER = ModelManager() | ||
|
||
|
||
def modelcontext(model: Optional["Model"]) -> "Model": | ||
def modelcontext(model: Model | None) -> Model: | ||
"""Return the given model or, if None was supplied, try to find one in the context stack.""" | ||
if model is None: | ||
model = Model.get_context(error_if_none=False) | ||
thomasaarholt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -372,6 +286,18 @@ def profile(self): | |
return self._pytensor_function.profile | ||
|
||
|
||
class ContextMeta(type): | ||
"""A metaclass in order to apply a model's context during `Model.__init__``.""" | ||
|
||
# We want the Model's context to be active during __init__. In order for this | ||
# to apply to subclasses of Model as well, we need to use a metaclass. | ||
def __call__(cls: type[Model], *args, **kwargs): | ||
instance = cls.__new__(cls, *args, **kwargs) | ||
with instance: # applies context | ||
instance.__init__(*args, **kwargs) | ||
return instance | ||
|
||
|
||
class Model(WithMemoization, metaclass=ContextMeta): | ||
"""Encapsulates the variables and likelihood factors of a model. | ||
|
||
|
@@ -495,22 +421,14 @@ class Model(WithMemoization, metaclass=ContextMeta): | |
|
||
""" | ||
|
||
if TYPE_CHECKING: | ||
def __enter__(self): | ||
"""Enter the context manager.""" | ||
MODEL_MANAGER.active_contexts.append(self) | ||
return self | ||
|
||
def __enter__(self: Self) -> Self: | ||
"""Enter the context manager.""" | ||
|
||
def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: | ||
"""Exit the context manager.""" | ||
|
||
def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs): | ||
# resolves the parent instance | ||
instance = super().__new__(cls) | ||
if model is UNSET: | ||
instance._parent = cls.get_context(error_if_none=False) | ||
else: | ||
instance._parent = model | ||
return instance | ||
def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: | ||
"""Exit the context manager.""" | ||
_ = MODEL_MANAGER.active_contexts.pop() | ||
|
||
@staticmethod | ||
def _validate_name(name): | ||
|
@@ -525,11 +443,11 @@ def __init__( | |
check_bounds=True, | ||
*, | ||
coords_mutable=None, | ||
model: Union[Literal[UNSET], None, "Model"] = UNSET, | ||
model: _UnsetType | None | Model = UNSET, | ||
): | ||
del model # used in __new__ to define the parent of this model | ||
self.name = self._validate_name(name) | ||
self.check_bounds = check_bounds | ||
self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context | ||
|
||
if coords_mutable is not None: | ||
warnings.warn( | ||
|
@@ -577,6 +495,17 @@ def __init__( | |
functools.partial(str_for_model, formatting="latex"), self | ||
) | ||
|
||
@classmethod | ||
def get_context( | ||
cls, error_if_none: bool = True, allow_block_model_access: bool = False | ||
) -> Model | None: | ||
model = MODEL_MANAGER.current_context | ||
if isinstance(model, BlockModelAccess) and not allow_block_model_access: | ||
raise BlockModelAccessError(model.error_msg_on_access) | ||
if model is None and error_if_none: | ||
raise TypeError("No model on context stack") | ||
return model | ||
|
||
lucianopaz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@property | ||
def parent(self): | ||
return self._parent | ||
|
@@ -967,7 +896,7 @@ def shape_from_dims(self, dims): | |
def add_coord( | ||
self, | ||
name: str, | ||
values: Sequence | None = None, | ||
values: Sequence | np.ndarray | None = None, | ||
mutable: bool | None = None, | ||
*, | ||
length: int | Variable | None = None, | ||
|
@@ -1233,16 +1162,16 @@ def set_data( | |
|
||
def register_rv( | ||
self, | ||
rv_var, | ||
name, | ||
rv_var: RandomVariable, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is wrong, RandomVariable is not a variable type, it's an Op type There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hello there 😊 |
||
name: str, | ||
*, | ||
observed=None, | ||
total_size=None, | ||
dims=None, | ||
default_transform=UNSET, | ||
transform=UNSET, | ||
initval=None, | ||
): | ||
) -> TensorVariable: | ||
"""Register an (un)observed random variable with the model. | ||
|
||
Parameters | ||
|
@@ -2074,11 +2003,6 @@ def to_graphviz( | |
) | ||
|
||
|
||
# this is really disgusting, but it breaks a self-loop: I can't pass Model | ||
# itself as context class init arg. | ||
Model._context_class = Model | ||
|
||
|
||
class BlockModelAccess(Model): | ||
"""Can be used to prevent user access to Model contexts.""" | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.