Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Custom Directive Support for Fields #543

Merged
merged 20 commits into from
Jan 19, 2023
5 changes: 3 additions & 2 deletions directives/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (
// see the graphql.DirectiveVisitors() Schema Option
type Visitor interface {
// Before() is always called when the operation includes a directive matching this implementation's name
// When the first return value is true, the field resolver will not be called.
// Errors in Before() will prevent field resolution
pavelnikolov marked this conversation as resolved.
Show resolved Hide resolved
Before(ctx context.Context, directive *types.Directive, input interface{}) error
Before(ctx context.Context, directive *types.Directive, input interface{}) (skipResolver bool, err error)
// After is called if Before() *and* the field resolver do not error
After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error)
After(ctx context.Context, directive *types.Directive, output interface{}) (modified interface{}, err error)
}
56 changes: 54 additions & 2 deletions graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ type customDirectiveVisitor struct {
beforeWasCalled bool
}

func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) error {
func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) (bool, error) {
v.beforeWasCalled = true
return nil
return false, nil
}

func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) {
Expand All @@ -74,6 +74,30 @@ func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Dir
return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil
}

type cachedDirectiveVisitor struct {
cachedValue interface{}
}

func (v *cachedDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) (bool, error) {
pavelnikolov marked this conversation as resolved.
Show resolved Hide resolved
s := "valueFromCache"
v.cachedValue = s
return true, nil
}

func (v *cachedDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) {
return v.cachedValue, nil
}

type cachedDirectiveResolver struct {
t *testing.T
}

func (r *cachedDirectiveResolver) Hello(ctx context.Context, args struct{ FullName string }) string {
r.t.Error("expected cached resolver to not be called, but it was")

return ""
}

type theNumberResolver struct {
number int32
}
Expand Down Expand Up @@ -329,6 +353,34 @@ func TestCustomDirective(t *testing.T) {
}
`,
},
{
Schema: graphql.MustParseSchema(`
directive @cached(
key: String!
) on FIELD_DEFINITION

schema {
query: Query
}

type Query {
hello(full_name: String!): String! @cached(key: "notcheckedintest")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be cool if this key can come from the resolved entity's ID somehow. In a real-world situation we would need to use the ID of that entity or the ID of its parent, for example, the products of a category or the comments of a post. In both of these examples we need to cache by the ID of the parent entity. I wish we could somehow do that.

}
`, &cachedDirectiveResolver{t: t},
graphql.DirectiveVisitors(map[string]directives.Visitor{
"cached": &cachedDirectiveVisitor{},
})),
Query: `
{
hello(full_name: "Full Name")
}
`,
ExpectedResult: `
{
"hello": "valueFromCache"
}
`,
},
})
}

Expand Down
53 changes: 38 additions & 15 deletions internal/exec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,13 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f

res := f.resolver
if f.field.UseMethodResolver() {
var in []reflect.Value
var (
skipResolver bool
in []reflect.Value
callOut []reflect.Value
visitorErr error
modified interface{}
)
if f.field.HasContext {
in = append(in, reflect.ValueOf(traceCtx))
}
Expand All @@ -219,8 +225,7 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
for _, inValue := range in {
values = append(values, inValue.Interface())
}

visitorErr := visitor.Before(ctx, directive, values)
skipResolver, visitorErr = visitor.Before(ctx, directive, values)
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
Expand All @@ -231,27 +236,36 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
}
}

// Call method
callOut := res.Method(f.field.MethodIndex).Call(in)
result = callOut[0]
// Call method unless the Before visitor tells us not to
if !skipResolver {
callOut = res.Method(f.field.MethodIndex).Call(in)
result = callOut[0]
}

// After hook directive visitor (when no error is returned from resolver)
if !f.field.HasError && len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
modified, visitorErr := visitor.After(ctx, directive, result.Interface())
if (result.IsValid() && !result.IsZero()) && result.CanInterface() {
modified, visitorErr = visitor.After(ctx, directive, result.Interface())
} else {
modified, visitorErr = visitor.After(ctx, directive, nil)
}

if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
} else {
result = reflect.ValueOf(modified)
}
result = reflect.ValueOf(modified)
}
}
}

if skipResolver {
return nil
}
if f.field.HasError && !callOut[1].IsNil() {
resolverErr := callOut[1].Interface().(error)
err := errors.Errorf("%s", resolverErr)
Expand All @@ -263,6 +277,11 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
return err
}
} else {
var (
skipResolver bool
visitorErr error
modified interface{}
)
// TODO extract out unwrapping ptr logic to a common place
if res.Kind() == reflect.Ptr {
res = res.Elem()
Expand All @@ -271,8 +290,7 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
if len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
// TODO check that directive arity == 0-that should be an error at schema init time
visitorErr := visitor.Before(ctx, directive, nil)
skipResolver, visitorErr = visitor.Before(ctx, directive, nil)
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
Expand All @@ -282,20 +300,25 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
}
}
}
result = res.FieldByIndex(f.field.FieldIndex)
if !skipResolver {
result = res.FieldByIndex(f.field.FieldIndex)
}
// After hook directive visitor (when no error is returned from resolver)
if !f.field.HasError && len(f.field.Directives) > 0 {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that struct resolver fields can ever produce an error but I'm not an expert. Shall we delete this, Pavel?

Suggested change
if !f.field.HasError && len(f.field.Directives) > 0 {
if len(f.field.Directives) > 0 {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, struct fields cannot produce an error. It doesn't make sense to me to check for one.

for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
modified, visitorErr := visitor.After(ctx, directive, result.Interface())
if (result.IsValid() && !result.IsZero()) && result.CanInterface() {
modified, visitorErr = visitor.After(ctx, directive, result.Interface())
} else {
modified, visitorErr = visitor.After(ctx, directive, nil)
}
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
} else {
result = reflect.ValueOf(modified)
}
result = reflect.ValueOf(modified)
}
}
}
Expand Down