Forwarding context management
For context: the problem and the desired solution were originally discussed in this issue.
Before going further, note that the core of the solution is just:
class ContextualizeCall:
    """A callable wrapper around ``func`` that is also a context manager.

    Calling the wrapper simply calls ``func``. Entering the context returns
    ``enter_func(func)``; exiting it returns ``exit_func(func, *exception_info)``.
    """

    def __init__(self, func, enter_func, exit_func):
        self.func = func
        self.enter_func = enter_func
        self.exit_func = exit_func
        # Give this instance func's metadata (name, docstring, __wrapped__, ...).
        wraps(func)(self)

    def __call__(self, *args, **kwargs):
        target = self.func
        return target(*args, **kwargs)

    def __enter__(self):
        on_enter = self.enter_func
        return on_enter(self.func)

    def __exit__(self, *exception_info):
        on_exit = self.exit_func
        return on_exit(self.func, *exception_info)
The rest of the code just makes it more production-ready (validation, etc.).
Let's make a toy "streamer" that has all it needs:
- A `read` method that returns the next item if and only if the streamer was "turned on" (and raises an error if not).
- A mechanism to turn things on and off (a context manager).
class StreamHasNotBeenStarted(RuntimeError):
    """Signals that an operation needs the stream to be on (started) first."""
class Streamer:
    """Toy stream whose ``read`` only works between ``enter`` and ``exit``.

    ``enter``/``exit`` double as ``__enter__``/``__exit__``, so an instance can
    be turned on for the duration of a ``with`` block.
    """

    def __init__(self, iterable):
        self.iterable = iterable
        self.is_running = False
        self._read = None

    def enter(self):
        """Turn the stream on, starting a fresh iterator over ``self.iterable``."""
        iterator = iter(self.iterable)
        self._read = iterator.__next__
        self.is_running = True

    def exit(self, *exc):
        """Turn the stream off; any exception info in ``exc`` is ignored."""
        self.is_running = False
        self._read = None

    __enter__ = enter
    __exit__ = exit

    def read(self):
        """Return the next item of the stream.

        Raises StreamHasNotBeenStarted if the stream is not on.
        """
        if self.is_running:
            return self._read()
        raise StreamHasNotBeenStarted(
            "The stream needs to be on/started (in a context) for that!"
        )
Let's try it out:
# A streamer over the characters of 'stream' (it starts out turned off).
s = Streamer(iterable='stream')
See that it works:
with s:  # the stream is on only inside this block
    for expected in ('s', 't'):
        assert s.read() == expected
Let's make a reader test function to carry out our validation more easily.
def test_reader(reader):
"""An (initialized) reader should be able to be called six times and produce s t r e a m"""
assert ''.join(reader() for _ in range(6)) == 'stream'
If we enter `s`'s context, our reader will work!
reader = s.read  # the bound method is grabbed *before* s is turned on...
with s:
    test_reader(reader)  # ...yet it works once we are inside s's context
We can also perform this entering/exiting manually:
# Manually turn the stream on, use it, and turn it off again.
s.enter()
try:
    reader = s.read
    test_reader(reader)
finally:
    # Always turn the stream off, even if the check above fails; otherwise
    # s would be left running and later "not started" checks would be wrong.
    s.exit()
But if we don't turn things on, it won't work:
reader = s.read
try:
    # oops: s was never entered, so the stream is still off
    test_reader(reader)
except StreamHasNotBeenStarted as err:
    it_worked = False
    assert isinstance(err, StreamHasNotBeenStarted)
    assert err.args[0] == 'The stream needs to be on/started (in a context) for that!'
else:
    it_worked = True
assert not it_worked
But we can't turn the reader (the `read` method) on -- it's not a context manager (its instance is!).
reader = s.read
try:
    with reader:  # can we actually do this? (answer: no! We can enter s, not s.read)
        test_reader(reader)
    it_worked = True
except (AttributeError, TypeError) as e:
    it_worked = False
    # A bound method has no __enter__/__exit__. Python < 3.11 reports this as
    # AttributeError('__enter__'); Python 3.11+ raises a TypeError saying the
    # object "does not support the context manager protocol".
    assert '__enter__' in str(e) or 'context manager protocol' in str(e)
assert not it_worked
What we'd like is a `contextualize_with_instance` function that satisfies this:
reader = contextualize_with_instance(s.read)
try:
    with reader:  # now we can enter the reader!
        test_reader(reader)
    it_worked = True
except Exception as e:
    # Don't swallow the failure: keep the original error as the cause so the
    # traceback shows *why* entering the contextualized reader failed.
    raise AssertionError("Could not use the contextualized reader") from e
assert it_worked  # Hurray!
See the code for `contextualize_with_instance` below.
from abc import abstractmethod, abstractstaticmethod
from functools import wraps
from inspect import Parameter, signature
class TypeValidationError(TypeError):
    """Raised if an object is not valid.

    Subclasses implement the static predicate ``is_valid(obj)``. The
    ``validate`` classmethod raises the subclass (built from the given
    args/kwargs) when that predicate is falsy.
    """

    # ``abc.abstractstaticmethod`` has been deprecated since Python 3.3;
    # stacking ``staticmethod`` over ``abstractmethod`` is the documented
    # replacement. The abstractness is only a marker here: this class is not
    # built with ABCMeta, so nothing blocks instantiation.
    @staticmethod
    @abstractmethod
    def is_valid(obj) -> bool:
        """Returns True if and only if obj is considered valid"""

    @classmethod
    def validate(cls, obj, *args, **kwargs):
        """Raise ``cls(*args, **kwargs)`` unless ``cls.is_valid(obj)`` is truthy."""
        if not cls.is_valid(obj):
            raise cls(*args, **kwargs)
def null_enter_func(obj):
    """Default enter hook: hand ``obj`` back unchanged."""
    result = obj
    return result
def null_exit_func(obj, *exception_info):
    """Default exit hook: does nothing and returns None.

    The falsy return means any exception raised in the ``with`` block is not
    suppressed.
    """
def _has_min_num_of_arguments(func, mininum_n_args):
return len(signature(func).parameters) >= mininum_n_args
def _callable_has_a_variadic_pos(func):
"""True if and only if the function has a variadic positional argument (i.e. *args)"""
return any(x.kind == Parameter.VAR_POSITIONAL for x in signature(func).parameters.values())
class EnterFunctionValidation(TypeValidationError):
    """Raised if a function isn't a valid (context manager) __enter__ function.

    Valid means: callable, with a signature declaring at least one parameter.
    """

    @staticmethod
    def is_valid(obj):
        if not callable(obj):
            return False
        return _has_min_num_of_arguments(obj, 1)
class ExitFunctionValidation(TypeValidationError):
    """Raised if a function isn't a valid (context manager) __exit__ function.

    ContextualizeCall.__exit__ calls the exit function as
    ``exit_func(func, exc_type, exc_value, traceback)``. A valid exit function
    is therefore a callable that either takes variadic positional arguments
    (``*args``) or declares at least four parameters.
    """

    @staticmethod
    def is_valid(obj):
        # The variadic check must be *called*: the bare function object is
        # always truthy, which previously made every callable "valid".
        return callable(obj) and (
            _callable_has_a_variadic_pos(obj) or _has_min_num_of_arguments(obj, 4)
        )
# TODO: If needed, a more general Contextualizer would proxy/delegate all
# methods (including special ones) to the wrapped callable.
class ContextualizeCall:
    """Wrap a callable so it can also be used as a context manager.

    Calling the wrapper calls ``func``. ``with wrapper:`` calls
    ``enter_func(func)`` on entry and ``exit_func(func, *exception_info)`` on
    exit. The wrapper also carries ``func``'s metadata (name, docstring,
    ``__wrapped__``, ...) via ``functools.wraps``.

    Args:
        func: the callable to wrap.
        enter_func: called as ``enter_func(func)``. Its return value is what
            ``with wrapper as x`` binds. Defaults to ``null_enter_func``,
            which returns ``func`` itself.
        exit_func: called as ``exit_func(func, exc_type, exc_value, traceback)``.
            As with ``__exit__``, a truthy return value suppresses the
            exception. Defaults to ``null_exit_func``, a no-op returning None.

    Raises:
        TypeError: if ``func`` is not callable.
        EnterFunctionValidation: if ``EnterFunctionValidation.is_valid(enter_func)``
            is false.
        ExitFunctionValidation: if ``ExitFunctionValidation.is_valid(exit_func)``
            is false.
    """

    def __init__(self, func, enter_func=null_enter_func, exit_func=null_exit_func):
        # validation
        if not callable(func):
            raise TypeError(f"First argument should be a callable, was: {func}")
        EnterFunctionValidation.validate(
            enter_func,
            f"Not a valid enter function (should be a callable with at least one arg): {enter_func}",
        )
        ExitFunctionValidation.validate(
            exit_func,
            "Not a valid exit function (should be a callable accepting "
            "(func, exc_type, exc_value, traceback), i.e. at least four args, "
            f"or variadic args): {exit_func}",
        )
        # Copy func's metadata onto this instance FIRST. wraps() also merges
        # func.__dict__ into self.__dict__, so running it after the assignments
        # below could clobber self.func / self.enter_func / self.exit_func.
        wraps(func)(self)
        # assignment
        self.func = func
        self.enter_func = enter_func
        self.exit_func = exit_func

    def __call__(self, *args, **kwargs):
        """Call the wrapped ``func`` with the given arguments."""
        return self.func(*args, **kwargs)

    def __enter__(self):
        """Return ``enter_func(func)``."""
        return self.enter_func(self.func)

    def __exit__(self, *exception_info):
        """Return ``exit_func(func, exc_type, exc_value, traceback)``."""
        return self.exit_func(self.func, *exception_info)
Now we use these general objects for the particular case of bound methods that forward their context management to the instances they're bound to:
from functools import partial
def forward_to_instance_enter(obj):
    """Enter the context of the instance that the bound method ``obj`` belongs to.

    Returns whatever that instance's ``__enter__`` returns.
    """
    instance = obj.__self__
    return instance.__enter__()
def forward_to_instance_exit(obj, *exception_info):
    """Exit the context of the instance that the bound method ``obj`` belongs to.

    ``exception_info`` is the ``(exc_type, exc_value, traceback)`` triple given
    to ``__exit__``. It is passed on unchanged, and the instance's ``__exit__``
    result (truthy to suppress the exception) is returned.
    """
    # Do NOT pass ``obj`` along: ``__self__.__exit__`` is already bound, and a
    # standard ``__exit__(self, exc_type, exc_value, traceback)`` accepts only
    # the three exception arguments.
    return obj.__self__.__exit__(*exception_info)
# ContextualizeCall pre-configured so that entering/exiting the wrapper
# enters/exits the instance the wrapped bound method belongs to.
contextualize_with_instance = partial(
    ContextualizeCall,
    enter_func=forward_to_instance_enter,
    exit_func=forward_to_instance_exit,
)
contextualize_with_instance.__doc__ = "To be applied to a bound method. Returns a callable that forwards context enters/exits to the bound instance"