Skip to content

Forwarding context management

Thor Whalen edited this page Dec 18, 2021 · 2 revisions

Context

For context: the problem and the desired solution were originally discussed in this issue.

Before going further, know that the core of the solution is just:

class ContextualizeCall:
    """Wrap ``func`` so it stays callable *and* can be used in a ``with`` statement.

    Entering calls ``enter_func(func)``; exiting calls
    ``exit_func(func, exc_type, exc_value, traceback)``.
    """

    def __init__(self, func, enter_func, exit_func):
        self.func = func
        self.enter_func = enter_func
        self.exit_func = exit_func
        # Copy func's metadata (__name__, __doc__, __wrapped__, ...) onto self.
        wraps(func)(self)

    def __call__(self, *args, **kwargs):
        # Calling the wrapper is just calling the wrapped function.
        wrapped = self.func
        return wrapped(*args, **kwargs)

    def __enter__(self):
        enter = self.enter_func
        return enter(self.func)

    def __exit__(self, *exception_info):
        exit_ = self.exit_func
        return exit_(self.func, *exception_info)

The rest of the code just makes it more production-ready (validation, etc.).

Situation

Let's make a toy "streamer" that has all it needs:

  • A read method that returns the next item if and only if the streamer was "turned on" (and raises an error if not)
  • A mechanism to turn things on and off (a context manager)
class StreamHasNotBeenStarted(RuntimeError):
    """Signals an operation that needs the stream to be on was attempted while it was off."""
    
class Streamer:
    """Toy stream that can only be read while it is on.

    ``enter``/``exit`` turn the stream on/off and double as
    ``__enter__``/``__exit__``, so a ``Streamer`` can be used in a ``with``
    statement. ``read()`` returns the next item of the wrapped iterable, or
    raises ``StreamHasNotBeenStarted`` when the stream is off.
    """

    def __init__(self, iterable):
        self.iterable = iterable
        self.is_running = False
        self._read = None  # holds an iterator's __next__ while the stream is on

    def enter(self):
        """Turn the stream on, reading from a new ``iter(self.iterable)``."""
        next_item = iter(self.iterable).__next__
        self._read = next_item
        self.is_running = True

    def exit(self, *exc):
        """Turn the stream off; any ``__exit__`` exception info is accepted and ignored."""
        self._read, self.is_running = None, False

    __enter__ = enter
    __exit__ = exit

    def read(self):
        """Return the next item of the stream, which must be on."""
        if self.is_running:
            return self._read()
        raise StreamHasNotBeenStarted(
            "The stream needs to be on/started (in a context) for that!"
        )

Let's try it out:

s = Streamer(iterable='stream')

See that it works:

with s:
    # While s is on, each read yields the next character of 'stream'.
    for expected in 'st':
        assert s.read() == expected

Let's make a reader test function to carry out our validation more easily.

def test_reader(reader):
    """An (initialized) reader should be able to be called six times and produce s t r e a m"""
    assert ''.join(reader() for _ in range(6)) == 'stream'

If we enter s's context, our reader will work!

reader = s.read  # grabbed *before* s is turned on

with s:
    test_reader(reader)  # works: s is on while we read

We can also perform this entering/exiting manually:

s.enter()
try:
    reader = s.read
    test_reader(reader)
finally:
    # Turn the stream off even if the reads fail, as `with s:` would guarantee.
    s.exit()

But if we don't turn things on, it won't work:

reader = s.read
try:
    # oops, forgot to enter the s context
    test_reader(reader)
except StreamHasNotBeenStarted as e:
    it_worked = False
    assert e.args[0] == 'The stream needs to be on/started (in a context) for that!'
else:
    it_worked = True
assert not it_worked

But we can't turn the reader (the read method) on: it's not a context manager (its instance is!)

reader = s.read

try:
    with reader:  # can we actually do this? (answer: no! We can enter s, not s.read)
        test_reader(reader)
except (AttributeError, TypeError) as e:
    # Python < 3.11 raises AttributeError('__enter__'); Python >= 3.11 raises a
    # TypeError saying the object does not support the context manager protocol.
    # Either way: reader doesn't have an __enter__!
    it_worked = False
    assert isinstance(e, TypeError) or e.args[0] == '__enter__'
else:
    it_worked = True

assert not it_worked

What we'd like is a contextualize_with_instance that satisfies this:

reader = contextualize_with_instance(s.read)

# Deliberately no `try/except Exception` here: swallowing the error would turn any
# failure into an uninformative `assert it_worked`; letting it propagate shows why.
with reader:  # now we can enter the reader!
    test_reader(reader)
it_worked = True

assert it_worked  # Hurray!

See code for contextualize_with_instance below.

The general solution objects

from inspect import signature, Parameter
from abc import abstractstaticmethod
from functools import wraps


class TypeValidationError(TypeError):
    """Raised if an object is not valid.

    Subclasses override ``is_valid(obj) -> bool`` (as a staticmethod) and inherit
    ``validate(obj, *args, **kwargs)``, which raises the subclass, built from
    ``*args, **kwargs``, when ``obj`` is not valid.
    """

    @staticmethod
    def is_valid(obj) -> bool:
        """Return True if and only if obj is considered valid (subclasses must override)."""
        # `abc.abstractstaticmethod` is deprecated (since Python 3.3) and, without an
        # ABCMeta metaclass, enforces nothing: the inherited body silently returned
        # None, i.e. "invalid". Fail loudly instead.
        raise NotImplementedError(
            "TypeValidationError subclasses must implement is_valid"
        )

    @classmethod
    def validate(cls, obj, *args, **kwargs):
        """Raise ``cls(*args, **kwargs)`` if ``cls.is_valid(obj)`` is falsy."""
        if not cls.is_valid(obj):
            raise cls(*args, **kwargs)
    
    
def null_enter_func(obj):
    """Default enter function: hand ``obj`` itself back unchanged."""
    return obj

def null_exit_func(obj, *exception_info):
    """Default exit function: do nothing and return None, so exceptions are not suppressed."""

def _has_min_num_of_arguments(func, mininum_n_args):
    return len(signature(func).parameters) >= mininum_n_args

def _callable_has_a_variadic_pos(func):
    """True if and only if the function has a variadic positional argument (i.e. *args)"""
    return any(x.kind == Parameter.VAR_POSITIONAL for x in signature(func).parameters.values())

class EnterFunctionValidation(TypeValidationError):
    """Raised if a function isn't a valid (context manager) __enter__ function.

    A valid enter function is a callable whose signature declares at least one
    parameter.
    """

    @staticmethod
    def is_valid(obj):
        if not callable(obj):
            return False
        return _has_min_num_of_arguments(obj, 1)
    
class ExitFunctionValidation(TypeValidationError):
    """Raised if a function isn't a valid (context manager) __exit__ function.

    A valid exit function is a callable that either takes variadic positional
    arguments (``*args``) or declares at least four parameters
    (``obj, exc_type, exc_value, traceback``).
    """

    @staticmethod
    def is_valid(obj):
        # Bug fix: the variadic-args helper used to be referenced without being
        # called; a function object is always truthy, so any callable was accepted.
        return callable(obj) and (
            _callable_has_a_variadic_pos(obj) or _has_min_num_of_arguments(obj, 4)
        )

# TODO: If needed:
#  A more general Contextualizer would proxy/delegate all methods (including special ones) to the wrapped func.
class ContextualizeCall:
    """Make ``func`` usable as a context manager while keeping it callable.

    Calling the instance calls ``func``. ``with instance:`` calls
    ``enter_func(func)`` on entry (its return value is the ``as`` target) and
    ``exit_func(func, exc_type, exc_value, traceback)`` on exit (its return value
    is returned from ``__exit__``, so a truthy value suppresses the exception).

    :param func: the callable to wrap.
    :param enter_func: callable declaring at least one parameter (receives ``func``).
    :param exit_func: callable declaring at least four parameters
        (``func`` plus the three exception-info values) or taking ``*args``.
    :raises TypeError: if ``func`` isn't callable.
    :raises EnterFunctionValidation: if ``enter_func`` isn't a valid enter function.
    :raises ExitFunctionValidation: if ``exit_func`` isn't a valid exit function.
    """

    def __init__(self, func, enter_func=null_enter_func, exit_func=null_exit_func):
        # validation
        if not callable(func):
            raise TypeError(f"First argument should be a callable, was: {func}")
        EnterFunctionValidation.validate(
            enter_func,
            f"Not a valid enter function (should be a callable with at least one arg): {enter_func}",
        )
        ExitFunctionValidation.validate(
            exit_func,
            # The check requires four args (or *args), not one: say so.
            "Not a valid exit function (should be a callable with at least four args "
            f"(obj, exc_type, exc_value, traceback) or variadic args): {exit_func}",
        )
        # assignment
        self.func = func
        self.enter_func = enter_func
        self.exit_func = exit_func
        # Copy func's metadata (__name__, __doc__, __wrapped__, ...) onto self so the
        # wrapper looks like func to introspection (e.g. inspect.signature).
        wraps(func)(self)

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    def __enter__(self):
        return self.enter_func(self.func)

    def __exit__(self, *exception_info):
        return self.exit_func(self.func, *exception_info)

contextualize_with_instance

Using these general objects for the particular case of having bound methods forward their context management to the instances they're bound to.

from functools import partial

def forward_to_instance_enter(obj):
    """Enter the context of the instance that the bound method ``obj`` belongs to."""
    instance = obj.__self__
    return instance.__enter__()

def forward_to_instance_exit(obj, *exception_info):
    """Exit the context of the instance that the bound method ``obj`` belongs to.

    ``exception_info`` (``exc_type, exc_value, traceback``) is passed straight to
    the instance's ``__exit__``, whose return value (truthy to suppress the
    exception) is returned unchanged.
    """
    # Bug fix: `obj.__self__.__exit__` is already bound, so `obj` must not be passed
    # again. Doing so handed a standard `__exit__(self, exc_type, exc_value, tb)`
    # four arguments and raised TypeError.
    return obj.__self__.__exit__(*exception_info)

# Bound-method flavour of ContextualizeCall: entering/exiting the returned wrapper
# enters/exits the instance that the wrapped method is bound to.
contextualize_with_instance = partial(
    ContextualizeCall,
    enter_func=forward_to_instance_enter,
    exit_func=forward_to_instance_exit,
)
contextualize_with_instance.__doc__ = (
    "To be applied to a bound method. "
    "Returns a callable that forwards context enters/exits to the bound instance"
)
Clone this wiki locally