@@ -72,6 +72,11 @@ func (id StreamID) String() string {
72
72
return net .JoinHostPort (id .remoteIP .String (), id .remotePort .String ())
73
73
}
74
74
75
+ type IPLayer interface {
76
+ gopacket.NetworkLayer
77
+ gopacket.SerializableLayer
78
+ }
79
+
75
80
type Stream struct {
76
81
mx sync.Mutex
77
82
seq tcpassembly.Sequence
@@ -157,7 +162,7 @@ func (s *Stream) tcpReceived(src net.IP, tcp *layers.TCP) error {
157
162
return nil
158
163
}
159
164
160
- func (s * Stream ) tcpSent (pkt nfq.Packet , ip * layers. IPv4 , tcp * layers.TCP ) error {
165
+ func (s * Stream ) tcpSent (pkt nfq.Packet , ip IPLayer , tcp * layers.TCP ) error {
161
166
s .mx .Lock ()
162
167
defer s .mx .Unlock ()
163
168
@@ -191,7 +196,15 @@ func (s *Stream) tcpSent(pkt nfq.Packet, ip *layers.IPv4, tcp *layers.TCP) error
191
196
192
197
buf := gopacket .NewSerializeBuffer ()
193
198
opts := gopacket.SerializeOptions {ComputeChecksums : true }
194
- ip .TTL = claim .ttl
199
+
200
+ switch ip := ip .(type ) {
201
+ case * layers.IPv4 :
202
+ ip .TTL = claim .ttl
203
+ case * layers.IPv6 :
204
+ ip .HopLimit = claim .ttl
205
+ default :
206
+ panic ("not a IPv4/6 layer" )
207
+ }
195
208
tcp .SetNetworkLayerForChecksum (ip )
196
209
if err := gopacket .SerializeLayers (buf , opts , ip , tcp , gopacket .Payload (tcp .Payload )); err != nil {
197
210
return err
@@ -309,9 +322,23 @@ func NewStreamTracker() *StreamTracker {
309
322
}
310
323
311
324
func (st * StreamTracker ) HandlePacket (pkt nfq.Packet ) error {
312
- p := gopacket .NewPacket (pkt .Data (), layers .LayerTypeIPv4 , gopacket .Lazy )
313
- if ip , ok := p .NetworkLayer ().(* layers.IPv4 ); ok && ip .Version == 4 {
314
- return st .handleIPv4 (pkt , p , ip )
325
+ data := pkt .Data ()
326
+ if len (data ) == 0 {
327
+ return pkt .Accept ()
328
+ }
329
+
330
+ version := data [0 ] >> 4
331
+ if version == 4 {
332
+ p := gopacket .NewPacket (data , layers .LayerTypeIPv4 , gopacket .Lazy )
333
+ if ip , ok := p .NetworkLayer ().(* layers.IPv4 ); ok {
334
+ return st .handleIPv4 (pkt , p , ip )
335
+ }
336
+ }
337
+ if version == 6 {
338
+ p := gopacket .NewPacket (data , layers .LayerTypeIPv6 , gopacket .Lazy )
339
+ if ip , ok := p .NetworkLayer ().(* layers.IPv6 ); ok {
340
+ return st .handleIPv6 (pkt , p , ip )
341
+ }
315
342
}
316
343
return pkt .Accept ()
317
344
}
@@ -321,7 +348,17 @@ func (st *StreamTracker) handleIPv4(pkt nfq.Packet, p gopacket.Packet, ip *layer
321
348
return st .handleICMPv4 (pkt , ip .SrcIP , icmp )
322
349
}
323
350
if tcp , ok := p .TransportLayer ().(* layers.TCP ); ok {
324
- return st .handleTCP (pkt , ip , tcp )
351
+ return st .handleTCP (pkt , ip , ip .SrcIP , tcp )
352
+ }
353
+ return pkt .Accept ()
354
+ }
355
+
356
+ func (st * StreamTracker ) handleIPv6 (pkt nfq.Packet , p gopacket.Packet , ip * layers.IPv6 ) error {
357
+ if icmp , ok := p .Layer (layers .LayerTypeICMPv6 ).(* layers.ICMPv6 ); ok {
358
+ return st .handleICMPv6 (pkt , ip .SrcIP , icmp )
359
+ }
360
+ if tcp , ok := p .TransportLayer ().(* layers.TCP ); ok {
361
+ return st .handleTCP (pkt , ip , ip .SrcIP , tcp )
325
362
}
326
363
return pkt .Accept ()
327
364
}
@@ -350,7 +387,31 @@ func (st *StreamTracker) handleICMPv4(pkt nfq.Packet, srcIP net.IP, icmp *layers
350
387
return nil
351
388
}
352
389
353
- func (st * StreamTracker ) handleTCP (pkt nfq.Packet , ip * layers.IPv4 , tcp * layers.TCP ) error {
390
+ func (st * StreamTracker ) handleICMPv6 (pkt nfq.Packet , srcIP net.IP , icmp * layers.ICMPv6 ) error {
391
+ defer pkt .Accept ()
392
+
393
+ if icmp .TypeCode .Type () != layers .ICMPv6TypeTimeExceeded {
394
+ return nil
395
+ }
396
+
397
+ p := gopacket .NewPacket (icmp .Payload , layers .LayerTypeIPv6 , gopacket .Lazy )
398
+ ip , ok := p .NetworkLayer ().(* layers.IPv6 )
399
+ if ! ok || ip .NextHeader != layers .IPProtocolTCP || len (ip .Payload ) < 8 {
400
+ return nil
401
+ }
402
+
403
+ dstPort := layers .TCPPort (binary .BigEndian .Uint16 (ip .Payload [2 :4 ]))
404
+ seq := binary .BigEndian .Uint32 (ip .Payload [4 :8 ])
405
+ id := StreamID {ip .NetworkFlow ().Dst (), layers .NewTCPPortEndpoint (dstPort )}
406
+
407
+ stream := st .Get (id )
408
+ if stream != nil {
409
+ return stream .icmpReceived (srcIP , seq )
410
+ }
411
+ return nil
412
+ }
413
+
414
+ func (st * StreamTracker ) handleTCP (pkt nfq.Packet , ip IPLayer , srcIP net.IP , tcp * layers.TCP ) error {
354
415
srcID := StreamID {ip .NetworkFlow ().Src (), tcp .TransportFlow ().Src ()}
355
416
dstID := StreamID {ip .NetworkFlow ().Dst (), tcp .TransportFlow ().Dst ()}
356
417
@@ -382,7 +443,7 @@ func (st *StreamTracker) handleTCP(pkt nfq.Packet, ip *layers.IPv4, tcp *layers.
382
443
}
383
444
384
445
if src := st .Get (srcID ); src != nil {
385
- if err := src .tcpReceived (ip . SrcIP , tcp ); err != nil {
446
+ if err := src .tcpReceived (srcIP , tcp ); err != nil {
386
447
pkt .Accept ()
387
448
return err
388
449
}
0 commit comments