-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Simplifies working with field violation errors, especially conversion to gRPC responses.
- Loading branch information
Showing
5 changed files
with
246 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
// Package validation provides primitives for validating proto messages and gRPC requests. | ||
package validation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package validation | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
|
||
"google.golang.org/genproto/googleapis/rpc/errdetails" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/status" | ||
) | ||
|
||
// Error represents a message validation error. | ||
type Error struct { | ||
fieldViolations []*errdetails.BadRequest_FieldViolation | ||
grpcStatus *status.Status | ||
str string | ||
} | ||
|
||
// NewError creates a new validation error from the provided field violations. | ||
func NewError(fieldViolations []*errdetails.BadRequest_FieldViolation) error { | ||
if len(fieldViolations) == 0 { | ||
panic("validation.NewError: must provide at least one field violation") | ||
} | ||
return &Error{ | ||
fieldViolations: fieldViolations, | ||
} | ||
} | ||
|
||
// GRPCStatus converts the validation error to a gRPC status with code INVALID_ARGUMENT. | ||
func (e *Error) GRPCStatus() *status.Status { | ||
if e.grpcStatus == nil { | ||
var fields strings.Builder | ||
for i, fieldViolation := range e.fieldViolations { | ||
_, _ = fields.WriteString(fieldViolation.Field) | ||
if i < len(e.fieldViolations)-1 { | ||
_, _ = fields.WriteString(", ") | ||
} | ||
} | ||
withoutDetails := status.Newf(codes.InvalidArgument, "invalid fields: %s", fields.String()) | ||
if withDetails, err := withoutDetails.WithDetails(&errdetails.BadRequest{ | ||
FieldViolations: e.fieldViolations, | ||
}); err != nil { | ||
e.grpcStatus = withoutDetails | ||
} else { | ||
e.grpcStatus = withDetails | ||
} | ||
} | ||
return e.grpcStatus | ||
} | ||
|
||
// Error implements the error interface. | ||
func (e *Error) Error() string { | ||
if e.str == "" { | ||
if len(e.fieldViolations) == 1 { | ||
e.str = fmt.Sprintf( | ||
"field violation on %s: %s", | ||
e.fieldViolations[0].Field, | ||
e.fieldViolations[0].Description, | ||
) | ||
} else { | ||
var result strings.Builder | ||
_, _ = result.WriteString("field violation on multiple fields:\n") | ||
for i, fieldViolation := range e.fieldViolations { | ||
_, _ = result.WriteString(fmt.Sprintf("\t%s: %s", fieldViolation.Field, fieldViolation.Description)) | ||
if i < len(e.fieldViolations)-1 { | ||
_ = result.WriteByte('\n') | ||
} | ||
} | ||
e.str = result.String() | ||
} | ||
} | ||
return e.str | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package validation | ||
|
||
import ( | ||
"testing" | ||
|
||
"google.golang.org/genproto/googleapis/rpc/errdetails" | ||
"google.golang.org/grpc/codes" | ||
"google.golang.org/grpc/status" | ||
"google.golang.org/protobuf/testing/protocmp" | ||
"gotest.tools/v3/assert" | ||
"gotest.tools/v3/assert/cmp" | ||
) | ||
|
||
func TestError_NewError(t *testing.T) { | ||
t.Parallel() | ||
t.Run("panics on empty field violations", func(t *testing.T) { | ||
t.Parallel() | ||
assert.Assert(t, cmp.Panics(func() { | ||
_ = NewError(nil) | ||
})) | ||
}) | ||
} | ||
|
||
func TestError_Error(t *testing.T) { | ||
t.Parallel() | ||
err := NewError([]*errdetails.BadRequest_FieldViolation{ | ||
{Field: "foo.bar", Description: "test"}, | ||
{Field: "baz", Description: "test2"}, | ||
}) | ||
assert.Error(t, err, `field violation on multiple fields: | ||
foo.bar: test | ||
baz: test2`) | ||
} | ||
|
||
func TestError_GRPCStatus(t *testing.T) { | ||
t.Parallel() | ||
expected := &errdetails.BadRequest{ | ||
FieldViolations: []*errdetails.BadRequest_FieldViolation{ | ||
{Field: "foo.bar", Description: "test"}, | ||
{Field: "baz", Description: "test2"}, | ||
}, | ||
} | ||
s := status.Convert(NewError(expected.FieldViolations)) | ||
assert.Equal(t, codes.InvalidArgument, s.Code()) | ||
assert.Equal(t, "invalid fields: foo.bar, baz", s.Message()) | ||
details := s.Details() | ||
assert.Assert(t, len(details) == 1) | ||
actual, ok := details[0].(*errdetails.BadRequest) | ||
assert.Assert(t, ok) | ||
assert.DeepEqual(t, expected, actual, protocmp.Transform()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package validation | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"strings" | ||
|
||
"google.golang.org/genproto/googleapis/rpc/errdetails" | ||
) | ||
|
||
// MessageValidator provides primitives for validating the fields of a message. | ||
type MessageValidator struct { | ||
parentField string | ||
fieldViolations []*errdetails.BadRequest_FieldViolation | ||
} | ||
|
||
// SetParentField sets a parent field which will be prepended to all the subsequently added violations. | ||
func (m *MessageValidator) SetParentField(parentField string) { | ||
m.parentField = parentField | ||
} | ||
|
||
// AddFieldViolation adds a field violation to the message validator. | ||
func (m *MessageValidator) AddFieldViolation(field, description string, formatArgs ...interface{}) { | ||
if m.parentField != "" { | ||
field = makeFieldWithParent(m.parentField, field) | ||
} | ||
if len(formatArgs) > 0 { | ||
description = fmt.Sprintf(description, formatArgs...) | ||
} | ||
m.fieldViolations = append(m.fieldViolations, &errdetails.BadRequest_FieldViolation{ | ||
Field: field, | ||
Description: description, | ||
}) | ||
} | ||
|
||
// AddFieldError adds a field violation from the provided error. | ||
// If the provided error is a validation.Error, the individual field violations from the provided error are added. | ||
func (m *MessageValidator) AddFieldError(field string, err error) { | ||
var errValidation *Error | ||
if errors.As(err, &errValidation) { | ||
// Add the child field violations with the current field as parent. | ||
originalParentField := m.parentField | ||
m.parentField = makeFieldWithParent(m.parentField, field) | ||
for _, fieldViolation := range errValidation.fieldViolations { | ||
m.AddFieldViolation(fieldViolation.Field, fieldViolation.Description) | ||
} | ||
m.parentField = originalParentField | ||
} else { | ||
m.AddFieldViolation(field, err.Error()) | ||
} | ||
} | ||
|
||
// Err returns the validator's current validation error, or nil if no field validations have been registered. | ||
func (m *MessageValidator) Err() error { | ||
if len(m.fieldViolations) > 0 { | ||
return NewError(m.fieldViolations) | ||
} | ||
return nil | ||
} | ||
|
||
func makeFieldWithParent(parentField, field string) string { | ||
if parentField == "" { | ||
return field | ||
} | ||
var result strings.Builder | ||
result.Grow(len(parentField) + 1 + len(field)) | ||
_, _ = result.WriteString(parentField) | ||
_ = result.WriteByte('.') | ||
_, _ = result.WriteString(field) | ||
return result.String() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package validation | ||
|
||
import ( | ||
"errors" | ||
"testing" | ||
|
||
"gotest.tools/v3/assert" | ||
) | ||
|
||
func TestMessageValidator(t *testing.T) { | ||
t.Parallel() | ||
|
||
t.Run("no violation", func(t *testing.T) { | ||
t.Parallel() | ||
var v MessageValidator | ||
assert.NilError(t, v.Err()) | ||
}) | ||
|
||
t.Run("add single violation", func(t *testing.T) { | ||
t.Parallel() | ||
var v MessageValidator | ||
v.AddFieldViolation("foo", "bar") | ||
assert.Error(t, v.Err(), "field violation on foo: bar") | ||
}) | ||
|
||
t.Run("add single violation with parent", func(t *testing.T) { | ||
t.Parallel() | ||
var v MessageValidator | ||
v.SetParentField("foo") | ||
v.AddFieldViolation("bar", "baz") | ||
assert.Error(t, v.Err(), "field violation on foo.bar: baz") | ||
}) | ||
|
||
t.Run("add nested violations", func(t *testing.T) { | ||
t.Parallel() | ||
var inner MessageValidator | ||
inner.AddFieldViolation("b", "c") | ||
var outer MessageValidator | ||
outer.AddFieldError("a", inner.Err()) | ||
assert.Error(t, outer.Err(), "field violation on a.b: c") | ||
}) | ||
|
||
t.Run("add field error", func(t *testing.T) { | ||
t.Parallel() | ||
var v MessageValidator | ||
v.AddFieldError("a", errors.New("boom")) | ||
assert.Error(t, v.Err(), "field violation on a: boom") | ||
}) | ||
} |