Skip to content
This repository was archived by the owner on Nov 10, 2023. It is now read-only.

Commit c638b7d

Browse files
committed
Add IPv6 support
1 parent bab636a commit c638b7d

File tree

1 file changed

+69
-8
lines changed

1 file changed

+69
-8
lines changed

main.go

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ func (id StreamID) String() string {
7272
return net.JoinHostPort(id.remoteIP.String(), id.remotePort.String())
7373
}
7474

75+
type IPLayer interface {
76+
gopacket.NetworkLayer
77+
gopacket.SerializableLayer
78+
}
79+
7580
type Stream struct {
7681
mx sync.Mutex
7782
seq tcpassembly.Sequence
@@ -157,7 +162,7 @@ func (s *Stream) tcpReceived(src net.IP, tcp *layers.TCP) error {
157162
return nil
158163
}
159164

160-
func (s *Stream) tcpSent(pkt nfq.Packet, ip *layers.IPv4, tcp *layers.TCP) error {
165+
func (s *Stream) tcpSent(pkt nfq.Packet, ip IPLayer, tcp *layers.TCP) error {
161166
s.mx.Lock()
162167
defer s.mx.Unlock()
163168

@@ -191,7 +196,15 @@ func (s *Stream) tcpSent(pkt nfq.Packet, ip *layers.IPv4, tcp *layers.TCP) error
191196

192197
buf := gopacket.NewSerializeBuffer()
193198
opts := gopacket.SerializeOptions{ComputeChecksums: true}
194-
ip.TTL = claim.ttl
199+
200+
switch ip := ip.(type) {
201+
case *layers.IPv4:
202+
ip.TTL = claim.ttl
203+
case *layers.IPv6:
204+
ip.HopLimit = claim.ttl
205+
default:
206+
panic("not a IPv4/6 layer")
207+
}
195208
tcp.SetNetworkLayerForChecksum(ip)
196209
if err := gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload(tcp.Payload)); err != nil {
197210
return err
@@ -309,9 +322,23 @@ func NewStreamTracker() *StreamTracker {
309322
}
310323

311324
func (st *StreamTracker) HandlePacket(pkt nfq.Packet) error {
312-
p := gopacket.NewPacket(pkt.Data(), layers.LayerTypeIPv4, gopacket.Lazy)
313-
if ip, ok := p.NetworkLayer().(*layers.IPv4); ok && ip.Version == 4 {
314-
return st.handleIPv4(pkt, p, ip)
325+
data := pkt.Data()
326+
if len(data) == 0 {
327+
return pkt.Accept()
328+
}
329+
330+
version := data[0] >> 4
331+
if version == 4 {
332+
p := gopacket.NewPacket(data, layers.LayerTypeIPv4, gopacket.Lazy)
333+
if ip, ok := p.NetworkLayer().(*layers.IPv4); ok {
334+
return st.handleIPv4(pkt, p, ip)
335+
}
336+
}
337+
if version == 6 {
338+
p := gopacket.NewPacket(data, layers.LayerTypeIPv6, gopacket.Lazy)
339+
if ip, ok := p.NetworkLayer().(*layers.IPv6); ok {
340+
return st.handleIPv6(pkt, p, ip)
341+
}
315342
}
316343
return pkt.Accept()
317344
}
@@ -321,7 +348,17 @@ func (st *StreamTracker) handleIPv4(pkt nfq.Packet, p gopacket.Packet, ip *layer
321348
return st.handleICMPv4(pkt, ip.SrcIP, icmp)
322349
}
323350
if tcp, ok := p.TransportLayer().(*layers.TCP); ok {
324-
return st.handleTCP(pkt, ip, tcp)
351+
return st.handleTCP(pkt, ip, ip.SrcIP, tcp)
352+
}
353+
return pkt.Accept()
354+
}
355+
356+
func (st *StreamTracker) handleIPv6(pkt nfq.Packet, p gopacket.Packet, ip *layers.IPv6) error {
357+
if icmp, ok := p.Layer(layers.LayerTypeICMPv6).(*layers.ICMPv6); ok {
358+
return st.handleICMPv6(pkt, ip.SrcIP, icmp)
359+
}
360+
if tcp, ok := p.TransportLayer().(*layers.TCP); ok {
361+
return st.handleTCP(pkt, ip, ip.SrcIP, tcp)
325362
}
326363
return pkt.Accept()
327364
}
@@ -350,7 +387,31 @@ func (st *StreamTracker) handleICMPv4(pkt nfq.Packet, srcIP net.IP, icmp *layers
350387
return nil
351388
}
352389

353-
func (st *StreamTracker) handleTCP(pkt nfq.Packet, ip *layers.IPv4, tcp *layers.TCP) error {
390+
func (st *StreamTracker) handleICMPv6(pkt nfq.Packet, srcIP net.IP, icmp *layers.ICMPv6) error {
391+
defer pkt.Accept()
392+
393+
if icmp.TypeCode.Type() != layers.ICMPv6TypeTimeExceeded {
394+
return nil
395+
}
396+
397+
p := gopacket.NewPacket(icmp.Payload, layers.LayerTypeIPv6, gopacket.Lazy)
398+
ip, ok := p.NetworkLayer().(*layers.IPv6)
399+
if !ok || ip.NextHeader != layers.IPProtocolTCP || len(ip.Payload) < 8 {
400+
return nil
401+
}
402+
403+
dstPort := layers.TCPPort(binary.BigEndian.Uint16(ip.Payload[2:4]))
404+
seq := binary.BigEndian.Uint32(ip.Payload[4:8])
405+
id := StreamID{ip.NetworkFlow().Dst(), layers.NewTCPPortEndpoint(dstPort)}
406+
407+
stream := st.Get(id)
408+
if stream != nil {
409+
return stream.icmpReceived(srcIP, seq)
410+
}
411+
return nil
412+
}
413+
414+
func (st *StreamTracker) handleTCP(pkt nfq.Packet, ip IPLayer, srcIP net.IP, tcp *layers.TCP) error {
354415
srcID := StreamID{ip.NetworkFlow().Src(), tcp.TransportFlow().Src()}
355416
dstID := StreamID{ip.NetworkFlow().Dst(), tcp.TransportFlow().Dst()}
356417

@@ -382,7 +443,7 @@ func (st *StreamTracker) handleTCP(pkt nfq.Packet, ip *layers.IPv4, tcp *layers.
382443
}
383444

384445
if src := st.Get(srcID); src != nil {
385-
if err := src.tcpReceived(ip.SrcIP, tcp); err != nil {
446+
if err := src.tcpReceived(srcIP, tcp); err != nil {
386447
pkt.Accept()
387448
return err
388449
}

0 commit comments

Comments
 (0)