forked from pion/dtls
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fragment_buffer.go
135 lines (109 loc) · 3.81 KB
/
fragment_buffer.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package dtls
import (
"github.com/pion/dtls/v3/pkg/protocol"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
)
// fragmentBufferMaxSize is the byte limit enforced by push: a record is
// rejected with errFragmentBufferOverflow once the cached fragment data plus
// the incoming record would reach it. 2 MB, i.e. 2,000,000 bytes.
const fragmentBufferMaxSize = 2 * 1000 * 1000
// fragment is one handshake message, or one piece of a fragmented handshake
// message, parsed out of a DTLS record by push.
type fragment struct {
	// recordLayerHeader is the DTLS record header filled in by push; pop
	// reports its Epoch as the epoch of the reassembled message.
	recordLayerHeader recordlayer.Header
	// handshakeHeader is the handshake header that preceded data in the
	// record (message sequence, fragment offset/length, total length).
	handshakeHeader handshake.Header
	// data is a private copy of the bytes that followed the handshake header;
	// the header itself is not stored here.
	data []byte
}
// fragmentBuffer collects handshake fragments from incoming DTLS records
// (push) and hands back complete handshake messages one sequence number at a
// time (pop).
type fragmentBuffer struct {
	// cache maps a handshake message sequence number to the fragments
	// received for it, in arrival order.
	cache map[uint16][]*fragment

	// currentMessageSequenceNumber is the sequence number pop tries to
	// reassemble next; it is incremented after every successful pop.
	currentMessageSequenceNumber uint16
}
// newFragmentBuffer returns an empty fragmentBuffer with an initialized cache
// map, ready for push. The expected sequence number starts at zero.
func newFragmentBuffer() *fragmentBuffer {
	buffer := &fragmentBuffer{
		cache: make(map[uint16][]*fragment),
	}
	return buffer
}
// size returns the total number of data bytes held by every cached fragment,
// across all sequence numbers. Header bytes are not stored, so they are not
// counted.
func (f *fragmentBuffer) size() int {
	total := 0
	for _, frags := range f.cache {
		for _, frag := range frags {
			total += len(frag.data)
		}
	}
	return total
}
// push attempts to add one raw DTLS record to the fragmentBuffer.
//
// isHandshake is true when the record's content type is handshake. In that
// case every handshake message in the record has been consumed by the buffer,
// and the caller should not handle the record any further. Records of any
// other content type are ignored and reported with isHandshake == false.
//
// isRetransmit reports whether the last handshake message in the record
// starts at fragment offset 0 and has a sequence number that pop has already
// delivered.
//
// A non-nil error (buffer overflow or a malformed header) is fatal, and the
// DTLS connection should be stopped.
func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) {
	if f.size()+len(buf) >= fragmentBufferMaxSize {
		return false, false, errFragmentBufferOverflow
	}

	var recordLayerHeader recordlayer.Header
	if err := recordLayerHeader.Unmarshal(buf); err != nil {
		return false, false, err
	}

	// Only handshake records need reassembly.
	if recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
		return false, false, nil
	}

	// A single record may carry several handshake messages back to back.
	// Messages parsed before a malformed one stay cached.
	for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; {
		// Every message in this record shares the record header, in particular
		// the epoch that pop reports. Previously only the first one got it.
		frag := &fragment{recordLayerHeader: recordLayerHeader}
		if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
			return false, false, err
		}

		seq := frag.handshakeHeader.MessageSequence

		// The fragment is a retransmission of a message we already assembled.
		isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && seq < f.currentMessageSequenceNumber

		// The end index is the handshake header plus the message Length. If the
		// message was fragmented, keep everything left in the record.
		end := int(handshake.HeaderLength + frag.handshakeHeader.Length)
		if end > len(buf) {
			end = len(buf)
		}

		// pop only ever reads currentMessageSequenceNumber and only moves
		// forward. Fragments of already-delivered messages would never be read
		// or freed, and would count against fragmentBufferMaxSize, so skip them.
		if seq >= f.currentMessageSequenceNumber {
			// Discard all headers; pop rebuilds one header for the whole message.
			frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...)
			f.cache[seq] = append(f.cache[seq], frag)
		}

		buf = buf[end:]
	}

	return true, isRetransmit, nil
}
// pop reassembles and removes the handshake message whose sequence number is
// currentMessageSequenceNumber.
//
// Cached fragments are chained starting at fragment offset 0. Each fragment's
// FragmentOffset+FragmentLength gives the offset of the next one, and the
// chain stops at a fragment that reaches the message Length or has
// FragmentLength 0. When several fragments share an offset, the earliest
// cached one is used.
//
// The returned content is a single handshake header, rewritten to describe an
// unfragmented message, followed by the concatenated fragment data. epoch
// comes from the record header of the first cached fragment.
//
// If the message is still incomplete, or its header cannot be marshaled, pop
// returns (nil, 0) and leaves the buffer unchanged. On success the message's
// fragments are dropped and the expected sequence number advances by one.
func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
	frags, ok := f.cache[f.currentMessageSequenceNumber]
	if !ok {
		return nil, 0
	}

	// Walk forward from offset 0, appending each fragment's data in order.
	// This replaces recursive prepending, which re-copied the accumulated
	// message at every step and recursed once per fragment.
	var rawMessage []byte
	offset := uint32(0)
	for steps := 0; ; steps++ {
		// Every non-final step advances the offset, so a successful walk uses
		// each cached fragment at most once. Give up instead of cycling if that
		// ever fails to hold.
		if steps == len(frags) {
			return nil, 0
		}

		var next *fragment
		for _, frag := range frags {
			if frag.handshakeHeader.FragmentOffset == offset {
				next = frag
				break
			}
		}
		if next == nil {
			// A piece of the message has not arrived yet.
			return nil, 0
		}

		rawMessage = append(rawMessage, next.data...)

		fragmentEnd := next.handshakeHeader.FragmentOffset + next.handshakeHeader.FragmentLength
		// A zero FragmentLength would point back at the same offset, so it also
		// terminates the chain.
		if fragmentEnd == next.handshakeHeader.Length || next.handshakeHeader.FragmentLength == 0 {
			break
		}
		offset = fragmentEnd
	}

	firstHeader := frags[0].handshakeHeader
	firstHeader.FragmentOffset = 0
	firstHeader.FragmentLength = firstHeader.Length

	rawHeader, err := firstHeader.Marshal()
	if err != nil {
		return nil, 0
	}

	messageEpoch := frags[0].recordLayerHeader.Epoch

	delete(f.cache, f.currentMessageSequenceNumber)
	f.currentMessageSequenceNumber++

	return append(rawHeader, rawMessage...), messageEpoch
}