diff --git a/codegen/testserver/enum.graphql b/codegen/testserver/enum.graphql new file mode 100644 index 00000000000..08559b65c6f --- /dev/null +++ b/codegen/testserver/enum.graphql @@ -0,0 +1,12 @@ +enum EnumTest { + OK + NG +} + +input InputWithEnumValue { + enum: EnumTest! +} + +extend type Query { + enumInInput(input: InputWithEnumValue): EnumTest! +} diff --git a/codegen/testserver/enums_test.go b/codegen/testserver/enums_test.go new file mode 100644 index 00000000000..15fd861ce0a --- /dev/null +++ b/codegen/testserver/enums_test.go @@ -0,0 +1,52 @@ +package testserver + +import ( + "context" + "testing" + + "github.com/99designs/gqlgen/client" + "github.com/99designs/gqlgen/handler" + "github.com/stretchr/testify/require" +) + +func TestEnumsResolver(t *testing.T) { + resolvers := &Stub{} + resolvers.QueryResolver.EnumInInput = func(ctx context.Context, input *InputWithEnumValue) (EnumTest, error) { + return input.Enum, nil + } + + c := client.New(handler.GraphQL(NewExecutableSchema(Config{Resolvers: resolvers}))) + + t.Run("input with valid enum value", func(t *testing.T) { + var resp struct { + EnumInInput EnumTest + } + c.MustPost(`query { + enumInInput(input: {enum: OK}) + } + `, &resp) + require.Equal(t, resp.EnumInInput, EnumTestOk) + }) + + t.Run("input with invalid enum value", func(t *testing.T) { + var resp struct { + EnumInInput EnumTest + } + err := c.Post(`query { + enumInInput(input: {enum: INVALID}) + } + `, &resp) + require.EqualError(t, err, `http 422: {"errors":[{"message":"Expected type EnumTest!, found INVALID.","locations":[{"line":2,"column":30}]}],"data":null}`) + }) + + t.Run("input with invalid enum value via vars", func(t *testing.T) { + var resp struct { + EnumInInput EnumTest + } + err := c.Post(`query ($input: InputWithEnumValue) { + enumInInput(input: $input) + } + `, &resp, client.Var("input", map[string]interface{}{"enum": "INVALID"})) + require.EqualError(t, err, `http 422: {"errors":[{"message":"Expected type EnumTest!, found INVALID.","locations":[{"line":2,"column":30}]}],"data":null}`) + }) +} diff --git a/codegen/testserver/generated.go b/codegen/testserver/generated.go index 0891bd075fa..d31eca65173 100644 --- a/codegen/testserver/generated.go +++ b/codegen/testserver/generated.go @@ -224,6 +224,7 @@ type ComplexityRoot struct { DirectiveObject func(childComplexity int) int DirectiveObjectWithCustomGoModel func(childComplexity int) int DirectiveUnimplemented func(childComplexity int) int + EnumInInput func(childComplexity int, input *InputWithEnumValue) int ErrorBubble func(childComplexity int) int Errors func(childComplexity int) int Fallback func(childComplexity int, arg FallbackToStringEncoding) int @@ -366,6 +367,7 @@ type QueryResolver interface { DirectiveField(ctx context.Context) (*string, error) DirectiveDouble(ctx context.Context) (*string, error) DirectiveUnimplemented(ctx context.Context) (*string, error) + EnumInInput(ctx context.Context, input *InputWithEnumValue) (EnumTest, error) Shapes(ctx context.Context) ([]Shape, error) NoShape(ctx context.Context) (Shape, error) MapStringInterface(ctx context.Context, in map[string]interface{}) (map[string]interface{}, error) @@ -926,6 +928,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.DirectiveUnimplemented(childComplexity), true + case "Query.enumInInput": + if e.complexity.Query.EnumInInput == nil { + break + } + + args, err := ec.field_Query_enumInInput_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.EnumInInput(childComplexity, args["input"].(*InputWithEnumValue)), true + case "Query.errorBubble": if e.complexity.Query.ErrorBubble == nil { break @@ -1518,6 +1532,19 @@ type ObjectDirectives { type ObjectDirectivesWithCustomGoModel { nullableText: String @toNull } +`}, + &ast.Source{Name: "enum.graphql", Input: `enum EnumTest { + OK + NG +} + +input InputWithEnumValue { + enum: EnumTest! +} + +extend type Query { + enumInInput(input: InputWithEnumValue): EnumTest! +} `}, &ast.Source{Name: "interfaces.graphql", Input: `extend type Query { shapes: [Shape] @@ -2184,6 +2211,20 @@ func (ec *executionContext) field_Query_directiveNullableArg_args(ctx context.Co return args, nil } +func (ec *executionContext) field_Query_enumInInput_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 *InputWithEnumValue + if tmp, ok := rawArgs["input"]; ok { + arg0, err = ec.unmarshalOInputWithEnumValue2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐInputWithEnumValue(ctx, tmp) + if err != nil { + return nil, err + } + } + args["input"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query_fallback_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -5573,6 +5614,47 @@ func (ec *executionContext) _Query_directiveUnimplemented(ctx context.Context, f return ec.marshalOString2ᚖstring(ctx, field.Selections, res) } +func (ec *executionContext) _Query_enumInInput(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + ctx = ec.Tracer.StartFieldExecution(ctx, field) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + ec.Tracer.EndFieldExecution(ctx) + }() + rctx := &graphql.ResolverContext{ + Object: "Query", + Field: field, + Args: nil, + IsMethod: true, + } + ctx = graphql.WithResolverContext(ctx, rctx) + rawArgs := field.ArgumentMap(ec.Variables) + args, err := ec.field_Query_enumInInput_args(ctx, rawArgs) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + rctx.Args = args + ctx = ec.Tracer.StartFieldResolverExecution(ctx, rctx) + resTmp := ec._fieldMiddleware(ctx, nil, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().EnumInInput(rctx, args["input"].(*InputWithEnumValue)) + }) + + if resTmp == nil { + if !ec.HasError(rctx) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(EnumTest) + rctx.Result = res + ctx = ec.Tracer.StartFieldChildExecution(ctx) + return ec.marshalNEnumTest2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEnumTest(ctx, field.Selections, res) +} + func (ec *executionContext) _Query_shapes(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { ctx = ec.Tracer.StartFieldExecution(ctx, field) defer func() { @@ -8473,6 +8555,24 @@ func (ec *executionContext) unmarshalInputInputDirectives(ctx context.Context, o return it, nil } +func (ec *executionContext) unmarshalInputInputWithEnumValue(ctx context.Context, obj interface{}) (InputWithEnumValue, error) { + var it InputWithEnumValue + var asMap = obj.(map[string]interface{}) + + for k, v := range asMap { + switch k { + case "enum": + var err error + it.Enum, err = ec.unmarshalNEnumTest2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEnumTest(ctx, v) + if err != nil { + return it, err + } + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputNestedMapInput(ctx context.Context, obj interface{}) (NestedMapInput, error) { var it NestedMapInput var asMap = obj.(map[string]interface{}) @@ -10050,6 +10150,20 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr res = ec._Query_directiveUnimplemented(ctx, field) return res }) + case "enumInInput": + field := field + out.Concurrently(i, func() (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_enumInInput(ctx, field) + if res == graphql.Null { + atomic.AddUint32(&invalids, 1) + } + return res + }) case "shapes": field := field out.Concurrently(i, func() (res graphql.Marshaler) { @@ -10893,6 +11007,15 @@ func (ec *executionContext) marshalNDefaultScalarImplementation2string(ctx conte return res } +func (ec *executionContext) unmarshalNEnumTest2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEnumTest(ctx context.Context, v interface{}) (EnumTest, error) { + var res EnumTest + return res, res.UnmarshalGQL(v) +} + +func (ec *executionContext) marshalNEnumTest2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐEnumTest(ctx context.Context, sel ast.SelectionSet, v EnumTest) graphql.Marshaler { + return v +} + func (ec *executionContext) marshalNError2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐError(ctx context.Context, sel ast.SelectionSet, v Error) graphql.Marshaler { return ec._Error(ctx, sel, &v) } @@ -11749,6 +11872,18 @@ func (ec *executionContext) unmarshalOInputDirectives2ᚖgithubᚗcomᚋ99design return &res, err } +func (ec *executionContext) unmarshalOInputWithEnumValue2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐInputWithEnumValue(ctx context.Context, v interface{}) (InputWithEnumValue, error) { + return ec.unmarshalInputInputWithEnumValue(ctx, v) +} + +func (ec *executionContext) unmarshalOInputWithEnumValue2ᚖgithubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐInputWithEnumValue(ctx context.Context, v interface{}) (*InputWithEnumValue, error) { + if v == nil { + return nil, nil + } + res, err := ec.unmarshalOInputWithEnumValue2githubᚗcomᚋ99designsᚋgqlgenᚋcodegenᚋtestserverᚐInputWithEnumValue(ctx, v) + return &res, err +} + func (ec *executionContext) unmarshalOInt2int(ctx context.Context, v interface{}) (int, error) { return graphql.UnmarshalInt(v) } diff --git a/codegen/testserver/generated_test.go b/codegen/testserver/generated_test.go index 38362c131f5..7b12da9645a 100644 --- a/codegen/testserver/generated_test.go +++ b/codegen/testserver/generated_test.go @@ -31,6 +31,11 @@ func TestEnums(t *testing.T) { require.Equal(t, StatusOk, AllStatus[0]) require.Equal(t, StatusError, AllStatus[1]) }) + + t.Run("invalid enum values", func(t *testing.T) { + require.Equal(t, StatusOk, AllStatus[0]) + require.Equal(t, StatusError, AllStatus[1]) + }) } func TestUnionFragments(t *testing.T) { diff --git a/codegen/testserver/models-gen.go b/codegen/testserver/models-gen.go index 197ca86b807..5e1f78f7df0 100644 --- a/codegen/testserver/models-gen.go +++ b/codegen/testserver/models-gen.go @@ -73,6 +73,10 @@ type InputDirectives struct { ThirdParty *ThirdParty `json:"thirdParty"` } +type InputWithEnumValue struct { + Enum EnumTest `json:"enum"` +} + type LoopA struct { B *LoopB `json:"b"` } @@ -171,6 +175,47 @@ type IIt struct { ID string `json:"id"` } +type EnumTest string + +const ( + EnumTestOk EnumTest = "OK" + EnumTestNg EnumTest = "NG" +) + +var AllEnumTest = []EnumTest{ + EnumTestOk, + EnumTestNg, +} + +func (e EnumTest) IsValid() bool { + switch e { + case EnumTestOk, EnumTestNg: + return true + } + return false +} + +func (e EnumTest) String() string { + return string(e) +} + +func (e *EnumTest) UnmarshalGQL(v interface{}) error { + str, ok := v.(string) + if !ok { + return fmt.Errorf("enums must be strings") + } + + *e = EnumTest(str) + if !e.IsValid() { + return fmt.Errorf("%s is not a valid EnumTest", str) + } + return nil +} + +func (e EnumTest) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(e.String())) +} + type Status string const ( diff --git a/codegen/testserver/resolver.go b/codegen/testserver/resolver.go index 29df1b49451..c05bc27e574 100644 --- a/codegen/testserver/resolver.go +++ b/codegen/testserver/resolver.go @@ -179,6 +179,9 @@ func (r *queryResolver) DirectiveDouble(ctx context.Context) (*string, error) { func (r *queryResolver) DirectiveUnimplemented(ctx context.Context) (*string, error) { panic("not implemented") } +func (r *queryResolver) EnumInInput(ctx context.Context, input *InputWithEnumValue) (EnumTest, error) { + panic("not implemented") +} func (r *queryResolver) Shapes(ctx context.Context) ([]Shape, error) { panic("not implemented") } diff --git a/codegen/testserver/stub.go b/codegen/testserver/stub.go index bcc7ad7f243..4c6c1d78cf6 100644 --- a/codegen/testserver/stub.go +++ b/codegen/testserver/stub.go @@ -63,6 +63,7 @@ type Stub struct { DirectiveField func(ctx context.Context) (*string, error) DirectiveDouble func(ctx context.Context) (*string, error) DirectiveUnimplemented func(ctx context.Context) (*string, error) + EnumInInput func(ctx context.Context, input *InputWithEnumValue) (EnumTest, error) Shapes func(ctx context.Context) ([]Shape, error) NoShape func(ctx context.Context) (Shape, error) MapStringInterface func(ctx context.Context, in map[string]interface{}) (map[string]interface{}, error) @@ -263,6 +264,9 @@ func (r *stubQuery) DirectiveDouble(ctx context.Context) (*string, error) { func (r *stubQuery) DirectiveUnimplemented(ctx context.Context) (*string, error) { return r.QueryResolver.DirectiveUnimplemented(ctx) } +func (r *stubQuery) EnumInInput(ctx context.Context, input *InputWithEnumValue) (EnumTest, error) { + return r.QueryResolver.EnumInInput(ctx, input) +} func (r *stubQuery) Shapes(ctx context.Context) ([]Shape, error) { return r.QueryResolver.Shapes(ctx) }