Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "winter"
version = "30.0.1"
version = "31.0.0"
homepage = "https://github.com/WinterFramework/winter"
description = "Web Framework with focus on python typing, dataclasses and modular design"
authors = ["Alexander Egorov <mofr@zond.org>"]
Expand Down Expand Up @@ -41,6 +41,7 @@ pydantic = ">=1.10, <2"
openapi-spec-validator = ">=0.5.7, <1"
uritemplate = "==4.2.0" # Lib doesn't follow semantic versioning
httpx = ">=0.24.1, <0.28"
redis = "^6.2.0"

[tool.poetry.dev-dependencies]
flake8 = ">=3.7.7, <4"
Expand All @@ -61,6 +62,7 @@ pytz = ">=2020.5"

[tool.poetry.group.dev.dependencies]
setuptools = "^71.1.0"
testcontainers = "^4.10.0"

[build-system]
requires = ["poetry-core>=1.3.1"]
Expand Down
27 changes: 27 additions & 0 deletions tests/apps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import atexit

from django.apps import AppConfig
from testcontainers.redis import RedisContainer

from tests.web.interceptors import HelloWorldInterceptor
from winter.web import RedisThrottlingConfiguration
from winter.web import exception_handlers_registry
from winter.web import interceptor_registry
from winter.web.exceptions.handlers import DefaultExceptionHandler
Expand All @@ -9,6 +13,10 @@
class TestAppConfig(AppConfig):
name = 'tests'

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._redis_container: RedisContainer | None = None

def ready(self):
# define this import for force initialization all modules and to register Exceptions
from .urls import urlpatterns # noqa: F401
Expand All @@ -19,7 +27,26 @@ def ready(self):
interceptor_registry.add_interceptor(HelloWorldInterceptor())

winter_openapi.setup()

winter.web.setup()

self._redis_container = RedisContainer()
self._redis_container.start()
self._redis_container.get_client().flushdb()
atexit.register(self.cleanup_redis)

redis_throttling_configuration = RedisThrottlingConfiguration(
host=self._redis_container.get_container_host_ip(),
port=self._redis_container.get_exposed_port(self._redis_container.port),
db=0,
password=self._redis_container.password
)
winter.web.set_redis_throttling_configuration(redis_throttling_configuration)

winter_django.setup()

exception_handlers_registry.set_default_handler(DefaultExceptionHandler) # for 100% test coverage

def cleanup_redis(self): # pragma: no cover
if self._redis_container:
self._redis_container.stop()
39 changes: 39 additions & 0 deletions tests/test_throttling.py → tests/web/test_throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

import freezegun
import pytest
from mock import patch

from winter.web import RedisThrottlingConfiguration
from winter.web import ThrottlingMisconfigurationException
from winter.web import set_redis_throttling_configuration
from winter.web.throttling.redis_throttling_client import get_redis_throttling_client
from winter.web.throttling import redis_throttling_client
from winter.web.throttling import redis_throttling_configuration

expected_error_response = {
'status': 429,
Expand Down Expand Up @@ -65,3 +72,35 @@ def test_get_throttling_with_conditional_reset(api_client):
is_reset = True if i == 5 else False
response = api_client.get(f'/with-throttling/with-reset/?is_reset={is_reset}')
assert response.status_code == HTTPStatus.OK, i


@patch.object(redis_throttling_client, 'get_redis_throttling_configuration', return_value=None)
@patch.object(redis_throttling_client, '_redis_throttling_client', None)
def test_get_redis_throttling_client_without_configuration(_):
with pytest.raises(ThrottlingMisconfigurationException) as exc_info:
get_redis_throttling_client()

assert 'Configuration for Redis must be set' in str(exc_info.value)


@patch.object(
redis_throttling_configuration,
'_redis_throttling_configuration',
RedisThrottlingConfiguration(
host='localhost',
port=1234,
db=0,
password=None
)
)
def test_try_to_set_redis_configuration_twice():
configuration = RedisThrottlingConfiguration(
host='localhost',
port=5678,
db=0,
password=None
)
with pytest.raises(ThrottlingMisconfigurationException) as exc_info:
set_redis_throttling_configuration(configuration)

assert 'RedisThrottlingConfiguration is already initialized' in str(exc_info.value)
3 changes: 3 additions & 0 deletions winter/web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from .response_header_resolver import ResponseHeaderArgumentResolver
from .response_header_serializer import response_headers_serializer
from .response_status_annotation import response_status
from .throttling import ThrottlingMisconfigurationException
from .throttling import RedisThrottlingConfiguration
from .throttling import set_redis_throttling_configuration
from .throttling import throttling
from .urls import register_url_regexp

Expand Down
6 changes: 6 additions & 0 deletions winter/web/throttling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .exceptions import ThrottlingMisconfigurationException
from .throttling import throttling
from .throttling import reset
from .throttling import create_throttle_class
from .redis_throttling_configuration import set_redis_throttling_configuration
from .redis_throttling_configuration import RedisThrottlingConfiguration
2 changes: 2 additions & 0 deletions winter/web/throttling/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class ThrottlingMisconfigurationException(Exception):
pass
67 changes: 67 additions & 0 deletions winter/web/throttling/redis_throttling_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import time

from redis import Redis

from .exceptions import ThrottlingMisconfigurationException
from .redis_throttling_configuration import get_redis_throttling_configuration
from .redis_throttling_configuration import RedisThrottlingConfiguration


class RedisThrottlingClient:
# Redis Lua scripts are atomic
# Sliding window throttling.
# Rejected requests aren't counted.
THROTTLING_LUA = '''
local key = KEYS[1]
local now = tonumber(ARGV[1])
local duration = tonumber(ARGV[2])
local max_requests = tonumber(ARGV[3])

redis.call("ZREMRANGEBYSCORE", key, 0, now - duration)
local count = redis.call("ZCARD", key)

if count >= max_requests then
return 0
end

redis.call("ZADD", key, now, now)
redis.call("EXPIRE", key, duration)
return 1
'''

def __init__(self, configuration: RedisThrottlingConfiguration):
self._redis_client = Redis(
host=configuration.host,
port=configuration.port,
db=configuration.db,
password=configuration.password,
decode_responses=True,
)
self._throttling_script = self._redis_client.register_script(self.THROTTLING_LUA)

def is_request_allowed(self, key: str, duration: int, num_requests: int) -> bool:
now = time.time()
is_allowed = self._throttling_script(
keys=[key],
args=[now, duration, num_requests]
)
return is_allowed == 1

def delete(self, key: str):
self._redis_client.delete(key)


_redis_throttling_client: RedisThrottlingClient | None = None

def get_redis_throttling_client() -> RedisThrottlingClient:
global _redis_throttling_client

if _redis_throttling_client is None:
configuration = get_redis_throttling_configuration()

if configuration is None:
raise ThrottlingMisconfigurationException('Configuration for Redis must be set before using the throttling')

_redis_throttling_client = RedisThrottlingClient(configuration)

return _redis_throttling_client
25 changes: 25 additions & 0 deletions winter/web/throttling/redis_throttling_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from dataclasses import dataclass

from .exceptions import ThrottlingMisconfigurationException


@dataclass
class RedisThrottlingConfiguration:
host: str
port: int
db: int
password: str | None = None


_redis_throttling_configuration: RedisThrottlingConfiguration | None = None


def set_redis_throttling_configuration(configuration: RedisThrottlingConfiguration):
global _redis_throttling_configuration
if _redis_throttling_configuration is not None:
raise ThrottlingMisconfigurationException(f'{RedisThrottlingConfiguration.__name__} is already initialized')
_redis_throttling_configuration = configuration


def get_redis_throttling_configuration() -> RedisThrottlingConfiguration | None:
return _redis_throttling_configuration
21 changes: 6 additions & 15 deletions winter/web/throttling.py → winter/web/throttling/throttling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from typing import Tuple

import django.http
from django.core.cache import cache as default_cache

from winter.core import annotate_method
from .redis_throttling_client import get_redis_throttling_client

if TYPE_CHECKING:
from .routing import Route # noqa: F401
from winter.web.routing import Route # noqa: F401


@dataclasses.dataclass
Expand All @@ -33,23 +33,13 @@ def throttling(rate: Optional[str], scope: Optional[str] = None):
class BaseRateThrottle:
def __init__(self, throttling_: Throttling):
self._throttling = throttling_
self._redis_client = get_redis_throttling_client()

def allow_request(self, request: django.http.HttpRequest) -> bool:
ident = _get_ident(request)
key = _get_cache_key(self._throttling.scope, ident)

history = default_cache.get(key, [])
now = time.time()

while history and history[-1] <= now - self._throttling.duration:
history.pop()

if len(history) >= self._throttling.num_requests:
return False

history.insert(0, now)
default_cache.set(key, history, self._throttling.duration)
return True
return self._redis_client.is_request_allowed(key, self._throttling.duration, self._throttling.num_requests)


def reset(request: django.http.HttpRequest, scope: str):
Expand All @@ -59,7 +49,8 @@ def reset(request: django.http.HttpRequest, scope: str):
"""
ident = _get_ident(request)
key = _get_cache_key(scope, ident)
default_cache.delete(key)
redis_client = get_redis_throttling_client()
redis_client.delete(key)


CACHE_KEY_FORMAT = 'throttle_{scope}_{ident}'
Expand Down