rpc_util: Fix RecvBufferPool deactivation issues #6766

Merged · 3 commits · Feb 23, 2024

196 changes: 149 additions & 47 deletions experimental/shared_buffer_pool_test.go
@@ -26,12 +26,12 @@ import (
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/experimental"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)

type s struct {
@@ -44,59 +44,161 @@ func Test(t *testing.T) {

const defaultTestTimeout = 10 * time.Second

func (s) TestRecvBufferPool(t *testing.T) {
ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for i := 0; i < 10; i++ {
preparedMsg := &grpc.PreparedMsg{}
err := preparedMsg.Encode(stream, &testpb.StreamingOutputCallResponse{
Payload: &testpb.Payload{
Body: []byte{'0' + uint8(i)},
},
})
func (s) TestRecvBufferPoolStream(t *testing.T) {
tcs := []struct {
name string
callOpts []grpc.CallOption
}{
{
name: "default",
},
{
name: "useCompressor",
callOpts: []grpc.CallOption{
grpc.UseCompressor(gzip.Name),
},
},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
const reqCount = 10

ss := &stubserver.StubServer{
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for i := 0; i < reqCount; i++ {
preparedMsg := &grpc.PreparedMsg{}
if err := preparedMsg.Encode(stream, &testgrpc.StreamingOutputCallResponse{
Payload: &testgrpc.Payload{
Body: []byte{'0' + uint8(i)},
},
}); err != nil {
return err
}
stream.SendMsg(preparedMsg)
}
return nil
},
}

pool := &checkBufferPool{}
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

stream, err := ss.Client.FullDuplexCall(ctx, tc.callOpts...)
if err != nil {
t.Fatalf("ss.Client.FullDuplexCall failed: %v", err)
}

var ngot int
var buf bytes.Buffer
for {
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return err
t.Fatal(err)
}
stream.SendMsg(preparedMsg)
ngot++
if buf.Len() > 0 {
buf.WriteByte(',')
}
buf.Write(reply.GetPayload().GetBody())
}
return nil
},
if want := 10; ngot != want {
t.Fatalf("Got %d replies, want %d", ngot, want)
}
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
t.Fatalf("Got replies %q; want %q", got, want)
}

if len(pool.puts) != reqCount {
t.Fatalf("Expected 10 buffers to be returned to the pool, got %d", len(pool.puts))
}
})
}
sopts := []grpc.ServerOption{experimental.RecvBufferPool(grpc.NewSharedBufferPool())}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(grpc.NewSharedBufferPool())}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}

func (s) TestRecvBufferPoolUnary(t *testing.T) {
tcs := []struct {
name string
callOpts []grpc.CallOption
}{
{
name: "default",
},
{
name: "useCompressor",
callOpts: []grpc.CallOption{
grpc.UseCompressor(gzip.Name),
},
},
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
const largeSize = 1024

stream, err := ss.Client.FullDuplexCall(ctx)
if err != nil {
t.Fatalf("ss.Client.FullDuplexCall failed: %f", err)
}
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testgrpc.SimpleRequest) (*testgrpc.SimpleResponse, error) {
return &testgrpc.SimpleResponse{
Payload: &testgrpc.Payload{
Body: make([]byte, largeSize),
},
}, nil
},
}

var ngot int
var buf bytes.Buffer
for {
reply, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
ngot++
if buf.Len() > 0 {
buf.WriteByte(',')
}
buf.Write(reply.GetPayload().GetBody())
}
if want := 10; ngot != want {
t.Errorf("Got %d replies, want %d", ngot, want)
}
if got, want := buf.String(), "0,1,2,3,4,5,6,7,8,9"; got != want {
t.Errorf("Got replies %q; want %q", got, want)
pool := &checkBufferPool{}
sopts := []grpc.ServerOption{experimental.RecvBufferPool(pool)}
dopts := []grpc.DialOption{experimental.WithRecvBufferPool(pool)}
if err := ss.Start(sopts, dopts...); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

const reqCount = 10
for i := 0; i < reqCount; i++ {
if _, err := ss.Client.UnaryCall(
ctx,
&testgrpc.SimpleRequest{
Payload: &testgrpc.Payload{
Body: make([]byte, largeSize),
},
},
tc.callOpts...,
); err != nil {
t.Fatalf("ss.Client.UnaryCall failed: %v", err)
}
}

const bufferCount = reqCount * 2 // req + resp
if len(pool.puts) != bufferCount {
t.Fatalf("Expected %d buffers to be returned to the pool, got %d", bufferCount, len(pool.puts))
}
})
}
}

type checkBufferPool struct {
puts [][]byte
}

func (p *checkBufferPool) Get(size int) []byte {
return make([]byte, size)
}

func (p *checkBufferPool) Put(bs *[]byte) {
p.puts = append(p.puts, *bs)
}
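
For orientation, here is a minimal sketch of how a shared receive-buffer pool is wired into a server and a client using the experimental options these tests exercise. The listen address, credentials, and absence of service registration are placeholders for the sketch, not part of this PR:

```go
package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/experimental"
)

func main() {
	// Server side: incoming message payloads are read into buffers drawn
	// from the pool instead of freshly allocated slices.
	srv := grpc.NewServer(experimental.RecvBufferPool(grpc.NewSharedBufferPool()))

	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	go srv.Serve(lis)
	defer srv.Stop()

	// Client side: the dial option enables the same pooling for received
	// responses on this connection.
	conn, err := grpc.Dial(lis.Addr().String(),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		experimental.WithRecvBufferPool(grpc.NewSharedBufferPool()),
	)
	if err != nil {
		log.Fatalf("failed to dial: %v", err)
	}
	defer conn.Close()
}
```

A custom pool only needs to satisfy the grpc.SharedBufferPool interface (Get and Put), which is exactly what the checkBufferPool test double above implements.
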
54 changes: 35 additions & 19 deletions rpc_util.go
@@ -726,39 +726,55 @@ type payloadInfo struct {
uncompressedBytes []byte
}

func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) {
pf, buf, err := p.recvMsg(maxReceiveMessageSize)
// recvAndDecompress reads a message from the stream, decompressing it if necessary.
//
// Cancelling the returned cancel function releases the buffer back to the pool. So the caller should cancel as soon as
// the buffer is no longer needed.
func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor,
) (uncompressedBuf []byte, cancel func(), err error) {
pf, compressedBuf, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, err
}
if payInfo != nil {
payInfo.compressedLength = len(buf)
return nil, nil, err
}

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
return nil, st.Err()
return nil, nil, st.Err()
}

var size int
if pf == compressionMade {
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
buf, err = dc.Do(bytes.NewReader(buf))
size = len(buf)
uncompressedBuf, err = dc.Do(bytes.NewReader(compressedBuf))
size = len(uncompressedBuf)
} else {
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
uncompressedBuf, size, err = decompress(compressor, compressedBuf, maxReceiveMessageSize)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
return nil, nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
}
if size > maxReceiveMessageSize {
// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
return nil, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
}
} else {
uncompressedBuf = compressedBuf
}
return buf, nil

if payInfo != nil {
payInfo.compressedLength = len(compressedBuf)
payInfo.uncompressedBytes = uncompressedBuf

cancel = func() {}
} else {
cancel = func() {
p.recvBufferPool.Put(&compressedBuf)
}
}

return uncompressedBuf, cancel, nil
}
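
The reshaped signature follows the familiar acquire-with-release pattern: the function that draws a buffer from the pool hands back a cancel func, and the caller defers it once the decoded message no longer needs the buffer. When payInfo retains a reference (for stats handlers and binary logging), cancel is deliberately a no-op so the pool never recycles bytes that are still reachable. A self-contained sketch of the same pattern, with an illustrative sync.Pool standing in for grpc's recv buffer pool:

```go
package main

import (
	"fmt"
	"sync"
)

// pool mirrors the role of p.recvBufferPool in rpc_util.go.
var pool = sync.Pool{New: func() any { b := make([]byte, 0, 1024); return &b }}

// readMessage draws a pooled buffer and returns it with a release func.
// If the caller retains the bytes (retain=true), release is a no-op,
// matching the payInfo branch in recvAndDecompress above.
func readMessage(retain bool) (buf []byte, release func()) {
	bp := pool.Get().(*[]byte)
	buf = append((*bp)[:0], "payload"...) // stand-in for reading off the wire
	if retain {
		return buf, func() {} // buffer escapes; never recycle it
	}
	return buf, func() { pool.Put(bp) }
}

func main() {
	buf, release := readMessage(false)
	defer release() // release only after the message has been consumed
	fmt.Println(string(buf))
}
```
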

// Using compressor, decompress d, returning data and size.
@@ -778,6 +794,9 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
// size is used as an estimate to size the buffer, but we
// will read more data if available.
// +MinRead so ReadFrom will not reallocate if size is correct.
//
// TODO: If we ensure that the buffer size is the same as the DecompressedSize,
// we can also utilize the recv buffer pool here.
buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
bytesRead, err := buf.ReadFrom(io.LimitReader(dcReader, int64(maxReceiveMessageSize)+1))
return buf.Bytes(), int(bytesRead), err
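
The +MinRead comment encodes a small bytes.Buffer idiom: if the buffer's capacity is the expected size plus bytes.MinRead, a single ReadFrom never reallocates when the estimate is exact. A quick standalone illustration (the 4 KiB payload is arbitrary):

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

func main() {
	payload := strings.Repeat("x", 4096)
	size := len(payload) // assume an exact size estimate

	// Pre-size with +MinRead so ReadFrom always has its required spare
	// capacity and never grows the underlying slice.
	buf := bytes.NewBuffer(make([]byte, 0, size+bytes.MinRead))
	n, err := buf.ReadFrom(io.LimitReader(strings.NewReader(payload), int64(size)+1))
	if err != nil {
		panic(err)
	}
	fmt.Println(n, buf.Cap() == size+bytes.MinRead) // 4096 true: no reallocation
}
```
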
@@ -793,18 +812,15 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m any, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
buf, cancel, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
if err != nil {
return err
}
defer cancel()

if err := c.Unmarshal(buf, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
}
if payInfo != nil {
payInfo.uncompressedBytes = buf
} else {
p.recvBufferPool.Put(&buf)
}
return nil
}

5 changes: 4 additions & 1 deletion server.go
@@ -1319,7 +1319,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
if len(shs) != 0 || len(binlogs) != 0 {
payInfo = &payloadInfo{}
}
d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)

d, cancel, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
@@ -1330,6 +1331,8 @@ func (s *Server) processUnaryRPC(ctx context.Context, t transport.ServerTranspor
t.IncrMsgRecv()
}
df := func(v any) error {
defer cancel()

if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
}
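
In processUnaryRPC the request buffer must outlive the df closure, because the codec unmarshals from d inside it; deferring cancel at the top of df releases the buffer back to the pool only after the codec is done with it. A condensed sketch of that control flow (names are illustrative, not grpc internals):

```go
package main

import "fmt"

// handleUnary sketches the lifetime rule applied in processUnaryRPC above:
// the buffer's release callback runs after decoding, never before.
func handleUnary(recv func() ([]byte, func(), error), decode func([]byte) error) error {
	d, cancel, err := recv()
	if err != nil {
		return err // recv failed; no buffer was handed out
	}
	df := func() error {
		defer cancel() // recycle the buffer once the codec is done with d
		return decode(d)
	}
	return df()
}

func main() {
	recv := func() ([]byte, func(), error) {
		return []byte("req"), func() { fmt.Println("buffer returned to pool") }, nil
	}
	decode := func(d []byte) error { fmt.Printf("decoded %q\n", d); return nil }
	if err := handleUnary(recv, decode); err != nil {
		fmt.Println(err)
	}
}
```

Running the sketch prints the decode line before the release line, demonstrating the ordering the PR enforces: decode first, then return the buffer to the pool.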