Skip to content

Commit

Permalink
use similar code for all callback-applying methods
Browse files Browse the repository at this point in the history
avoid building nested chain iterables
avoid triggering defaultdict when looking up registries
apply functions as they are looked up
  • Loading branch information
davidism committed Oct 4, 2021
1 parent 166a2a6 commit 3f6cdbd
Showing 1 changed file with 52 additions and 51 deletions.
103 changes: 52 additions & 51 deletions src/flask/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,12 @@
from .signals import request_tearing_down
from .templating import DispatchingJinjaLoader
from .templating import Environment
from .typing import AfterRequestCallable
from .typing import BeforeFirstRequestCallable
from .typing import BeforeRequestCallable
from .typing import ResponseReturnValue
from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable
from .typing import TemplateFilterCallable
from .typing import TemplateGlobalCallable
from .typing import TemplateTestCallable
from .typing import URLDefaultCallable
from .typing import URLValuePreprocessorCallable
from .wrappers import Request
from .wrappers import Response

Expand Down Expand Up @@ -745,20 +740,21 @@ def update_template_context(self, context: dict) -> None:
:param context: the context as a dictionary that is updated in place
to add extra variables.
"""
funcs: t.Iterable[TemplateContextProcessorCallable] = []
if None in self.template_context_processors:
funcs = chain(funcs, self.template_context_processors[None])
reqctx = _request_ctx_stack.top
if reqctx is not None:
for bp in reversed(request.blueprints):
if bp in self.template_context_processors:
funcs = chain(funcs, self.template_context_processors[bp])
names: t.Iterable[t.Optional[str]] = (None,)

# A template may be rendered outside a request context.
if request:
names = chain(names, reversed(request.blueprints))

# The values passed to render_template take precedence. Keep a
# copy to re-apply after all context functions.
orig_ctx = context.copy()
for func in funcs:
context.update(func())
# make sure the original values win. This makes it possible to
# easier add new variables in context processors without breaking
# existing views.

for name in names:
if name in self.template_context_processors:
for func in self.template_context_processors[name]:
context.update(func())

context.update(orig_ctx)

def make_shell_context(self) -> dict:
Expand Down Expand Up @@ -1278,9 +1274,10 @@ def _find_error_handler(
class, or ``None`` if a suitable handler is not found.
"""
exc_class, code = self._get_exc_class_and_code(type(e))
names = (*request.blueprints, None)

for c in [code, None] if code is not None else [None]:
for name in chain(request.blueprints, [None]):
for c in (code, None) if code is not None else (None,):
for name in names:
handler_map = self.error_handler_spec[name][c]

if not handler_map:
Expand Down Expand Up @@ -1800,19 +1797,19 @@ def inject_url_defaults(self, endpoint: str, values: dict) -> None:
.. versionadded:: 0.7
"""
funcs: t.Iterable[URLDefaultCallable] = self.url_default_functions[None]
names: t.Iterable[t.Optional[str]] = (None,)

# url_for may be called outside a request context, parse the
# passed endpoint instead of using request.blueprints.
if "." in endpoint:
# This is called by url_for, which can be called outside a
# request, can't use request.blueprints.
bps = _split_blueprint_path(endpoint.rpartition(".")[0])
bp_funcs = chain.from_iterable(
self.url_default_functions[bp] for bp in reversed(bps)
names = chain(
names, reversed(_split_blueprint_path(endpoint.rpartition(".")[0]))
)
funcs = chain(funcs, bp_funcs)

for func in funcs:
func(endpoint, values)
for name in names:
if name in self.url_default_functions:
for func in self.url_default_functions[name]:
func(endpoint, values)

def handle_url_build_error(
self, error: Exception, endpoint: str, values: dict
Expand Down Expand Up @@ -1847,22 +1844,20 @@ def preprocess_request(self) -> t.Optional[ResponseReturnValue]:
value is handled as if it was the return value from the view, and
further request handling is stopped.
"""
names = (None, *reversed(request.blueprints))

funcs: t.Iterable[URLValuePreprocessorCallable] = []
for name in chain([None], reversed(request.blueprints)):
for name in names:
if name in self.url_value_preprocessors:
funcs = chain(funcs, self.url_value_preprocessors[name])
for func in funcs:
func(request.endpoint, request.view_args)
for url_func in self.url_value_preprocessors[name]:
url_func(request.endpoint, request.view_args)

funcs: t.Iterable[BeforeRequestCallable] = []
for name in chain([None], reversed(request.blueprints)):
for name in names:
if name in self.before_request_funcs:
funcs = chain(funcs, self.before_request_funcs[name])
for func in funcs:
rv = self.ensure_sync(func)()
if rv is not None:
return rv
for before_func in self.before_request_funcs[name]:
rv = self.ensure_sync(before_func)()

if rv is not None:
return rv

return None

Expand All @@ -1880,14 +1875,18 @@ def process_response(self, response: Response) -> Response:
instance of :attr:`response_class`.
"""
ctx = _request_ctx_stack.top
funcs: t.Iterable[AfterRequestCallable] = ctx._after_request_functions
for name in chain(request.blueprints, [None]):

for func in ctx._after_request_functions:
response = self.ensure_sync(func)(response)

for name in chain(request.blueprints, (None,)):
if name in self.after_request_funcs:
funcs = chain(funcs, reversed(self.after_request_funcs[name]))
for handler in funcs:
response = self.ensure_sync(handler)(response)
for func in reversed(self.after_request_funcs[name]):
response = self.ensure_sync(func)(response)

if not self.session_interface.is_null_session(ctx.session):
self.session_interface.save_session(self, ctx.session, response)

return response

def do_teardown_request(
Expand Down Expand Up @@ -1915,12 +1914,12 @@ def do_teardown_request(
"""
if exc is _sentinel:
exc = sys.exc_info()[1]
funcs: t.Iterable[TeardownCallable] = []
for name in chain(request.blueprints, [None]):

for name in chain(request.blueprints, (None,)):
if name in self.teardown_request_funcs:
funcs = chain(funcs, reversed(self.teardown_request_funcs[name]))
for func in funcs:
self.ensure_sync(func)(exc)
for func in reversed(self.teardown_request_funcs[name]):
self.ensure_sync(func)(exc)

request_tearing_down.send(self, exc=exc)

def do_teardown_appcontext(
Expand All @@ -1942,8 +1941,10 @@ def do_teardown_appcontext(
"""
if exc is _sentinel:
exc = sys.exc_info()[1]

for func in reversed(self.teardown_appcontext_funcs):
self.ensure_sync(func)(exc)

appcontext_tearing_down.send(self, exc=exc)

def app_context(self) -> AppContext:
Expand Down

0 comments on commit 3f6cdbd

Please sign in to comment.