Skip to content

Commit

Permalink
Add ErrorConverter to protovalidate
Browse files Browse the repository at this point in the history
  • Loading branch information
khasanovbi committed May 7, 2024
1 parent 7da22cf commit 6ced6eb
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 55 deletions.
10 changes: 9 additions & 1 deletion interceptors/protovalidate/example_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type StreamService struct {
Expand All @@ -32,7 +34,13 @@ func ExampleStreamServerInterceptor() {
protovalidate_middleware.StreamServerInterceptor(validator,
protovalidate_middleware.WithIgnoreMessages(
(&testvalidatev1.SendStreamRequest{}).ProtoReflect().Type(),
)),
),
protovalidate_middleware.WithErrorConverter(
func(err error) error {
return status.Error(codes.InvalidArgument, err.Error())
},
),
),
),
)
svc = &StreamService{}
Expand Down
7 changes: 7 additions & 0 deletions interceptors/protovalidate/example_unary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate"
testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

type UnaryService struct {
Expand All @@ -34,6 +36,11 @@ func ExampleUnaryServerInterceptor() {
protovalidate_middleware.WithIgnoreMessages(
(&testvalidatev1.SendRequest{}).ProtoReflect().Type(),
),
protovalidate_middleware.WithErrorConverter(
func(err error) error {
return status.Error(codes.InvalidArgument, err.Error())
},
),
),
),
)
Expand Down
27 changes: 27 additions & 0 deletions interceptors/protovalidate/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,33 @@ package protovalidate

import (
"golang.org/x/exp/slices"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protoreflect"
)

// DefaultErrorConverter returns InvalidArgument status with error message from validator.
func DefaultErrorConverter(err error) error {
return status.Error(codes.InvalidArgument, err.Error())
}

var (
defaultOptions = &options{
errorConverter: DefaultErrorConverter,
}
)

type options struct {
ignoreMessages []protoreflect.MessageType
errorConverter ErrorConverter
}

// An Option lets you add options to protovalidate interceptors using With* funcs.
type Option func(*options)

func evaluateOpts(opts []Option) *options {
optCopy := &options{}
*optCopy = *defaultOptions
for _, o := range opts {
o(optCopy)
}
Expand All @@ -39,3 +54,15 @@ func (o *options) shouldIgnoreMessage(m protoreflect.MessageType) bool {
return m == t
})
}

// ErrorConverter function customize the error returned by protovalidate.Validator.
type ErrorConverter = func(err error) error

// WithErrorConverter customizes the function for mapping errors.
//
// By default, DefaultErrorConverter used.
func WithErrorConverter(errorConverter ErrorConverter) Option {
return func(o *options) {
o.errorConverter = errorConverter
}
}
73 changes: 31 additions & 42 deletions interceptors/protovalidate/protovalidate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,38 @@ package protovalidate

import (
"context"
"errors"

"github.com/bufbuild/protovalidate-go"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)

func validateMessage(validator *protovalidate.Validator, o *options, req any) error {
msg := req.(proto.Message)

if o.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
return nil
}

if err := validator.Validate(msg); err != nil {
return o.errorConverter(err)
}

return nil
}

// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateOpts(opts)

return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (resp interface{}, err error) {
o := evaluateOpts(opts)
switch msg := req.(type) {
case proto.Message:
if o.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
break
}
if err = validator.Validate(msg); err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
default:
return nil, errors.New("unsupported message type")
if err := validateMessage(validator, o, req); err != nil {
return nil, err
}

return handler(ctx, req)
Expand All @@ -41,55 +45,40 @@ func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option)

// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
func StreamServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.StreamServerInterceptor {
o := evaluateOpts(opts)

return func(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
ctx := stream.Context()

wrapped := wrapServerStream(stream)
wrapped.wrappedContext = ctx
wrapped.validator = validator
wrapped.options = evaluateOpts(opts)
wrapped := wrapServerStream(stream, validator, o)

return handler(srv, wrapped)
}
}

func (w *wrappedServerStream) RecvMsg(m interface{}) error {
if err := w.ServerStream.RecvMsg(m); err != nil {
if err := validateMessage(w.validator, w.options, m); err != nil {
return err
}

msg := m.(proto.Message)
if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
return nil
}
if err := w.validator.Validate(msg); err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}

return nil
return w.ServerStream.RecvMsg(m)
}

// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows to validate messages.
type wrappedServerStream struct {
grpc.ServerStream
// wrappedContext is the wrapper's own Context. You can assign it.
wrappedContext context.Context

validator *protovalidate.Validator
options *options
}

// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
func (w *wrappedServerStream) Context() context.Context {
return w.wrappedContext
}

// wrapServerStream returns a ServerStream that has the ability to overwrite context.
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()}
// wrapServerStream returns a ServerStream that has the ability to validate messages.
func wrapServerStream(
stream grpc.ServerStream,
validator *protovalidate.Validator,
options *options,
) *wrappedServerStream {
return &wrappedServerStream{ServerStream: stream, validator: validator, options: options}
}
59 changes: 47 additions & 12 deletions interceptors/protovalidate/protovalidate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package protovalidate_test

import (
"context"
"fmt"
"log"
"net"
"testing"
Expand All @@ -19,12 +20,15 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/reflect/protoreflect"
)

func customErrorConverter(err error) error {
return fmt.Errorf("my custom wrapper: %w", err)
}

func TestUnaryServerInterceptor(t *testing.T) {
validator, err := protovalidate.New()
assert.Nil(t, err)
assert.NoError(t, err)

interceptor := protovalidate_middleware.UnaryServerInterceptor(validator)

Expand All @@ -38,7 +42,7 @@ func TestUnaryServerInterceptor(t *testing.T) {
}

resp, err := interceptor(context.TODO(), testvalidate.GoodUnaryRequest, info, handler)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, resp, "good")
})

Expand All @@ -62,9 +66,24 @@ func TestUnaryServerInterceptor(t *testing.T) {
}

resp, err := interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, resp, "good")
})

interceptor = protovalidate_middleware.UnaryServerInterceptor(validator,
protovalidate_middleware.WithErrorConverter(customErrorConverter),
)

t.Run("custom_error_converter", func(t *testing.T) {
info := &grpc.UnaryServerInfo{
FullMethod: "FakeMethod",
}

_, err = interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler)
assert.Error(t, err)
assert.Equal(t, codes.Unknown, status.Code(err))
assert.EqualError(t, err, "my custom wrapper: validation error:\n - message: value must be a valid email address [string.email]")
})
}

type server struct {
Expand All @@ -84,17 +103,15 @@ func (g *server) SendStream(

const bufSize = 1024 * 1024

func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) *grpc.ClientConn {
func startGrpcServer(t *testing.T, opts ...protovalidate_middleware.Option) *grpc.ClientConn {
lis := bufconn.Listen(bufSize)

validator, err := protovalidate.New()
assert.Nil(t, err)

s := grpc.NewServer(
grpc.StreamInterceptor(
protovalidate_middleware.StreamServerInterceptor(validator,
protovalidate_middleware.WithIgnoreMessages(ignoreMessages...),
),
protovalidate_middleware.StreamServerInterceptor(validator, opts...),
),
)
testvalidatev1.RegisterTestValidateServiceServer(s, &server{})
Expand Down Expand Up @@ -133,7 +150,7 @@ func TestStreamServerInterceptor(t *testing.T) {
)

_, err := client.SendStream(context.Background(), testvalidate.GoodStreamRequest)
assert.Nil(t, err)
assert.NoError(t, err)
})

t.Run("invalid_email", func(t *testing.T) {
Expand All @@ -151,13 +168,31 @@ func TestStreamServerInterceptor(t *testing.T) {

t.Run("invalid_email_ignored", func(t *testing.T) {
client := testvalidatev1.NewTestValidateServiceClient(
startGrpcServer(t, testvalidate.BadStreamRequest.ProtoReflect().Type()),
startGrpcServer(
t,
protovalidate_middleware.WithIgnoreMessages(testvalidate.BadStreamRequest.ProtoReflect().Type()),
),
)

out, err := client.SendStream(context.Background(), testvalidate.BadStreamRequest)
assert.Nil(t, err)
assert.NoError(t, err)

_, err = out.Recv()
assert.Nil(t, err)
assert.NoError(t, err)
})

t.Run("custom_error_converter", func(t *testing.T) {
client := testvalidatev1.NewTestValidateServiceClient(
startGrpcServer(t, protovalidate_middleware.WithErrorConverter(customErrorConverter)),
)

out, err := client.SendStream(context.Background(), testvalidate.BadStreamRequest)
assert.NoError(t, err)

_, err = out.Recv()
assert.Error(t, err)
st, _ := status.FromError(err)
assert.Equal(t, codes.Unknown, st.Code())
assert.Equal(t, "my custom wrapper: validation error:\n - message: value must be a valid email address [string.email]", st.Message())
})
}

0 comments on commit 6ced6eb

Please sign in to comment.