-
Notifications
You must be signed in to change notification settings - Fork 4.5k
grpc: Add a pointer of server to ctx passed into stats handler #6625
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -73,6 +73,10 @@ func init() { | |
internal.DrainServerTransports = func(srv *Server, addr string) { | ||
srv.drainServerTransports(addr) | ||
} | ||
internal.IsRegisteredMethod = func(srv *Server, method string) bool { | ||
return srv.isRegisteredMethod(method) | ||
} | ||
internal.GetServer = getServer | ||
internal.AddGlobalServerOptions = func(opt ...ServerOption) { | ||
globalServerOptions = append(globalServerOptions, opt...) | ||
} | ||
|
@@ -912,9 +916,17 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { | |
return | ||
} | ||
rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) | ||
ctx := context.Background() | ||
for _, sh := range s.opts.statsHandlers { | ||
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{ | ||
RemoteAddr: rawConn.RemoteAddr(), | ||
LocalAddr: rawConn.LocalAddr(), | ||
}) | ||
sh.HandleConn(ctx, &stats.ConnBegin{}) | ||
} | ||
|
||
// Finish handshaking (HTTP2) | ||
st := s.newHTTP2Transport(rawConn) | ||
st := s.newHTTP2Transport(ctx, rawConn) | ||
rawConn.SetDeadline(time.Time{}) | ||
if st == nil { | ||
return | ||
|
@@ -940,7 +952,7 @@ func (s *Server) drainServerTransports(addr string) { | |
|
||
// newHTTP2Transport sets up a http/2 transport (using the | ||
// gRPC http2 server transport in transport/http2_server.go). | ||
func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { | ||
func (s *Server) newHTTP2Transport(ctx context.Context, c net.Conn) transport.ServerTransport { | ||
config := &transport.ServerConfig{ | ||
MaxStreams: s.opts.maxConcurrentStreams, | ||
ConnectionTimeout: s.opts.connectionTimeout, | ||
|
@@ -958,7 +970,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { | |
MaxHeaderListSize: s.opts.maxHeaderListSize, | ||
HeaderTableSize: s.opts.headerTableSize, | ||
} | ||
st, err := transport.NewServerTransport(c, config) | ||
st, err := transport.NewServerTransport(ctx, c, config) | ||
if err != nil { | ||
s.mu.Lock() | ||
s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) | ||
|
@@ -1689,8 +1701,22 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran | |
return t.WriteStatus(ss.s, statusOK) | ||
} | ||
|
||
type serverKey struct{} | ||
|
||
// getServer gets the Server from the context. | ||
func getServer(ctx context.Context) *Server { | ||
s, _ := ctx.Value(serverKey{}).(*Server) | ||
return s | ||
} | ||
|
||
// setServer sets the Server in the context. | ||
func setServer(ctx context.Context, server *Server) context.Context { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
return context.WithValue(ctx, serverKey{}, server) | ||
} | ||
|
||
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) { | ||
ctx := stream.Context() | ||
ctx = setServer(ctx, s) | ||
var ti *traceInfo | ||
if EnableTracing { | ||
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method()) | ||
|
@@ -1707,6 +1733,20 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str | |
} | ||
} | ||
|
||
md, _ := metadata.FromIncomingContext(ctx) | ||
for _, sh := range s.opts.statsHandlers { | ||
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()}) | ||
stream.SetContext(ctx) // To have calls in stream callouts work. Will delete once all stats handler calls come from the gRPC layer. | ||
sh.HandleRPC(ctx, &stats.InHeader{ | ||
FullMethod: stream.Method(), | ||
RemoteAddr: t.RemoteAddr(), | ||
LocalAddr: t.LocalAddr(), | ||
Compression: stream.RecvCompress(), | ||
WireLength: stream.HeaderWireLength(), | ||
Header: md, | ||
}) | ||
} | ||
|
||
sm := stream.Method() | ||
if sm != "" && sm[0] == '/' { | ||
sm = sm[1:] | ||
|
@@ -1920,6 +1960,24 @@ func (s *Server) getCodec(contentSubtype string) baseCodec { | |
return codec | ||
} | ||
|
||
// isRegisteredMethod returns whether the passed in method is registered as a | ||
// method on the server. | ||
func (s *Server) isRegisteredMethod(method string) bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed offline this needs to be a "full method name", parsed accordingly, with both service & method included in the check. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
for _, service := range s.services { | ||
for mName := range service.methods { | ||
if method == mName { | ||
return true | ||
} | ||
} | ||
for mName := range service.streams { | ||
if method == mName { | ||
return true | ||
} | ||
} | ||
} | ||
return false | ||
} | ||
|
||
// SetHeader sets the header metadata to be sent from the server to the client. | ||
// The context provided must be the context passed to the server's handler. | ||
// | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6539,3 +6539,89 @@ func (s) TestRPCBlockingOnPickerStatsCall(t *testing.T) { | |
t.Fatalf("sh.pickerUpdated count: %v, want: %v", pickerUpdatedCount, 2) | ||
} | ||
} | ||
|
||
type statsHandlerServerAssert struct { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we make a stub stats handler that works like our server, so the implementation relevant to the test can live inside the test itself? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. I'm gonna switch this test to use this new util. However, I won't switch all the tests that I wrote that use stats handler components to use this, but if you want me to I can do that (probably as another PR). Any future tests I will write I'll use this stub type. |
||
errorCh *testutils.Channel | ||
} | ||
|
||
func (shsa *statsHandlerServerAssert) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { | ||
// OpenTelemetry instrumentation needs the passed in Server to determine if | ||
// methods are registered in different handle calls in to record metrics. | ||
// This tag RPC call context gets passed into every handle call, so can | ||
// assert once here, since it maps to all the handle RPC calls that come | ||
// after. These internal calls will be how the OpenTelemetry instrumentation | ||
// component accesses this server and the subsequent helper on the server. | ||
server := internal.GetServer.(func(context.Context) *grpc.Server)(ctx) | ||
if server == nil { | ||
shsa.errorCh.Send("stats handler received ctx has no server present") | ||
} | ||
|
||
if registeredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)(server, "UnaryCall"); !registeredMethod { | ||
shsa.errorCh.Send(errors.New("UnaryCall should be a registered method according to server")) | ||
return ctx | ||
} | ||
Comment on lines
+6559
to
+6562
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A little more readable: isRegistered := internal.IsRegisteredMethod.(func..etc..)
if !isRegistered(server, "UnaryCall) {
// fail
}
// etc There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, great point. Switched. |
||
|
||
if registeredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)(server, "FullDuplexCall"); !registeredMethod { | ||
shsa.errorCh.Send(errors.New("FullDuplexCall should be a registered method according to server")) | ||
return ctx | ||
} | ||
|
||
if registeredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)(server, "DoesNotExistCall"); registeredMethod { | ||
shsa.errorCh.Send(errors.New("DoesNotExistCall should not be a registered method according to server")) | ||
return ctx | ||
} | ||
|
||
shsa.errorCh.Send(nil) | ||
return ctx | ||
} | ||
|
||
func (shsa *statsHandlerServerAssert) HandleRPC(ctx context.Context, s stats.RPCStats) {} | ||
|
||
func (shsa *statsHandlerServerAssert) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { | ||
return ctx | ||
} | ||
|
||
func (shsa *statsHandlerServerAssert) HandleConn(context.Context, stats.ConnStats) {} | ||
|
||
// TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler | ||
// gets access to a Server on the server side, and thus the method that the | ||
// server owns which specifies whether a method is made or not. The test sets up | ||
// a server with a unary call and full duplex call configured, and makes an RPC. | ||
// Within the stats handler, asking the server whether unary or duplex method | ||
// names are registered should return true, and any other query should return | ||
// false. | ||
func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add this to a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Chose to put it in stats directory, since that is where we want this helper to be called. |
||
errorCh := testutils.NewChannel() | ||
shsa := &statsHandlerServerAssert{ | ||
errorCh: errorCh, | ||
} | ||
ss := &stubserver.StubServer{ | ||
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { | ||
return &testpb.SimpleResponse{}, nil | ||
}, | ||
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error { | ||
for { | ||
if _, err := stream.Recv(); err == io.EOF { | ||
return nil | ||
} | ||
} | ||
}, | ||
} | ||
if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(shsa)}); err != nil { | ||
t.Fatalf("Error starting endpoint server: %v", err) | ||
} | ||
defer ss.Stop() | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) | ||
defer cancel() | ||
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil { | ||
t.Fatalf("Unexpected error from UnaryCall: %v", err) | ||
} | ||
err, errRecv := errorCh.Receive(ctx) | ||
if errRecv != nil { | ||
t.Fatalf("error receiving from channel: %v", errRecv) | ||
} | ||
if err != nil { | ||
t.Fatalf("error received from error channel: %v", err) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
serverFromContext
pleaseThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.