diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index c7d3e8d91bf..9e03f1de388 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -127,17 +127,18 @@ var _ = utilities.PascalFromSnake _ = template.Must(handlerTemplate.New("request-func-signature").Parse(strings.Replace(` {{if .Method.GetServerStreaming}} -func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.GetName}}_{{.Method.GetName}}Client, error) +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) ({{.Method.Service.GetName}}_{{.Method.GetName}}Client, runtime.ServerMetadata, error) {{else}} -func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, error) +func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx context.Context, client {{.Method.Service.GetName}}Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {{end}}`, "\n", "", -1))) _ = template.Must(handlerTemplate.New("client-streaming-request-func").Parse(` {{template "request-func-signature" .}} { + var metadata runtime.ServerMetadata stream, err := client.{{.Method.GetName}}(ctx) if err != nil { glog.Errorf("Failed to start streaming: %v", err) - return nil, err + return nil, metadata, err } dec := json.NewDecoder(req.Body) for { @@ -148,21 +149,25 @@ func request_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}(ctx cont } if err != nil { glog.Errorf("Failed to decode request: %v", err) - return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err) } if err = stream.Send(&protoReq); err != nil { glog.Errorf("Failed to send request: %v", err) - return nil, err + return nil, metadata, err } } {{if .Method.GetServerStreaming}} if err = stream.CloseSend(); err != nil { glog.Errorf("Failed to terminate client stream: %v", err) - return nil, err + return nil, metadata, err } - return stream, nil + + metadata.TrailerMD = stream.Trailer() + return stream, metadata, nil {{else}} - return stream.CloseAndRecv() + metadata.TrailerMD = stream.Trailer() + msg, err := stream.CloseAndRecv() + return msg, metadata, err {{end}} } `)) @@ -175,9 +180,10 @@ var ( {{end}} {{template "request-func-signature" .}} { var protoReq {{.Method.RequestType.GoType .Method.Service.File.GoPkg.Path}} + var metadata runtime.ServerMetadata {{if .Body}} if err := json.NewDecoder(req.Body).Decode(&{{.Body.RHS "protoReq"}}); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err) } {{end}} {{if .PathParams}} @@ -190,7 +196,7 @@ var ( {{range $param := .PathParams}} val, ok = pathParams[{{$param | printf "%q"}}] if !ok { - return nil, grpc.Errorf(codes.InvalidArgument, "missing parameter %s", {{$param | printf "%q"}}) + return nil, metadata, grpc.Errorf(codes.InvalidArgument, "missing parameter %s", {{$param | printf "%q"}}) } {{if $param.IsNestedProto3 }} err = runtime.PopulateFieldFromPath(&protoReq, {{$param | printf "%q"}}, val) @@ -198,17 +204,19 @@ var ( {{$param.RHS "protoReq"}}, err = {{$param.ConvertFuncExpr}}(val) {{end}} if err != nil { - return nil, err + return nil, metadata, err } {{end}} {{end}} {{if .HasQueryParam}} if err := runtime.PopulateQueryParameters(&protoReq, req.URL.Query(), filter_{{.Method.Service.GetName}}_{{.Method.GetName}}_{{.Index}}); err != nil { - return nil, grpc.Errorf(codes.InvalidArgument, "%v", err) + return nil, metadata, grpc.Errorf(codes.InvalidArgument, "%v", err) } {{end}} - return client.{{.Method.GetName}}(ctx, &protoReq) + msg, err := client.{{.Method.GetName}}(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + + return msg, metadata, err }`)) trailerTemplate = template.Must(template.New("trailer").Parse(` @@ -245,7 +253,8 @@ func Register{{$svc.GetName}}Handler(ctx context.Context, mux *runtime.ServeMux, {{range $m := $svc.Methods}} {{range $b := $m.Bindings}} mux.Handle({{$b.HTTPMethod | printf "%q"}}, pattern_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - resp, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(runtime.AnnotateContext(ctx, req), client, req, pathParams) + resp, md, err := request_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(runtime.AnnotateContext(ctx, req), client, req, pathParams) + ctx = runtime.NewServerMetadataContext(ctx, &md) if err != nil { runtime.HTTPError(ctx, w, req, err) return diff --git a/protoc-gen-grpc-gateway/gengateway/template_test.go b/protoc-gen-grpc-gateway/gengateway/template_test.go index 77e8d50fbb3..c53d51e0eeb 100644 --- a/protoc-gen-grpc-gateway/gengateway/template_test.go +++ b/protoc-gen-grpc-gateway/gengateway/template_test.go @@ -133,11 +133,11 @@ func TestApplyTemplateRequestWithoutClientStreaming(t *testing.T) { }{ { serverStreaming: false, - sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`, }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming) @@ -294,11 +294,11 @@ func TestApplyTemplateRequestWithClientStreaming(t *testing.T) { }{ { serverStreaming: false, - sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {`, }, { serverStreaming: true, - sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, error) {`, + sigWant: `func request_ExampleService_Echo_0(ctx context.Context, client ExampleServiceClient, req *http.Request, pathParams map[string]string) (ExampleService_EchoClient, runtime.ServerMetadata, error) {`, }, } { meth.ServerStreaming = proto.Bool(spec.serverStreaming) diff --git a/runtime/context.go b/runtime/context.go index cde43136867..8c77befd521 100644 --- a/runtime/context.go +++ b/runtime/context.go @@ -35,3 +35,21 @@ func AnnotateContext(ctx context.Context, req *http.Request) context.Context { } return metadata.NewContext(ctx, metadata.Pairs(pairs...)) } + +type ServerMetadata struct { + HeaderMD metadata.MD + TrailerMD metadata.MD +} + +type serverMetadataKey struct{} + +// NewServerMetadataContext creates a new context with ServerMetadata +func NewServerMetadataContext(ctx context.Context, md *ServerMetadata) context.Context { + return context.WithValue(ctx, serverMetadataKey{}, md) +} + +// ServerMetadataFromContext returns the ServerMetadata in ctx +func ServerMetadataFromContext(ctx context.Context) (md *ServerMetadata, ok bool) { + md, ok = ctx.Value(serverMetadataKey{}).(*ServerMetadata) + return +}