Skip to content
This repository was archived by the owner on Nov 10, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/src/class_resource_view1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from fastapi_utils import Resource


class MyApi(Resource):

def get(self):
return 'done'
16 changes: 16 additions & 0 deletions docs/src/class_resource_view2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from fastapi import FastAPI

from docs.src.class_resource_view1 import MyApi
from fastapi_utils import Api


def main():
app = FastAPI()
api = Api(app)

myapi = MyApi()
api.add_resource(myapi, '/uri')


if __name__ == '__main__':
main()
18 changes: 18 additions & 0 deletions docs/src/class_resource_view3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from fastapi import FastAPI

from docs.src.class_resource_view1 import MyApi
from pymongo import MongoClient
from fastapi_utils import Api


def main():
app = FastAPI()
api = Api(app)

mongo_client = MongoClient('mongodb://localhost:27017')
myapi = MyApi(mongo_client)
api.add_resource(myapi, '/uri')


if __name__ == '__main__':
main()
27 changes: 27 additions & 0 deletions docs/src/class_resource_view4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pydantic import BaseModel

from fastapi_utils import Resource, set_responses


# Setup
class ResponseModel(BaseModel):
answer: str


class NotFoundModel(BaseModel):
IsFound: bool

# Setup end


class MyApi(Resource):
def __init__(self, mongo_client):
self.mongo = mongo_client

@set_responses(ResponseModel)
def get(self):
return 'done'

@set_responses(ResponseModel, 201, {404: NotFoundModel})
def post(self):
return 'Done again'
40 changes: 40 additions & 0 deletions docs/user-guide/class-resource.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#### Source module: [`fastapi_utils.cbv_base`](https://github.com/yuval9313/fastapi-utils/blob/master/fastapi_utils/cbv_base.py){.internal-link target=_blank}

---

If you familiar with Flask-RESTful and you want to quickly create CRUD application,
full of features and resources, and you also support OOP you might want to use this Resource based class

---

Similar to Flask-RESTful all we have to do is create a class at inherit from `Resource`
```python hl_lines="61 62 74 75 85 86 100 101"
{!./src/class_resource_view1.py!}
```

And then in `app.py`
```python hl_lines="61 62 74 75 85 86 100 101"
{!./src/class_resource_view2.py!}
```

And that's it, You now got an app.

---

Now how to handle things when it starting to get complicated:

##### Resource with dependencies
Since initialization is taking place **before** adding the resource to the api,
we can just insert our dependencies in the instance init: (`app.py`)
```python hl_lines="61 62 74 75 85 86 100 101"
{!./src/class_resource_view3.py!}
```

#### Responses
FastApi swagger is all beautiful with the responses and fit status codes,
it is no sweat to declare those.

Inside the resource class have `@set_responses` before the function
```python hl_lines="61 62 74 75 85 86 100 101"
{!./src/class_resource_view4.py!}
```
9 changes: 9 additions & 0 deletions fastapi_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,10 @@
__version__ = "0.2.1"

from .cbv_base import Api, Resource, set_responses, take_init_parameters

__all__ = [
"Api",
"Resource",
"set_responses",
"take_init_parameters",
]
94 changes: 73 additions & 21 deletions fastapi_utils/cbv.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import inspect
from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints
from typing import Any, Callable, List, Tuple, Type, TypeVar, Union, cast, get_type_hints

from fastapi import APIRouter, Depends
from fastapi.routing import APIRoute
from pydantic.typing import is_classvar
from starlette.routing import Route, WebSocketRoute

T = TypeVar("T")

CBV_CLASS_KEY = "__cbv_class__"
INCLUDE_INIT_PARAMS_KEY = "__include_init_params__"
RETURN_TYPES_FUNC_KEY = "__return_types_func__"


def cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]:
def cbv(router: APIRouter, *urls: str) -> Callable[[Type[T]], Type[T]]:
"""
This function returns a decorator that converts the decorated into a class-based view for the provided router.

Expand All @@ -23,34 +26,23 @@ def cbv(router: APIRouter) -> Callable[[Type[T]], Type[T]]:
"""

def decorator(cls: Type[T]) -> Type[T]:
return _cbv(router, cls)
# Define cls as cbv class exclusively when using the decorator
return _cbv(router, cls, *urls)

return decorator


def _cbv(router: APIRouter, cls: Type[T]) -> Type[T]:
def _cbv(router: APIRouter, cls: Type[T], *urls: str, instance: Any = None) -> Type[T]:
"""
Replaces any methods of the provided class `cls` that are endpoints of routes in `router` with updated
function calls that will properly inject an instance of `cls`.
"""
_init_cbv(cls)
cbv_router = APIRouter()
function_members = inspect.getmembers(cls, inspect.isfunction)
functions_set = set(func for _, func in function_members)
cbv_routes = [
route
for route in router.routes
if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in functions_set
]
for route in cbv_routes:
router.routes.remove(route)
_update_cbv_route_endpoint_signature(cls, route)
cbv_router.routes.append(route)
router.include_router(cbv_router)
_init_cbv(cls, instance)
_register_endpoints(router, cls, *urls)
return cls


def _init_cbv(cls: Type[Any]) -> None:
def _init_cbv(cls: Type[Any], instance: Any = None) -> None:
"""
Idempotently modifies the provided `cls`, performing the following modifications:
* The `__init__` function is updated to set any class-annotated dependencies as instance attributes
Expand All @@ -64,6 +56,7 @@ def _init_cbv(cls: Type[Any]) -> None:
new_parameters = [
x for x in old_parameters if x.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
]

dependency_names: List[str] = []
for name, hint in get_type_hints(cls).items():
if is_classvar(hint):
Expand All @@ -73,19 +66,77 @@ def _init_cbv(cls: Type[Any]) -> None:
new_parameters.append(
inspect.Parameter(name=name, kind=inspect.Parameter.KEYWORD_ONLY, annotation=hint, **parameter_kwargs)
)
new_signature = old_signature.replace(parameters=new_parameters)
new_signature = inspect.Signature(())
if not instance or hasattr(cls, INCLUDE_INIT_PARAMS_KEY):
new_signature = old_signature.replace(parameters=new_parameters)

def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
for dep_name in dependency_names:
dep_value = kwargs.pop(dep_name)
setattr(self, dep_name, dep_value)
old_init(self, *args, **kwargs)
if instance and not hasattr(cls, INCLUDE_INIT_PARAMS_KEY):
self.__class__ = instance.__class__
self.__dict__ = instance.__dict__
else:
old_init(self, *args, **kwargs)

setattr(cls, "__signature__", new_signature)
setattr(cls, "__init__", new_init)
setattr(cls, CBV_CLASS_KEY, True)


def _register_endpoints(router: APIRouter, cls: Type[Any], *urls: str) -> None:
cbv_router = APIRouter()
function_members = inspect.getmembers(cls, inspect.isfunction)
for url in urls:
_allocate_routes_by_method_name(router, url, function_members)
router_roles = []
for route in router.routes:
assert isinstance(route, APIRoute)
route_methods: Any = route.methods
cast(Tuple[Any], route_methods)
router_roles.append((route.path, tuple(route_methods)))

if len(set(router_roles)) != len(router_roles):
raise Exception("An identical route role has been implemented more then once")

functions_set = set(func for _, func in function_members)
cbv_routes = [
route
for route in router.routes
if isinstance(route, (Route, WebSocketRoute)) and route.endpoint in functions_set
]
for route in cbv_routes:
router.routes.remove(route)
_update_cbv_route_endpoint_signature(cls, route)
cbv_router.routes.append(route)
router.include_router(cbv_router)


def _allocate_routes_by_method_name(router: APIRouter, url: str, function_members: List[Tuple[str, Any]]) -> None:
existing_routes_endpoints: List[Tuple[Any, str]] = [
(route.endpoint, route.path) for route in router.routes if isinstance(route, APIRoute)
]
for name, func in function_members:
if hasattr(router, name) and not name.startswith("__") and not name.endswith("__"):
if (func, url) not in existing_routes_endpoints:
response_model = None
responses = None
status_code = 200
return_types_func = getattr(func, RETURN_TYPES_FUNC_KEY, None)
if return_types_func:
response_model, status_code, responses = return_types_func()

api_resource = router.api_route(
url,
methods=[name.capitalize()],
response_model=response_model,
status_code=status_code,
responses=responses,
)
api_resource(func)


def _update_cbv_route_endpoint_signature(cls: Type[Any], route: Union[Route, WebSocketRoute]) -> None:
"""
Fixes the endpoint signature for a cbv route to ensure FastAPI performs dependency injection properly.
Expand All @@ -98,5 +149,6 @@ def _update_cbv_route_endpoint_signature(cls: Type[Any], route: Union[Route, Web
new_parameters = [new_first_parameter] + [
parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY) for parameter in old_parameters[1:]
]

new_signature = old_signature.replace(parameters=new_parameters)
setattr(route.endpoint, "__signature__", new_signature)
36 changes: 36 additions & 0 deletions fastapi_utils/cbv_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, Dict, Optional, Tuple

from fastapi import APIRouter, FastAPI

from fastapi_utils.cbv import INCLUDE_INIT_PARAMS_KEY, RETURN_TYPES_FUNC_KEY, _cbv


class Resource:
# raise NotImplementedError
pass


class Api:
def __init__(self, app: FastAPI):
self.app = app

def add_resource(self, resource: Resource, *urls: str, **kwargs: Any) -> None:
router = APIRouter()
_cbv(router, type(resource), *urls, instance=resource)
self.app.include_router(router)


def take_init_parameters(cls: Any) -> Any:
setattr(cls, INCLUDE_INIT_PARAMS_KEY, True)
return cls


def set_responses(response: Any, status_code: int = 200, responses: Dict[str, Any] = None) -> Any:
def decorator(func: Any) -> Any:
def get_responses() -> Tuple[Any, int, Optional[Dict[str, Any]]]:
return response, status_code, responses

setattr(func, RETURN_TYPES_FUNC_KEY, get_responses)
return func

return decorator
71 changes: 71 additions & 0 deletions tests/test_cbv_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Any, Dict, List, Union

from fastapi import FastAPI
from starlette.testclient import TestClient

from fastapi_utils.cbv_base import Api, Resource, set_responses


def test_cbv() -> None:
class CBV(Resource):
def __init__(self, z: int = 1):
super().__init__()
self.y = 1
self.z = z

@set_responses(int)
def post(self, x: int) -> int:
print(x)
return x + self.y + self.z

@set_responses(bool)
def get(self) -> bool:
return hasattr(self, "cy")

app = FastAPI()
api = Api(app)
cbv = CBV(2)
api.add_resource(cbv, "/", "/classvar")

client = TestClient(app)
response_1 = client.post("/", params={"x": 1}, json={})
assert response_1.status_code == 200
assert response_1.content == b"4"

response_2 = client.get("/classvar")
assert response_2.status_code == 200
assert response_2.content == b"false"


def test_arg_in_path() -> None:
class TestCBV(Resource):
@set_responses(str)
def get(self, item_id: str) -> str:
return item_id

app = FastAPI()
api = Api(app)

test_cbv_resource = TestCBV()
api.add_resource(test_cbv_resource, "/{item_id}")

assert TestClient(app).get("/test").json() == "test"


def test_multiple_routes() -> None:
class RootHandler(Resource):
def get(self, item_path: str = None) -> Union[List[Any], Dict[str, str]]:
if item_path:
return {"item_path": item_path}
return []

app = FastAPI()
api = Api(app)

root_handler_resource = RootHandler()
api.add_resource(root_handler_resource, "/items/?", "/items/{item_path:path}")

client = TestClient(app)

assert client.get("/items/1").json() == {"item_path": "1"}
assert client.get("/items").json() == []