Skip to content

Commit

Permalink
Merge pull request #715 from akshayjshah/ajs/robust
Browse files Browse the repository at this point in the history
protovalidate: avoid pointer comparisons
  • Loading branch information
johanbrandhorst authored May 27, 2024
2 parents 8036513 + 21bacae commit 3606823
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 65 deletions.
24 changes: 16 additions & 8 deletions interceptors/protovalidate/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

type options struct {
ignoreMessages []protoreflect.MessageType
ignoreMessages []protoreflect.FullName
}

// An Option lets you add options to protovalidate interceptors using With* funcs.
Expand All @@ -26,16 +26,24 @@ func evaluateOpts(opts []Option) *options {
return optCopy
}

// WithIgnoreMessages sets the messages that should be ignored by the validator. Use with
// caution and ensure validation is performed elsewhere.
// WithIgnoreMessages sets the messages that should be ignored by the
// validator. Message types are matched using their fully-qualified Protobuf
// names.
//
// Use with caution and ensure validation is performed elsewhere.
func WithIgnoreMessages(msgs ...protoreflect.MessageType) Option {
names := make([]protoreflect.FullName, 0, len(msgs))
for _, msg := range msgs {
names = append(names, msg.Descriptor().FullName())
}
slices.Sort(names)
return func(o *options) {
o.ignoreMessages = msgs
o.ignoreMessages = names
}
}

func (o *options) shouldIgnoreMessage(m protoreflect.MessageType) bool {
return slices.ContainsFunc(o.ignoreMessages, func(t protoreflect.MessageType) bool {
return m == t
})
func (o *options) shouldIgnoreMessage(fqn protoreflect.FullName) bool {
// Names are sorted in WithIgnoreMessages, so we can use binary search.
_, found := slices.BinarySearch(o.ignoreMessages, fqn)
return found
}
80 changes: 29 additions & 51 deletions interceptors/protovalidate/protovalidate.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,98 +15,76 @@ import (
)

// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
// If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateOpts(opts)

return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (resp interface{}, err error) {
o := evaluateOpts(opts)
switch msg := req.(type) {
case proto.Message:
if o.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
break
}
if err = validator.Validate(msg); err != nil {
return nil, validationErrToStatus(err).Err()
}
default:
return nil, errors.New("unsupported message type")
if err := validateMsg(req, validator, o); err != nil {
return nil, err
}

return handler(ctx, req)
}
}

// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
// If the request is invalid, clients may access a structured representation of the validation failure as an error detail.
func StreamServerInterceptor(validator *protovalidate.Validator, opts ...Option) grpc.StreamServerInterceptor {
return func(
srv interface{},
stream grpc.ServerStream,
info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
ctx := stream.Context()
return handler(srv, &wrappedServerStream{
ServerStream: stream,
validator: validator,
options: evaluateOpts(opts),
})
}
}

wrapped := wrapServerStream(stream)
wrapped.wrappedContext = ctx
wrapped.validator = validator
wrapped.options = evaluateOpts(opts)
// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
type wrappedServerStream struct {
grpc.ServerStream

return handler(srv, wrapped)
}
validator *protovalidate.Validator
options *options
}

func (w *wrappedServerStream) RecvMsg(m interface{}) error {
if err := w.ServerStream.RecvMsg(m); err != nil {
return err
}
return validateMsg(m, w.validator, w.options)
}

func validateMsg(m interface{}, validator *protovalidate.Validator, opts *options) error {
msg, ok := m.(proto.Message)
if !ok {
return errors.New("unsupported message type")
}
if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) {
if opts.shouldIgnoreMessage(msg.ProtoReflect().Descriptor().FullName()) {
return nil
}
if err := w.validator.Validate(msg); err != nil {
return validationErrToStatus(err).Err()
err := validator.Validate(msg)
if err == nil {
return nil
}

return nil
}

// wrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context.
type wrappedServerStream struct {
grpc.ServerStream
// wrappedContext is the wrapper's own Context. You can assign it.
wrappedContext context.Context

validator *protovalidate.Validator
options *options
}

// Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context()
func (w *wrappedServerStream) Context() context.Context {
return w.wrappedContext
}

// wrapServerStream returns a ServerStream that has the ability to overwrite context.
func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream {
return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()}
}

func validationErrToStatus(err error) *status.Status {
// Message is invalid.
if valErr := new(protovalidate.ValidationError); errors.As(err, &valErr) {
// Message is invalid.
st := status.New(codes.InvalidArgument, err.Error())
ds, detErr := st.WithDetails(valErr.ToProto())
if detErr != nil {
return st
return st.Err()
}
return ds
return ds.Err()
}
// CEL expression doesn't compile or type-check.
return status.New(codes.Unknown, err.Error())
return status.Error(codes.Unknown, err.Error())
}
25 changes: 19 additions & 6 deletions interceptors/protovalidate/protovalidate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,15 @@ func TestUnaryServerInterceptor(t *testing.T) {

type server struct {
testvalidatev1.UnimplementedTestValidateServiceServer

called *bool
}

func (g *server) SendStream(
_ *testvalidatev1.SendStreamRequest,
stream testvalidatev1.TestValidateService_SendStreamServer,
) error {
*g.called = true
if err := stream.Send(&testvalidatev1.SendStreamResponse{}); err != nil {
return err
}
Expand All @@ -85,7 +88,7 @@ func (g *server) SendStream(

const bufSize = 1024 * 1024

func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) *grpc.ClientConn {
func startGrpcServer(t *testing.T, called *bool, ignoreMessages ...protoreflect.MessageType) *grpc.ClientConn {
lis := bufconn.Listen(bufSize)

validator, err := protovalidate.New()
Expand All @@ -98,7 +101,7 @@ func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) *
),
),
)
testvalidatev1.RegisterTestValidateServiceServer(s, &server{})
testvalidatev1.RegisterTestValidateServiceServer(s, &server{called: called})
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("Server exited with error: %v", err)
Expand Down Expand Up @@ -129,17 +132,24 @@ func startGrpcServer(t *testing.T, ignoreMessages ...protoreflect.MessageType) *

func TestStreamServerInterceptor(t *testing.T) {
t.Run("valid_email", func(t *testing.T) {
called := proto.Bool(false)
client := testvalidatev1.NewTestValidateServiceClient(
startGrpcServer(t),
startGrpcServer(t, called),
)

_, err := client.SendStream(context.Background(), testvalidate.GoodStreamRequest)
out, err := client.SendStream(context.Background(), testvalidate.GoodStreamRequest)
assert.Nil(t, err)

_, err = out.Recv()
t.Log(err)
assert.Nil(t, err)
assert.True(t, *called)
})

t.Run("invalid_email", func(t *testing.T) {
called := proto.Bool(false)
client := testvalidatev1.NewTestValidateServiceClient(
startGrpcServer(t),
startGrpcServer(t, called),
)

out, err := client.SendStream(context.Background(), testvalidate.BadStreamRequest)
Expand All @@ -151,18 +161,21 @@ func TestStreamServerInterceptor(t *testing.T) {
ConstraintId: "string.email",
Message: "value must be a valid email address",
}, err)
assert.False(t, *called)
})

t.Run("invalid_email_ignored", func(t *testing.T) {
called := proto.Bool(false)
client := testvalidatev1.NewTestValidateServiceClient(
startGrpcServer(t, testvalidate.BadStreamRequest.ProtoReflect().Type()),
startGrpcServer(t, called, testvalidate.BadStreamRequest.ProtoReflect().Type()),
)

out, err := client.SendStream(context.Background(), testvalidate.BadStreamRequest)
assert.Nil(t, err)

_, err = out.Recv()
assert.Nil(t, err)
assert.True(t, *called)
})
}

Expand Down

0 comments on commit 3606823

Please sign in to comment.