Skip to content

Commit

Permalink
Switch default behavior to block=True
Browse files Browse the repository at this point in the history
Previous versions defaulted to block=False for the decorator, which was
surprising and lead to more than a few issues being filed. This change
reverses the default, setting block=True on the decorator. To opt into
the previous behavior, use block=False.
  • Loading branch information
jsocol committed Dec 4, 2022
1 parent 96034df commit 7d8e317
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 44 deletions.
2 changes: 1 addition & 1 deletion django_ratelimit/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
__all__ = ['ratelimit']


def ratelimit(group=None, key=None, rate=None, method=ALL, block=False):
def ratelimit(group=None, key=None, rate=None, method=ALL, block=True):
def decorator(fn):
@wraps(fn)
def _wrapped(request, *args, **kw):
Expand Down
79 changes: 43 additions & 36 deletions django_ratelimit/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def setUp(self):
cache.clear()

def test_no_key(self):
@ratelimit(rate='1/m', block=True)
@ratelimit(rate='1/m')
def view(request):
return True

Expand All @@ -66,15 +66,15 @@ def view(request):
view(req)

def test_ip(self):
@ratelimit(key='ip', rate='1/m')
@ratelimit(key='ip', rate='1/m', block=False)
def view(request):
return request.limited

assert not view(rf.get('/')), 'First request works.'
assert view(rf.get('/')), 'Second request is limited'

def test_block(self):
@ratelimit(key='ip', rate='1/m', block=True)
@ratelimit(key='ip', rate='1/m')
def blocked(request):
return request.limited

Expand All @@ -83,7 +83,7 @@ def blocked(request):
blocked(rf.get('/')), 'Second request is blocked.'

def test_ratelimit_custom_string_exception_class(self):
@ratelimit(key='ip', rate='1/m', block=True)
@ratelimit(key='ip', rate='1/m')
def view(request):
return request.limited

Expand All @@ -98,7 +98,7 @@ def view(request):
view(req)

def test_ratelimit_custom_exception_class(self):
@ratelimit(key='ip', rate='1/m', block=True)
@ratelimit(key='ip', rate='1/m')
def view(request):
return request.limited

Expand All @@ -111,7 +111,7 @@ def view(request):
view(req)

def test_method(self):
@ratelimit(key='ip', method='POST', rate='1/m', group='a')
@ratelimit(key='ip', method='POST', rate='1/m', group='a', block=False)
def limit_post(request):
return request.limited

Expand All @@ -120,7 +120,7 @@ def limit_post(request):
assert not limit_post(rf.get('/')), 'Do not limit GET.'

def test_unsafe_methods(self):
@ratelimit(key='ip', method=ratelimit.UNSAFE, rate='0/m')
@ratelimit(key='ip', method=ratelimit.UNSAFE, rate='0/m', block=False)
def limit_unsafe(request):
return request.limited

Expand All @@ -133,7 +133,7 @@ def limit_unsafe(request):
assert limit_unsafe(rf.patch('/'))

def test_key_get(self):
@ratelimit(key='get:foo', rate='1/m', method='GET')
@ratelimit(key='get:foo', rate='1/m', method='GET', block=False)
def view(request):
return request.limited

Expand All @@ -143,7 +143,7 @@ def view(request):
assert view(rf.get('/', {'foo': 'b'}))

def test_key_post(self):
@ratelimit(key='post:foo', rate='1/m')
@ratelimit(key='post:foo', rate='1/m', block=False)
def view(request):
return request.limited

Expand All @@ -158,16 +158,16 @@ def _req():
req.META['HTTP_X_REAL_IP'] = '1.2.3.4'
return req

@ratelimit(key='header:x-real-ip', rate='1/m')
@ratelimit(key='header:x-missing-header', rate='1/m')
@ratelimit(key='header:x-real-ip', rate='1/m', block=False)
@ratelimit(key='header:x-missing-header', rate='1/m', block=False)
def view(request):
return request.limited

assert not view(_req())
assert view(_req())

def test_rate(self):
@ratelimit(key='ip', rate='2/m')
@ratelimit(key='ip', rate='2/m', block=False)
def twice(request):
return request.limited

Expand All @@ -176,14 +176,14 @@ def twice(request):
assert twice(rf.post('/')), 'Third request is limited.'

def test_zero_rate(self):
@ratelimit(key='ip', rate='0/m')
@ratelimit(key='ip', rate='0/m', block=False)
def never(request):
return request.limited

assert never(rf.post('/'))

def test_none_rate(self):
@ratelimit(key='ip', rate=None)
@ratelimit(key='ip', rate=None, block=False)
def always(request):
return request.limited

Expand All @@ -206,7 +206,7 @@ def get_rate(group, request):
return (2, 60)
return (1, 60)

@ratelimit(key='user_or_ip', rate=get_rate)
@ratelimit(key='user_or_ip', rate=get_rate, block=False)
def view(request):
return request.limited

Expand All @@ -224,7 +224,7 @@ def _req(never_limit=False):

get_rate = lambda g, r: None if r.never_limit else '1/m'

@ratelimit(key='ip', rate=get_rate)
@ratelimit(key='ip', rate=get_rate, block=False)
def view(request):
return request.limited

Expand All @@ -244,7 +244,7 @@ def get_rate(group, request):
return '1/m'
return '0/m'

@ratelimit(key='ip', rate=get_rate)
@ratelimit(key='ip', rate=get_rate, block=False)
def view(request):
return request.limited

Expand All @@ -259,7 +259,8 @@ def _req(auth):
return req

@ratelimit(key='user_or_ip',
rate='django_ratelimit.tests.callable_rate')
rate='django_ratelimit.tests.callable_rate',
block=False)
def view(request):
return request.limited

Expand Down Expand Up @@ -288,15 +289,15 @@ def view(request):
assert view(_req(auth=True))

def test_callable_key_path(self):
@ratelimit(key='django_ratelimit.tests.mykey', rate='1/m')
@ratelimit(key='django_ratelimit.tests.mykey', rate='1/m', block=False)
def view(request):
return request.limited

assert not view(rf.post('/'))
assert view(rf.post('/'))

def test_callable_key(self):
@ratelimit(key=mykey, rate='1/m')
@ratelimit(key=mykey, rate='1/m', block=False)
def view(request):
return request.limited

Expand All @@ -317,8 +318,8 @@ def view(request):

def test_stacked_methods(self):
"""Different methods should result in different counts."""
@ratelimit(rate='1/m', key='ip', method='GET')
@ratelimit(rate='1/m', key='ip', method='POST')
@ratelimit(rate='1/m', key='ip', method='GET', block=False)
@ratelimit(rate='1/m', key='ip', method='POST', block=False)
def view(request):
return request.limited

Expand All @@ -329,19 +330,21 @@ def view(request):

def test_sorted_methods(self):
"""Order of the methods shouldn't matter."""
@ratelimit(rate='1/m', key='ip', method=['GET', 'POST'], group='a')
@ratelimit(rate='1/m', key='ip', method=['GET', 'POST'],
group='a', block=False)
def get_post(request):
return request.limited

@ratelimit(rate='1/m', key='ip', method=['POST', 'GET'], group='a')
@ratelimit(rate='1/m', key='ip', method=['POST', 'GET'],
group='a', block=False)
def post_get(request):
return request.limited

assert not get_post(rf.get('/'))
assert post_get(rf.get('/'))

def test_ratelimit_full_mask_v4(self):
@ratelimit(rate='1/m', key='ip')
@ratelimit(rate='1/m', key='ip', block=False)
def view(request):
return request.limited

Expand All @@ -356,7 +359,7 @@ def view(request):
assert not view(req)

def test_ratelimit_full_mask_v6(self):
@ratelimit(rate='1/m', key='ip')
@ratelimit(rate='1/m', key='ip', block=False)
def view(request):
return request.limited

Expand All @@ -371,7 +374,7 @@ def view(request):
assert not view(req)

def test_ratelimit_mask_v4(self):
@ratelimit(rate='1/m', key='ip')
@ratelimit(rate='1/m', key='ip', block=False)
def view(request):
return request.limited

Expand All @@ -390,7 +393,7 @@ def view(request):
assert not view(req)

def test_ratelimit_mask_v6(self):
@ratelimit(rate='1/m', key='ip')
@ratelimit(rate='1/m', key='ip', block=False)
def view(request):
return request.limited

Expand Down Expand Up @@ -513,11 +516,13 @@ def get(self, request):

def test_methods_counted_separately(self):
class TestView(View):
@method_decorator(ratelimit(key='ip', rate='1/m', method='GET'))
@method_decorator(ratelimit(key='ip', rate='1/m',
method='GET', block=False))
def get(self, request):
return request.limited

@method_decorator(ratelimit(key='ip', rate='1/m', method='POST'))
@method_decorator(ratelimit(key='ip', rate='1/m',
method='POST', block=False))
def post(self, request):
return request.limited

Expand All @@ -529,12 +534,14 @@ def post(self, request):

def test_views_counted_separately(self):
class TestView(View):
@method_decorator(ratelimit(key='ip', rate='1/m', method='GET'))
@method_decorator(ratelimit(key='ip', rate='1/m',
method='GET', block=False))
def get(self, request):
return request.limited

class AnotherTestView(View):
@method_decorator(ratelimit(key='ip', rate='1/m', method='GET'))
@method_decorator(ratelimit(key='ip', rate='1/m',
method='GET', block=False))
def get(self, request):
return request.limited

Expand All @@ -549,7 +556,7 @@ def get(self, request):
class CacheFailTests(TestCase):
@override_settings(RATELIMIT_USE_CACHE='fake-cache')
def test_bad_cache(self):
@ratelimit(key='ip', rate='1/m')
@ratelimit(key='ip', rate='1/m', block=False)
def view(request):
return request.limited

Expand All @@ -558,7 +565,7 @@ def view(request):

@override_settings(RATELIMIT_USE_CACHE='connection-errors')
def test_limit_on_cache_connection_error(self):
@ratelimit(key='ip', rate='10/m')
@ratelimit(key='ip', rate='10/m', block=False)
def view(request):
return request.limited

Expand All @@ -567,7 +574,7 @@ def view(request):
@override_settings(RATELIMIT_USE_CACHE='connection-errors',
RATELIMIT_FAIL_OPEN=True)
def test_fail_open_setting(self):
@ratelimit(key='ip', rate='1/m')
@ratelimit(key='ip', rate='1/m', block=False)
def view(request):
return request.limited

Expand Down Expand Up @@ -606,7 +613,7 @@ def do_increment(request):

@override_settings(RATELIMIT_USE_CACHE='instant-expiration')
def test_cache_timeout(self):
@ratelimit(key='ip', rate='1/m', block=True)
@ratelimit(key='ip', rate='1/m')
def view(request):
return True

Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Use as a decorator in ``views.py``:
def secondview(request):
# ...
After activating django-ratelimit, you should ensure that your cache
Before activating django-ratelimit, you should ensure that your cache
backend is setup to be both persistent and work across multiple
deployment worker instances (for instance UWSGI workers). Read more in
the Django docs on `caching
Expand Down
13 changes: 7 additions & 6 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Import:
from django_ratelimit.decorators import ratelimit
.. py:decorator:: ratelimit(group=None, key=, rate=None, method=ALL, block=False)
.. py:decorator:: ratelimit(group=None, key=, rate=None, method=ALL, block=True)
:arg group:
*None* A group of rate limits to count together. Defaults to the
Expand Down Expand Up @@ -46,7 +46,7 @@ Import:
``PATCH``).

:arg block:
*False* Whether to block the request instead of annotating.
*True* Whether to block the request instead of annotating.


HTTP Methods
Expand Down Expand Up @@ -79,7 +79,7 @@ Examples

.. code-block:: python
@ratelimit(key='ip', rate='5/m')
@ratelimit(key='ip', rate='5/m', block=False)
def myview(request):
# Will be true if the same IP makes more than 5 POST
# requests/minute.
Expand All @@ -91,7 +91,8 @@ Examples
# If the same IP makes >5 reqs/min, will raise Ratelimited
return HttpResponse()
@ratelimit(key='post:username', rate='5/m', method=['GET', 'POST'])
@ratelimit(key='post:username', rate='5/m',
method=['GET', 'POST'], block=False)
def login(request):
# If the same username is used >5 times/min, this will be True.
# The `username` value will come from GET or POST, determined by the
Expand All @@ -118,8 +119,8 @@ Examples
# Allow 4 reqs/hour.
return HttpResponse()
rate = lambda g, r: None if r.user.is_authenticated else '100/h'
@ratelimit(key='ip', rate=rate)
get_rate = lambda g, r: None if r.user.is_authenticated else '100/h'
@ratelimit(key='ip', rate=get_rate)
def skipif1(request):
# Only rate limit anonymous requests
return HttpResponse()
Expand Down

0 comments on commit 7d8e317

Please sign in to comment.