Skip to content

Commit

Permalink
return trailer metadata in unary rpc
Browse files Browse the repository at this point in the history
  • Loading branch information
kazegusuri committed Feb 6, 2016
1 parent ee78409 commit cb54dbd
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 12 deletions.
68 changes: 61 additions & 7 deletions examples/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
sub "github.com/gengo/grpc-gateway/examples/sub"
"github.com/gengo/grpc-gateway/runtime"
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/codes"
)

type aBitOfEverything struct {
Expand All @@ -44,11 +45,6 @@ func TestIntegration(t *testing.T) {
runtime.WithForwardResponseOption(
func(ctx context.Context, w http.ResponseWriter, _ proto.Message) error {
if md, ok := runtime.ServerMetadataFromContext(ctx); ok {
for k, vs := range md.HeaderMD {
for i := range vs {
w.Header().Add(fmt.Sprintf("Grpc-Header-%s", k), vs[i])
}
}
for k, vs := range md.TrailerMD {
for i := range vs {
w.Header().Add(fmt.Sprintf("Grpc-Trailer-%s", k), vs[i])
Expand All @@ -71,6 +67,7 @@ func TestIntegration(t *testing.T) {
testABECreateBody(t)
testABEBulkCreate(t)
testABELookup(t)
testABELookupNotFound(t)
testABEList(t)
testABEListError(t)
testAdditionalBindings(t)
Expand Down Expand Up @@ -161,11 +158,11 @@ func testEchoBody(t *testing.T) {
t.Errorf("msg.Id = %q; want %q", got, want)
}

if value := resp.Header.Get("Grpc-Header-foo"); value != "foo1" {
if value := resp.Header.Get("Grpc-Metadata-foo"); value != "foo1" {
t.Errorf("Grpc-Header-foo was %s, wanted %s", value, "foo1")
}

if value := resp.Header.Get("Grpc-Header-bar"); value != "bar1" {
if value := resp.Header.Get("Grpc-Metadata-bar"); value != "bar1" {
t.Errorf("Grpc-Header-bar was %s, wanted %s", value, "bar1")
}

Expand Down Expand Up @@ -427,6 +424,63 @@ func testABELookup(t *testing.T) {
if got := msg; !reflect.DeepEqual(got, want) {
t.Errorf("msg= %v; want %v", &got, &want)
}

if got, want := resp.Header.Get("Grpc-Metadata-uuid"), want.Uuid; got != want {
t.Errorf("Grpc-Metadata-foo was %s, wanted %s", got, want)
}
}

func testABELookupNotFound(t *testing.T) {
url := "http://localhost:8080/v1/example/a_bit_of_everything"
uuid := "not_exist"
url = fmt.Sprintf("%s/%s", url, uuid)
resp, err := http.Get(url)
if err != nil {
t.Errorf("http.Get(%q) failed with %v; want success", url, err)
return
}
defer resp.Body.Close()

buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("ioutil.ReadAll(resp.Body) failed with %v; want success", err)
return
}

if got, want := resp.StatusCode, http.StatusNotFound; got != want {
t.Errorf("resp.StatusCode = %d; want %d", got, want)
t.Logf("%s", buf)
return
}

var msg runtime.ErrorBody
if err := json.Unmarshal(buf, &msg); err != nil {
t.Errorf("json.Unmarshal(%s, &msg) failed with %v; want success", buf, err)
return
}

if got, want := msg.Code, int(codes.NotFound); got != want {
t.Errorf("msg.Code = %d; want %d", got, want)
return
}

if got, want := resp.Header.Get("Grpc-Metadata-uuid"), uuid; got != want {
t.Errorf("Grpc-Metadata-foo was %s, wanted %s", got, want)
}

md := msg.Trailer
if md.Len() == 0 {
t.Errorf("no trailer is set")
return
}

if got, want := md["foo"], []string{"foo2"}; !reflect.DeepEqual(got, want) {
t.Errorf("msg.Trailer[%q] = %v; want %v", "foo", got, want)
}

if got, want := md["bar"], []string{"bar2"}; !reflect.DeepEqual(got, want) {
t.Errorf("msg.Trailer[%q] = %v; want %v", "bar", got, want)
}
}

func testABEList(t *testing.T) {
Expand Down
11 changes: 10 additions & 1 deletion examples/server/a_bit_of_everything.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,20 @@ func (s *_ABitOfEverythingServer) BulkCreate(stream examples.ABitOfEverythingSer
func (s *_ABitOfEverythingServer) Lookup(ctx context.Context, msg *examples.IdMessage) (*examples.ABitOfEverything, error) {
s.m.Lock()
defer s.m.Unlock()

glog.Info(msg)

grpc.SendHeader(ctx, metadata.New(map[string]string{
"uuid": msg.Uuid,
}))

if a, ok := s.v[msg.Uuid]; ok {
return a, nil
}

grpc.SetTrailer(ctx, metadata.New(map[string]string{
"foo": "foo2",
"bar": "bar2",
}))
return nil, grpc.Errorf(codes.NotFound, "not found")
}

Expand Down
23 changes: 19 additions & 4 deletions runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package runtime

import (
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/golang/glog"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
)

// HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status.
Expand Down Expand Up @@ -62,9 +64,10 @@ var (
OtherErrorHandler = DefaultOtherErrorHandler
)

type errorBody struct {
Error string `json:"error"`
Code int `json:"code"`
type ErrorBody struct {
Error string `json:"error"`
Code int `json:"code"`
Trailer metadata.MD `json:"trailer,omitempty"`
}

// DefaultHTTPError is the default implementation of HTTPError.
Expand All @@ -77,7 +80,19 @@ func DefaultHTTPError(ctx context.Context, w http.ResponseWriter, _ *http.Reques
const fallback = `{"error": "failed to marshal error message"}`

w.Header().Set("Content-Type", "application/json")
body := errorBody{Error: grpc.ErrorDesc(err), Code: int(grpc.Code(err))}
body := ErrorBody{
Error: grpc.ErrorDesc(err),
Code: int(grpc.Code(err)),
}
if md, ok := ServerMetadataFromContext(ctx); ok {
for k, vs := range md.HeaderMD {
hKey := fmt.Sprintf("%s%s", metadataHeaderPrefix, k)
for i := range vs {
w.Header().Add(hKey, vs[i])
}
}
body.Trailer = md.TrailerMD
}
buf, merr := json.Marshal(body)
if merr != nil {
glog.Errorf("Failed to marshal error message %q: %v", body, merr)
Expand Down
14 changes: 14 additions & 0 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,20 @@ func ForwardResponseStream(ctx context.Context, w http.ResponseWriter, req *http

// ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
func ForwardResponseMessage(ctx context.Context, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
md, ok := ServerMetadataFromContext(ctx)
if !ok {
glog.Errorf("Failed to extract ServerMetadata from context")
}

if md != nil {
for k, vs := range md.HeaderMD {
hKey := fmt.Sprintf("%s%s", metadataHeaderPrefix, k)
for i := range vs {
w.Header().Add(hKey, vs[i])
}
}
}

w.Header().Set("Content-Type", "application/json")
if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
HTTPError(ctx, w, req, err)
Expand Down

0 comments on commit cb54dbd

Please sign in to comment.