From 51a96ad6414edd8c39affd9923a614e4c8d1b221 Mon Sep 17 00:00:00 2001 From: Jenny Zhang Date: Wed, 21 Jun 2023 14:46:28 -0400 Subject: [PATCH] [tailscale] net: add TCP socket creation/close hooks to SockTrace API Extends the hooks added by #45 to also expose when TCP sockets are created or closed (meant to allow TCP stats to be read from them). We don't do this for all socket types since stats are not available for UDP sockets, and they tend to be short-lived, thus invoking the hooks would be useless overhead. Also fixes read/write hooks to not count out-of-band data, since that's usually not sent over the wire. Updates tailscale/corp#9230 Updates #58 Signed-off-by: Jenny Zhang (Cherry-picked from db4dc90) --- api/go1.99999.txt | 2 ++ src/net/fd_posix.go | 22 ++++++++++++++++------ src/net/sock_posix.go | 15 +++++++++++++++ src/net/socktrace.go | 7 +++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/api/go1.99999.txt b/api/go1.99999.txt index 752e6be3447d0..a4b8591391989 100644 --- a/api/go1.99999.txt +++ b/api/go1.99999.txt @@ -6,3 +6,5 @@ pkg net, type SockTrace struct #58 pkg net, type SockTrace struct, DidRead func(int) #58 pkg net, type SockTrace struct, DidWrite func(int) #58 pkg net, type SockTrace struct, WillOverwrite func(*SockTrace) #58 +pkg net, type SockTrace struct, DidCreateTCPConn func(syscall.RawConn) #58 +pkg net, type SockTrace struct, WillCloseTCPConn func(syscall.RawConn) #58 diff --git a/src/net/fd_posix.go b/src/net/fd_posix.go index 7f3aeff580c95..5c88b50cdae49 100644 --- a/src/net/fd_posix.go +++ b/src/net/fd_posix.go @@ -29,6 +29,7 @@ type netFD struct { // number of bytes transferred. readHook func(int) writeHook func(int) + closeHook func() } func (fd *netFD) setAddr(laddr, raddr Addr) { @@ -39,6 +40,9 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { func (fd *netFD) Close() error { runtime.SetFinalizer(fd, nil) + if fd.closeHook != nil { + fd.closeHook() + } return fd.pfd.Close() } @@ -49,10 +53,16 @@ func (fd *netFD) shutdown(how int) error { } func (fd *netFD) closeRead() error { + if fd.closeHook != nil { + fd.closeHook() + } return fd.shutdown(syscall.SHUT_RD) } func (fd *netFD) closeWrite() error { + if fd.closeHook != nil { + fd.closeHook() + } return fd.shutdown(syscall.SHUT_WR) } @@ -94,7 +104,7 @@ func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, er func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags) if fd.readHook != nil && err == nil { - fd.readHook(n + oobn) + fd.readHook(n) } runtime.KeepAlive(fd) return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err) @@ -103,7 +113,7 @@ func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) { n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa) if fd.readHook != nil && err == nil { - fd.readHook(n + oobn) + fd.readHook(n) } runtime.KeepAlive(fd) return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err) @@ -112,7 +122,7 @@ func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.Socka func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) { n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa) if fd.readHook != nil && err == nil { - fd.readHook(n + oobn) + fd.readHook(n) } runtime.KeepAlive(fd) return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err) @@ -157,7 +167,7 @@ func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err e func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) if fd.writeHook != nil && err == nil { - fd.writeHook(n + oobn) + fd.writeHook(n) } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) @@ -166,7 +176,7 @@ func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa) if fd.writeHook != nil && err == nil { - fd.writeHook(n + oobn) + fd.writeHook(n) } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) @@ -175,7 +185,7 @@ func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa) if fd.writeHook != nil && err == nil { - fd.writeHook(n + oobn) + fd.writeHook(n) } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go index cd36c2ab6bb5b..7f0ff9bc91096 100644 --- a/src/net/sock_posix.go +++ b/src/net/sock_posix.go @@ -31,6 +31,21 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only if trace := ContextSockTrace(ctx); trace != nil { fd.readHook = trace.DidRead fd.writeHook = trace.DidWrite + if (trace.DidCreateTCPConn != nil || trace.WillCloseTCPConn != nil) && len(net) >= 3 && net[0:3] == "tcp" { + // Ignore newRawConn errors (they're not possible in the current + // implementation, but even if they were, we don't want to + // affect socket operations for a trace hook invocation). + if c, err := newRawConn(fd); err == nil { + if trace.DidCreateTCPConn != nil { + trace.DidCreateTCPConn(c) + } + if trace.WillCloseTCPConn != nil { + fd.closeHook = func() { + trace.WillCloseTCPConn(c) + } + } + } + } } // This function makes a network file descriptor for the diff --git a/src/net/socktrace.go b/src/net/socktrace.go index 57af2be8b75e3..b02a8d12484d4 100644 --- a/src/net/socktrace.go +++ b/src/net/socktrace.go @@ -6,12 +6,16 @@ package net import ( "context" + "syscall" ) // SockTrace is a set of hooks to run at various operations on a network socket. // Any particular hook may be nil. Functions may be called concurrently from // different goroutines. type SockTrace struct { + // DidOpenTCPConn is called when a TCP socket was created. The + // underlying raw network connection that was created is provided. + DidCreateTCPConn func(c syscall.RawConn) // DidRead is called after a successful read from the socket, where n bytes // were read. DidRead func(n int) @@ -22,6 +26,9 @@ type SockTrace struct { // subsequent call to WithSockTrace. The provided trace is the new trace // that will be used. WillOverwrite func(trace *SockTrace) + // WillCloseTCPConn is called when a TCP socket is about to be closed. The + // underlying raw network connection that is being closed is provided. + WillCloseTCPConn func(c syscall.RawConn) } // WithSockTrace returns a new context based on the provided parent