Skip to content

grpc: Add a pointer of server to ctx passed into stats handler #6625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ var (
// xDS-enabled server invokes this method on a grpc.Server when a particular
// listener moves to "not-serving" mode.
DrainServerTransports any // func(*grpc.Server, string)
// IsRegisteredMethod returns whether the passed in method is registered as
// a method on the server.
IsRegisteredMethod any // func(*grpc.Server, string)
// GetServer returns the server from the context.
GetServer any // func(context.Context) *Server
// AddGlobalServerOptions adds an array of ServerOption that will be
// effective globally for newly created servers. The priority will be: 1.
// user-provided; 2. this method; 3. default values.
Expand Down
2 changes: 2 additions & 0 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ func (ht *serverHandlerTransport) Close(err error) {

func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }

func (ht *serverHandlerTransport) LocalAddr() net.Addr { return nil }

// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
// the empty string if unknown.
type strAddr string
Expand Down
37 changes: 11 additions & 26 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ type http2Server struct {
// returns a nil transport and a non-nil error. For a special case where the
// underlying conn gets closed before the client preface could be read, it
// returns a nil transport and a nil error.
func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) {
func NewServerTransport(ctx context.Context, conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) {
var authInfo credentials.AuthInfo
rawConn := conn
if config.Credentials != nil {
Expand Down Expand Up @@ -249,7 +249,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,

done := make(chan struct{})
t := &http2Server{
ctx: setConnection(context.Background(), rawConn),
ctx: setConnection(ctx, rawConn),
done: done,
conn: conn,
remoteAddr: conn.RemoteAddr(),
Expand Down Expand Up @@ -282,14 +282,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl,
}
}
for _, sh := range t.stats {
t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
})
connBegin := &stats.ConnBegin{}
sh.HandleConn(t.ctx, connBegin)
}
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
if err != nil {
return nil, err
Expand Down Expand Up @@ -374,10 +366,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(

buf := newRecvBuffer()
s := &Stream{
id: streamID,
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
id: streamID,
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
headerWireLength: int(frame.Header().Length),
}
var (
// if false, content-type was missing or invalid
Expand Down Expand Up @@ -597,18 +590,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
for _, sh := range t.stats {
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
Compression: s.recvCompress,
WireLength: int(frame.Header().Length),
Header: mdata.Copy(),
}
sh.HandleRPC(s.ctx, inHeader)
}
s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
Expand Down Expand Up @@ -1317,6 +1298,10 @@ func (t *http2Server) RemoteAddr() net.Addr {
return t.remoteAddr
}

func (t *http2Server) LocalAddr() net.Addr {
return t.localAddr
}

func (t *http2Server) Drain(debugData string) {
t.mu.Lock()
defer t.mu.Unlock()
Expand Down
20 changes: 18 additions & 2 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,9 @@ type Stream struct {
// On server-side it is unused.
status *status.Status

bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
bytesReceived uint32 // indicates whether any bytes have been received on this stream
unprocessed uint32 // set if the server sends a refused stream or GOAWAY including this stream
headerWireLength int

// contentSubtype is the content-subtype for requests.
// this must be lowercase or the behavior is undefined.
Expand Down Expand Up @@ -425,6 +426,12 @@ func (s *Stream) Context() context.Context {
return s.ctx
}

// SetContext sets the context of the stream. This will be deleted once the
// stats handler callouts all move to gRPC layer.
func (s *Stream) SetContext(ctx context.Context) {
s.ctx = ctx
}

// Method returns the method for the stream.
func (s *Stream) Method() string {
return s.method
Expand All @@ -437,6 +444,12 @@ func (s *Stream) Status() *status.Status {
return s.status
}

// HeaderWireLength returns the size of theheaders of the stream as received
// from the wire.
func (s *Stream) HeaderWireLength() int {
return s.headerWireLength
}

// SetHeader sets the header metadata. This can be called multiple times.
// Server side only.
// This should not be called in parallel to other data writes.
Expand Down Expand Up @@ -720,6 +733,9 @@ type ServerTransport interface {
// RemoteAddr returns the remote network address.
RemoteAddr() net.Addr

// LocalAddr returns the local network address.
LocalAddr() net.Addr

// Drain notifies the client this ServerTransport stops accepting new RPCs.
Drain(debugData string)

Expand Down
2 changes: 1 addition & 1 deletion internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
return
}
rawConn := conn
transport, err := NewServerTransport(conn, serverConfig)
transport, err := NewServerTransport(context.Background(), conn, serverConfig)
if err != nil {
return
}
Expand Down
64 changes: 61 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ func init() {
internal.DrainServerTransports = func(srv *Server, addr string) {
srv.drainServerTransports(addr)
}
internal.IsRegisteredMethod = func(srv *Server, method string) bool {
return srv.isRegisteredMethod(method)
}
internal.GetServer = getServer
internal.AddGlobalServerOptions = func(opt ...ServerOption) {
globalServerOptions = append(globalServerOptions, opt...)
}
Expand Down Expand Up @@ -912,9 +916,17 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
return
}
rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
ctx := context.Background()
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagConn(ctx, &stats.ConnTagInfo{
RemoteAddr: rawConn.RemoteAddr(),
LocalAddr: rawConn.LocalAddr(),
})
sh.HandleConn(ctx, &stats.ConnBegin{})
}

// Finish handshaking (HTTP2)
st := s.newHTTP2Transport(rawConn)
st := s.newHTTP2Transport(ctx, rawConn)
rawConn.SetDeadline(time.Time{})
if st == nil {
return
Expand All @@ -940,7 +952,7 @@ func (s *Server) drainServerTransports(addr string) {

// newHTTP2Transport sets up a http/2 transport (using the
// gRPC http2 server transport in transport/http2_server.go).
func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
func (s *Server) newHTTP2Transport(ctx context.Context, c net.Conn) transport.ServerTransport {
config := &transport.ServerConfig{
MaxStreams: s.opts.maxConcurrentStreams,
ConnectionTimeout: s.opts.connectionTimeout,
Expand All @@ -958,7 +970,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport {
MaxHeaderListSize: s.opts.maxHeaderListSize,
HeaderTableSize: s.opts.headerTableSize,
}
st, err := transport.NewServerTransport(c, config)
st, err := transport.NewServerTransport(ctx, c, config)
if err != nil {
s.mu.Lock()
s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err)
Expand Down Expand Up @@ -1689,8 +1701,22 @@ func (s *Server) processStreamingRPC(ctx context.Context, t transport.ServerTran
return t.WriteStatus(ss.s, statusOK)
}

type serverKey struct{}

// getServer gets the Server from the context.
func getServer(ctx context.Context) *Server {
Comment on lines +1706 to +1707
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serverFromContext please

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

s, _ := ctx.Value(serverKey{}).(*Server)
return s
}

// setServer sets the Server in the context.
func setServer(ctx context.Context, server *Server) context.Context {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

contextWithServer please

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return context.WithValue(ctx, serverKey{}, server)
}

func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream) {
ctx := stream.Context()
ctx = setServer(ctx, s)
var ti *traceInfo
if EnableTracing {
tr := trace.New("grpc.Recv."+methodFamily(stream.Method()), stream.Method())
Expand All @@ -1707,6 +1733,20 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Str
}
}

md, _ := metadata.FromIncomingContext(ctx)
for _, sh := range s.opts.statsHandlers {
ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()})
stream.SetContext(ctx) // To have calls in stream callouts work. Will delete once all stats handler calls come from the gRPC layer.
sh.HandleRPC(ctx, &stats.InHeader{
FullMethod: stream.Method(),
RemoteAddr: t.RemoteAddr(),
LocalAddr: t.LocalAddr(),
Compression: stream.RecvCompress(),
WireLength: stream.HeaderWireLength(),
Header: md,
})
}

sm := stream.Method()
if sm != "" && sm[0] == '/' {
sm = sm[1:]
Expand Down Expand Up @@ -1920,6 +1960,24 @@ func (s *Server) getCodec(contentSubtype string) baseCodec {
return codec
}

// isRegisteredMethod returns whether the passed in method is registered as a
// method on the server.
func (s *Server) isRegisteredMethod(method string) bool {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline this needs to be a "full method name", parsed accordingly, with both service & method included in the check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

for _, service := range s.services {
for mName := range service.methods {
if method == mName {
return true
}
}
for mName := range service.streams {
if method == mName {
return true
}
}
}
return false
}

// SetHeader sets the header metadata to be sent from the server to the client.
// The context provided must be the context passed to the server's handler.
//
Expand Down
86 changes: 86 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6539,3 +6539,89 @@ func (s) TestRPCBlockingOnPickerStatsCall(t *testing.T) {
t.Fatalf("sh.pickerUpdated count: %v, want: %v", pickerUpdatedCount, 2)
}
}

type statsHandlerServerAssert struct {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make a stub stats handler that works like our server, so the implementation relevant to the test can live inside the test itself?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I'm gonna switch this test to use this new util. However, I won't switch all the tests that I wrote that use stats handler components to use this, but if you want me to I can do that (probably as another PR). Any future tests I will write I'll use this stub type.

errorCh *testutils.Channel
}

func (shsa *statsHandlerServerAssert) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
// OpenTelemetry instrumentation needs the passed in Server to determine if
// methods are registered in different handle calls in to record metrics.
// This tag RPC call context gets passed into every handle call, so can
// assert once here, since it maps to all the handle RPC calls that come
// after. These internal calls will be how the OpenTelemetry instrumentation
// component accesses this server and the subsequent helper on the server.
server := internal.GetServer.(func(context.Context) *grpc.Server)(ctx)
if server == nil {
shsa.errorCh.Send("stats handler received ctx has no server present")
}

if registeredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)(server, "UnaryCall"); !registeredMethod {
shsa.errorCh.Send(errors.New("UnaryCall should be a registered method according to server"))
return ctx
}
Comment on lines +6559 to +6562
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A little more readable:

isRegistered := internal.IsRegisteredMethod.(func..etc..)
if !isRegistered(server, "UnaryCall) {
  // fail
}
// etc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, great point. Switched.


if registeredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)(server, "FullDuplexCall"); !registeredMethod {
shsa.errorCh.Send(errors.New("FullDuplexCall should be a registered method according to server"))
return ctx
}

if registeredMethod := internal.IsRegisteredMethod.(func(*grpc.Server, string) bool)(server, "DoesNotExistCall"); registeredMethod {
shsa.errorCh.Send(errors.New("DoesNotExistCall should not be a registered method according to server"))
return ctx
}

shsa.errorCh.Send(nil)
return ctx
}

func (shsa *statsHandlerServerAssert) HandleRPC(ctx context.Context, s stats.RPCStats) {}

func (shsa *statsHandlerServerAssert) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
return ctx
}

func (shsa *statsHandlerServerAssert) HandleConn(context.Context, stats.ConnStats) {}

// TestStatsHandlerCallsServerIsRegisteredMethod tests whether a stats handler
// gets access to a Server on the server side, and thus the method that the
// server owns which specifies whether a method is made or not. The test sets up
// a server with a unary call and full duplex call configured, and makes an RPC.
// Within the stats handler, asking the server whether unary or duplex method
// names are registered should return true, and any other query should return
// false.
func (s) TestStatsHandlerCallsServerIsRegisteredMethod(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add this to a stats_test.go, or put it into the stats directory, or even in the root directory since it's actually testing grpc.Server functionality? This file is approaching 7k lines which is crazy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chose to put it in stats directory, since that is where we want this helper to be called.

errorCh := testutils.NewChannel()
shsa := &statsHandlerServerAssert{
errorCh: errorCh,
}
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
return &testpb.SimpleResponse{}, nil
},
FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
for {
if _, err := stream.Recv(); err == io.EOF {
return nil
}
}
},
}
if err := ss.Start([]grpc.ServerOption{grpc.StatsHandler(shsa)}); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{Payload: &testpb.Payload{}}); err != nil {
t.Fatalf("Unexpected error from UnaryCall: %v", err)
}
err, errRecv := errorCh.Receive(ctx)
if errRecv != nil {
t.Fatalf("error receiving from channel: %v", errRecv)
}
if err != nil {
t.Fatalf("error received from error channel: %v", err)
}
}