@@ -53,19 +53,21 @@ type SocketOperations struct {
5353 fsutil.NoopFlush `state:"nosave"`
5454 fsutil.NoMMap `state:"nosave"`
5555 ep transport.Endpoint
56+ isPacket bool
5657}
5758
5859// New creates a new unix socket.
59- func New (ctx context.Context , endpoint transport.Endpoint ) * fs.File {
60+ func New (ctx context.Context , endpoint transport.Endpoint , isPacket bool ) * fs.File {
6061 dirent := socket .NewDirent (ctx , unixSocketDevice )
6162 defer dirent .DecRef ()
62- return NewWithDirent (ctx , dirent , endpoint , fs.FileFlags {Read : true , Write : true })
63+ return NewWithDirent (ctx , dirent , endpoint , isPacket , fs.FileFlags {Read : true , Write : true })
6364}
6465
6566// NewWithDirent creates a new unix socket using an existing dirent.
66- func NewWithDirent (ctx context.Context , d * fs.Dirent , ep transport.Endpoint , flags fs.FileFlags ) * fs.File {
67+ func NewWithDirent (ctx context.Context , d * fs.Dirent , ep transport.Endpoint , isPacket bool , flags fs.FileFlags ) * fs.File {
6768 return fs .NewFile (ctx , d , flags , & SocketOperations {
68- ep : ep ,
69+ ep : ep ,
70+ isPacket : isPacket ,
6971 })
7072}
7173
@@ -188,7 +190,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
188190 }
189191 }
190192
191- ns := New (t , ep )
193+ ns := New (t , ep , s . isPacket )
192194 defer ns .DecRef ()
193195
194196 if flags & linux .SOCK_NONBLOCK != 0 {
@@ -471,6 +473,8 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
471473func (s * SocketOperations ) RecvMsg (t * kernel.Task , dst usermem.IOSequence , flags int , haveDeadline bool , deadline ktime.Time , senderRequested bool , controlDataLen uint64 ) (n int , senderAddr interface {}, senderAddrLen uint32 , controlMessages socket.ControlMessages , err * syserr.Error ) {
472474 trunc := flags & linux .MSG_TRUNC != 0
473475 peek := flags & linux .MSG_PEEK != 0
476+ dontWait := flags & linux .MSG_DONTWAIT != 0
477+ waitAll := flags & linux .MSG_WAITALL != 0
474478
475479 // Calculate the number of FDs for which we have space and if we are
476480 // requesting credentials.
@@ -497,7 +501,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
497501 if senderRequested {
498502 r .From = & tcpip.FullAddress {}
499503 }
500- if n , err := dst .CopyOutFrom (t , & r ); err != syserror .ErrWouldBlock || flags & linux .MSG_DONTWAIT != 0 {
504+ var total int64
505+ if n , err := dst .CopyOutFrom (t , & r ); err != syserror .ErrWouldBlock || dontWait {
501506 var from interface {}
502507 var fromLen uint32
503508 if r .From != nil {
@@ -506,7 +511,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
506511 if trunc {
507512 n = int64 (r .MsgSize )
508513 }
509- return int (n ), from , fromLen , socket.ControlMessages {Unix : r .Control }, syserr .FromError (err )
514+ if err != nil || dontWait || ! waitAll || s .isPacket || n >= dst .NumBytes () {
515+ return int (n ), from , fromLen , socket.ControlMessages {Unix : r .Control }, syserr .FromError (err )
516+ }
517+
518+ // Don't overwrite any data we received.
519+ dst = dst .DropFirst64 (n )
520+ total += n
510521 }
511522
512523 // We'll have to block. Register for notification and keep trying to
@@ -525,7 +536,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
525536 if trunc {
526537 n = int64 (r .MsgSize )
527538 }
528- return int (n ), from , fromLen , socket.ControlMessages {Unix : r .Control }, syserr .FromError (err )
539+ total += n
540+ if err != nil || ! waitAll || s .isPacket || n >= dst .NumBytes () {
541+ return int (total ), from , fromLen , socket.ControlMessages {Unix : r .Control }, syserr .FromError (err )
542+ }
543+
544+ // Don't overwrite any data we received.
545+ dst = dst .DropFirst64 (n )
529546 }
530547
531548 if err := t .BlockWithDeadline (ch , haveDeadline , deadline ); err != nil {
@@ -549,16 +566,21 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int)
549566
550567 // Create the endpoint and socket.
551568 var ep transport.Endpoint
569+ var isPacket bool
552570 switch stype {
553571 case linux .SOCK_DGRAM :
572+ isPacket = true
554573 ep = transport .NewConnectionless ()
555- case linux .SOCK_STREAM , linux .SOCK_SEQPACKET :
574+ case linux .SOCK_SEQPACKET :
575+ isPacket = true
576+ fallthrough
577+ case linux .SOCK_STREAM :
556578 ep = transport .NewConnectioned (stype , t .Kernel ())
557579 default :
558580 return nil , syserr .ErrInvalidArgument
559581 }
560582
561- return New (t , ep ), nil
583+ return New (t , ep , isPacket ), nil
562584}
563585
564586// Pair creates a new pair of AF_UNIX connected sockets.
@@ -568,16 +590,19 @@ func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*
568590 return nil , nil , syserr .ErrInvalidArgument
569591 }
570592
593+ var isPacket bool
571594 switch stype {
572- case linux .SOCK_STREAM , linux .SOCK_DGRAM , linux .SOCK_SEQPACKET :
595+ case linux .SOCK_STREAM :
596+ case linux .SOCK_DGRAM , linux .SOCK_SEQPACKET :
597+ isPacket = true
573598 default :
574599 return nil , nil , syserr .ErrInvalidArgument
575600 }
576601
577602 // Create the endpoints and sockets.
578603 ep1 , ep2 := transport .NewPair (stype , t .Kernel ())
579- s1 := New (t , ep1 )
580- s2 := New (t , ep2 )
604+ s1 := New (t , ep1 , isPacket )
605+ s2 := New (t , ep2 , isPacket )
581606
582607 return s1 , s2 , nil
583608}
0 commit comments