-
Notifications
You must be signed in to change notification settings - Fork 62
/
request-dedup.go
144 lines (127 loc) · 3.08 KB
/
request-dedup.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package rdns
import (
"encoding/binary"
"sync"
"github.com/miekg/dns"
)
// dedupKey identifies a query for deduplication purposes. Two queries with
// equal keys share a single upstream request and receive the same answer, so
// every query attribute that can change the upstream response must be part
// of the key.
type dedupKey struct {
	name   string
	qtype  uint16
	qclass uint16

	// Flags that change the content of the answer: DO asks for DNSSEC
	// records, CD disables upstream validation.
	dnssecOK         bool
	checkingDisabled bool

	// EDNS0 Client Subnet. ecsFamily is 0 when the query has no ECS option,
	// which keeps "no ECS" distinct from "ECS 0.0.0.0/0".
	ecsFamily uint16
	ecsIPv4   uint32
	ecsIPv6Hi uint64
	ecsIPv6Lo uint64
	ecsMask   uint8
}

// inflightRequest tracks a query that has been forwarded upstream. answer and
// err are written exactly once by the goroutine that forwarded the query,
// before done is closed; waiters may only read them after receiving from done.
type inflightRequest struct {
	answer *dns.Msg
	err    error
	done   chan struct{}
}

// requestDedup forwards queries to the wrapped resolver. While a query is
// outstanding, identical queries (see dedupKey) are held until it completes
// and are then answered with a copy of the same response. This element is
// used to smooth out spikes of queries for the same name.
type requestDedup struct {
	id       string
	resolver Resolver

	mu       sync.Mutex // guards inflight
	inflight map[dedupKey]*inflightRequest
}

var _ Resolver = (*requestDedup)(nil)

// NewRequestDedup returns a deduplicating resolver with the given id that
// forwards non-duplicate queries to resolver.
func NewRequestDedup(id string, resolver Resolver) *requestDedup {
	return &requestDedup{
		id:       id,
		resolver: resolver,
		inflight: make(map[dedupKey]*inflightRequest),
	}
}

// newDedupKey builds the deduplication key for q. q must contain at least one
// question; only the first question and the first ECS option are considered.
func newDedupKey(q *dns.Msg) dedupKey {
	question := q.Question[0]
	k := dedupKey{
		name:             question.Name,
		qtype:            question.Qtype,
		qclass:           question.Qclass,
		checkingDisabled: q.CheckingDisabled,
	}

	edns0 := q.IsEdns0()
	if edns0 == nil {
		return k
	}
	k.dnssecOK = edns0.Do()

	var ecs *dns.EDNS0_SUBNET
	for _, opt := range edns0.Option {
		if e, ok := opt.(*dns.EDNS0_SUBNET); ok {
			ecs = e
			break
		}
	}
	if ecs == nil {
		return k
	}

	k.ecsFamily = ecs.Family
	switch ecs.Family {
	case 1: // IPv4
		k.ecsIPv4 = byteToUint32(ecs.Address.To4())
		k.ecsMask = ecs.SourceNetmask
	case 2: // IPv6
		k.ecsIPv6Hi, k.ecsIPv6Lo = byteToUint128(ecs.Address.To16())
		k.ecsMask = ecs.SourceNetmask
	}
	return k
}

// Resolve forwards q to the wrapped resolver unless an identical query is
// already in flight. In that case it waits for the in-flight query to finish
// and returns a copy of its answer (with this query's message ID) and its
// error. Queries without a question section cannot be keyed and are
// forwarded without deduplication.
func (r *requestDedup) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
	if len(q.Question) == 0 {
		return r.resolver.Resolve(q, ci)
	}
	k := newDedupKey(q)

	r.mu.Lock()
	req, dup := r.inflight[k]
	if !dup {
		req = &inflightRequest{done: make(chan struct{})}
		r.inflight[k] = req
	}
	r.mu.Unlock()

	log := logger(r.id, q, ci)

	// If the request is already in flight, wait for that to complete and
	// return the same answer.
	if dup {
		log.Debug("duplicated request, waiting for first answer")
		<-req.done
		a, err := req.answer, req.err
		if a != nil {
			// The answer is shared with every waiter and other elements may
			// modify it, so hand out a copy. Its ID belongs to the query that
			// was actually forwarded; the response must carry this query's ID.
			a = a.Copy()
			a.Id = q.Id
		}
		return a, err
	}

	log.WithField("resolver", r.resolver).Debug("forwarding query to resolver")
	a, err := r.resolver.Resolve(q, ci)
	req.answer = a
	req.err = err
	close(req.done) // release other goroutines waiting for the response

	// No longer in flight
	r.mu.Lock()
	delete(r.inflight, k)
	r.mu.Unlock()

	// Return a copy since it could be modified in the chain (i.e. in the
	// listener) while it is also still being read by waiting goroutines.
	if a != nil {
		return a.Copy(), err
	}
	return a, err
}
// String implements fmt.Stringer by returning the identifier this element
// was constructed with.
func (r *requestDedup) String() string {
	name := r.id
	return name
}
// byteToUint128 interprets b as a 128-bit big-endian integer and returns its
// high and low 64-bit halves. Any input that is not exactly 16 bytes long
// (including nil) yields 0, 0.
func byteToUint128(b []byte) (uint64, uint64) {
	if len(b) == 16 {
		return binary.BigEndian.Uint64(b[:8]), binary.BigEndian.Uint64(b[8:])
	}
	return 0, 0
}
// byteToUint32 interprets b as a 32-bit big-endian integer. Any input that is
// not exactly 4 bytes long (including nil) yields 0.
func byteToUint32(b []byte) uint32 {
	if len(b) == 4 {
		return binary.BigEndian.Uint32(b)
	}
	return 0
}