Simplify Model __new__ and metaclass #7473

Merged · 14 commits · Oct 10, 2024
Changes from 13 commits
26 changes: 15 additions & 11 deletions pymc/data.py
@@ -221,9 +221,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
def determine_coords(
model,
value: pd.DataFrame | pd.Series | xr.DataArray,
dims: Sequence[str | None] | None = None,
dims: Sequence[str] | None = None,
coords: dict[str, Sequence | np.ndarray] | None = None,
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]:
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
"""Determine coordinate values from data or the model (via ``dims``)."""
if coords is None:
coords = {}
@@ -268,9 +268,10 @@ def determine_coords(

if dims is None:
# TODO: Also determine dim names from the index
dims = [None] * np.ndim(value)

return coords, dims
new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value)
else:
new_dims = dims
return coords, new_dims
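
For reviewers tracing the `dims` handling above, a minimal sketch of how the reworked helper is expected to behave; the DataFrame, dim names, and model below are illustrative assumptions, not part of this diff:

import pandas as pd
import pymc as pm
from pymc.data import determine_coords

df = pd.DataFrame(
    {"a": [1.0, 2.0], "b": [3.0, 4.0]},
    index=pd.Index(["r0", "r1"], name="obs"),
)
with pm.Model() as model:
    # With explicit dims, coordinate values should be read off the DataFrame's
    # index and columns.
    coords, dims = determine_coords(model, df, dims=["obs", "col"])
    # With dims=None, the helper now builds a fresh list of None placeholders
    # (one per array dimension) instead of rebinding the typed `dims` argument.
    coords2, dims2 = determine_coords(model, df)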


def ConstantData(
@@ -366,7 +367,7 @@ def Data(
The name for this variable.
value : array_like or pandas.Series, pandas.Dataframe
A value to associate with this variable.
dims : str or tuple of str, optional
dims : str, tuple of str or tuple of None, optional
Dimension names of the random variables (as opposed to the shapes of these
random variables). Use this when ``value`` is a pandas Series or DataFrame. The
``dims`` will then be the name of the Series / DataFrame's columns. See ArviZ
@@ -451,14 +452,17 @@ def Data(
expected=x.ndim,
)

new_dims: Sequence[str] | Sequence[None] | None
if infer_dims_and_coords:
coords, dims = determine_coords(model, value, dims)
coords, new_dims = determine_coords(model, value, dims)
else:
new_dims = dims

if dims:
if new_dims:
xshape = x.shape
# Register new dimension lengths
for d, dname in enumerate(dims):
if dname not in model.dim_lengths:
for d, dname in enumerate(new_dims):
if dname not in model.dim_lengths and dname is not None:
model.add_coord(
name=dname,
# Note: Coordinate values can't be taken from
@@ -467,6 +471,6 @@
length=xshape[d],
)

model.register_data_var(x, dims=dims)
model.register_data_var(x, dims=new_dims)

return x
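
A short usage sketch of the public `pm.Data` path touched above; the data and dim names are invented for illustration:

import pandas as pd
import pymc as pm

df = pd.DataFrame(
    {"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]},
    index=pd.Index(["r0", "r1", "r2"], name="obs"),
)

with pm.Model() as model:
    # Dim names the model does not know yet are registered with lengths taken
    # from x.shape; entries that are still None (possible when dims are
    # inferred) are now skipped explicitly.
    x = pm.Data("x", df, dims=("obs", "col"))
# model.dim_lengths should now contain "obs" and "col" with lengths 3 and 2.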
189 changes: 59 additions & 130 deletions pymc/model/core.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
import sys
@@ -19,13 +20,8 @@
import warnings

from collections.abc import Iterable, Sequence
from sys import modules
from typing import (
TYPE_CHECKING,
Literal,
Optional,
TypeVar,
Union,
cast,
overload,
)
@@ -42,7 +38,6 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.variable import TensorConstant, TensorVariable
from typing_extensions import Self

from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.data import is_valid_observed
@@ -73,6 +68,7 @@
VarName,
WithMemoization,
_add_future_warning_tag,
_UnsetType,
get_transformed_name,
get_value_vars_from_user_vars,
get_var_name,
@@ -92,118 +88,36 @@
]


T = TypeVar("T", bound="ContextMeta")
class ModelManager(threading.local):
"""Keeps track of currently active model contexts.

A global instance of this is created in this module on import.
Use that instance, `MODEL_MANAGER` to inspect current contexts.

class ContextMeta(type):
"""Functionality for objects that put themselves in a context manager."""

def __new__(cls, name, bases, dct, **kwargs):
"""Add __enter__ and __exit__ methods to the class."""
It inherits from threading.local so is thread-safe, if models
can be entered/exited within individual threads.
"""

def __enter__(self):
self.__class__.context_class.get_contexts().append(self)
return self
def __init__(self):
self.active_contexts: list[Model] = []

def __exit__(self, typ, value, traceback):
self.__class__.context_class.get_contexts().pop()
@property
def current_context(self) -> Model | None:
"""Return the innermost context of any current contexts."""
return self.active_contexts[-1] if self.active_contexts else None

dct[__enter__.__name__] = __enter__
dct[__exit__.__name__] = __exit__
@property
def parent_context(self) -> Model | None:
"""Return the parent context to the active context, if any."""
return self.active_contexts[-2] if len(self.active_contexts) > 1 else None

# We strip off keyword args, per the warning from
# StackExchange:
# DO NOT send "**kwargs" to "type.__new__". It won't catch them and
# you'll get a "TypeError: type() takes 1 or 3 arguments" exception.
return super().__new__(cls, name, bases, dct)

# FIXME: is there a more elegant way to automatically add methods to the class that
# are instance methods instead of class methods?
def __init__(cls, name, bases, nmspc, context_class: type | None = None, **kwargs):
"""Add ``__enter__`` and ``__exit__`` methods to the new class automatically."""
if context_class is not None:
cls._context_class = context_class
super().__init__(name, bases, nmspc)
# MODEL_MANAGER is instantiated at import, and serves as a truth for
# what any currently active model contexts are.
MODEL_MANAGER = ModelManager()

def get_context(cls, error_if_none=True, allow_block_model_access=False) -> T | None:
"""Return the most recently pushed context object of type ``cls`` on the stack, or ``None``.

If ``error_if_none`` is True (default), raise a ``TypeError`` instead of returning ``None``.
"""
try:
candidate: T | None = cls.get_contexts()[-1]
except IndexError:
# Calling code expects to get a TypeError if the entity
# is unfound, and there's too much to fix.
if error_if_none:
raise TypeError(f"No {cls} on context stack")
return None
if isinstance(candidate, BlockModelAccess) and not allow_block_model_access:
raise BlockModelAccessError(candidate.error_msg_on_access)
return candidate

def get_contexts(cls) -> list[T]:
"""Return a stack of context instances for the ``context_class`` of ``cls``."""
# This lazily creates the context class's contexts
# thread-local object, as needed. This seems inelegant to me,
# but since the context class is not guaranteed to exist when
# the metaclass is being instantiated, I couldn't figure out a
# better way. [2019/10/11:rpg]

# no race-condition here, contexts is a thread-local object
# be sure not to override contexts in a subclass however!
context_class = cls.context_class
assert isinstance(
context_class, type
), f"Name of context class, {context_class} was not resolvable to a class"
if not hasattr(context_class, "contexts"):
context_class.contexts = threading.local()

contexts = context_class.contexts

if not hasattr(contexts, "stack"):
contexts.stack = []
return contexts.stack

# the following complex property accessor is necessary because the
# context_class may not have been created at the point it is
# specified, so the context_class may be a class *name* rather
# than a class.
@property
def context_class(cls) -> type:
def resolve_type(c: type | str) -> type:
if isinstance(c, str):
c = getattr(modules[cls.__module__], c)
if isinstance(c, type):
return c
raise ValueError(f"Cannot resolve context class {c}")

assert cls is not None
if isinstance(cls._context_class, str):
cls._context_class = resolve_type(cls._context_class)
if not isinstance(cls._context_class, str | type):
raise ValueError(
f"Context class for {cls.__name__}, {cls._context_class}, is not of the right type"
)
return cls._context_class

# Inherit context class from parent
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.context_class = super().context_class

# Initialize object in its own context...
# Merged from InitContextMeta in the original.
def __call__(cls, *args, **kwargs):
# We type hint Model here so type checkers understand that Model is a context manager.
# This metaclass is only used for Model, so this is safe to do. See #6809 for more info.
instance: Model = cls.__new__(cls, *args, **kwargs)
with instance: # appends context
instance.__init__(*args, **kwargs)
return instance


def modelcontext(model: Optional["Model"]) -> "Model":
def modelcontext(model: Model | None) -> Model:
"""Return the given model or, if None was supplied, try to find one in the context stack."""
if model is None:
model = Model.get_context(error_if_none=False)
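
As a sanity check on the new `ModelManager`-backed lookup, a hedged sketch of how nested contexts are expected to resolve (the model names are invented):

import pymc as pm
from pymc.model.core import MODEL_MANAGER, modelcontext

with pm.Model(name="outer") as outer:
    assert modelcontext(None) is outer
    with pm.Model(name="inner") as inner:
        # The innermost active model wins; the enclosing one is the parent context.
        assert modelcontext(None) is inner
        assert MODEL_MANAGER.current_context is inner
        assert MODEL_MANAGER.parent_context is outer
    # Exiting the inner block pops it off the thread-local stack again.
    assert modelcontext(None) is outer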
@@ -372,6 +286,18 @@ def profile(self):
return self._pytensor_function.profile


class ContextMeta(type):
"""A metaclass in order to apply a model's context during `Model.__init__``."""

# We want the Model's context to be active during __init__. In order for this
# to apply to subclasses of Model as well, we need to use a metaclass.
def __call__(cls: type[Model], *args, **kwargs):
instance = cls.__new__(cls, *args, **kwargs)
with instance: # applies context
instance.__init__(*args, **kwargs)
return instance
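
To illustrate why the slimmed-down metaclass still enters the instance before `__init__`, a small subclassing sketch (the subclass and variable names are illustrative, not from this PR):

import pymc as pm

class LinearModel(pm.Model):
    def __init__(self, name=""):
        super().__init__(name)
        # Because ContextMeta.__call__ wraps __init__ in `with instance:`, this
        # variable is registered on the LinearModel instance itself, with no
        # explicit `with` block needed at the call site.
        self.intercept = pm.Normal("intercept", 0, 1)

m = LinearModel()
assert "intercept" in m.named_vars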


class Model(WithMemoization, metaclass=ContextMeta):
"""Encapsulates the variables and likelihood factors of a model.

@@ -495,22 +421,14 @@ class Model(WithMemoization, metaclass=ContextMeta):

"""

if TYPE_CHECKING:

def __enter__(self: Self) -> Self:
"""Enter the context manager."""

def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
"""Exit the context manager."""
def __enter__(self):
"""Enter the context manager."""
MODEL_MANAGER.active_contexts.append(self)
return self

def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs):
# resolves the parent instance
instance = super().__new__(cls)
if model is UNSET:
instance._parent = cls.get_context(error_if_none=False)
else:
instance._parent = model
return instance
def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None:
"""Exit the context manager."""
_ = MODEL_MANAGER.active_contexts.pop()

@staticmethod
def _validate_name(name):
@@ -525,11 +443,11 @@ def __init__(
check_bounds=True,
*,
coords_mutable=None,
model: Union[Literal[UNSET], None, "Model"] = UNSET,
model: _UnsetType | None | Model = UNSET,
):
del model # used in __new__ to define the parent of this model
self.name = self._validate_name(name)
self.check_bounds = check_bounds
self._parent = model if not isinstance(model, _UnsetType) else MODEL_MANAGER.parent_context

if coords_mutable is not None:
warnings.warn(
@@ -577,6 +495,17 @@ def __init__(
functools.partial(str_for_model, formatting="latex"), self
)

@classmethod
def get_context(
cls, error_if_none: bool = True, allow_block_model_access: bool = False
) -> Model | None:
model = MODEL_MANAGER.current_context
if isinstance(model, BlockModelAccess) and not allow_block_model_access:
raise BlockModelAccessError(model.error_msg_on_access)
if model is None and error_if_none:
raise TypeError("No model on context stack")
return model
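
A quick sketch of the expected `get_context` semantics after the move onto `MODEL_MANAGER` (purely illustrative):

import pymc as pm

# Outside any model context the classmethod raises, unless asked not to.
try:
    pm.Model.get_context()
except TypeError:
    pass
assert pm.Model.get_context(error_if_none=False) is None

with pm.Model() as m:
    assert pm.Model.get_context() is m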

@property
def parent(self):
return self._parent
@@ -967,7 +896,7 @@ def shape_from_dims(self, dims):
def add_coord(
self,
name: str,
values: Sequence | None = None,
values: Sequence | np.ndarray | None = None,
mutable: bool | None = None,
*,
length: int | Variable | None = None,
@@ -1233,16 +1162,16 @@ def set_data(

def register_rv(
self,
rv_var,
name,
rv_var: RandomVariable,
name: str,
*,
observed=None,
total_size=None,
dims=None,
default_transform=UNSET,
transform=UNSET,
initval=None,
):
) -> TensorVariable:
"""Register an (un)observed random variable with the model.

Parameters