Skip to content

Commit 5d87d88

Browse files
iangudgershentubot
authored andcommitted
Implement MSG_WAITALL
MSG_WAITALL requests that recv family calls do not perform short reads. It only has an effect for SOCK_STREAM sockets, other types ignore it. PiperOrigin-RevId: 224918540 Change-Id: Id97fbf972f1f7cbd4e08eec0138f8cbdf1c94fe7
1 parent d3bc79b commit 5d87d88

File tree

7 files changed

+79
-27
lines changed

7 files changed

+79
-27
lines changed

pkg/sentry/fs/host/socket.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F
169169

170170
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
171171

172-
return unixsocket.NewWithDirent(ctx, d, ep, flags), nil
172+
return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil
173173
}
174174

175175
// newSocket allocates a new unix socket with host endpoint.
@@ -201,7 +201,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
201201

202202
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
203203

204-
return unixsocket.New(ctx, ep), nil
204+
return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil
205205
}
206206

207207
// Send implements transport.ConnectedEndpoint.Send.

pkg/sentry/socket/epsocket/epsocket.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
13001300
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
13011301
trunc := flags&linux.MSG_TRUNC != 0
13021302
peek := flags&linux.MSG_PEEK != 0
1303+
dontWait := flags&linux.MSG_DONTWAIT != 0
1304+
waitAll := flags&linux.MSG_WAITALL != 0
13031305
if senderRequested && !s.isPacketBased() {
13041306
// Stream sockets ignore the sender address.
13051307
senderRequested = false
@@ -1311,21 +1313,43 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
13111313
return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
13121314
}
13131315

1314-
if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
1316+
if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
1317+
// Read failed and we should not retry.
1318+
return 0, nil, 0, socket.ControlMessages{}, err
1319+
}
1320+
1321+
if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
1322+
// We got all the data we need.
13151323
return
13161324
}
13171325

1326+
// Don't overwrite any data we received.
1327+
dst = dst.DropFirst(n)
1328+
13181329
// We'll have to block. Register for notifications and keep trying to
13191330
// send all the data.
13201331
e, ch := waiter.NewChannelEntry(nil)
13211332
s.EventRegister(&e, waiter.EventIn)
13221333
defer s.EventUnregister(&e)
13231334

13241335
for {
1325-
n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
1326-
if err != syserr.ErrWouldBlock {
1336+
var rn int
1337+
rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
1338+
n += rn
1339+
if err != nil && err != syserr.ErrWouldBlock {
1340+
// Always stop on errors other than would block as we generally
1341+
// won't be able to get any more data. Eat the error if we got
1342+
// any data.
1343+
if n > 0 {
1344+
err = nil
1345+
}
1346+
return
1347+
}
1348+
if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) {
1349+
// We got all the data we need.
13271350
return
13281351
}
1352+
dst = dst.DropFirst(rn)
13291353

13301354
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
13311355
if err == syserror.ETIMEDOUT {

pkg/sentry/socket/unix/unix.go

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,21 @@ type SocketOperations struct {
5353
fsutil.NoopFlush `state:"nosave"`
5454
fsutil.NoMMap `state:"nosave"`
5555
ep transport.Endpoint
56+
isPacket bool
5657
}
5758

5859
// New creates a new unix socket.
59-
func New(ctx context.Context, endpoint transport.Endpoint) *fs.File {
60+
func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
6061
dirent := socket.NewDirent(ctx, unixSocketDevice)
6162
defer dirent.DecRef()
62-
return NewWithDirent(ctx, dirent, endpoint, fs.FileFlags{Read: true, Write: true})
63+
return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
6364
}
6465

6566
// NewWithDirent creates a new unix socket using an existing dirent.
66-
func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, flags fs.FileFlags) *fs.File {
67+
func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
6768
return fs.NewFile(ctx, d, flags, &SocketOperations{
68-
ep: ep,
69+
ep: ep,
70+
isPacket: isPacket,
6971
})
7072
}
7173

@@ -188,7 +190,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
188190
}
189191
}
190192

191-
ns := New(t, ep)
193+
ns := New(t, ep, s.isPacket)
192194
defer ns.DecRef()
193195

194196
if flags&linux.SOCK_NONBLOCK != 0 {
@@ -471,6 +473,8 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
471473
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
472474
trunc := flags&linux.MSG_TRUNC != 0
473475
peek := flags&linux.MSG_PEEK != 0
476+
dontWait := flags&linux.MSG_DONTWAIT != 0
477+
waitAll := flags&linux.MSG_WAITALL != 0
474478

475479
// Calculate the number of FDs for which we have space and if we are
476480
// requesting credentials.
@@ -497,7 +501,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
497501
if senderRequested {
498502
r.From = &tcpip.FullAddress{}
499503
}
500-
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
504+
var total int64
505+
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
501506
var from interface{}
502507
var fromLen uint32
503508
if r.From != nil {
@@ -506,7 +511,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
506511
if trunc {
507512
n = int64(r.MsgSize)
508513
}
509-
return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
514+
if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
515+
return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
516+
}
517+
518+
// Don't overwrite any data we received.
519+
dst = dst.DropFirst64(n)
520+
total += n
510521
}
511522

512523
// We'll have to block. Register for notification and keep trying to
@@ -525,7 +536,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
525536
if trunc {
526537
n = int64(r.MsgSize)
527538
}
528-
return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
539+
total += n
540+
if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
541+
return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
542+
}
543+
544+
// Don't overwrite any data we received.
545+
dst = dst.DropFirst64(n)
529546
}
530547

531548
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
@@ -549,16 +566,21 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int)
549566

550567
// Create the endpoint and socket.
551568
var ep transport.Endpoint
569+
var isPacket bool
552570
switch stype {
553571
case linux.SOCK_DGRAM:
572+
isPacket = true
554573
ep = transport.NewConnectionless()
555-
case linux.SOCK_STREAM, linux.SOCK_SEQPACKET:
574+
case linux.SOCK_SEQPACKET:
575+
isPacket = true
576+
fallthrough
577+
case linux.SOCK_STREAM:
556578
ep = transport.NewConnectioned(stype, t.Kernel())
557579
default:
558580
return nil, syserr.ErrInvalidArgument
559581
}
560582

561-
return New(t, ep), nil
583+
return New(t, ep, isPacket), nil
562584
}
563585

564586
// Pair creates a new pair of AF_UNIX connected sockets.
@@ -568,16 +590,19 @@ func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*
568590
return nil, nil, syserr.ErrInvalidArgument
569591
}
570592

593+
var isPacket bool
571594
switch stype {
572-
case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
595+
case linux.SOCK_STREAM:
596+
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
597+
isPacket = true
573598
default:
574599
return nil, nil, syserr.ErrInvalidArgument
575600
}
576601

577602
// Create the endpoints and sockets.
578603
ep1, ep2 := transport.NewPair(stype, t.Kernel())
579-
s1 := New(t, ep1)
580-
s2 := New(t, ep2)
604+
s1 := New(t, ep1, isPacket)
605+
s2 := New(t, ep2, isPacket)
581606

582607
return s1, s2, nil
583608
}

pkg/sentry/syscalls/linux/sys_socket.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
602602
}
603603

604604
// Reject flags that we don't handle yet.
605-
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
605+
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE|linux.MSG_WAITALL) != 0 {
606606
return 0, nil, syscall.EINVAL
607607
}
608608

@@ -635,7 +635,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
635635
}
636636

637637
// Reject flags that we don't handle yet.
638-
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
638+
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE|linux.MSG_WAITALL) != 0 {
639639
return 0, nil, syscall.EINVAL
640640
}
641641

@@ -791,7 +791,7 @@ func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, f
791791
}
792792

793793
// Reject flags that we don't handle yet.
794-
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CONFIRM) != 0 {
794+
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CONFIRM|linux.MSG_WAITALL) != 0 {
795795
return 0, syscall.EINVAL
796796
}
797797

test/syscalls/linux/socket_generic.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,6 @@ TEST_P(AllSocketPairTest, RecvmsgTimeoutOneSecondSucceeds) {
383383
}
384384

385385
TEST_P(AllSocketPairTest, RecvWaitAll) {
386-
SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
387-
388386
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
389387

390388
char sent_data[100];
@@ -399,5 +397,14 @@ TEST_P(AllSocketPairTest, RecvWaitAll) {
399397
SyscallSucceedsWithValue(sizeof(sent_data)));
400398
}
401399

400+
TEST_P(AllSocketPairTest, RecvWaitAllDontWait) {
401+
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
402+
403+
char data[100] = {};
404+
ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), data, sizeof(data),
405+
MSG_WAITALL | MSG_DONTWAIT),
406+
SyscallFailsWithErrno(EAGAIN));
407+
}
408+
402409
} // namespace testing
403410
} // namespace gvisor

test/syscalls/linux/socket_non_stream_blocking.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ namespace gvisor {
3131
namespace testing {
3232

3333
TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) {
34-
SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
35-
3634
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
3735

3836
char sent_data[100];

test/syscalls/linux/socket_stream_blocking.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,6 @@ TEST_P(BlockingStreamSocketPairTest, RecvLessThanBuffer) {
9999
}
100100

101101
TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) {
102-
SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
103-
104102
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
105103

106104
char sent_data[100];

0 commit comments

Comments
 (0)