Skip to content

Commit

Permalink
[tailscale] net: add TCP socket creation/close hooks to SockTrace API
Browse files Browse the repository at this point in the history
Extends the hooks added by #45 to also expose when TCP sockets are
created or closed (meant to allow TCP stats to be read from them). We
don't do this for all socket types since stats are not available for
UDP sockets, and they tend to be short-lived, thus invoking the hooks
would be useless overhead.

Also fixes read/write hooks to not count out-of-band data, since that's
usually not sent over the wire.

Updates tailscale/corp#9230
Updates #58

Signed-off-by: Jenny Zhang <jz@tailscale.com>
(Cherry-picked from db4dc90)
  • Loading branch information
phirework committed Jun 21, 2023
1 parent c3f0500 commit 1b166d5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 6 deletions.
2 changes: 2 additions & 0 deletions api/go1.99999.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ pkg net, type SockTrace struct #58
pkg net, type SockTrace struct, DidRead func(int) #58
pkg net, type SockTrace struct, DidWrite func(int) #58
pkg net, type SockTrace struct, WillOverwrite func(*SockTrace) #58
pkg net, type SockTrace struct, DidCreateTCPConn func(syscall.RawConn) #58
pkg net, type SockTrace struct, WillCloseTCPConn func(syscall.RawConn) #58
22 changes: 16 additions & 6 deletions src/net/fd_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type netFD struct {
// number of bytes transferred.
readHook func(int)
writeHook func(int)
closeHook func()
}

func (fd *netFD) setAddr(laddr, raddr Addr) {
Expand All @@ -39,6 +40,9 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {

func (fd *netFD) Close() error {
runtime.SetFinalizer(fd, nil)
if fd.closeHook != nil {
fd.closeHook()
}
return fd.pfd.Close()
}

Expand All @@ -49,10 +53,16 @@ func (fd *netFD) shutdown(how int) error {
}

func (fd *netFD) closeRead() error {
if fd.closeHook != nil {
fd.closeHook()
}
return fd.shutdown(syscall.SHUT_RD)
}

func (fd *netFD) closeWrite() error {
if fd.closeHook != nil {
fd.closeHook()
}
return fd.shutdown(syscall.SHUT_WR)
}

Expand Down Expand Up @@ -94,7 +104,7 @@ func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, er
func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) {
n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err)
Expand All @@ -103,7 +113,7 @@ func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int
func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) {
n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
Expand All @@ -112,7 +122,7 @@ func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.Socka
func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) {
n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa)
if fd.readHook != nil && err == nil {
fd.readHook(n + oobn)
fd.readHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err)
Expand Down Expand Up @@ -157,7 +167,7 @@ func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err e
func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
Expand All @@ -166,7 +176,7 @@ func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob
func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
Expand All @@ -175,7 +185,7 @@ func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4)
func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) {
n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa)
if fd.writeHook != nil && err == nil {
fd.writeHook(n + oobn)
fd.writeHook(n)
}
runtime.KeepAlive(fd)
return n, oobn, wrapSyscallError(writeMsgSyscallName, err)
Expand Down
15 changes: 15 additions & 0 deletions src/net/sock_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
if trace := ContextSockTrace(ctx); trace != nil {
fd.readHook = trace.DidRead
fd.writeHook = trace.DidWrite
if (trace.DidCreateTCPConn != nil || trace.WillCloseTCPConn != nil) && len(net) >= 3 && net[0:3] == "tcp" {
// Ignore newRawConn errors (they're not possible in the current
// implementation, but even if they were, we don't want to
// affect socket operations for a trace hook invocation).
if c, err := newRawConn(fd); err == nil {
if trace.DidCreateTCPConn != nil {
trace.DidCreateTCPConn(c)
}
if trace.WillCloseTCPConn != nil {
fd.closeHook = func() {
trace.WillCloseTCPConn(c)
}
}
}
}
}

// This function makes a network file descriptor for the
Expand Down
7 changes: 7 additions & 0 deletions src/net/socktrace.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@ package net

import (
"context"
"syscall"
)

// SockTrace is a set of hooks to run at various operations on a network socket.
// Any particular hook may be nil. Functions may be called concurrently from
// different goroutines.
type SockTrace struct {
// DidOpenTCPConn is called when a TCP socket was created. The
// underlying raw network connection that was created is provided.
DidCreateTCPConn func(c syscall.RawConn)
// DidRead is called after a successful read from the socket, where n bytes
// were read.
DidRead func(n int)
Expand All @@ -22,6 +26,9 @@ type SockTrace struct {
// subsequent call to WithSockTrace. The provided trace is the new trace
// that will be used.
WillOverwrite func(trace *SockTrace)
// WillCloseTCPConn is called when a TCP socket is about to be closed. The
// underlying raw network connection that is being closed is provided.
WillCloseTCPConn func(c syscall.RawConn)
}

// WithSockTrace returns a new context based on the provided parent
Expand Down

0 comments on commit 1b166d5

Please sign in to comment.