Skip to content

Commit

Permalink
transport/http2: use HTTP 400 for bad requests instead of 500 (#5804)
Browse files Browse the repository at this point in the history
  • Loading branch information
sjbarag authored Dec 13, 2022
1 parent 5003029 commit 2f413c4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
30 changes: 21 additions & 9 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,32 @@ import (
"google.golang.org/grpc/status"
)

// NewServerHandlerTransport returns a ServerTransport handling gRPC
// from inside an http.Handler. It requires that the http Server
// supports HTTP/2.
// NewServerHandlerTransport returns a ServerTransport handling gRPC from
// inside an http.Handler, or writes an HTTP error to w and returns an error.
// It requires that the http Server supports HTTP/2.
func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler) (ServerTransport, error) {
if r.ProtoMajor != 2 {
return nil, errors.New("gRPC requires HTTP/2")
msg := "gRPC requires HTTP/2"
http.Error(w, msg, http.StatusBadRequest)
return nil, errors.New(msg)
}
if r.Method != "POST" {
return nil, errors.New("invalid gRPC request method")
msg := fmt.Sprintf("invalid gRPC request method %q", r.Method)
http.Error(w, msg, http.StatusBadRequest)
return nil, errors.New(msg)
}
contentType := r.Header.Get("Content-Type")
// TODO: do we assume contentType is lowercase? we did before
contentSubtype, validContentType := grpcutil.ContentSubtype(contentType)
if !validContentType {
return nil, errors.New("invalid gRPC request content-type")
msg := fmt.Sprintf("invalid gRPC request content-type %q", contentType)
http.Error(w, msg, http.StatusBadRequest)
return nil, errors.New(msg)
}
if _, ok := w.(http.Flusher); !ok {
return nil, errors.New("gRPC requires a ResponseWriter supporting http.Flusher")
msg := "gRPC requires a ResponseWriter supporting http.Flusher"
http.Error(w, msg, http.StatusInternalServerError)
return nil, errors.New(msg)
}

st := &serverHandlerTransport{
Expand All @@ -79,7 +87,9 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
if v := r.Header.Get("grpc-timeout"); v != "" {
to, err := decodeTimeout(v)
if err != nil {
return nil, status.Errorf(codes.Internal, "malformed time-out: %v", err)
msg := fmt.Sprintf("malformed time-out: %v", err)
http.Error(w, msg, http.StatusBadRequest)
return nil, status.Error(codes.Internal, msg)
}
st.timeoutSet = true
st.timeout = to
Expand All @@ -97,7 +107,9 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
for _, v := range vv {
v, err := decodeMetadataHeader(k, v)
if err != nil {
return nil, status.Errorf(codes.Internal, "malformed binary metadata: %v", err)
msg := fmt.Sprintf("malformed binary metadata %q in header %q: %v", v, k, err)
http.Error(w, msg, http.StatusBadRequest)
return nil, status.Error(codes.Internal, msg)
}
metakv = append(metakv, k, v)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
Method: "GET",
Header: http.Header{},
},
wantErr: "invalid gRPC request method",
wantErr: `invalid gRPC request method "GET"`,
},
{
name: "bad content type",
Expand All @@ -74,7 +74,7 @@ func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
"Content-Type": {"application/foo"},
},
},
wantErr: "invalid gRPC request content-type",
wantErr: `invalid gRPC request content-type "application/foo"`,
},
{
name: "not flusher",
Expand Down
3 changes: 2 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,8 @@ var _ http.Handler = (*Server)(nil)
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
// Errors returned from transport.NewServerHandlerTransport have
// already been written to w.
return
}
if !s.addConn(listenerAddressForServeHTTP, st) {
Expand Down

0 comments on commit 2f413c4

Please sign in to comment.