-
Notifications
You must be signed in to change notification settings - Fork 62
/
rate-limiter.go
123 lines (106 loc) · 3.04 KB
/
rate-limiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package rdns
import (
"expvar"
"net"
"sync"
"time"
"github.com/miekg/dns"
)
// RateLimiter is a resolver that limits the number of queries by a client (network)
// that are passed to the upstream resolver per timeframe. Clients are grouped by
// masking their source IP with Prefix4/Prefix6, and requests are counted per
// fixed time window of Window seconds.
type RateLimiter struct {
	// id identifies this instance; it is used for logging, for metric names
	// and is returned by String.
	id string
	// resolver receives all queries that are within the rate limit.
	resolver Resolver
	RateLimiterOptions

	// mu guards currWinID and counters.
	mu sync.RWMutex
	// currWinID is the ID of the window the counters belong to, computed as
	// Unix time in seconds divided by Window.
	currWinID int64
	// counters holds the number of requests seen in the current window,
	// keyed by the string form of the masked client IP.
	counters map[string]*uint
	// metrics holds the expvar counters updated on every call to Resolve.
	metrics *RateLimiterMetrics
}

// Compile-time check that *RateLimiter implements Resolver.
var _ Resolver = &RateLimiter{}
// RateLimiterOptions holds the configuration of a RateLimiter. NewRateLimiter
// replaces zero values of Window, Prefix4 and Prefix6 with defaults.
type RateLimiterOptions struct {
	Requests      uint     // Number of requests allowed per client (network) per time period; with 0 every query is treated as exceeding the limit
	Window        uint     // Time period in seconds (NewRateLimiter defaults this to 60)
	Prefix4       uint8    // Netmask length used to identify IPv4 clients (NewRateLimiter defaults this to 24)
	Prefix6       uint8    // Netmask length used to identify IPv6 clients (NewRateLimiter defaults this to 56)
	LimitResolver Resolver // Alternate resolver for rate-limited requests; if nil, such requests are dropped
}
// RateLimiterMetrics groups the expvar counters maintained by a RateLimiter.
type RateLimiterMetrics struct {
	// Count of all queries received by the rate limiter.
	query *expvar.Int
	// Count of queries that have exceeded the rate limit, whether they were
	// forwarded to the limit-resolver or dropped.
	exceed *expvar.Int
	// Count of dropped queries, i.e. queries over the limit when no
	// limit-resolver is configured.
	drop *expvar.Int
}
// NewRateLimiter returns a new instance of a query rate limiter that forwards
// permitted queries to resolver and is identified by id in logs and metrics.
//
// Zero values in opt are replaced with defaults: a Window of 60 seconds, a
// Prefix4 of 24 and a Prefix6 of 56. Requests and LimitResolver are used as
// given.
func NewRateLimiter(id string, resolver Resolver, opt RateLimiterOptions) *RateLimiter {
	if opt.Window == 0 {
		opt.Window = 60
	}
	if opt.Prefix4 == 0 {
		opt.Prefix4 = 24
	}
	if opt.Prefix6 == 0 {
		opt.Prefix6 = 56
	}

	// The counter map is allocated up front. Resolve only (re)creates it when
	// the computed window ID differs from currWinID, which starts out as 0. If
	// the first window ID is also 0 (system clock at the Unix epoch, or a
	// Window larger than the current Unix time) the map would otherwise still
	// be nil when the first counter is stored, which panics.
	return &RateLimiter{
		id:                 id,
		resolver:           resolver,
		RateLimiterOptions: opt,
		counters:           make(map[string]*uint),
		metrics: &RateLimiterMetrics{
			query:  getVarInt("router", id, "query"),
			exceed: getVarInt("router", id, "exceed"),
			drop:   getVarInt("router", id, "drop"),
		},
	}
}
// Resolve a DNS query while limiting the query rate per time period.
//
// The client is identified by its source IP masked with Prefix4 (IPv4) or
// Prefix6 (everything else). Each client (network) may send up to Requests
// queries per fixed window of Window seconds; those are passed to the
// upstream resolver. Queries beyond the limit are forwarded to LimitResolver
// if one is configured, otherwise they are dropped by returning a nil
// response and a nil error.
func (r *RateLimiter) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
	log := logger(r.id, q, ci)
	r.metrics.query.Add(1)

	// Apply the desired mask to the client IP to build a key that identifies
	// the client (network).
	source := ci.SourceIP
	if source.To4() != nil {
		source = source.Mask(net.CIDRMask(int(r.Prefix4), 32))
	} else {
		source = source.Mask(net.CIDRMask(int(r.Prefix6), 128))
	}
	key := source.String()

	// Calculate the current (fixed) window.
	windowID := time.Now().Unix() / int64(r.Window)

	r.mu.Lock()

	// If we have moved on to the next window, re-initialize the counters. The
	// nil check covers the case where the very first window ID equals the
	// zero-valued currWinID, which would otherwise leave the map unallocated
	// and make the store below panic.
	if windowID != r.currWinID || r.counters == nil {
		r.currWinID = windowID
		r.counters = make(map[string]*uint)
	}

	// Load the current counter for this client or make a new one.
	v, ok := r.counters[key]
	if !ok {
		v = new(uint)
		r.counters[key] = v
	}

	// Check the number of requests already made in this window, then count
	// this one regardless of the outcome.
	reject := *v >= r.Requests
	*v++

	r.mu.Unlock()

	if !reject {
		log.WithField("resolver", r.resolver).Debug("forwarding query to resolver")
		return r.resolver.Resolve(q, ci)
	}

	r.metrics.exceed.Add(1)
	if r.LimitResolver != nil {
		log.WithField("resolver", r.LimitResolver).Debug("rate-limit exceeded, forwarding to limit-resolver")
		return r.LimitResolver.Resolve(q, ci)
	}

	// No alternate resolver configured: drop the query.
	r.metrics.drop.Add(1)
	log.Debug("rate-limit reached, dropping")
	return nil, nil
}
// String returns the id the rate limiter was created with.
func (r *RateLimiter) String() string {
	return r.id
}