-
Notifications
You must be signed in to change notification settings - Fork 1
/
ratelimitcache.py
108 lines (85 loc) · 3.45 KB
/
ratelimitcache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from django.http import HttpResponseForbidden
from django.core.cache import cache
from datetime import datetime, timedelta
import functools
import hashlib
class ratelimit(object):
    """Decorator class that rate-limits a Django view using the cache.

    Instances of this class can be used as decorators.  Requests are counted
    in one cache counter per calendar minute; a request is refused once the
    counters for the current minute plus the previous ``minutes`` minutes add
    up to ``requests`` or more.

    Any keyword argument passed to the constructor overrides the class
    attribute of the same name for that instance.
    """
    # This class is designed to be sub-classed
    minutes = 2  # The time period
    requests = 20  # Number of allowed requests in that time period
    prefix = 'rl-'  # Prefix for memcache key
    # Lifetime of each per-minute counter, in seconds.  A subclass that
    # overrides ``minutes`` at class level should override this as well.
    expire_after = (minutes + 1) * 60

    def __init__(self, **options):
        """Apply ``options`` as per-instance overrides of the class settings."""
        for key, value in options.items():
            setattr(self, key, value)
        # The class-level expire_after was derived from the default
        # ``minutes``; re-derive it when only ``minutes`` was overridden so
        # counters live at least as long as the window that reads them.
        if 'minutes' in options and 'expire_after' not in options:
            self.expire_after = (self.minutes + 1) * 60

    def __call__(self, fn):
        """Decorate view ``fn`` so each call goes through ``view_wrapper``."""
        @functools.wraps(fn)
        def wrapper(request, *args, **kwargs):
            return self.view_wrapper(request, fn, *args, **kwargs)
        return wrapper

    def _request_count(self, request):
        """Return the total number of requests recorded in the window."""
        # Counters are coerced with int() so that string-typed values
        # (e.g. a counter seeded as '0') still add up.
        return sum(int(value) for value in self.get_counters(request).values())

    def rate_limit_manual(self, request):
        """Check the limit without counting this request.

        Returns the ``disallowed`` response when the request is subject to
        rate limiting and the limit has been reached, otherwise ``None``.
        """
        if not self.should_ratelimit(request):
            return None
        if self._request_count(request) >= self.requests:
            return self.disallowed(request)
        return None

    def view_wrapper(self, request, fn, *args, **kwargs):
        """Count the request, then either refuse it or call ``fn``.

        Requests for which ``should_ratelimit`` is false are passed straight
        to ``fn`` and are not counted.
        """
        if not self.should_ratelimit(request):
            return fn(request, *args, **kwargs)
        # Read the totals before incrementing, so the current request is not
        # included in the number compared against the limit.
        total = self._request_count(request)
        # Increment rate limiting counter
        self.cache_incr(self.current_key(request))
        # Have they failed?
        if total >= self.requests:
            return self.disallowed(request)
        return fn(request, *args, **kwargs)

    def cache_get_many(self, keys):
        """Fetch ``keys`` from the cache; returns a dict of the keys found."""
        return cache.get_many(keys)

    def cache_incr(self, key):
        """Add one to the counter stored under ``key``, creating it if needed.

        Uses Django's public cache API rather than the backend's private
        client, so the counter is written under the same (prefixed and
        versioned) key that ``cache_get_many`` reads back.
        """
        # add() only stores the value when the key is absent, which
        # guarantees the counter exists without resetting a live one.
        cache.add(key, 0, self.expire_after)
        try:
            cache.incr(key)
        except ValueError:
            # The key disappeared between add() and incr() (expiry or
            # eviction); start the counter again at one.
            cache.set(key, 1, self.expire_after)

    def should_ratelimit(self, request):
        """Return True if ``request`` is subject to rate limiting."""
        return True

    def get_counters(self, request):
        """Return a dict of cache key -> counter for the current window."""
        return self.cache_get_many(self.keys_to_check(request))

    def keys_to_check(self, request):
        """Return the cache keys for the current minute and the previous
        ``minutes`` minutes, newest first."""
        extra = self.key_extra(request)
        now = datetime.now()
        return [
            '%s%s-%s' % (
                self.prefix,
                extra,
                (now - timedelta(minutes=minute)).strftime('%Y%m%d%H%M')
            ) for minute in range(self.minutes + 1)
        ]

    def current_key(self, request):
        """Return the cache key of the counter for the current minute."""
        return '%s%s-%s' % (
            self.prefix,
            self.key_extra(request),
            datetime.now().strftime('%Y%m%d%H%M')
        )

    def key_extra(self, request):
        """Return the per-client part of the cache key."""
        # By default, their IP address is used
        return request.META.get('REMOTE_ADDR', '')

    def disallowed(self, request):
        "Over-ride this method if you want to log incidents"
        return HttpResponseForbidden('Rate limit exceeded')
class ratelimit_post(ratelimit):
    """Rate limit POSTs - can be used to protect a login form.

    Only requests whose method is POST are counted.  When ``key_field`` is
    set, the value of that POST variable becomes part of the cache key, so
    the limit applies per (IP address, field value) pair.
    """
    key_field = None  # If provided, this POST var will affect the rate limit

    def should_ratelimit(self, request):
        """Return True only for POST requests."""
        return request.method == 'POST'

    def key_extra(self, request):
        """Return the IP address, plus a hash of ``key_field`` if it is set."""
        # IP address and key_field (if it is set)
        extra = super(ratelimit_post, self).key_extra(request)
        if self.key_field:
            raw = request.POST.get(self.key_field, '')
            # hashlib only accepts bytes; encode text explicitly so that
            # non-ASCII input (and any text at all on Python 3) hashes
            # instead of raising.
            if not isinstance(raw, bytes):
                raw = raw.encode('utf-8')
            # The value is hashed rather than embedded verbatim in the key.
            extra += '-' + hashlib.sha1(raw).hexdigest()
        return extra