Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Custom Directive Support for Fields #543

Merged
merged 20 commits into from
Jan 19, 2023
10 changes: 10 additions & 0 deletions graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ type Schema struct {
useStringDescriptions bool
disableIntrospection bool
subscribeResolverTimeout time.Duration
visitors map[string]types.DirectiveVisitor
}

func (s *Schema) ASTSchema() *types.Schema {
Expand Down Expand Up @@ -169,6 +170,14 @@ func SubscribeResolverTimeout(timeout time.Duration) SchemaOpt {
}
}

// DirectiveVisitors allows to pass custom directive visitors that will be able to handle
// your GraphQL schema directives.
func DirectiveVisitors(visitors map[string]types.DirectiveVisitor) SchemaOpt {
return func(s *Schema) {
s.visitors = visitors
}
}

// Response represents a typical response of a GraphQL server. It may be encoded to JSON directly or
// it may be further processed to a custom response type, for example to include custom error data.
// Errors are intentionally serialized first based on the advice in https://github.com/facebook/graphql/commit/7b40390d48680b15cb93e02d46ac5eb249689876#diff-757cea6edf0288677a9eea4cfc801d87R107
Expand Down Expand Up @@ -258,6 +267,7 @@ func (s *Schema) exec(ctx context.Context, queryString string, operationName str
Tracer: s.tracer,
Logger: s.logger,
PanicHandler: s.panicHandler,
Visitors: s.visitors,
}
varTypes := make(map[string]*introspection.Type)
for _, v := range op.Vars {
Expand Down
178 changes: 177 additions & 1 deletion graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/graph-gophers/graphql-go/gqltesting"
"github.com/graph-gophers/graphql-go/introspection"
"github.com/graph-gophers/graphql-go/trace/tracer"
"github.com/graph-gophers/graphql-go/types"
)

type helloWorldResolver1 struct{}
Expand Down Expand Up @@ -48,6 +49,30 @@ func (r *helloSnakeResolver2) SayHello(ctx context.Context, args struct{ FullNam
return "Hello " + args.FullName + "!", nil
}

type structFieldResolver struct {
Hello string
}

type customDirectiveVisitor struct {
beforeWasCalled bool
}

func (v *customDirectiveVisitor) Before(ctx context.Context, directive *types.Directive, input interface{}) error {
v.beforeWasCalled = true
return nil
}

func (v *customDirectiveVisitor) After(ctx context.Context, directive *types.Directive, output interface{}) (interface{}, error) {
if v.beforeWasCalled == false {
return nil, errors.New("Before directive visitor method wasn't called.")
}

if value, ok := directive.Arguments.Get("customAttribute"); ok {
return fmt.Sprintf("Directive '%s' (with arg '%s') modified result: %s", directive.Name.Name, value.String(), output.(string)), nil
}
return fmt.Sprintf("Directive '%s' modified result: %s", directive.Name.Name, output.(string)), nil
}

type theNumberResolver struct {
number int32
}
Expand Down Expand Up @@ -191,7 +216,6 @@ func TestHelloWorld(t *testing.T) {
}
`,
},

{
Schema: graphql.MustParseSchema(`
schema {
Expand All @@ -216,6 +240,158 @@ func TestHelloWorld(t *testing.T) {
})
}

func TestHelloWorldStructFieldResolver(t *testing.T) {
rudle marked this conversation as resolved.
Show resolved Hide resolved
gqltesting.RunTests(t, []*gqltesting.Test{
{
Schema: graphql.MustParseSchema(`
schema {
query: Query
}

type Query {
hello: String!
}
`,
&structFieldResolver{Hello: "Hello world!"},
graphql.UseFieldResolvers()),
Query: `
{
hello
}
`,
ExpectedResult: `
{
"hello": "Hello world!"
}
`,
},
})
}

func TestCustomDirective(t *testing.T) {
t.Parallel()

gqltesting.RunTests(t, []*gqltesting.Test{
{
Schema: graphql.MustParseSchema(`
directive @customDirective on FIELD_DEFINITION

schema {
query: Query
}

type Query {
hello_html: String! @customDirective
}
`, &helloSnakeResolver1{},
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
"customDirective": &customDirectiveVisitor{},
})),
Query: `
{
hello_html
}
`,
ExpectedResult: `
{
"hello_html": "Directive 'customDirective' modified result: Hello snake!"
}
`,
},
{
Schema: graphql.MustParseSchema(`
directive @customDirective(
customAttribute: String!
) on FIELD_DEFINITION

schema {
query: Query
}

type Query {
say_hello(full_name: String!): String! @customDirective(customAttribute: hi)
}
`, &helloSnakeResolver1{},
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
"customDirective": &customDirectiveVisitor{},
})),
Query: `
{
say_hello(full_name: "Johnny")
}
`,
ExpectedResult: `
{
"say_hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello Johnny!"
}
`,
},

// tests for struct field resolvers
rudle marked this conversation as resolved.
Show resolved Hide resolved

})
}

func TestCustomDirectiveStructFieldResolver(t *testing.T) {
schemaOpt := []graphql.SchemaOpt{
graphql.DirectiveVisitors(map[string]types.DirectiveVisitor{
"customDirective": &customDirectiveVisitor{},
}),
graphql.UseFieldResolvers(),
}

gqltesting.RunTests(t, []*gqltesting.Test{
{
Schema: graphql.MustParseSchema(`
directive @customDirective on FIELD_DEFINITION

schema {
query: Query
}

type Query {
hello: String! @customDirective
}
`, &structFieldResolver{Hello: "Hello world!"}, schemaOpt...),
Query: `
{
hello
}
`,
ExpectedResult: `
{
"hello": "Directive 'customDirective' modified result: Hello world!"
}
`,
},
{
Schema: graphql.MustParseSchema(`
directive @customDirective(
customAttribute: String!
) on FIELD_DEFINITION

schema {
query: Query
}

type Query {
hello: String! @customDirective(customAttribute: hi)
}
`, &structFieldResolver{Hello: "Hello world!"}, schemaOpt...),
Query: `
{
hello
}
`,
ExpectedResult: `
{
"hello": "Directive 'customDirective' (with arg 'hi') modified result: Hello world!"
}
`,
},
})
}

func TestHelloSnake(t *testing.T) {
t.Parallel()

Expand Down
72 changes: 72 additions & 0 deletions internal/exec/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type Request struct {
Logger log.Logger
PanicHandler errors.PanicHandler
SubscribeResolverTimeout time.Duration
Visitors map[string]types.DirectiveVisitor
}

func (r *Request) handlePanic(ctx context.Context) {
Expand Down Expand Up @@ -208,8 +209,48 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
if f.field.ArgsPacker != nil {
in = append(in, f.field.PackedArgs)
}

// Before hook directive visitor
if len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
values := make([]interface{}, 0, len(in))
for _, inValue := range in {
values = append(values, inValue.Interface())
}

visitorErr := visitor.Before(ctx, directive, values)
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
}
}
}
}

// Call method
callOut := res.Method(f.field.MethodIndex).Call(in)
result = callOut[0]

// After hook directive visitor (when no error is returned from resolver)
if !f.field.HasError && len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
returned, visitorErr := visitor.After(ctx, directive, result.Interface())
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
} else {
result = reflect.ValueOf(returned)
}
rudle marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

if f.field.HasError && !callOut[1].IsNil() {
resolverErr := callOut[1].Interface().(error)
err := errors.Errorf("%s", resolverErr)
Expand All @@ -225,7 +266,38 @@ func execFieldSelection(ctx context.Context, r *Request, s *resolvable.Schema, f
if res.Kind() == reflect.Ptr {
res = res.Elem()
}
// Before hook directive visitor struct field
if len(f.field.Directives) > 0 {
for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
// TODO check that directive arity == 0-that should be an error at schema init time
visitorErr := visitor.Before(ctx, directive, nil)
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
}
}
}
}
result = res.FieldByIndex(f.field.FieldIndex)
// After hook directive visitor (when no error is returned from resolver)
if !f.field.HasError && len(f.field.Directives) > 0 {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that struct resolver fields can ever produce an error but I'm not an expert. Shall we delete this, Pavel?

Suggested change
if !f.field.HasError && len(f.field.Directives) > 0 {
if len(f.field.Directives) > 0 {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, struct fields cannot produce an error. It doesn't make sense to me to check for one.

for _, directive := range f.field.Directives {
if visitor, ok := r.Visitors[directive.Name.Name]; ok {
returned, visitorErr := visitor.After(ctx, directive, result.Interface())
if visitorErr != nil {
err := errors.Errorf("%s", visitorErr)
err.Path = path.toSlice()
err.ResolverError = visitorErr
return err
} else {
result = reflect.ValueOf(returned)
}
rudle marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
}
return nil
}()
Expand Down
11 changes: 10 additions & 1 deletion types/directive.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package types

import "github.com/graph-gophers/graphql-go/errors"
import (
"context"

"github.com/graph-gophers/graphql-go/errors"
)

// Directive is a representation of the GraphQL Directive.
//
Expand All @@ -24,6 +28,11 @@ type DirectiveDefinition struct {

type DirectiveList []*Directive

type DirectiveVisitor interface {
pavelnikolov marked this conversation as resolved.
Show resolved Hide resolved
Before(ctx context.Context, directive *Directive, input interface{}) error
After(ctx context.Context, directive *Directive, output interface{}) (interface{}, error)
}
pavelnikolov marked this conversation as resolved.
Show resolved Hide resolved

// Returns the Directive in the DirectiveList by name or nil if not found.
func (l DirectiveList) Get(name string) *Directive {
for _, d := range l {
Expand Down