From d09ec435457694f2e314397fe5bc18a1d959f303 Mon Sep 17 00:00:00 2001 From: dfawley Date: Mon, 5 Feb 2018 12:54:13 -0800 Subject: [PATCH] Implement unary functionality using streams (#1835) --- call.go | 299 +++---------------------------------------- stream.go | 181 +++++++++++++------------- test/end2end_test.go | 2 +- 3 files changed, 104 insertions(+), 378 deletions(-) diff --git a/call.go b/call.go index 8fa542a7ee3b..a66e3c2d9581 100644 --- a/call.go +++ b/call.go @@ -19,123 +19,9 @@ package grpc import ( - "io" - "time" - "golang.org/x/net/context" - "golang.org/x/net/trace" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/encoding" - "google.golang.org/grpc/stats" - "google.golang.org/grpc/status" - "google.golang.org/grpc/transport" ) -// recvResponse receives and parses an RPC response. -// On error, it returns the error and indicates whether the call should be retried. -// -// TODO(zhaoq): Check whether the received message sequence is valid. -// TODO ctx is used for stats collection and processing. It is the context passed from the application. -func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) (err error) { - // Try to acquire header metadata from the server if there is any. - defer func() { - if err != nil { - if _, ok := err.(transport.ConnectionError); !ok { - t.CloseStream(stream, err) - } - } - }() - p := &parser{r: stream} - var inPayload *stats.InPayload - if dopts.copts.StatsHandler != nil { - inPayload = &stats.InPayload{ - Client: true, - } - } - for { - if c.maxReceiveMessageSize == nil { - return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") - } - - // Set dc if it exists and matches the message compression type used, - // otherwise set comp if a registered compressor exists for it. - var comp encoding.Compressor - var dc Decompressor - if rc := stream.RecvCompress(); dopts.dc != nil && dopts.dc.Type() == rc { - dc = dopts.dc - } else if rc != "" && rc != encoding.Identity { - comp = encoding.GetCompressor(rc) - } - if err = recv(p, c.codec, stream, dc, reply, *c.maxReceiveMessageSize, inPayload, comp); err != nil { - if err == io.EOF { - break - } - return - } - } - if inPayload != nil && err == io.EOF && stream.Status().Code() == codes.OK { - // TODO in the current implementation, inTrailer may be handled before inPayload in some cases. - // Fix the order if necessary. - dopts.copts.StatsHandler.HandleRPC(ctx, inPayload) - } - return nil -} - -// sendRequest writes out various information of an RPC such as Context and Message. -func sendRequest(ctx context.Context, dopts dialOptions, compressor Compressor, c *callInfo, callHdr *transport.CallHdr, stream *transport.Stream, t transport.ClientTransport, args interface{}, opts *transport.Options) (err error) { - defer func() { - if err != nil { - // If err is connection error, t will be closed, no need to close stream here. - if _, ok := err.(transport.ConnectionError); !ok { - t.CloseStream(stream, err) - } - } - }() - var ( - outPayload *stats.OutPayload - ) - if dopts.copts.StatsHandler != nil { - outPayload = &stats.OutPayload{ - Client: true, - } - } - // Set comp and clear compressor if a registered compressor matches the type - // specified via UseCompressor. (And error if a matching compressor is not - // registered.) - var comp encoding.Compressor - if ct := c.compressorType; ct != "" && ct != encoding.Identity { - compressor = nil // Disable the legacy compressor. - comp = encoding.GetCompressor(ct) - if comp == nil { - return status.Errorf(codes.Internal, "grpc: Compressor is not installed for grpc-encoding %q", ct) - } - } - hdr, data, err := encode(c.codec, args, compressor, outPayload, comp) - if err != nil { - return err - } - if c.maxSendMessageSize == nil { - return status.Errorf(codes.Internal, "callInfo maxSendMessageSize field uninitialized(nil)") - } - if len(data) > *c.maxSendMessageSize { - return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(data), *c.maxSendMessageSize) - } - err = t.Write(stream, hdr, data, opts) - if err == nil && outPayload != nil { - outPayload.SentTime = time.Now() - dopts.copts.StatsHandler.HandleRPC(ctx, outPayload) - } - // t.NewStream(...) could lead to an early rejection of the RPC (e.g., the service/method - // does not exist.) so that t.Write could get io.EOF from wait(...). Leave the following - // recvResponse to get the final status. - if err != nil && err != io.EOF { - return err - } - // Sent successfully. - return nil -} - // Invoke sends the RPC request on the wire and returns after response is // received. This is typically called by generated code. // @@ -155,189 +41,34 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli return cc.Invoke(ctx, method, args, reply, opts...) } -func invoke(ctx context.Context, method string, args, reply interface{}, cc *ClientConn, opts ...CallOption) (e error) { - c := defaultCallInfo() - mc := cc.GetMethodConfig(method) - if mc.WaitForReady != nil { - c.failFast = !*mc.WaitForReady - } - - if mc.Timeout != nil && *mc.Timeout >= 0 { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, *mc.Timeout) - defer cancel() - } - - opts = append(cc.dopts.callOptions, opts...) - for _, o := range opts { - if err := o.before(c); err != nil { - return toRPCErr(err) - } - } - defer func() { - for _, o := range opts { - o.after(c) - } - }() - - c.maxSendMessageSize = getMaxSize(mc.MaxReqSize, c.maxSendMessageSize, defaultClientMaxSendMessageSize) - c.maxReceiveMessageSize = getMaxSize(mc.MaxRespSize, c.maxReceiveMessageSize, defaultClientMaxReceiveMessageSize) - if err := setCallInfoCodec(c); err != nil { - return err - } +var unaryStreamDesc = &StreamDesc{ServerStreams: false, ClientStreams: false} - if EnableTracing { - c.traceInfo.tr = trace.New("grpc.Sent."+methodFamily(method), method) - defer c.traceInfo.tr.Finish() - c.traceInfo.firstLine.client = true - if deadline, ok := ctx.Deadline(); ok { - c.traceInfo.firstLine.deadline = deadline.Sub(time.Now()) - } - c.traceInfo.tr.LazyLog(&c.traceInfo.firstLine, false) - // TODO(dsymonds): Arrange for c.traceInfo.firstLine.remoteAddr to be set. - defer func() { - if e != nil { - c.traceInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{e}}, true) - c.traceInfo.tr.SetError() - } - }() - } - ctx = newContextWithRPCInfo(ctx, c.failFast) - sh := cc.dopts.copts.StatsHandler - if sh != nil { - ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) - begin := &stats.Begin{ - Client: true, - BeginTime: time.Now(), - FailFast: c.failFast, - } - sh.HandleRPC(ctx, begin) - defer func() { - end := &stats.End{ - Client: true, - EndTime: time.Now(), - Error: e, - } - sh.HandleRPC(ctx, end) - }() - } - topts := &transport.Options{ - Last: true, - Delay: false, - } - callHdr := &transport.CallHdr{ - Host: cc.authority, - Method: method, - } - if c.creds != nil { - callHdr.Creds = c.creds - } - if c.compressorType != "" { - callHdr.SendCompress = c.compressorType - } else if cc.dopts.cp != nil { - callHdr.SendCompress = cc.dopts.cp.Type() - } +func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error { + // TODO: implement retries in clientStream and make this simply + // newClientStream, SendMsg, RecvMsg. firstAttempt := true - for { - // Check to make sure the context has expired. This will prevent us from - // looping forever if an error occurs for wait-for-ready RPCs where no data - // is sent on the wire. - select { - case <-ctx.Done(): - return toRPCErr(ctx.Err()) - default: - } - - // Record the done handler from Balancer.Get(...). It is called once the - // RPC has completed or failed. - t, done, err := cc.getTransport(ctx, c.failFast) + csInt, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...) if err != nil { return err } - stream, err := t.NewStream(ctx, callHdr) - if err != nil { - if done != nil { - done(balancer.DoneInfo{Err: err}) - } - // In the event of any error from NewStream, we never attempted to write - // anything to the wire, so we can retry indefinitely for non-fail-fast - // RPCs. - if !c.failFast { + cs := csInt.(*clientStream) + if err := cs.SendMsg(req); err != nil { + if !cs.c.failFast && cs.s.Unprocessed() && firstAttempt { + // TODO: Add a field to header for grpc-transparent-retry-attempts + firstAttempt = false continue } - return toRPCErr(err) - } - c.stream = stream - if c.traceInfo.tr != nil { - c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true) - } - err = sendRequest(ctx, cc.dopts, cc.dopts.cp, c, callHdr, stream, t, args, topts) - if err != nil { - if done != nil { - done(balancer.DoneInfo{ - Err: err, - BytesSent: true, - BytesReceived: stream.BytesReceived(), - }) - } - // Retry a non-failfast RPC when - // i) the server started to drain before this RPC was initiated. - // ii) the server refused the stream. - if !c.failFast && stream.Unprocessed() { - // In this case, the server did not receive the data, but we still - // created wire traffic, so we should not retry indefinitely. - if firstAttempt { - // TODO: Add a field to header for grpc-transparent-retry-attempts - firstAttempt = false - continue - } - // Otherwise, give up and return an error anyway. - } - return toRPCErr(err) - } - err = recvResponse(ctx, cc.dopts, t, c, stream, reply) - if err != nil { - if done != nil { - done(balancer.DoneInfo{ - Err: err, - BytesSent: true, - BytesReceived: stream.BytesReceived(), - }) - } - if !c.failFast && stream.Unprocessed() { - // In these cases, the server did not receive the data, but we still - // created wire traffic, so we should not retry indefinitely. - if firstAttempt { - // TODO: Add a field to header for grpc-transparent-retry-attempts - firstAttempt = false - continue - } - // Otherwise, give up and return an error anyway. - } - return toRPCErr(err) - } - if c.traceInfo.tr != nil { - c.traceInfo.tr.LazyLog(&payload{sent: false, msg: reply}, true) - } - t.CloseStream(stream, nil) - err = stream.Status().Err() - if done != nil { - done(balancer.DoneInfo{ - Err: err, - BytesSent: true, - BytesReceived: stream.BytesReceived(), - }) + return err } - if !c.failFast && stream.Unprocessed() { - // In these cases, the server did not receive the data, but we still - // created wire traffic, so we should not retry indefinitely. - if firstAttempt { + if err := cs.RecvMsg(reply); err != nil { + if !cs.c.failFast && cs.s.Unprocessed() && firstAttempt { // TODO: Add a field to header for grpc-transparent-retry-attempts firstAttempt = false continue } + return err } - return err + return nil } } diff --git a/stream.go b/stream.go index d76f4dbebfee..50ae3ea1ea62 100644 --- a/stream.go +++ b/stream.go @@ -246,14 +246,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth s, err = t.NewStream(ctx, callHdr) if err != nil { if done != nil { - doneInfo := balancer.DoneInfo{Err: err} - if _, ok := err.(transport.ConnectionError); ok { - // If error is connection error, transport was sending data on wire, - // and we are not sure if anything has been sent on wire. - // If error is not connection error, we are sure nothing has been sent. - doneInfo.BytesSent = true - } - done(doneInfo) + done(balancer.DoneInfo{Err: err}) done = nil } // In the event of any error from NewStream, we never attempted to write @@ -289,29 +282,33 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth statsCtx: ctx, statsHandler: cc.dopts.copts.StatsHandler, } - // Listen on s.Context().Done() to detect cancellation and s.Done() to detect - // normal termination when there is no pending I/O operations on this stream. - go func() { - select { - case <-t.Error(): - // Incur transport error, simply exit. - case <-cc.ctx.Done(): - cs.finish(ErrClientConnClosing) - cs.closeTransportStream(ErrClientConnClosing) - case <-s.Done(): - // TODO: The trace of the RPC is terminated here when there is no pending - // I/O, which is probably not the optimal solution. - cs.finish(s.Status().Err()) - cs.closeTransportStream(nil) - case <-s.GoAway(): - cs.finish(errConnDrain) - cs.closeTransportStream(errConnDrain) - case <-s.Context().Done(): - err := s.Context().Err() - cs.finish(err) - cs.closeTransportStream(transport.ContextErr(err)) - } - }() + if desc != unaryStreamDesc { + // Listen on s.Context().Done() to detect cancellation and s.Done() to + // detect normal termination when there is no pending I/O operations on + // this stream. Not necessary for "unary" streams, since we are guaranteed + // to always be in stream functions. + go func() { + select { + case <-t.Error(): + // Incur transport error, simply exit. + case <-cc.ctx.Done(): + cs.finish(ErrClientConnClosing) + cs.closeTransportStream(ErrClientConnClosing) + case <-s.Done(): + // TODO: The trace of the RPC is terminated here when there is no pending + // I/O, which is probably not the optimal solution. + cs.finish(s.Status().Err()) + cs.closeTransportStream(nil) + case <-s.GoAway(): + cs.finish(errConnDrain) + cs.closeTransportStream(errConnDrain) + case <-s.Context().Done(): + err := s.Context().Err() + cs.finish(err) + cs.closeTransportStream(transport.ContextErr(err)) + } + }() + } return cs, nil } @@ -339,6 +336,7 @@ type clientStream struct { mu sync.Mutex done func(balancer.DoneInfo) + sentLast bool // sent an end stream closed bool finished bool // trInfo.tr is set when the clientStream is created (if EnableTracing is true), @@ -371,6 +369,7 @@ func (cs *clientStream) Trailer() metadata.MD { } func (cs *clientStream) SendMsg(m interface{}) (err error) { + // TODO: Check cs.sentLast and error if we already ended the stream. if cs.tracing { cs.mu.Lock() if cs.trInfo.tr != nil { @@ -381,18 +380,15 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { // TODO Investigate how to signal the stats handling party. // generate error stats if err != nil && err != io.EOF? defer func() { - if err != nil { - cs.finish(err) - } if err == nil { return } + cs.finish(err) if err == io.EOF { - // Specialize the process for server streaming. SendMsg is only called - // once when creating the stream object. io.EOF needs to be skipped when - // the rpc is early finished (before the stream object is created.). - // TODO: It is probably better to move this into the generated code. - if !cs.desc.ClientStreams && cs.desc.ServerStreams { + // SendMsg is only called once for non-client-streams. io.EOF needs to be + // skipped when the rpc is early finished (before the stream object is + // created.). + if !cs.desc.ClientStreams { err = nil } return @@ -418,7 +414,10 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { if len(data) > *cs.c.maxSendMessageSize { return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(data), *cs.c.maxSendMessageSize) } - err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: false}) + if !cs.desc.ClientStreams { + cs.sentLast = true + } + err = cs.t.Write(cs.s, hdr, data, &transport.Options{Last: !cs.desc.ClientStreams}) if err == nil && outPayload != nil { outPayload.SentTime = time.Now() cs.statsHandler.HandleRPC(cs.statsCtx, outPayload) @@ -427,6 +426,15 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { } func (cs *clientStream) RecvMsg(m interface{}) (err error) { + defer func() { + // err != nil indicates the termination of the stream. + if err != nil { + cs.finish(err) + if cs.cancel != nil { + cs.cancel() + } + } + }() var inPayload *stats.InPayload if cs.statsHandler != nil { inPayload = &stats.InPayload{ @@ -453,71 +461,55 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) { cs.decompSet = true } err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, inPayload, cs.decomp) - defer func() { - // err != nil indicates the termination of the stream. - if err != nil { - cs.finish(err) - if cs.cancel != nil { - cs.cancel() - } - } - }() - if err == nil { - if cs.tracing { - cs.mu.Lock() - if cs.trInfo.tr != nil { - cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true) - } - cs.mu.Unlock() - } - if inPayload != nil { - cs.statsHandler.HandleRPC(cs.statsCtx, inPayload) - } - if !cs.desc.ClientStreams || cs.desc.ServerStreams { - return - } - // Special handling for client streaming rpc. - // This recv expects EOF or errors, so we don't collect inPayload. - if cs.c.maxReceiveMessageSize == nil { - return status.Errorf(codes.Internal, "callInfo maxReceiveMessageSize field uninitialized(nil)") - } - err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp) - cs.closeTransportStream(err) - if err == nil { - return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) + if err != nil { + if _, ok := err.(transport.ConnectionError); !ok { + cs.closeTransportStream(err) } if err == io.EOF { - if se := cs.s.Status().Err(); se != nil { - return se - } - cs.finish(err) - if cs.cancel != nil { - cs.cancel() + if statusErr := cs.s.Status().Err(); statusErr != nil { + return statusErr } - return nil + return io.EOF // indicates end of stream. } return toRPCErr(err) } - if _, ok := err.(transport.ConnectionError); !ok { - cs.closeTransportStream(err) + if cs.tracing { + cs.mu.Lock() + if cs.trInfo.tr != nil { + cs.trInfo.tr.LazyLog(&payload{sent: false, msg: m}, true) + } + cs.mu.Unlock() + } + if inPayload != nil { + cs.statsHandler.HandleRPC(cs.statsCtx, inPayload) + } + if cs.desc.ServerStreams { + return nil + } + + // Special handling for non-server-stream rpcs. + // This recv expects EOF or errors, so we don't collect inPayload. + err = recv(cs.p, cs.codec, cs.s, cs.dc, m, *cs.c.maxReceiveMessageSize, nil, cs.decomp) + cs.closeTransportStream(err) + if err == nil { + return toRPCErr(errors.New("grpc: client streaming protocol violation: get , want ")) } if err == io.EOF { - if statusErr := cs.s.Status().Err(); statusErr != nil { - return statusErr + if se := cs.s.Status().Err(); se != nil { + return se } - // Returns io.EOF to indicate the end of the stream. - return + cs.finish(err) + return nil } return toRPCErr(err) } -func (cs *clientStream) CloseSend() (err error) { - err = cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true}) - defer func() { - if err != nil { - cs.finish(err) - } - }() +func (cs *clientStream) CloseSend() error { + if cs.sentLast { + return nil + } + cs.sentLast = true + err := cs.t.Write(cs.s, nil, nil, &transport.Options{Last: true}) if err == nil || err == io.EOF { return nil } @@ -525,7 +517,10 @@ func (cs *clientStream) CloseSend() (err error) { cs.closeTransportStream(err) } err = toRPCErr(err) - return + if err != nil { + cs.finish(err) + } + return err } func (cs *clientStream) closeTransportStream(err error) { diff --git a/test/end2end_test.go b/test/end2end_test.go index 412027d3b1a1..772f6e606df5 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -925,7 +925,7 @@ func testServerGoAwayPendingRPC(t *testing.T, e env) { cc := te.clientConn() tc := testpb.NewTestServiceClient(cc) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err)