Skip to content

Commit

Permalink
feat: add validation package
Browse files Browse the repository at this point in the history
Simplifies working with field violation errors, especially conversion to
gRPC responses.
  • Loading branch information
odsod committed Apr 28, 2021
1 parent 9c398c9 commit 8e8179b
Show file tree
Hide file tree
Showing 5 changed files with 246 additions and 0 deletions.
2 changes: 2 additions & 0 deletions validation/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package validation provides primitives for validating proto messages and gRPC requests.
package validation
73 changes: 73 additions & 0 deletions validation/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package validation

import (
"fmt"
"strings"

"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// Error represents a message validation error.
type Error struct {
fieldViolations []*errdetails.BadRequest_FieldViolation
grpcStatus *status.Status
str string
}

// NewError creates a new validation error from the provided field violations.
func NewError(fieldViolations []*errdetails.BadRequest_FieldViolation) error {
if len(fieldViolations) == 0 {
panic("validation.NewError: must provide at least one field violation")
}
return &Error{
fieldViolations: fieldViolations,
}
}

// GRPCStatus converts the validation error to a gRPC status with code INVALID_ARGUMENT.
func (e *Error) GRPCStatus() *status.Status {
if e.grpcStatus == nil {
var fields strings.Builder
for i, fieldViolation := range e.fieldViolations {
_, _ = fields.WriteString(fieldViolation.Field)
if i < len(e.fieldViolations)-1 {
_, _ = fields.WriteString(", ")
}
}
withoutDetails := status.Newf(codes.InvalidArgument, "invalid fields: %s", fields.String())
if withDetails, err := withoutDetails.WithDetails(&errdetails.BadRequest{
FieldViolations: e.fieldViolations,
}); err != nil {
e.grpcStatus = withoutDetails
} else {
e.grpcStatus = withDetails
}
}
return e.grpcStatus
}

// Error implements the error interface.
func (e *Error) Error() string {
if e.str == "" {
if len(e.fieldViolations) == 1 {
e.str = fmt.Sprintf(
"field violation on %s: %s",
e.fieldViolations[0].Field,
e.fieldViolations[0].Description,
)
} else {
var result strings.Builder
_, _ = result.WriteString("field violation on multiple fields:\n")
for i, fieldViolation := range e.fieldViolations {
_, _ = result.WriteString(fmt.Sprintf("\t%s: %s", fieldViolation.Field, fieldViolation.Description))
if i < len(e.fieldViolations)-1 {
_ = result.WriteByte('\n')
}
}
e.str = result.String()
}
}
return e.str
}
51 changes: 51 additions & 0 deletions validation/error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package validation

import (
"testing"

"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/testing/protocmp"
"gotest.tools/v3/assert"
"gotest.tools/v3/assert/cmp"
)

func TestError_NewError(t *testing.T) {
t.Parallel()
t.Run("panics on empty field violations", func(t *testing.T) {
t.Parallel()
assert.Assert(t, cmp.Panics(func() {
_ = NewError(nil)
}))
})
}

func TestError_Error(t *testing.T) {
t.Parallel()
err := NewError([]*errdetails.BadRequest_FieldViolation{
{Field: "foo.bar", Description: "test"},
{Field: "baz", Description: "test2"},
})
assert.Error(t, err, `field violation on multiple fields:
foo.bar: test
baz: test2`)
}

func TestError_GRPCStatus(t *testing.T) {
t.Parallel()
expected := &errdetails.BadRequest{
FieldViolations: []*errdetails.BadRequest_FieldViolation{
{Field: "foo.bar", Description: "test"},
{Field: "baz", Description: "test2"},
},
}
s := status.Convert(NewError(expected.FieldViolations))
assert.Equal(t, codes.InvalidArgument, s.Code())
assert.Equal(t, "invalid fields: foo.bar, baz", s.Message())
details := s.Details()
assert.Assert(t, len(details) == 1)
actual, ok := details[0].(*errdetails.BadRequest)
assert.Assert(t, ok)
assert.DeepEqual(t, expected, actual, protocmp.Transform())
}
71 changes: 71 additions & 0 deletions validation/messagevalidator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package validation

import (
"errors"
"fmt"
"strings"

"google.golang.org/genproto/googleapis/rpc/errdetails"
)

// MessageValidator provides primitives for validating the fields of a message.
type MessageValidator struct {
parentField string
fieldViolations []*errdetails.BadRequest_FieldViolation
}

// SetParentField sets a parent field which will be prepended to all the subsequently added violations.
func (m *MessageValidator) SetParentField(parentField string) {
m.parentField = parentField
}

// AddFieldViolation adds a field violation to the message validator.
func (m *MessageValidator) AddFieldViolation(field, description string, formatArgs ...interface{}) {
if m.parentField != "" {
field = makeFieldWithParent(m.parentField, field)
}
if len(formatArgs) > 0 {
description = fmt.Sprintf(description, formatArgs...)
}
m.fieldViolations = append(m.fieldViolations, &errdetails.BadRequest_FieldViolation{
Field: field,
Description: description,
})
}

// AddFieldError adds a field violation from the provided error.
// If the provided error is a validation.Error, the individual field violations from the provided error are added.
func (m *MessageValidator) AddFieldError(field string, err error) {
var errValidation *Error
if errors.As(err, &errValidation) {
// Add the child field violations with the current field as parent.
originalParentField := m.parentField
m.parentField = makeFieldWithParent(m.parentField, field)
for _, fieldViolation := range errValidation.fieldViolations {
m.AddFieldViolation(fieldViolation.Field, fieldViolation.Description)
}
m.parentField = originalParentField
} else {
m.AddFieldViolation(field, err.Error())
}
}

// Err returns the validator's current validation error, or nil if no field validations have been registered.
func (m *MessageValidator) Err() error {
if len(m.fieldViolations) > 0 {
return NewError(m.fieldViolations)
}
return nil
}

func makeFieldWithParent(parentField, field string) string {
if parentField == "" {
return field
}
var result strings.Builder
result.Grow(len(parentField) + 1 + len(field))
_, _ = result.WriteString(parentField)
_ = result.WriteByte('.')
_, _ = result.WriteString(field)
return result.String()
}
49 changes: 49 additions & 0 deletions validation/messagevalidator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package validation

import (
"errors"
"testing"

"gotest.tools/v3/assert"
)

func TestMessageValidator(t *testing.T) {
t.Parallel()

t.Run("no violation", func(t *testing.T) {
t.Parallel()
var v MessageValidator
assert.NilError(t, v.Err())
})

t.Run("add single violation", func(t *testing.T) {
t.Parallel()
var v MessageValidator
v.AddFieldViolation("foo", "bar")
assert.Error(t, v.Err(), "field violation on foo: bar")
})

t.Run("add single violation with parent", func(t *testing.T) {
t.Parallel()
var v MessageValidator
v.SetParentField("foo")
v.AddFieldViolation("bar", "baz")
assert.Error(t, v.Err(), "field violation on foo.bar: baz")
})

t.Run("add nested violations", func(t *testing.T) {
t.Parallel()
var inner MessageValidator
inner.AddFieldViolation("b", "c")
var outer MessageValidator
outer.AddFieldError("a", inner.Err())
assert.Error(t, outer.Err(), "field violation on a.b: c")
})

t.Run("add field error", func(t *testing.T) {
t.Parallel()
var v MessageValidator
v.AddFieldError("a", errors.New("boom"))
assert.Error(t, v.Err(), "field violation on a: boom")
})
}

0 comments on commit 8e8179b

Please sign in to comment.