|
| 1 | +from asgiref.sync import iscoroutinefunction, markcoroutinefunction |
| 2 | + |
| 3 | +from django.core.exceptions import ImproperlyConfigured |
| 4 | + |
| 5 | + |
| 6 | +class AsyncMiddlewareMixin: |
| 7 | + sync_capable = False |
| 8 | + async_capable = True |
| 9 | + |
| 10 | + def __init__(self, get_response): |
| 11 | + if get_response is None: |
| 12 | + raise ValueError("get_response must be provided.") |
| 13 | + self.get_response = get_response |
| 14 | + # If get_response is not an async function, raise an error. |
| 15 | + self.async_mode = iscoroutinefunction(self.get_response) or iscoroutinefunction( |
| 16 | + getattr(self.get_response, "__call__", None) |
| 17 | + ) |
| 18 | + if self.async_mode: |
| 19 | + # Mark the class as async-capable. |
| 20 | + markcoroutinefunction(self) |
| 21 | + else: |
| 22 | + raise ImproperlyConfigured("get_response must be async") |
| 23 | + |
| 24 | + super().__init__() |
| 25 | + |
| 26 | + def __repr__(self): |
| 27 | + return "<%s get_response=%s>" % ( |
| 28 | + self.__class__.__qualname__, |
| 29 | + getattr( |
| 30 | + self.get_response, |
| 31 | + "__qualname__", |
| 32 | + self.get_response.__class__.__name__, |
| 33 | + ), |
| 34 | + ) |
| 35 | + |
| 36 | + async def __call__(self, request): |
| 37 | + response = None |
| 38 | + if hasattr(self, "process_request"): |
| 39 | + response = await self.process_request(request) |
| 40 | + response = response or await self.get_response(request) |
| 41 | + if hasattr(self, "process_response"): |
| 42 | + response = await self.process_response(request, response) |
| 43 | + return response |
0 commit comments