Skip to content

Commit

Permalink
♻️ Refactor and simplify internal data from solve_dependencies() us…
Browse files Browse the repository at this point in the history
…ing dataclasses (fastapi#12100)
  • Loading branch information
tiangolo authored Aug 31, 2024
1 parent 8d7d89e commit 5b7fa39
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 33 deletions.
45 changes: 24 additions & 21 deletions fastapi/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,15 @@ async def solve_generator(
return await stack.enter_async_context(cm)


@dataclass
class SolvedDependency:
values: Dict[str, Any]
errors: List[Any]
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]


async def solve_dependencies(
*,
request: Union[Request, WebSocket],
Expand All @@ -539,13 +548,7 @@ async def solve_dependencies(
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
async_exit_stack: AsyncExitStack,
) -> Tuple[
Dict[str, Any],
List[Any],
Optional[StarletteBackgroundTasks],
Response,
Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
) -> SolvedDependency:
values: Dict[str, Any] = {}
errors: List[Any] = []
if response is None:
Expand Down Expand Up @@ -587,27 +590,21 @@ async def solve_dependencies(
dependency_cache=dependency_cache,
async_exit_stack=async_exit_stack,
)
(
sub_values,
sub_errors,
background_tasks,
_, # the subdependency returns the same response we have
sub_dependency_cache,
) = solved_result
dependency_cache.update(sub_dependency_cache)
if sub_errors:
errors.extend(sub_errors)
background_tasks = solved_result.background_tasks
dependency_cache.update(solved_result.dependency_cache)
if solved_result.errors:
errors.extend(solved_result.errors)
continue
if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
solved = dependency_cache[sub_dependant.cache_key]
elif is_gen_callable(call) or is_async_gen_callable(call):
solved = await solve_generator(
call=call, stack=async_exit_stack, sub_values=sub_values
call=call, stack=async_exit_stack, sub_values=solved_result.values
)
elif is_coroutine_callable(call):
solved = await call(**sub_values)
solved = await call(**solved_result.values)
else:
solved = await run_in_threadpool(call, **sub_values)
solved = await run_in_threadpool(call, **solved_result.values)
if sub_dependant.name is not None:
values[sub_dependant.name] = solved
if sub_dependant.cache_key not in dependency_cache:
Expand Down Expand Up @@ -654,7 +651,13 @@ async def solve_dependencies(
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
)
return values, errors, background_tasks, response, dependency_cache
return SolvedDependency(
values=values,
errors=errors,
background_tasks=background_tasks,
response=response,
dependency_cache=dependency_cache,
)


def request_params_to_args(
Expand Down
33 changes: 21 additions & 12 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,26 +292,34 @@ async def app(request: Request) -> Response:
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
)
values, errors, background_tasks, sub_response, _ = solved_result
errors = solved_result.errors
if not errors:
raw_response = await run_endpoint_function(
dependant=dependant, values=values, is_coroutine=is_coroutine
dependant=dependant,
values=solved_result.values,
is_coroutine=is_coroutine,
)
if isinstance(raw_response, Response):
if raw_response.background is None:
raw_response.background = background_tasks
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: Dict[str, Any] = {"background": background_tasks}
response_args: Dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
# response class, in the case of redirect it's 307
current_status_code = (
status_code if status_code else sub_response.status_code
status_code
if status_code
else solved_result.response.status_code
)
if current_status_code is not None:
response_args["status_code"] = current_status_code
if sub_response.status_code:
response_args["status_code"] = sub_response.status_code
if solved_result.response.status_code:
response_args["status_code"] = (
solved_result.response.status_code
)
content = await serialize_response(
field=response_field,
response_content=raw_response,
Expand All @@ -326,7 +334,7 @@ async def app(request: Request) -> Response:
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
response.body = b""
response.headers.raw.extend(sub_response.headers.raw)
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
Expand Down Expand Up @@ -360,11 +368,12 @@ async def app(websocket: WebSocket) -> None:
dependency_overrides_provider=dependency_overrides_provider,
async_exit_stack=async_exit_stack,
)
values, errors, _, _2, _3 = solved_result
if errors:
raise WebSocketRequestValidationError(_normalize_errors(errors))
if solved_result.errors:
raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors)
)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**values)
await dependant.call(**solved_result.values)

return app

Expand Down

0 comments on commit 5b7fa39

Please sign in to comment.