Skip to content

Commit

Permalink
rpc_util: Reuse memory buffer for receiving message (#5862)
Browse files Browse the repository at this point in the history
  • Loading branch information
hueypark authored Jun 27, 2023
1 parent 789cf4e commit 1634254
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 20 deletions.
33 changes: 33 additions & 0 deletions benchmark/benchmain/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ var (
serverWriteBufferSize = flags.IntSlice("serverWriteBufferSize", []int{-1}, "Configures the server write buffer size in bytes. If negative, use the default - may be a a comma-separated list")
sleepBetweenRPCs = flags.DurationSlice("sleepBetweenRPCs", []time.Duration{0}, "Configures the maximum amount of time the client should sleep between consecutive RPCs - may be a a comma-separated list")
connections = flag.Int("connections", 1, "The number of connections. Each connection will handle maxConcurrentCalls RPC streams")
recvBufferPool = flags.StringWithAllowedValues("recvBufferPool", recvBufferPoolNil, "Configures the shared receive buffer pool. One of: nil, simple, all", allRecvBufferPools)

logger = grpclog.Component("benchmark")
)
Expand All @@ -136,6 +137,10 @@ const (
networkModeLAN = "LAN"
networkModeWAN = "WAN"
networkLongHaul = "Longhaul"
// Shared recv buffer pool
recvBufferPoolNil = "nil"
recvBufferPoolSimple = "simple"
recvBufferPoolAll = "all"

numStatsBuckets = 10
warmupCallCount = 10
Expand All @@ -147,6 +152,7 @@ var (
allCompModes = []string{compModeOff, compModeGzip, compModeNop, compModeAll}
allToggleModes = []string{toggleModeOff, toggleModeOn, toggleModeBoth}
allNetworkModes = []string{networkModeNone, networkModeLocal, networkModeLAN, networkModeWAN, networkLongHaul}
allRecvBufferPools = []string{recvBufferPoolNil, recvBufferPoolSimple, recvBufferPoolAll}
defaultReadLatency = []time.Duration{0, 40 * time.Millisecond} // if non-positive, no delay.
defaultReadKbps = []int{0, 10240} // if non-positive, infinite
defaultReadMTU = []int{0} // if non-positive, infinite
Expand Down Expand Up @@ -330,6 +336,15 @@ func makeClients(bf stats.Features) ([]testgrpc.BenchmarkServiceClient, func())
if bf.ServerWriteBufferSize >= 0 {
sopts = append(sopts, grpc.WriteBufferSize(bf.ServerWriteBufferSize))
}
switch bf.RecvBufferPool {
case recvBufferPoolNil:
// Do nothing.
case recvBufferPoolSimple:
opts = append(opts, grpc.WithRecvBufferPool(grpc.NewSharedBufferPool()))
sopts = append(sopts, grpc.RecvBufferPool(grpc.NewSharedBufferPool()))
default:
logger.Fatalf("Unknown shared recv buffer pool type: %v", bf.RecvBufferPool)
}

sopts = append(sopts, grpc.MaxConcurrentStreams(uint32(bf.MaxConcurrentCalls+1)))
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
Expand Down Expand Up @@ -573,6 +588,7 @@ type featureOpts struct {
serverReadBufferSize []int
serverWriteBufferSize []int
sleepBetweenRPCs []time.Duration
recvBufferPools []string
}

// makeFeaturesNum returns a slice of ints of size 'maxFeatureIndex' where each
Expand Down Expand Up @@ -619,6 +635,8 @@ func makeFeaturesNum(b *benchOpts) []int {
featuresNum[i] = len(b.features.serverWriteBufferSize)
case stats.SleepBetweenRPCs:
featuresNum[i] = len(b.features.sleepBetweenRPCs)
case stats.RecvBufferPool:
featuresNum[i] = len(b.features.recvBufferPools)
default:
log.Fatalf("Unknown feature index %v in generateFeatures. maxFeatureIndex is %v", i, stats.MaxFeatureIndex)
}
Expand Down Expand Up @@ -687,6 +705,7 @@ func (b *benchOpts) generateFeatures(featuresNum []int) []stats.Features {
ServerReadBufferSize: b.features.serverReadBufferSize[curPos[stats.ServerReadBufferSize]],
ServerWriteBufferSize: b.features.serverWriteBufferSize[curPos[stats.ServerWriteBufferSize]],
SleepBetweenRPCs: b.features.sleepBetweenRPCs[curPos[stats.SleepBetweenRPCs]],
RecvBufferPool: b.features.recvBufferPools[curPos[stats.RecvBufferPool]],
}
if len(b.features.reqPayloadCurves) == 0 {
f.ReqSizeBytes = b.features.reqSizeBytes[curPos[stats.ReqSizeBytesIndex]]
Expand Down Expand Up @@ -759,6 +778,7 @@ func processFlags() *benchOpts {
serverReadBufferSize: append([]int(nil), *serverReadBufferSize...),
serverWriteBufferSize: append([]int(nil), *serverWriteBufferSize...),
sleepBetweenRPCs: append([]time.Duration(nil), *sleepBetweenRPCs...),
recvBufferPools: setRecvBufferPool(*recvBufferPool),
},
}

Expand Down Expand Up @@ -834,6 +854,19 @@ func setCompressorMode(val string) []string {
}
}

func setRecvBufferPool(val string) []string {
switch val {
case recvBufferPoolNil, recvBufferPoolSimple:
return []string{val}
case recvBufferPoolAll:
return []string{recvBufferPoolNil, recvBufferPoolSimple}
default:
// This should never happen because a wrong value passed to this flag would
// be caught during flag.Parse().
return []string{}
}
}

func main() {
opts := processFlags()
before(opts)
Expand Down
10 changes: 8 additions & 2 deletions benchmark/stats/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ const (
ServerReadBufferSize
ServerWriteBufferSize
SleepBetweenRPCs
RecvBufferPool

// MaxFeatureIndex is a place holder to indicate the total number of feature
// indices we have. Any new feature indices should be added above this.
Expand Down Expand Up @@ -126,6 +127,8 @@ type Features struct {
ServerWriteBufferSize int
// SleepBetweenRPCs configures optional delay between RPCs.
SleepBetweenRPCs time.Duration
// RecvBufferPool represents the shared recv buffer pool used.
RecvBufferPool string
}

// String returns all the feature values as a string.
Expand All @@ -145,12 +148,13 @@ func (f Features) String() string {
"trace_%v-latency_%v-kbps_%v-MTU_%v-maxConcurrentCalls_%v-%s-%s-"+
"compressor_%v-channelz_%v-preloader_%v-clientReadBufferSize_%v-"+
"clientWriteBufferSize_%v-serverReadBufferSize_%v-serverWriteBufferSize_%v-"+
"sleepBetweenRPCs_%v-connections_%v-",
"sleepBetweenRPCs_%v-connections_%v-recvBufferPool_%v-",
f.NetworkMode, f.UseBufConn, f.EnableKeepalive, f.BenchTime, f.EnableTrace,
f.Latency, f.Kbps, f.MTU, f.MaxConcurrentCalls, reqPayloadString,
respPayloadString, f.ModeCompressor, f.EnableChannelz, f.EnablePreloader,
f.ClientReadBufferSize, f.ClientWriteBufferSize, f.ServerReadBufferSize,
f.ServerWriteBufferSize, f.SleepBetweenRPCs, f.Connections)
f.ServerWriteBufferSize, f.SleepBetweenRPCs, f.Connections,
f.RecvBufferPool)
}

// SharedFeatures returns the shared features as a pretty printable string.
Expand Down Expand Up @@ -224,6 +228,8 @@ func (f Features) partialString(b *bytes.Buffer, wantFeatures []bool, sep, delim
b.WriteString(fmt.Sprintf("ServerWriteBufferSize%v%v%v", sep, f.ServerWriteBufferSize, delim))
case SleepBetweenRPCs:
b.WriteString(fmt.Sprintf("SleepBetweenRPCs%v%v%v", sep, f.SleepBetweenRPCs, delim))
case RecvBufferPool:
b.WriteString(fmt.Sprintf("RecvBufferPool%v%v%v", sep, f.RecvBufferPool, delim))
default:
log.Fatalf("Unknown feature index %v. maxFeatureIndex is %v", i, MaxFeatureIndex)
}
Expand Down
23 changes: 23 additions & 0 deletions dialoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type dialOptions struct {
defaultServiceConfigRawJSON *string
resolvers []resolver.Builder
idleTimeout time.Duration
recvBufferPool SharedBufferPool
}

// DialOption configures how we set up the connection.
Expand Down Expand Up @@ -628,6 +629,7 @@ func defaultDialOptions() dialOptions {
ReadBufferSize: defaultReadBufSize,
UseProxy: true,
},
recvBufferPool: nopBufferPool{},
}
}

Expand Down Expand Up @@ -676,3 +678,24 @@ func WithIdleTimeout(d time.Duration) DialOption {
o.idleTimeout = d
})
}

// WithRecvBufferPool returns a DialOption that configures the ClientConn
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
//
// If you are unsure about how to implement a memory pool but want to utilize one,
// begin with grpc.NewSharedBufferPool.
//
// Note: The shared buffer pool feature will not be active if any of the following
// options are used: WithStatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func WithRecvBufferPool(bufferPool SharedBufferPool) DialOption {
return newFuncDialOption(func(o *dialOptions) {
o.recvBufferPool = bufferPool
})
}
27 changes: 15 additions & 12 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,9 @@ type parser struct {
// The header of a gRPC message. Find more detail at
// https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
header [5]byte

// recvBufferPool is the pool of shared receive buffers.
recvBufferPool SharedBufferPool
}

// recvMsg reads a complete gRPC message from the stream.
Expand Down Expand Up @@ -610,9 +613,7 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
if int(length) > maxReceiveMessageSize {
return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
msg = p.recvBufferPool.Get(int(length))
if _, err := p.r.Read(msg); err != nil {
if err == io.EOF {
err = io.ErrUnexpectedEOF
Expand Down Expand Up @@ -726,12 +727,12 @@ type payloadInfo struct {
}

func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) ([]byte, error) {
pf, d, err := p.recvMsg(maxReceiveMessageSize)
pf, buf, err := p.recvMsg(maxReceiveMessageSize)
if err != nil {
return nil, err
}
if payInfo != nil {
payInfo.compressedLength = len(d)
payInfo.compressedLength = len(buf)
}

if st := checkRecvPayload(pf, s.RecvCompress(), compressor != nil || dc != nil); st != nil {
Expand All @@ -743,10 +744,10 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei
// To match legacy behavior, if the decompressor is set by WithDecompressor or RPCDecompressor,
// use this decompressor as the default.
if dc != nil {
d, err = dc.Do(bytes.NewReader(d))
size = len(d)
buf, err = dc.Do(bytes.NewReader(buf))
size = len(buf)
} else {
d, size, err = decompress(compressor, d, maxReceiveMessageSize)
buf, size, err = decompress(compressor, buf, maxReceiveMessageSize)
}
if err != nil {
return nil, status.Errorf(codes.Internal, "grpc: failed to decompress the received message: %v", err)
Expand All @@ -757,7 +758,7 @@ func recvAndDecompress(p *parser, s *transport.Stream, dc Decompressor, maxRecei
return nil, status.Errorf(codes.ResourceExhausted, "grpc: received message after decompression larger than max (%d vs. %d)", size, maxReceiveMessageSize)
}
}
return d, nil
return buf, nil
}

// Using compressor, decompress d, returning data and size.
Expand Down Expand Up @@ -792,15 +793,17 @@ func decompress(compressor encoding.Compressor, d []byte, maxReceiveMessageSize
// dc takes precedence over compressor.
// TODO(dfawley): wrap the old compressor/decompressor using the new API?
func recv(p *parser, c baseCodec, s *transport.Stream, dc Decompressor, m interface{}, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor) error {
d, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
buf, err := recvAndDecompress(p, s, dc, maxReceiveMessageSize, payInfo, compressor)
if err != nil {
return err
}
if err := c.Unmarshal(d, m); err != nil {
if err := c.Unmarshal(buf, m); err != nil {
return status.Errorf(codes.Internal, "grpc: failed to unmarshal the received message: %v", err)
}
if payInfo != nil {
payInfo.uncompressedBytes = d
payInfo.uncompressedBytes = buf
} else {
p.recvBufferPool.Put(&buf)
}
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (s) TestSimpleParsing(t *testing.T) {
{append([]byte{0, 1, 0, 0, 0}, bigMsg...), nil, bigMsg, compressionNone},
} {
buf := fullReader{bytes.NewReader(test.p)}
parser := &parser{r: buf}
parser := &parser{r: buf, recvBufferPool: nopBufferPool{}}
pt, b, err := parser.recvMsg(math.MaxInt32)
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
Expand All @@ -77,7 +77,7 @@ func (s) TestMultipleParsing(t *testing.T) {
// Set a byte stream consists of 3 messages with their headers.
p := []byte{0, 0, 0, 0, 1, 'a', 0, 0, 0, 0, 2, 'b', 'c', 0, 0, 0, 0, 1, 'd'}
b := fullReader{bytes.NewReader(p)}
parser := &parser{r: b}
parser := &parser{r: b, recvBufferPool: nopBufferPool{}}

wantRecvs := []struct {
pt payloadFormat
Expand Down
27 changes: 25 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ type serverOptions struct {
maxHeaderListSize *uint32
headerTableSize *uint32
numServerWorkers uint32
recvBufferPool SharedBufferPool
}

var defaultServerOptions = serverOptions{
Expand All @@ -182,6 +183,7 @@ var defaultServerOptions = serverOptions{
connectionTimeout: 120 * time.Second,
writeBufferSize: defaultWriteBufSize,
readBufferSize: defaultReadBufSize,
recvBufferPool: nopBufferPool{},
}
var globalServerOptions []ServerOption

Expand Down Expand Up @@ -552,6 +554,27 @@ func NumStreamWorkers(numServerWorkers uint32) ServerOption {
})
}

// RecvBufferPool returns a ServerOption that configures the server
// to use the provided shared buffer pool for parsing incoming messages. Depending
// on the application's workload, this could result in reduced memory allocation.
//
// If you are unsure about how to implement a memory pool but want to utilize one,
// begin with grpc.NewSharedBufferPool.
//
// Note: The shared buffer pool feature will not be active if any of the following
// options are used: StatsHandler, EnableTracing, or binary logging. In such
// cases, the shared buffer pool will be ignored.
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a
// later release.
func RecvBufferPool(bufferPool SharedBufferPool) ServerOption {
return newFuncServerOption(func(o *serverOptions) {
o.recvBufferPool = bufferPool
})
}

// serverWorkerResetThreshold defines how often the stack must be reset. Every
// N requests, by spawning a new goroutine in its place, a worker can reset its
// stack so that large stacks don't live in memory forever. 2^16 should allow
Expand Down Expand Up @@ -1296,7 +1319,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
if len(shs) != 0 || len(binlogs) != 0 {
payInfo = &payloadInfo{}
}
d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
d, err := recvAndDecompress(&parser{r: stream, recvBufferPool: s.opts.recvBufferPool}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
if err != nil {
if e := t.WriteStatus(stream, status.Convert(err)); e != nil {
channelz.Warningf(logger, s.channelzID, "grpc: Server.processUnaryRPC failed to write status: %v", e)
Expand Down Expand Up @@ -1506,7 +1529,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
ctx: ctx,
t: t,
s: stream,
p: &parser{r: stream},
p: &parser{r: stream, recvBufferPool: s.opts.recvBufferPool},
codec: s.getCodec(stream.ContentSubtype()),
maxReceiveMessageSize: s.opts.maxReceiveMessageSize,
maxSendMessageSize: s.opts.maxSendMessageSize,
Expand Down
Loading

0 comments on commit 1634254

Please sign in to comment.