From aa21814a89853c17c139054a5c51f0bb1ea68a0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Thu, 5 Sep 2024 13:24:36 +0200 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20deciding=20if?= =?UTF-8?q?=20`embed`=20body=20fields,=20do=20not=20overwrite=20fields,=20?= =?UTF-8?q?compute=20once=20per=20router,=20refactor=20internals=20in=20pr?= =?UTF-8?q?eparation=20for=20Pydantic=20models=20in=20`Form`,=20`Query`=20?= =?UTF-8?q?and=20others=20(#12117)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi/_compat.py | 15 ++ fastapi/dependencies/utils.py | 300 ++++++++++++++++++------------- fastapi/param_functions.py | 4 +- fastapi/params.py | 3 +- fastapi/routing.py | 26 ++- tests/test_compat.py | 13 +- tests/test_forms_single_param.py | 99 ++++++++++ 7 files changed, 324 insertions(+), 136 deletions(-) create mode 100644 tests/test_forms_single_param.py diff --git a/fastapi/_compat.py b/fastapi/_compat.py index 06b847b4f3818..f940d65973223 100644 --- a/fastapi/_compat.py +++ b/fastapi/_compat.py @@ -279,6 +279,12 @@ def create_body_model( BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload] return BodyModel + def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: + return [ + ModelField(field_info=field_info, name=name) + for name, field_info in model.model_fields.items() + ] + else: from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from pydantic import AnyUrl as Url # noqa: F401 @@ -513,6 +519,9 @@ def create_body_model( BodyModel.__fields__[f.name] = f # type: ignore[index] return BodyModel + def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: + return list(model.__fields__.values()) # type: ignore[attr-defined] + def _regenerate_error_with_loc( *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] @@ -532,6 +541,12 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + for arg in get_args(annotation): + if field_annotation_is_sequence(arg): + return True + return False return _annotation_is_sequence(annotation) or _annotation_is_sequence( get_origin(annotation) ) diff --git a/fastapi/dependencies/utils.py b/fastapi/dependencies/utils.py index 0dcba62f130a9..7ac18d941cd85 100644 --- a/fastapi/dependencies/utils.py +++ b/fastapi/dependencies/utils.py @@ -59,7 +59,13 @@ from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks from starlette.concurrency import run_in_threadpool -from starlette.datastructures import FormData, Headers, QueryParams, UploadFile +from starlette.datastructures import ( + FormData, + Headers, + ImmutableMultiDict, + QueryParams, + UploadFile, +) from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket @@ -282,7 +288,7 @@ def get_dependant( ), f"Cannot specify multiple FastAPI annotations for {param_name!r}" continue assert param_details.field is not None - if is_body_param(param_field=param_details.field, is_path_param=is_path_param): + if isinstance(param_details.field.field_info, params.Body): dependant.body_params.append(param_details.field) else: add_param_to_fields(field=param_details.field, dependant=dependant) @@ -466,29 +472,16 @@ def analyze_param( required=field_info.default in (Required, Undefined), field_info=field_info, ) + if is_path_param: + assert is_scalar_field( + field=field + ), "Path params must be of one of the supported types" + elif isinstance(field_info, params.Query): + assert is_scalar_field(field) or is_scalar_sequence_field(field) return ParamDetails(type_annotation=type_annotation, depends=depends, field=field) -def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: - if is_path_param: - assert is_scalar_field( - field=param_field - ), "Path params must be of one of the supported types" - return False - elif is_scalar_field(field=param_field): - return False - elif isinstance( - param_field.field_info, (params.Query, params.Header) - ) and is_scalar_sequence_field(param_field): - return False - else: - assert isinstance( - param_field.field_info, params.Body - ), f"Param: {param_field.name} can only be a request body, using Body()" - return True - - def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: field_info = field.field_info field_info_in = getattr(field_info, "in_", None) @@ -557,6 +550,7 @@ async def solve_dependencies( dependency_overrides_provider: Optional[Any] = None, dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, async_exit_stack: AsyncExitStack, + embed_body_fields: bool, ) -> SolvedDependency: values: Dict[str, Any] = {} errors: List[Any] = [] @@ -598,6 +592,7 @@ async def solve_dependencies( dependency_overrides_provider=dependency_overrides_provider, dependency_cache=dependency_cache, async_exit_stack=async_exit_stack, + embed_body_fields=embed_body_fields, ) background_tasks = solved_result.background_tasks dependency_cache.update(solved_result.dependency_cache) @@ -640,7 +635,9 @@ async def solve_dependencies( body_values, body_errors, ) = await request_body_to_args( # body_params checked above - required_params=dependant.body_params, received_body=body + body_fields=dependant.body_params, + received_body=body, + embed_body_fields=embed_body_fields, ) values.update(body_values) errors.extend(body_errors) @@ -669,138 +666,185 @@ async def solve_dependencies( ) +def _validate_value_with_model_field( + *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] +) -> Tuple[Any, List[Any]]: + if value is None: + if field.required: + return None, [get_missing_field_error(loc=loc)] + else: + return deepcopy(field.default), [] + v_, errors_ = field.validate(value, values, loc=loc) + if isinstance(errors_, ErrorWrapper): + return None, [errors_] + elif isinstance(errors_, list): + new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) + return None, new_errors + else: + return v_, [] + + +def _get_multidict_value(field: ModelField, values: Mapping[str, Any]) -> Any: + if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): + value = values.getlist(field.alias) + else: + value = values.get(field.alias, None) + if ( + value is None + or ( + isinstance(field.field_info, params.Form) + and isinstance(value, str) # For type checks + and value == "" + ) + or (is_sequence_field(field) and len(value) == 0) + ): + if field.required: + return + else: + return deepcopy(field.default) + return value + + def request_params_to_args( - required_params: Sequence[ModelField], + fields: Sequence[ModelField], received_params: Union[Mapping[str, Any], QueryParams, Headers], ) -> Tuple[Dict[str, Any], List[Any]]: - values = {} + values: Dict[str, Any] = {} errors = [] - for field in required_params: - if is_scalar_sequence_field(field) and isinstance( - received_params, (QueryParams, Headers) - ): - value = received_params.getlist(field.alias) or field.default - else: - value = received_params.get(field.alias) + for field in fields: + value = _get_multidict_value(field, received_params) field_info = field.field_info assert isinstance( field_info, params.Param ), "Params must be subclasses of Param" loc = (field_info.in_.value, field.alias) - if value is None: - if field.required: - errors.append(get_missing_field_error(loc=loc)) - else: - values[field.name] = deepcopy(field.default) - continue - v_, errors_ = field.validate(value, values, loc=loc) - if isinstance(errors_, ErrorWrapper): - errors.append(errors_) - elif isinstance(errors_, list): - new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) - errors.extend(new_errors) + v_, errors_ = _validate_value_with_model_field( + field=field, value=value, values=values, loc=loc + ) + if errors_: + errors.extend(errors_) else: values[field.name] = v_ return values, errors +def _should_embed_body_fields(fields: List[ModelField]) -> bool: + if not fields: + return False + # More than one dependency could have the same field, it would show up as multiple + # fields but it's the same one, so count them by name + body_param_names_set = {field.name for field in fields} + # A top level field has to be a single field, not multiple + if len(body_param_names_set) > 1: + return True + first_field = fields[0] + # If it explicitly specifies it is embedded, it has to be embedded + if getattr(first_field.field_info, "embed", None): + return True + # If it's a Form (or File) field, it has to be a BaseModel to be top level + # otherwise it has to be embedded, so that the key value pair can be extracted + if isinstance(first_field.field_info, params.Form): + return True + return False + + +async def _extract_form_body( + body_fields: List[ModelField], + received_body: FormData, +) -> Dict[str, Any]: + values = {} + first_field = body_fields[0] + first_field_info = first_field.field_info + + for field in body_fields: + value = _get_multidict_value(field, received_body) + if ( + isinstance(first_field_info, params.File) + and is_bytes_field(field) + and isinstance(value, UploadFile) + ): + value = await value.read() + elif ( + is_bytes_sequence_field(field) + and isinstance(first_field_info, params.File) + and value_is_sequence(value) + ): + # For types + assert isinstance(value, sequence_types) # type: ignore[arg-type] + results: List[Union[bytes, str]] = [] + + async def process_fn( + fn: Callable[[], Coroutine[Any, Any, Any]], + ) -> None: + result = await fn() + results.append(result) # noqa: B023 + + async with anyio.create_task_group() as tg: + for sub_value in value: + tg.start_soon(process_fn, sub_value.read) + value = serialize_sequence_value(field=field, value=results) + values[field.name] = value + return values + + async def request_body_to_args( - required_params: List[ModelField], + body_fields: List[ModelField], received_body: Optional[Union[Dict[str, Any], FormData]], + embed_body_fields: bool, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: - values = {} + values: Dict[str, Any] = {} errors: List[Dict[str, Any]] = [] - if required_params: - field = required_params[0] - field_info = field.field_info - embed = getattr(field_info, "embed", None) - field_alias_omitted = len(required_params) == 1 and not embed - if field_alias_omitted: - received_body = {field.alias: received_body} - - for field in required_params: - loc: Tuple[str, ...] - if field_alias_omitted: - loc = ("body",) - else: - loc = ("body", field.alias) - - value: Optional[Any] = None - if received_body is not None: - if (is_sequence_field(field)) and isinstance(received_body, FormData): - value = received_body.getlist(field.alias) - else: - try: - value = received_body.get(field.alias) - except AttributeError: - errors.append(get_missing_field_error(loc)) - continue - if ( - value is None - or (isinstance(field_info, params.Form) and value == "") - or ( - isinstance(field_info, params.Form) - and is_sequence_field(field) - and len(value) == 0 - ) - ): - if field.required: - errors.append(get_missing_field_error(loc)) - else: - values[field.name] = deepcopy(field.default) + assert body_fields, "request_body_to_args() should be called with fields" + single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields + first_field = body_fields[0] + body_to_process = received_body + if isinstance(received_body, FormData): + body_to_process = await _extract_form_body(body_fields, received_body) + + if single_not_embedded_field: + loc: Tuple[str, ...] = ("body",) + v_, errors_ = _validate_value_with_model_field( + field=first_field, value=body_to_process, values=values, loc=loc + ) + return {first_field.name: v_}, errors_ + for field in body_fields: + loc = ("body", field.alias) + value: Optional[Any] = None + if body_to_process is not None: + try: + value = body_to_process.get(field.alias) + # If the received body is a list, not a dict + except AttributeError: + errors.append(get_missing_field_error(loc)) continue - if ( - isinstance(field_info, params.File) - and is_bytes_field(field) - and isinstance(value, UploadFile) - ): - value = await value.read() - elif ( - is_bytes_sequence_field(field) - and isinstance(field_info, params.File) - and value_is_sequence(value) - ): - # For types - assert isinstance(value, sequence_types) # type: ignore[arg-type] - results: List[Union[bytes, str]] = [] - - async def process_fn( - fn: Callable[[], Coroutine[Any, Any, Any]], - ) -> None: - result = await fn() - results.append(result) # noqa: B023 - - async with anyio.create_task_group() as tg: - for sub_value in value: - tg.start_soon(process_fn, sub_value.read) - value = serialize_sequence_value(field=field, value=results) - - v_, errors_ = field.validate(value, values, loc=loc) - - if isinstance(errors_, list): - errors.extend(errors_) - elif errors_: - errors.append(errors_) - else: - values[field.name] = v_ + v_, errors_ = _validate_value_with_model_field( + field=field, value=value, values=values, loc=loc + ) + if errors_: + errors.extend(errors_) + else: + values[field.name] = v_ return values, errors -def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: - flat_dependant = get_flat_dependant(dependant) +def get_body_field( + *, flat_dependant: Dependant, name: str, embed_body_fields: bool +) -> Optional[ModelField]: + """ + Get a ModelField representing the request body for a path operation, combining + all body parameters into a single field if necessary. + + Used to check if it's form data (with `isinstance(body_field, params.Form)`) + or JSON and to generate the JSON Schema for a request body. + + This is **not** used to validate/parse the request body, that's done with each + individual body parameter. + """ if not flat_dependant.body_params: return None first_param = flat_dependant.body_params[0] - field_info = first_param.field_info - embed = getattr(field_info, "embed", None) - body_param_names_set = {param.name for param in flat_dependant.body_params} - if len(body_param_names_set) == 1 and not embed: + if not embed_body_fields: return first_param - # If one field requires to embed, all have to be embedded - # in case a sub-dependency is evaluated with a single unique body field - # That is combined (embedded) with other body fields - for param in flat_dependant.body_params: - setattr(param.field_info, "embed", True) # noqa: B010 model_name = "Body_" + name BodyModel = create_body_model( fields=flat_dependant.body_params, model_name=model_name diff --git a/fastapi/param_functions.py b/fastapi/param_functions.py index 0d5f27af48a74..7ddaace25a936 100644 --- a/fastapi/param_functions.py +++ b/fastapi/param_functions.py @@ -1282,7 +1282,7 @@ def Body( # noqa: N802 ), ] = _Unset, embed: Annotated[ - bool, + Union[bool, None], Doc( """ When `embed` is `True`, the parameter will be expected in a JSON body as a @@ -1294,7 +1294,7 @@ def Body( # noqa: N802 [FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter). """ ), - ] = False, + ] = None, media_type: Annotated[ str, Doc( diff --git a/fastapi/params.py b/fastapi/params.py index cc2a5c13c0261..3dfa5a1a381e5 100644 --- a/fastapi/params.py +++ b/fastapi/params.py @@ -479,7 +479,7 @@ def __init__( *, default_factory: Union[Callable[[], Any], None] = _Unset, annotation: Optional[Any] = None, - embed: bool = False, + embed: Union[bool, None] = None, media_type: str = "application/json", alias: Optional[str] = None, alias_priority: Union[int, None] = _Unset, @@ -642,7 +642,6 @@ def __init__( default=default, default_factory=default_factory, annotation=annotation, - embed=True, media_type=media_type, alias=alias, alias_priority=alias_priority, diff --git a/fastapi/routing.py b/fastapi/routing.py index 61a112fc47ea6..86e30360216cb 100644 --- a/fastapi/routing.py +++ b/fastapi/routing.py @@ -33,8 +33,10 @@ from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import ( + _should_embed_body_fields, get_body_field, get_dependant, + get_flat_dependant, get_parameterless_sub_dependant, get_typed_return_annotation, solve_dependencies, @@ -225,6 +227,7 @@ def get_request_handler( response_model_exclude_defaults: bool = False, response_model_exclude_none: bool = False, dependency_overrides_provider: Optional[Any] = None, + embed_body_fields: bool = False, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" is_coroutine = asyncio.iscoroutinefunction(dependant.call) @@ -291,6 +294,7 @@ async def app(request: Request) -> Response: body=body, dependency_overrides_provider=dependency_overrides_provider, async_exit_stack=async_exit_stack, + embed_body_fields=embed_body_fields, ) errors = solved_result.errors if not errors: @@ -354,7 +358,9 @@ async def app(request: Request) -> Response: def get_websocket_app( - dependant: Dependant, dependency_overrides_provider: Optional[Any] = None + dependant: Dependant, + dependency_overrides_provider: Optional[Any] = None, + embed_body_fields: bool = False, ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: async with AsyncExitStack() as async_exit_stack: @@ -367,6 +373,7 @@ async def app(websocket: WebSocket) -> None: dependant=dependant, dependency_overrides_provider=dependency_overrides_provider, async_exit_stack=async_exit_stack, + embed_body_fields=embed_body_fields, ) if solved_result.errors: raise WebSocketRequestValidationError( @@ -399,11 +406,15 @@ def __init__( 0, get_parameterless_sub_dependant(depends=depends, path=self.path_format), ) - + self._flat_dependant = get_flat_dependant(self.dependant) + self._embed_body_fields = _should_embed_body_fields( + self._flat_dependant.body_params + ) self.app = websocket_session( get_websocket_app( dependant=self.dependant, dependency_overrides_provider=dependency_overrides_provider, + embed_body_fields=self._embed_body_fields, ) ) @@ -544,7 +555,15 @@ def __init__( 0, get_parameterless_sub_dependant(depends=depends, path=self.path_format), ) - self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id) + self._flat_dependant = get_flat_dependant(self.dependant) + self._embed_body_fields = _should_embed_body_fields( + self._flat_dependant.body_params + ) + self.body_field = get_body_field( + flat_dependant=self._flat_dependant, + name=self.unique_id, + embed_body_fields=self._embed_body_fields, + ) self.app = request_response(self.get_route_handler()) def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: @@ -561,6 +580,7 @@ def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response] response_model_exclude_defaults=self.response_model_exclude_defaults, response_model_exclude_none=self.response_model_exclude_none, dependency_overrides_provider=self.dependency_overrides_provider, + embed_body_fields=self._embed_body_fields, ) def matches(self, scope: Scope) -> Tuple[Match, Scope]: diff --git a/tests/test_compat.py b/tests/test_compat.py index bf268b860b1dc..270475bf3a42a 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -1,11 +1,13 @@ -from typing import List, Union +from typing import Any, Dict, List, Union from fastapi import FastAPI, UploadFile from fastapi._compat import ( ModelField, Undefined, _get_model_config, + get_model_fields, is_bytes_sequence_annotation, + is_scalar_field, is_uploadfile_sequence_annotation, ) from fastapi.testclient import TestClient @@ -91,3 +93,12 @@ def test_is_uploadfile_sequence_annotation(): # and other types, but I'm not even sure it's a good idea to support it as a first # class "feature" assert is_uploadfile_sequence_annotation(Union[List[str], List[UploadFile]]) + + +def test_is_pv1_scalar_field(): + # For coverage + class Model(BaseModel): + foo: Union[str, Dict[str, Any]] + + fields = get_model_fields(Model) + assert not is_scalar_field(fields[0]) diff --git a/tests/test_forms_single_param.py b/tests/test_forms_single_param.py new file mode 100644 index 0000000000000..3bb951441f52c --- /dev/null +++ b/tests/test_forms_single_param.py @@ -0,0 +1,99 @@ +from fastapi import FastAPI, Form +from fastapi.testclient import TestClient +from typing_extensions import Annotated + +app = FastAPI() + + +@app.post("/form/") +def post_form(username: Annotated[str, Form()]): + return username + + +client = TestClient(app) + + +def test_single_form_field(): + response = client.post("/form/", data={"username": "Rick"}) + assert response.status_code == 200, response.text + assert response.json() == "Rick" + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == { + "openapi": "3.1.0", + "info": {"title": "FastAPI", "version": "0.1.0"}, + "paths": { + "/form/": { + "post": { + "summary": "Post Form", + "operationId": "post_form_form__post", + "requestBody": { + "content": { + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Body_post_form_form__post" + } + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {}}}, + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + } + } + }, + "components": { + "schemas": { + "Body_post_form_form__post": { + "properties": {"username": {"type": "string", "title": "Username"}}, + "type": "object", + "required": ["username"], + "title": "Body_post_form_form__post", + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": { + "anyOf": [{"type": "string"}, {"type": "integer"}] + }, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + }