Skip to content

Commit

Permalink
Refactor getPidByAddr function
Browse files Browse the repository at this point in the history
This commit refactors the getPidByAddr function to handle some unusual cases
where the source IP address in the IP packet header has been rewritten by
iptables, nftables, or eBPF. In such cases, the function attempts to retrieve
the process ID (pid) and destination address (destAddr) using the local IP
address (127.0.0.1 for IPv4 or [::1] for IPv6).

The original getPidByAddr function is renamed to _getPidByAddr, and a new
getPidByAddr function is introduced. The new function first calls _getPidByAddr
to obtain the pid and destAddr. If either of them is empty, it checks for the
unusual case mentioned above and tries again with the local IP address.

The commit also includes a detailed comment explaining the rationale behind
this change and the limitations of using the pidAddrMap directly in concurrent
situations.

Note: This change does not address the case where the client binds to an IP
address other than the local IP before the connect call, as this scenario is
also considered unusual.
  • Loading branch information
hmgle committed May 27, 2024
1 parent d200391 commit 330f301
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 22 deletions.
29 changes: 28 additions & 1 deletion local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func (l *Local) Start() {
l.StartService(ln)
}

func getPidByAddr(localAddr, remoteAddr string, isTCP6 bool) (pid string, destAddr string) {
func _getPidByAddr(localAddr, remoteAddr string, isTCP6 bool) (pid string, destAddr string) {
inode, err := getInodeByAddrs(localAddr, remoteAddr, isTCP6)
if err != nil {
log.Errorf("getInodeByAddrs(%s, %s) err: %s", localAddr, remoteAddr, err.Error())
Expand All @@ -211,6 +211,33 @@ func getPidByAddr(localAddr, remoteAddr string, isTCP6 bool) (pid string, destAd
return
}

func getPidByAddr(localAddr, remoteAddr string, isTCP6 bool) (pid string, destAddr string) {
pid, destAddr = _getPidByAddr(localAddr, remoteAddr, isTCP6)
if pid == "" || destAddr == "" {
// NOTE: There are some unusual cases that can cause the above getPidByAddr
// to fail to obtain the pid and destAddr.
// For example: The Source IP Address field in the IP packet header has been
// rewritten (iptables, nftables, and eBPF can all do this).
// In this case, we try again using "127.0.0.1".
// However, if the client binds to another IP (which is also unusual) before
// the connect call, rather than "127.0.0.1", then the retrieval will also fail.
// Although the pidAddrMap does store the pid and destAddr information we're
// looking for, it cannot guarantee correctness in concurrent situations,
// so we're not planning to use the information in pidAddrMap directly for now.
var localIP string
if isTCP6 {
localIP = localIPv6
} else {
localIP = localIPv4
}
host, port, err := splitAddr(localAddr, isTCP6)
if err == nil && host != localIP {
pid, destAddr = _getPidByAddr(net.JoinHostPort(localIP, port), remoteAddr, isTCP6)
}
}
return pid, destAddr
}

// HandleConn handle conn.
func (l *Local) HandleConn(conn net.Conn) error {
raddr := conn.RemoteAddr()
Expand Down
17 changes: 7 additions & 10 deletions local/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,23 @@ import "github.com/jedisct1/dlog"

// Logger represents a general-purpose logger.
type Logger interface {
//Fatalf is critical fatal logging, should possibly followed by system shutdown
// Fatalf is critical fatal logging, should possibly followed by system shutdown
Fatalf(msg string, args ...interface{})

//Errorf is for logging errors
// Errorf is for logging errors
Errorf(msg string, args ...interface{})

//Warnf is for logging messages about possible issues
// Warnf is for logging messages about possible issues
Warnf(msg string, args ...interface{})

//Infof for logging general logging messages
// Infof for logging general logging messages
Infof(msg string, args ...interface{})

//Debugf is for logging verbose messages
// Debugf is for logging verbose messages
Debugf(msg string, args ...interface{})
}

type dlogT struct {
}
type dlogT struct{}

func (d dlogT) Debugf(msg string, args ...interface{}) {
dlog.Debugf(msg, args...)
Expand All @@ -43,9 +42,7 @@ func (d dlogT) Fatalf(msg string, args ...interface{}) {
dlog.Fatalf(msg, args...)
}

var (
log Logger = dlogT{}
)
var log Logger = dlogT{}

// SetLogger allows users to inject their own logger instead of the default one.
func SetLogger(l Logger) {
Expand Down
26 changes: 15 additions & 11 deletions local/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,11 @@ func getInodeByAddrs(localAddr, remoteAddr string, isTCP6 bool) (inode string, e
remoteIP string
remotePort string
)
if isTCP6 {
localIP, localPort, err = splitAddrIPv6(localAddr)
} else {
localIP, localPort, err = splitAddrIPv4(localAddr)
}
localIP, localPort, err = splitAddr(localAddr, isTCP6)
if err != nil {
return
}
if isTCP6 {
remoteIP, remotePort, err = splitAddrIPv6(remoteAddr)
} else {
remoteIP, remotePort, err = splitAddrIPv4(remoteAddr)
}
remoteIP, remotePort, err = splitAddr(remoteAddr, isTCP6)
if err != nil {
return
}
Expand All @@ -94,6 +86,11 @@ func getInodeByAddrs(localAddr, remoteAddr string, isTCP6 bool) (inode string, e
return getInode(localIPHex+":"+localPortHex, remoteIPHex+":"+remotePortHex, isTCP6)
}

const (
localIPv4 = "127.0.0.1"
localIPv6 = "[::1]"
)

// addr format: "127.0.0.1:53816"
func splitAddrIPv4(addr string) (ipv4 string, port string, err error) {
addrSplit := strings.Split(addr, ":")
Expand All @@ -103,7 +100,7 @@ func splitAddrIPv4(addr string) (ipv4 string, port string, err error) {
}
ipv4 = addrSplit[0]
if ipv4 == "" {
ipv4 = "127.0.0.1"
ipv4 = localIPv4
}
port = addrSplit[1]
return
Expand All @@ -121,6 +118,13 @@ func splitAddrIPv6(addr string) (ipv6 string, port string, err error) {
return
}

func splitAddr(addr string, isTCP6 bool) (ip string, port string, err error) {
if isTCP6 {
return splitAddrIPv6(addr)
}
return splitAddrIPv4(addr)
}

// getInode get the inode, localAddrHex format: 0100007F:04D2
func getInode(localAddrHex, remoteAddrHex string, isTCP6 bool) (inode string, err error) {
var path string
Expand Down

0 comments on commit 330f301

Please sign in to comment.