Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions slowapi/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from typing_extensions import Literal

from .errors import RateLimitExceeded
from .util import add_request_signature
from .wrappers import Limit, LimitGroup

# used to annotate get_app_config method
Expand Down Expand Up @@ -656,6 +657,7 @@ def __limit_decorator(
_scope = scope if shared else None

def decorator(func: Callable[..., Response]):
func = add_request_signature(func)
keyfunc = key_func or self._key_func
name = f"{func.__module__}.{func.__name__}"
dynamic_limit = None
Expand Down
70 changes: 70 additions & 0 deletions slowapi/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from asyncio import iscoroutinefunction
from functools import wraps
from inspect import signature, Parameter

from starlette.requests import Request
from typing import Callable, List


def get_ipaddr(request: Request) -> str:
Expand All @@ -25,3 +30,68 @@ def get_remote_address(request: Request) -> str:
return "127.0.0.1"

return request.client.host


def get_request_param(func: Callable) -> List[Parameter]:
"""Retrieve list of parameters that are a Request"""
sig = signature(func)
params = list(sig.parameters.values())
return [param for param in params if param.annotation == Request]


def add_request_signature(func: Callable):
"""Adds starlette.Request argument to function's signature so that it'll be accessible to custom decorators"""

def scrap_req(func: Callable, args, kwargs):
if getattr(func, "scrap_req", False):
req_param = get_request_param(func)[0]
try:
del kwargs[req_param.name]
except KeyError:
# Request is not in kwargs for some reason delete from args
# Deletion index: 0
del args[0]
return args, kwargs

if iscoroutinefunction(func):

@wraps(func)
async def wrapper(*args, **kwargs):
args, kwargs = scrap_req(func, args, kwargs)
return await func(*args, **kwargs)

else:

@wraps(func)
def wrapper(*args, **kwargs):
args, kwargs = scrap_req(func, args, kwargs)
return func(*args, **kwargs)

sig = signature(func)
params = list(sig.parameters.values())

rq = get_request_param(func)
if len(rq) >= 1:
if not hasattr(func, "scrap_req"): # Ignore if already set
func.scrap_req = False
else:
func.scrap_req = True
name = "request" # Slowapi should allow for request to be anything <- param name generator
param_names = [pname.name for pname in params]
if name not in param_names:
func.req = name

req = Parameter(
name=name, kind=Parameter.POSITIONAL_OR_KEYWORD, annotation=Request
)
params.insert(0, req)
sig = sig.replace(parameters=params)
func.__signature__ = sig
else:
fname = f"{func.__module__}.{func.__name__}"
raise Exception(
f"Remove 'request' argument from function {fname}"
f" or add [request : starlette.Request] manually."
)

return wrapper
52 changes: 11 additions & 41 deletions tests/test_fastapi_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,47 +144,18 @@ async def t1(request: Request, response: Response):
== 429
)

def test_endpoint_missing_request_param(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr)

with pytest.raises(Exception) as exc_info:

@app.get("/t3")
@limiter.limit("5/minute")
async def t3():
return PlainTextResponse("test")

assert exc_info.match(
r"""^No "request" or "websocket" argument on function .*"""
)

def test_endpoint_missing_request_param_sync(self, build_fastapi_app):
def test_endpoint_request_param_invalid(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr)

with pytest.raises(Exception) as exc_info:

@app.get("/t3_sync")
@app.get("/t4")
@limiter.limit("5/minute")
def t3():
async def t4(request: str = None):
return PlainTextResponse("test")

assert exc_info.match(
r"""^No "request" or "websocket" argument on function .*"""
)

def test_endpoint_request_param_invalid(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr)

@app.get("/t4")
@limiter.limit("5/minute")
async def t4(request: str = None):
return PlainTextResponse("test")

with pytest.raises(Exception) as exc_info:
client = TestClient(app)
client.get("/t4")
assert exc_info.match(
r"""parameter `request` must be an instance of starlette.requests.Request"""
r"Remove 'request' argument from function tests.test_fastapi_extension.t4 or add \[request : starlette.Request\] manually"
)

def test_endpoint_response_param_invalid(self, build_fastapi_app):
Expand All @@ -205,16 +176,15 @@ async def t4(request: Request, response: str = None):
def test_endpoint_request_param_invalid_sync(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr)

@app.get("/t5")
@limiter.limit("5/minute")
def t5(request: str = None):
return PlainTextResponse("test")

with pytest.raises(Exception) as exc_info:
client = TestClient(app)
client.get("/t5")

@app.get("/t5")
@limiter.limit("5/minute")
def t5(request: str = None):
return PlainTextResponse("test")

assert exc_info.match(
r"""parameter `request` must be an instance of starlette.requests.Request"""
r"Remove 'request' argument from function tests.test_fastapi_extension.t5 or add \[request : starlette.Request\] manually"
)

def test_endpoint_response_param_invalid_sync(self, build_fastapi_app):
Expand Down