From 64f8084368777e9e815b5191cef4e0c68fb07a12 Mon Sep 17 00:00:00 2001 From: Igor <9917165+ostrea@users.noreply.github.com> Date: Wed, 20 Jul 2022 13:50:46 +0300 Subject: [PATCH] Disallow repeat of non repeatable directives (#525) * Disallow repeat of non repeatable directives * Remove unnecessary scallar --- internal/schema/schema.go | 9 +++++++++ internal/schema/schema_test.go | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/internal/schema/schema.go b/internal/schema/schema.go index da37fe1f..be92f181 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -289,6 +289,7 @@ func resolveField(s *types.Schema, f *types.FieldDefinition) error { } func resolveDirectives(s *types.Schema, directives types.DirectiveList, loc string) error { + alreadySeenNonRepeatable := make(map[string]struct{}) for _, d := range directives { dirName := d.Name.Name dd, ok := s.Directives[dirName] @@ -315,6 +316,14 @@ func resolveDirectives(s *types.Schema, directives types.DirectiveList, loc stri d.Arguments = append(d.Arguments, &types.Argument{Name: arg.Name, Value: arg.Default}) } } + + if dd.Repeatable { + continue + } + if _, seen := alreadySeenNonRepeatable[dirName]; seen { + return errors.Errorf(`non repeatable directive %q can not be repeated. Consider adding "repeatable".`, dirName) + } + alreadySeenNonRepeatable[dirName] = struct{}{} } return nil } diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 726546b1..f57a33e7 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -920,6 +920,22 @@ Second line of the description. return nil }, }, + { + name: "Disallow repeat of a directive if it is not `repeatable`", + sdl: ` + directive @nonrepeatabledirective on FIELD_DEFINITION + type Foo { + bar: String @nonrepeatabledirective @nonrepeatabledirective + } + `, + validateError: func(err error) error { + prefix := `graphql: non repeatable directive "nonrepeatabledirective" can not be repeated. Consider adding "repeatable"` + if err == nil || !strings.HasPrefix(err.Error(), prefix) { + return fmt.Errorf("expected error starting with %q, but got %q", prefix, err) + } + return nil + }, + }, } { t.Run(test.name, func(t *testing.T) { s, err := schema.ParseSchema(test.sdl, test.useStringDescriptions)