Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pkg/sentry/socket/unix/transport/connectioned.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,21 @@ func (e *connectionedEndpoint) Close(ctx context.Context) {
e.Unlock()
if acceptedChan != nil {
for n := range acceptedChan {
// When listener is closed, pending connections should receive
// ECONNRESET instead of EOF to match Linux behavior.
n.Lock()
if n.connected != nil {
// Try to set SO_ERROR on the client endpoint so that
// getsockopt(SO_ERROR) or read() returns ECONNRESET.
if ce, ok := n.connected.(*connectedEndpoint); ok {
if clientEP, ok := ce.endpoint.(*connectionedEndpoint); ok {
clientEP.SocketOptions().SetLastError(&tcpip.ErrConnectionReset{})
// Notify waiter queue about error events so epoll detects EPOLLERR.
clientEP.Queue.Notify(waiter.EventErr)
}
}
}
n.Unlock()
n.Close(ctx)
}
}
Expand Down Expand Up @@ -589,6 +604,15 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
ready |= waiter.EventHUp
}
}
// Check for error condition (SO_ERROR is set).
if mask&waiter.EventErr != 0 {
e.lastErrorMu.Lock()
hasError := e.lastError != nil
e.lastErrorMu.Unlock()
if hasError {
ready |= waiter.EventErr
}
}
case e.ListeningLocked():
if mask&waiter.ReadableEvents != 0 && (len(e.acceptedChan) > 0 || e.isBoundSocketReadable()) {
ready |= waiter.ReadableEvents
Expand Down
28 changes: 26 additions & 2 deletions pkg/sentry/socket/unix/transport/unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/waiter"
Expand Down Expand Up @@ -876,6 +877,11 @@ type baseEndpoint struct {

// ops is used to get socket level options.
ops tcpip.SocketOptions

// lastError is the last error returned by getsockopt(SO_ERROR).
// This field is protected by lastErrorMu.
lastErrorMu sync.Mutex `state:"nosave"`
lastError tcpip.Error
}

// EventRegister implements waiter.Waitable.EventRegister.
Expand Down Expand Up @@ -934,6 +940,12 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, args RecvArgs

out, notify, err := receiver.Recv(ctx, data, args)
if err != nil {
// Check if there's a pending error (e.g., ECONNRESET set when listener closed).
if err == syserr.ErrClosedForReceive || err == syserr.ErrWouldBlock {
if lastErr := e.LastError(); lastErr != nil {
return RecvOutput{}, nil, syserr.TranslateNetstackError(lastErr)
}
}
return RecvOutput{}, nil, err
}

Expand Down Expand Up @@ -1021,8 +1033,20 @@ func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
}

// LastError implements Endpoint.LastError.
func (*baseEndpoint) LastError() tcpip.Error {
return nil
// It clears and returns the last error.
func (e *baseEndpoint) LastError() tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
e.lastError = nil
return err
}

// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
func (e *baseEndpoint) UpdateLastError(err tcpip.Error) {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
}

// SocketOptions implements Endpoint.SocketOptions.
Expand Down
4 changes: 4 additions & 0 deletions test/syscalls/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,10 @@ syscall_test(
test = "//test/syscalls/linux:socket_unix_stream_test",
)

syscall_test(
test = "//test/syscalls/linux:socket_unix_stream_listener_close_test",
)

syscall_test(
size = "medium",
# TODO(b/323000153): Test fails with S/R enabled during restore of abstract
Expand Down
16 changes: 16 additions & 0 deletions test/syscalls/linux/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3068,6 +3068,22 @@ cc_binary(
],
)

cc_binary(
name = "socket_unix_stream_listener_close_test",
testonly = 1,
srcs = ["socket_unix_stream_listener_close.cc"],
linkstatic = 1,
malloc = "//test/util:errno_safe_allocator",
deps = select_gtest() + [
"//test/util:file_descriptor",
"//test/util:posix_error",
"//test/util:socket_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
],
)

cc_binary(
name = "socket_ip_tcp_generic_loopback_test",
testonly = 1,
Expand Down
132 changes: 132 additions & 0 deletions test/syscalls/linux/socket_unix_stream_listener_close.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright 2026 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <errno.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <cstddef>
#include <cstdio>
#include <cstring>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
#include "test/util/socket_util.h"
#include "test/util/test_util.h"

namespace gvisor {
namespace testing {

namespace {

// Test fixture for Unix stream socket listener close tests.
// Sets up a listener, connects a client, then closes the listener
// while the connection is pending (not accepted).
class UnixStreamListenerCloseTest : public ::testing::Test {
protected:
void SetUp() override {
// Use abstract socket namespace to avoid file system issues.
addr_.sun_family = AF_UNIX;
addr_.sun_path[0] = '\0'; // Abstract namespace.
snprintf(&addr_.sun_path[1], sizeof(addr_.sun_path) - 1,
"test_listener_close_%d_%p", getpid(), this);

addr_len_ =
offsetof(struct sockaddr_un, sun_path) + 1 + strlen(&addr_.sun_path[1]);

// Create and setup the listener socket.
listener_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0));

ASSERT_THAT(bind(listener_.get(),
reinterpret_cast<struct sockaddr*>(&addr_), addr_len_),
SyscallSucceeds());
ASSERT_THAT(listen(listener_.get(), 5), SyscallSucceeds());

// Create a client and connect (but don't accept on the listener).
client_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0));
ASSERT_THAT(connect(client_.get(),
reinterpret_cast<struct sockaddr*>(&addr_), addr_len_),
SyscallSucceeds());

// Close the listener while the connection is pending (not accepted).
listener_.reset();
}

struct sockaddr_un addr_;
socklen_t addr_len_;
FileDescriptor listener_;
FileDescriptor client_;
};

// Test that when a Unix stream socket listener is closed while there are
// pending (connected but not accepted) connections, the client receives
// ECONNRESET instead of EOF.
//
// This matches Linux kernel behavior where closing the listener sends RST
// to pending connections rather than FIN.
TEST_F(UnixStreamListenerCloseTest, PendingConnectionGetsECONNRESET) {
// Check epoll events - should include EPOLLERR.
int epoll_fd = epoll_create1(0);
ASSERT_GE(epoll_fd, 0);
FileDescriptor epfd_wrapper(epoll_fd);

struct epoll_event ev;
ev.events = EPOLLIN | EPOLLOUT | EPOLLERR | EPOLLHUP;
ev.data.fd = client_.get();
ASSERT_THAT(epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_.get(), &ev),
SyscallSucceeds());

struct epoll_event events[1];
ASSERT_THAT(epoll_wait(epoll_fd, events, 1, 1000),
SyscallSucceedsWithValue(1));

// Verify EPOLLERR is set (in addition to EPOLLHUP).
EXPECT_TRUE(events[0].events & EPOLLHUP) << "Expected EPOLLHUP to be set";
EXPECT_TRUE(events[0].events & EPOLLERR) << "Expected EPOLLERR to be set";

// The first read should return ECONNRESET.
char buf[10];
EXPECT_THAT(read(client_.get(), buf, sizeof(buf)),
SyscallFailsWithErrno(ECONNRESET));

// After the error is consumed, subsequent reads should return EOF (0).
EXPECT_THAT(read(client_.get(), buf, sizeof(buf)),
SyscallSucceedsWithValue(0));
}

// Test that getsockopt(SO_ERROR) returns ECONNRESET for pending connections
// when the listener is closed.
TEST_F(UnixStreamListenerCloseTest, SOErrorReturnsECONNRESET) {
// getsockopt(SO_ERROR) should return ECONNRESET.
int err = 0;
socklen_t len = sizeof(err);
ASSERT_THAT(getsockopt(client_.get(), SOL_SOCKET, SO_ERROR, &err, &len),
SyscallSucceeds());
EXPECT_EQ(err, ECONNRESET);

// Second call to getsockopt(SO_ERROR) should return 0 (error cleared).
err = -1;
ASSERT_THAT(getsockopt(client_.get(), SOL_SOCKET, SO_ERROR, &err, &len),
SyscallSucceeds());
EXPECT_EQ(err, 0);
}

} // namespace

} // namespace testing
} // namespace gvisor
Loading