diff --git a/examples/integration_test.go b/examples/integration_test.go index bbb9c0be462..24f482ef73c 100644 --- a/examples/integration_test.go +++ b/examples/integration_test.go @@ -19,6 +19,7 @@ import ( sub "github.com/gengo/grpc-gateway/examples/sub" "github.com/gengo/grpc-gateway/runtime" "github.com/golang/protobuf/proto" + "google.golang.org/grpc/codes" ) type aBitOfEverything struct { @@ -44,11 +45,6 @@ func TestIntegration(t *testing.T) { runtime.WithForwardResponseOption( func(ctx context.Context, w http.ResponseWriter, _ proto.Message) error { if md, ok := runtime.ServerMetadataFromContext(ctx); ok { - for k, vs := range md.HeaderMD { - for i := range vs { - w.Header().Add(fmt.Sprintf("Grpc-Header-%s", k), vs[i]) - } - } for k, vs := range md.TrailerMD { for i := range vs { w.Header().Add(fmt.Sprintf("Grpc-Trailer-%s", k), vs[i]) @@ -71,6 +67,7 @@ func TestIntegration(t *testing.T) { testABECreateBody(t) testABEBulkCreate(t) testABELookup(t) + testABELookupNotFound(t) testABEList(t) testABEListError(t) testAdditionalBindings(t) @@ -161,11 +158,11 @@ func testEchoBody(t *testing.T) { t.Errorf("msg.Id = %q; want %q", got, want) } - if value := resp.Header.Get("Grpc-Header-foo"); value != "foo1" { + if value := resp.Header.Get("Grpc-Metadata-foo"); value != "foo1" { t.Errorf("Grpc-Header-foo was %s, wanted %s", value, "foo1") } - if value := resp.Header.Get("Grpc-Header-bar"); value != "bar1" { + if value := resp.Header.Get("Grpc-Metadata-bar"); value != "bar1" { t.Errorf("Grpc-Header-bar was %s, wanted %s", value, "bar1") } @@ -427,6 +424,63 @@ func testABELookup(t *testing.T) { if got := msg; !reflect.DeepEqual(got, want) { t.Errorf("msg= %v; want %v", &got, &want) } + + if got, want := resp.Header.Get("Grpc-Metadata-uuid"), want.Uuid; got != want { + t.Errorf("Grpc-Metadata-foo was %s, wanted %s", got, want) + } +} + +func testABELookupNotFound(t *testing.T) { + url := "http://localhost:8080/v1/example/a_bit_of_everything" + uuid := "not_exist" + url = fmt.Sprintf("%s/%s", url, uuid) + resp, err := http.Get(url) + if err != nil { + t.Errorf("http.Get(%q) failed with %v; want success", url, err) + return + } + defer resp.Body.Close() + + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err) + return + } + + if got, want := resp.StatusCode, http.StatusNotFound; got != want { + t.Errorf("resp.StatusCode = %d; want %d", got, want) + t.Logf("%s", buf) + return + } + + var msg runtime.ErrorBody + if err := json.Unmarshal(buf, &msg); err != nil { + t.Errorf("json.Unmarshal(%s, &msg) failed with %v; want success", buf, err) + return + } + + if got, want := msg.Code, int(codes.NotFound); got != want { + t.Errorf("msg.Code = %d; want %d", got, want) + return + } + + if got, want := resp.Header.Get("Grpc-Metadata-uuid"), uuid; got != want { + t.Errorf("Grpc-Metadata-foo was %s, wanted %s", got, want) + } + + md := msg.Trailer + if md.Len() == 0 { + t.Errorf("no trailer is set") + return + } + + if got, want := md["foo"], []string{"foo2"}; !reflect.DeepEqual(got, want) { + t.Errorf("msg.Trailer[%q] = %v; want %v", "foo", got, want) + } + + if got, want := md["bar"], []string{"bar2"}; !reflect.DeepEqual(got, want) { + t.Errorf("msg.Trailer[%q] = %v; want %v", "bar", got, want) + } } func testABEList(t *testing.T) { diff --git a/examples/server/a_bit_of_everything.go b/examples/server/a_bit_of_everything.go index 5166b2134e4..a5d492df287 100644 --- a/examples/server/a_bit_of_everything.go +++ b/examples/server/a_bit_of_everything.go @@ -83,11 +83,20 @@ func (s *_ABitOfEverythingServer) BulkCreate(stream examples.ABitOfEverythingSer func (s *_ABitOfEverythingServer) Lookup(ctx context.Context, msg *examples.IdMessage) (*examples.ABitOfEverything, error) { s.m.Lock() defer s.m.Unlock() - glog.Info(msg) + + grpc.SendHeader(ctx, metadata.New(map[string]string{ + "uuid": msg.Uuid, + })) + if a, ok := s.v[msg.Uuid]; ok { return a, nil } + + grpc.SetTrailer(ctx, metadata.New(map[string]string{ + "foo": "foo2", + "bar": "bar2", + })) return nil, grpc.Errorf(codes.NotFound, "not found") } diff --git a/runtime/errors.go b/runtime/errors.go index 98c1950d4de..30e2dc85dd5 100644 --- a/runtime/errors.go +++ b/runtime/errors.go @@ -2,6 +2,7 @@ package runtime import ( "encoding/json" + "fmt" "io" "net/http" @@ -9,6 +10,7 @@ import ( "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" ) // HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status. @@ -62,9 +64,10 @@ var ( OtherErrorHandler = DefaultOtherErrorHandler ) -type errorBody struct { - Error string `json:"error"` - Code int `json:"code"` +type ErrorBody struct { + Error string `json:"error"` + Code int `json:"code"` + Trailer metadata.MD `json:"trailer,omitempty"` } // DefaultHTTPError is the default implementation of HTTPError. @@ -77,7 +80,19 @@ func DefaultHTTPError(ctx context.Context, w http.ResponseWriter, _ *http.Reques const fallback = `{"error": "failed to marshal error message"}` w.Header().Set("Content-Type", "application/json") - body := errorBody{Error: grpc.ErrorDesc(err), Code: int(grpc.Code(err))} + body := ErrorBody{ + Error: grpc.ErrorDesc(err), + Code: int(grpc.Code(err)), + } + if md, ok := ServerMetadataFromContext(ctx); ok { + for k, vs := range md.HeaderMD { + hKey := fmt.Sprintf("%s%s", metadataHeaderPrefix, k) + for i := range vs { + w.Header().Add(hKey, vs[i]) + } + } + body.Trailer = md.TrailerMD + } buf, merr := json.Marshal(body) if merr != nil { glog.Errorf("Failed to marshal error message %q: %v", body, merr) diff --git a/runtime/handler.go b/runtime/handler.go index 1be324cafdd..517cf6575b1 100644 --- a/runtime/handler.go +++ b/runtime/handler.go @@ -80,6 +80,20 @@ func ForwardResponseStream(ctx context.Context, w http.ResponseWriter, req *http // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client. func ForwardResponseMessage(ctx context.Context, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { + md, ok := ServerMetadataFromContext(ctx) + if !ok { + glog.Errorf("Failed to extract ServerMetadata from context") + } + + if md != nil { + for k, vs := range md.HeaderMD { + hKey := fmt.Sprintf("%s%s", metadataHeaderPrefix, k) + for i := range vs { + w.Header().Add(hKey, vs[i]) + } + } + } + w.Header().Set("Content-Type", "application/json") if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { HTTPError(ctx, w, req, err)