Skip to content

Commit

Permalink
Remove buf copy when the compressor exist (#1427)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhouyihaiDing authored and dfawley committed Aug 25, 2017
1 parent c29d638 commit 01089b2
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 89 deletions.
8 changes: 4 additions & 4 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,17 @@ func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor,
Client: true,
}
}
outBuf, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
hdr, data, err := encode(dopts.codec, args, compressor, cbuf, outPayload)
if err != nil {
return err
}
if c.maxSendMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
if len(outBuf) > *c.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(outBuf), *c.maxSendMessageSize)
if len(data) > *c.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize)
}
err = t.Write(stream, outBuf, opts)
err = t.Write(stream, hdr, data, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
dopts.copts.StatsHandler.HandleRPC(ctx, outPayload)
Expand Down
4 changes: 2 additions & 2 deletions call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
}
}
// send a response back to end the stream.
reply, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
hdr, data, err := encode(testCodec{}, &expectedResponse, nil, nil, nil)
if err != nil {
t.Errorf("Failed to encode the response: %v", err)
return
}
h.t.Write(s, reply, &transport.Options{})
h.t.Write(s, hdr, data, &transport.Options{})
h.t.WriteStatus(s, status.New(codes.OK, ""))
}

Expand Down
48 changes: 19 additions & 29 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,19 +288,20 @@ func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byt
return pf, msg, nil
}

// encode serializes msg and prepends the message header. If msg is nil, it
// generates the message header of 0 message length.
func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, error) {
var (
b []byte
length uint
// encode serializes msg and returns a buffer of message header and a buffer of msg.
// If msg is nil, it generates the message header and an empty msg buffer.
func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayload *stats.OutPayload) ([]byte, []byte, error) {
var b []byte
const (
payloadLen = 1
sizeLen = 4
)

if msg != nil {
var err error
// TODO(zhaoq): optimize to reduce memory alloc and copying.
b, err = c.Marshal(msg)
if err != nil {
return nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
return nil, nil, Errorf(codes.Internal, "grpc: error while marshaling: %v", err.Error())
}
if outPayload != nil {
outPayload.Payload = msg
Expand All @@ -310,39 +311,28 @@ func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer, outPayl
}
if cp != nil {
if err := cp.Do(cbuf, b); err != nil {
return nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
return nil, nil, Errorf(codes.Internal, "grpc: error while compressing: %v", err.Error())
}
b = cbuf.Bytes()
}
length = uint(len(b))
}
if length > math.MaxUint32 {
return nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", length)
}

const (
payloadLen = 1
sizeLen = 4
)

var buf = make([]byte, payloadLen+sizeLen+len(b))
if len(b) > math.MaxUint32 {
return nil, nil, Errorf(codes.ResourceExhausted, "grpc: message too large (%d bytes)", len(b))
}

// Write payload format
bufHeader := make([]byte, payloadLen+sizeLen)
if cp == nil {
buf[0] = byte(compressionNone)
bufHeader[0] = byte(compressionNone)
} else {
buf[0] = byte(compressionMade)
bufHeader[0] = byte(compressionMade)
}
// Write length of b into buf
binary.BigEndian.PutUint32(buf[1:], uint32(length))
// Copy encoded msg to buf
copy(buf[5:], b)

binary.BigEndian.PutUint32(bufHeader[payloadLen:], uint32(len(b)))
if outPayload != nil {
outPayload.WireLength = len(buf)
outPayload.WireLength = payloadLen + sizeLen + len(b)
}

return buf, nil
return bufHeader, b, nil
}

func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error {
Expand Down
17 changes: 9 additions & 8 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,15 @@ func TestEncode(t *testing.T) {
msg proto.Message
cp Compressor
// outputs
b []byte
err error
hdr []byte
data []byte
err error
}{
{nil, nil, []byte{0, 0, 0, 0, 0}, nil},
{nil, nil, []byte{0, 0, 0, 0, 0}, []byte{}, nil},
} {
b, err := encode(protoCodec{}, test.msg, nil, nil, nil)
if err != test.err || !bytes.Equal(b, test.b) {
t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err)
hdr, data, err := encode(protoCodec{}, test.msg, nil, nil, nil)
if err != test.err || !bytes.Equal(hdr, test.hdr) || !bytes.Equal(data, test.data) {
t.Fatalf("encode(_, _, %v, _) = %v, %v, %v\nwant %v, %v, %v", test.cp, hdr, data, err, test.hdr, test.data, test.err)
}
}
}
Expand Down Expand Up @@ -164,8 +165,8 @@ func TestToRPCErr(t *testing.T) {
// bytes.
func bmEncode(b *testing.B, mSize int) {
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encoded, _ := encode(protoCodec{}, msg, nil, nil, nil)
encodedSz := int64(len(encoded))
encodeHdr, encodeData, _ := encode(protoCodec{}, msg, nil, nil, nil)
encodedSz := int64(len(encodeHdr) + len(encodeData))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Expand Down
8 changes: 4 additions & 4 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,15 +677,15 @@ func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Str
if s.opts.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
p, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
hdr, data, err := encode(s.opts.codec, msg, cp, cbuf, outPayload)
if err != nil {
grpclog.Errorln("grpc: server failed to encode response: ", err)
return err
}
if len(p) > s.opts.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(p), s.opts.maxSendMessageSize)
if len(data) > s.opts.maxSendMessageSize {
return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), s.opts.maxSendMessageSize)
}
err = t.Write(stream, p, opts)
err = t.Write(stream, hdr, data, opts)
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
s.opts.statsHandler.HandleRPC(stream.Context(), outPayload)
Expand Down
18 changes: 9 additions & 9 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
Client: true,
}
}
out, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
hdr, data, err := encode(cs.codec, m, cs.cp, cs.cbuf, outPayload)
defer func() {
if cs.cbuf != nil {
cs.cbuf.Reset()
Expand All @@ -374,10 +374,10 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
if cs.c.maxSendMessageSize == nil {
return Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)")
}
if len(out) > *cs.c.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), *cs.c.maxSendMessageSize)
if len(data) > *cs.c.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize)
}
err = cs.t.Write(cs.s, out, &transport.Options{Last: false})
err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false})
if err == nil && outPayload != nil {
outPayload.SentTime = time.Now()
cs.statsHandler.HandleRPC(cs.statsCtx, outPayload)
Expand Down Expand Up @@ -449,7 +449,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
}

func (cs *clientStream) CloseSend() (err error) {
err = cs.t.Write(cs.s, nil, &transport.Options{Last: true})
err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true})
defer func() {
if err != nil {
cs.finish(err)
Expand Down Expand Up @@ -608,7 +608,7 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if ss.statsHandler != nil {
outPayload = &stats.OutPayload{}
}
out, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
hdr, data, err := encode(ss.codec, m, ss.cp, ss.cbuf, outPayload)
defer func() {
if ss.cbuf != nil {
ss.cbuf.Reset()
Expand All @@ -617,10 +617,10 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) {
if err != nil {
return err
}
if len(out) > ss.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(out), ss.maxSendMessageSize)
if len(data) > ss.maxSendMessageSize {
return Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), ss.maxSendMessageSize)
}
if err := ss.t.Write(ss.s, out, &transport.Options{Last: false}); err != nil {
if err := ss.t.Write(ss.s, hdr, data, &transport.Options{Last: false}); err != nil {
return toRPCErr(err)
}
if outPayload != nil {
Expand Down
3 changes: 2 additions & 1 deletion transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,10 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) {
}
}

func (ht *serverHandlerTransport) Write(s *Stream, data []byte, opts *Options) error {
func (ht *serverHandlerTransport) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
return ht.do(func() {
ht.writeCommonHeaders(s)
ht.rw.Write(hdr)
ht.rw.Write(data)
if !opts.Delay {
ht.rw.(http.Flusher).Flush()
Expand Down
34 changes: 25 additions & 9 deletions transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -683,8 +683,15 @@ func (t *http2Client) GracefulClose() error {
// should proceed only if Write returns nil.
// TODO(zhaoq): opts.Delay is ignored in this implementation. Support it later
// if it improves the performance.
func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
r := bytes.NewBuffer(data)
func (t *http2Client) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen
if len(data) < secondStart {
secondStart = len(data)
}
hdr = append(hdr, data[:secondStart]...)
data = data[secondStart:]
isLastSlice := (len(data) == 0)
r := bytes.NewBuffer(hdr)
var (
p []byte
oqv uint32
Expand Down Expand Up @@ -726,9 +733,6 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
endStream bool
forceFlush bool
)
if opts.Last && r.Len() == 0 {
endStream = true
}
// Indicate there is a writer who is about to write a data frame.
t.framer.adjustNumWriters(1)
// Got some quota. Try to acquire writing privilege on the transport.
Expand Down Expand Up @@ -768,10 +772,22 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
t.writableChan <- 0
continue
}
if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 {
// Do a force flush iff this is last frame for the entire gRPC message
// and the caller is the only writer at this moment.
forceFlush = true
if r.Len() == 0 {
if isLastSlice {
if opts.Last {
endStream = true
}
if t.framer.adjustNumWriters(0) == 1 {
// Do a force flush iff this is last frame for the entire gRPC message
// and the caller is the only writer at this moment.
forceFlush = true
}
} else {
isLastSlice = true
if len(data) != 0 {
r = bytes.NewBuffer(data)
}
}
}
// If WriteData fails, all the pending streams will be handled
// by http2Client.Close(). No explicit CloseStream() needs to be
Expand Down
22 changes: 18 additions & 4 deletions transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,15 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {

// Write converts the data into HTTP2 data frame and sends it out. Non-nil error
// is returns if it fails (e.g., framing error, transport error).
func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) {
func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) (err error) {
// TODO(zhaoq): Support multi-writers for a single stream.
secondStart := http2MaxFrameLen - len(hdr)%http2MaxFrameLen
if len(data) < secondStart {
secondStart = len(data)
}
hdr = append(hdr, data[:secondStart]...)
data = data[secondStart:]
isLastSlice := (len(data) == 0)
var writeHeaderFrame bool
s.mu.Lock()
if s.state == streamDone {
Expand All @@ -842,7 +849,7 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) {
if writeHeaderFrame {
t.WriteHeader(s, nil)
}
r := bytes.NewBuffer(data)
r := bytes.NewBuffer(hdr)
var (
p []byte
oqv uint32
Expand Down Expand Up @@ -921,8 +928,15 @@ func (t *http2Server) Write(s *Stream, data []byte, opts *Options) (err error) {
continue
}
var forceFlush bool
if r.Len() == 0 && t.framer.adjustNumWriters(0) == 1 && !opts.Last {
forceFlush = true
if r.Len() == 0 {
if isLastSlice {
if t.framer.adjustNumWriters(0) == 1 && !opts.Last {
forceFlush = true
}
} else {
r = bytes.NewBuffer(data)
isLastSlice = true
}
}
// Reset ping strikes when sending data since this might cause
// the peer to send ping.
Expand Down
4 changes: 2 additions & 2 deletions transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ type ClientTransport interface {

// Write sends the data for the given stream. A nil stream indicates
// the write is to be performed on the transport as a whole.
Write(s *Stream, data []byte, opts *Options) error
Write(s *Stream, hdr []byte, data []byte, opts *Options) error

// NewStream creates a Stream for an RPC.
NewStream(ctx context.Context, callHdr *CallHdr) (*Stream, error)
Expand Down Expand Up @@ -606,7 +606,7 @@ type ServerTransport interface {

// Write sends the data for the given stream.
// Write may not be called on all streams.
Write(s *Stream, data []byte, opts *Options) error
Write(s *Stream, hdr []byte, data []byte, opts *Options) error

// WriteStatus sends the status of a stream to the client. WriteStatus is
// the final call made on a stream and always occurs.
Expand Down
Loading

0 comments on commit 01089b2

Please sign in to comment.