diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 2b1ce0148e3c..57475d27977f 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -566,7 +566,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( } if t.inTapHandle != nil { var err error - if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: s.method}); err != nil { + if s.ctx, err = t.inTapHandle(s.ctx, &tap.Info{FullMethodName: s.method, Header: mdata}); err != nil { t.mu.Unlock() if t.logger.V(logLevel) { t.logger.Infof("Aborting the stream early due to InTapHandle failure: %v", err) diff --git a/tap/tap.go b/tap/tap.go index bfa5dfa40e4d..07f012576880 100644 --- a/tap/tap.go +++ b/tap/tap.go @@ -27,6 +27,8 @@ package tap import ( "context" + + "google.golang.org/grpc/metadata" ) // Info defines the relevant information needed by the handles. @@ -34,6 +36,10 @@ type Info struct { // FullMethodName is the string of grpc method (in the format of // /package.service/method). FullMethodName string + + // Header contains the header metadata received. + Header metadata.MD + // TODO: More to be added. } diff --git a/test/end2end_test.go b/test/end2end_test.go index 6b6c6077ccdc..1dd5757c7b9f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -2099,6 +2099,10 @@ func (t *myTap) handle(ctx context.Context, info *tap.Info) (context.Context, er switch info.FullMethodName { case "/grpc.testing.TestService/EmptyCall": t.cnt++ + + if vals := info.Header.Get("return-error"); len(vals) > 0 && vals[0] == "true" { + return nil, status.Errorf(codes.Unknown, "tap error") + } case "/grpc.testing.TestService/UnaryCall": return nil, fmt.Errorf("tap error") case "/grpc.testing.TestService/FullDuplexCall": @@ -2120,6 +2124,7 @@ func testTap(t *testing.T, e env) { tc := testgrpc.NewTestServiceClient(cc) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() + if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); err != nil { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) } @@ -2127,6 +2132,20 @@ func testTap(t *testing.T, e env) { t.Fatalf("Get the count in ttap %d, want 1", ttap.cnt) } + if _, err := tc.EmptyCall(metadata.AppendToOutgoingContext(ctx, "return-error", "false"), &testpb.Empty{}); err != nil { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, ", err) + } + if ttap.cnt != 2 { + t.Fatalf("Get the count in ttap %d, want 2", ttap.cnt) + } + + if _, err := tc.EmptyCall(metadata.AppendToOutgoingContext(ctx, "return-error", "true"), &testpb.Empty{}); status.Code(err) != codes.Unknown { + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %s", err, codes.Unknown) + } + if ttap.cnt != 3 { + t.Fatalf("Get the count in ttap %d, want 3", ttap.cnt) + } + payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 31) if err != nil { t.Fatal(err)