Skip to content

Commit

Permalink
Fix recvfrom goroutine leak
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai authored and aboch committed Aug 23, 2024
1 parent 298ff27 commit 6f57139
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
60 changes: 37 additions & 23 deletions nl/nl_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/binary"
"fmt"
"net"
"os"
"runtime"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -655,8 +656,9 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
}

type NetlinkSocket struct {
fd int32
lsa unix.SockaddrNetlink
fd int32
file *os.File
lsa unix.SockaddrNetlink
sync.Mutex
}

Expand All @@ -665,8 +667,13 @@ func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
return nil, err
}
s := &NetlinkSocket{
fd: int32(fd),
fd: int32(fd),
file: os.NewFile(uintptr(fd), "netlink"),
}
s.lsa.Family = unix.AF_NETLINK
if err := unix.Bind(fd, &s.lsa); err != nil {
Expand Down Expand Up @@ -753,8 +760,13 @@ func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
if err != nil {
return nil, err
}
err = unix.SetNonblock(fd, true)
if err != nil {
return nil, err
}
s := &NetlinkSocket{
fd: int32(fd),
fd: int32(fd),
file: os.NewFile(uintptr(fd), "netlink"),
}
s.lsa.Family = unix.AF_NETLINK

Expand Down Expand Up @@ -783,33 +795,36 @@ func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*Ne
}

func (s *NetlinkSocket) Close() {
fd := int(atomic.SwapInt32(&s.fd, -1))
unix.Close(fd)
s.file.Close()
}

func (s *NetlinkSocket) GetFd() int {
return int(atomic.LoadInt32(&s.fd))
return int(s.fd)
}

func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
fd := int(atomic.LoadInt32(&s.fd))
if fd < 0 {
return fmt.Errorf("Send called on a closed socket")
}
if err := unix.Sendto(fd, request.Serialize(), 0, &s.lsa); err != nil {
return err
}
return nil
return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
}

func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
fd := int(atomic.LoadInt32(&s.fd))
if fd < 0 {
return nil, nil, fmt.Errorf("Receive called on a closed socket")
rawConn, err := s.file.SyscallConn()
if err != nil {
return nil, nil, err
}
var (
fromAddr *unix.SockaddrNetlink
rb [RECEIVE_BUFFER_SIZE]byte
nr int
from unix.Sockaddr
innerErr error
)
err = rawConn.Read(func(fd uintptr) (done bool) {
nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
return innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
var fromAddr *unix.SockaddrNetlink
var rb [RECEIVE_BUFFER_SIZE]byte
nr, from, err := unix.Recvfrom(fd, rb[:], 0)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -864,8 +879,7 @@ func (s *NetlinkSocket) SetExtAck(enable bool) error {
}

func (s *NetlinkSocket) GetPid() (uint32, error) {
fd := int(atomic.LoadInt32(&s.fd))
lsa, err := unix.Getsockname(fd)
lsa, err := unix.Getsockname(int(s.fd))
if err != nil {
return 0, err
}
Expand Down
4 changes: 1 addition & 3 deletions nl/nl_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,12 @@ func TestIfSocketCloses(t *testing.T) {
if err != nil {
t.Fatalf("Error on creating the socket: %v", err)
}
nlSock.SetReceiveTimeout(&unix.Timeval{Sec: 2, Usec: 0})
endCh := make(chan error)
go func(sk *NetlinkSocket, endCh chan error) {
endCh <- nil
for {
_, _, err := sk.Receive()
// Receive returned because of a timeout and the FD == -1 means that the socket got closed
if err == unix.EAGAIN && nlSock.GetFd() == -1 {
if err == unix.EAGAIN {
endCh <- err
return
}
Expand Down

0 comments on commit 6f57139

Please sign in to comment.