Skip to content

Commit 19c8693

Browse files
authored
fix: add ability to exclude introspection queries from complexity limits (#1342)
Allow external users to configure if they want to skip intospection querues in the complexity limits. Follow-up in Cosmo: wundergraph/cosmo#2296
1 parent ecbac93 commit 19c8693

File tree

5 files changed

+99
-156
lines changed

5 files changed

+99
-156
lines changed

execution/graphql/complexity.go

Lines changed: 0 additions & 76 deletions
This file was deleted.

execution/graphql/request.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
1010
"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
1111
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve"
12-
"github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity"
1312
"github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport"
1413
)
1514

@@ -67,23 +66,6 @@ func (r *Request) SetHeader(header http.Header) {
6766
r.request.Header = header
6867
}
6968

70-
func (r *Request) CalculateComplexity(complexityCalculator ComplexityCalculator, schema *Schema) (ComplexityResult, error) {
71-
if schema == nil {
72-
return ComplexityResult{}, ErrNilSchema
73-
}
74-
75-
report := r.parseQueryOnce()
76-
if report.HasErrors() {
77-
return complexityResult(
78-
operation_complexity.OperationStats{},
79-
[]operation_complexity.RootFieldStats{},
80-
report,
81-
)
82-
}
83-
84-
return complexityCalculator.Calculate(&r.document, &schema.document)
85-
}
86-
8769
func (r *Request) Document() *ast.Document {
8870
return &r.document
8971
}

execution/graphql/request_test.go

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77

88
"github.com/stretchr/testify/assert"
99

10+
"github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity"
1011
"github.com/wundergraph/graphql-go-tools/v2/pkg/starwars"
1112
)
1213

@@ -88,64 +89,69 @@ func TestRequest_parseQueryOnce(t *testing.T) {
8889
}
8990

9091
func TestRequest_CalculateComplexity(t *testing.T) {
91-
t.Run("should return error when schema is nil", func(t *testing.T) {
92-
request := Request{}
93-
result, err := request.CalculateComplexity(DefaultComplexityCalculator, nil)
94-
assert.Error(t, err)
95-
assert.Equal(t, ErrNilSchema, err)
96-
assert.Equal(t, 0, result.NodeCount, "unexpected node count")
97-
assert.Equal(t, 0, result.Complexity, "unexpected complexity")
98-
assert.Equal(t, 0, result.Depth, "unexpected depth")
99-
assert.Nil(t, result.PerRootField, "per root field results is not nil")
100-
})
101-
10292
t.Run("should successfully calculate the complexity of request", func(t *testing.T) {
10393
schema := StarwarsSchema(t)
104-
10594
request := StarwarsRequestForQuery(t, starwars.FileSimpleHeroQuery)
106-
result, err := request.CalculateComplexity(DefaultComplexityCalculator, schema)
107-
assert.NoError(t, err)
108-
assert.Equal(t, 1, result.NodeCount, "unexpected node count")
109-
assert.Equal(t, 1, result.Complexity, "unexpected complexity")
110-
assert.Equal(t, 2, result.Depth, "unexpected depth")
111-
assert.Equal(t, []FieldComplexityResult{
95+
96+
report := request.parseQueryOnce()
97+
assert.False(t, report.HasErrors())
98+
99+
estimator := operation_complexity.NewOperationComplexityEstimator(false)
100+
global, rootFields := estimator.Do(request.Document(), schema.Document(), &report)
101+
assert.False(t, report.HasErrors())
102+
103+
assert.Equal(t, 1, global.NodeCount, "unexpected node count")
104+
assert.Equal(t, 1, global.Complexity, "unexpected complexity")
105+
assert.Equal(t, 2, global.Depth, "unexpected depth")
106+
assert.Equal(t, []operation_complexity.RootFieldStats{
112107
{
113-
TypeName: "Query",
114-
FieldName: "hero",
115-
Alias: "",
116-
NodeCount: 1,
117-
Complexity: 1,
118-
Depth: 1,
108+
TypeName: "Query",
109+
FieldName: "hero",
110+
Alias: "",
111+
Stats: operation_complexity.OperationStats{
112+
NodeCount: 1,
113+
Complexity: 1,
114+
Depth: 1,
115+
},
119116
},
120-
}, result.PerRootField, "unexpected per root field results")
117+
}, rootFields, "unexpected per root field results")
121118
})
122119

123120
t.Run("should successfully calculate the complexity of request with multiple query fields", func(t *testing.T) {
124121
schema := StarwarsSchema(t)
125-
126122
request := StarwarsRequestForQuery(t, starwars.FileHeroWithAliasesQuery)
127-
result, err := request.CalculateComplexity(DefaultComplexityCalculator, schema)
128-
assert.NoError(t, err)
129-
assert.Equal(t, 2, result.NodeCount, "unexpected node count")
130-
assert.Equal(t, 2, result.Complexity, "unexpected complexity")
131-
assert.Equal(t, 2, result.Depth, "unexpected depth")
132-
assert.Equal(t, []FieldComplexityResult{
123+
124+
report := request.parseQueryOnce()
125+
assert.False(t, report.HasErrors())
126+
127+
estimator := operation_complexity.NewOperationComplexityEstimator(false)
128+
global, rootFields := estimator.Do(request.Document(), schema.Document(), &report)
129+
assert.False(t, report.HasErrors())
130+
131+
assert.Equal(t, 2, global.NodeCount, "unexpected node count")
132+
assert.Equal(t, 2, global.Complexity, "unexpected complexity")
133+
assert.Equal(t, 2, global.Depth, "unexpected depth")
134+
assert.Equal(t, []operation_complexity.RootFieldStats{
133135
{
134-
TypeName: "Query",
135-
FieldName: "hero",
136-
Alias: "empireHero",
137-
NodeCount: 1,
138-
Complexity: 1,
139-
Depth: 1,
136+
TypeName: "Query",
137+
FieldName: "hero",
138+
Alias: "empireHero",
139+
Stats: operation_complexity.OperationStats{
140+
NodeCount: 1,
141+
Complexity: 1,
142+
Depth: 1,
143+
},
140144
},
141145
{
142-
TypeName: "Query",
143-
FieldName: "hero",
144-
Alias: "jediHero",
145-
NodeCount: 1,
146-
Complexity: 1,
147-
Depth: 1,
148-
}}, result.PerRootField, "unexpected per root field results")
146+
TypeName: "Query",
147+
FieldName: "hero",
148+
Alias: "jediHero",
149+
Stats: operation_complexity.OperationStats{
150+
NodeCount: 1,
151+
Complexity: 1,
152+
Depth: 1,
153+
},
154+
}}, rootFields, "unexpected per root field results")
149155
})
150156
}
151157

v2/pkg/middleware/operation_complexity/operation_complexity.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,21 @@ var (
5252
)
5353

5454
const (
55-
skipIntrospection = true
56-
__schemaLiteral = "__schema"
57-
__typeLiteral = "__type"
55+
__schemaLiteral = "__schema"
56+
__typeLiteral = "__type"
5857
)
5958

6059
type OperationComplexityEstimator struct {
6160
walker *astvisitor.Walker
6261
visitor *complexityVisitor
6362
}
6463

65-
func NewOperationComplexityEstimator() *OperationComplexityEstimator {
66-
64+
func NewOperationComplexityEstimator(skipIntrospection bool) *OperationComplexityEstimator {
6765
walker := astvisitor.NewWalker(48)
6866
visitor := &complexityVisitor{
69-
Walker: &walker,
70-
multipliers: make([]multiplier, 0, 16),
67+
Walker: &walker,
68+
multipliers: make([]multiplier, 0, 16),
69+
skipIntrospection: skipIntrospection,
7170
}
7271

7372
walker.RegisterEnterDocumentVisitor(visitor)
@@ -116,8 +115,9 @@ func (n *OperationComplexityEstimator) Do(operation, definition *ast.Document, r
116115
return globalResult, n.visitor.calculatedRootFieldStats
117116
}
118117

118+
// Deprecated: use NewOperationComplexityEstimator.
119119
func CalculateOperationComplexity(operation, definition *ast.Document, report *operationreport.Report) (OperationStats, []RootFieldStats) {
120-
estimator := NewOperationComplexityEstimator()
120+
estimator := NewOperationComplexityEstimator(false)
121121
return estimator.Do(operation, definition, report)
122122
}
123123

@@ -141,6 +141,9 @@ type complexityVisitor struct {
141141
currentRootFieldSelectionSetDepth int
142142

143143
calculatedRootFieldStats []RootFieldStats
144+
145+
// Enforces to ignore introspection queries in calculations.
146+
skipIntrospection bool
144147
}
145148

146149
type multiplier struct {
@@ -202,7 +205,7 @@ func (c *complexityVisitor) EnterField(ref int) {
202205
}
203206

204207
typeName, fieldName, alias := c.extractFieldRelatedNames(ref, definition)
205-
if skipIntrospection && (fieldName == __schemaLiteral || fieldName == __typeLiteral) {
208+
if c.skipIntrospection && (fieldName == __schemaLiteral || fieldName == __typeLiteral) {
206209
c.SkipNode()
207210
return
208211
}

v2/pkg/middleware/operation_complexity/operation_complexity_test.go

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,26 @@ func TestCalculateOperationComplexity(t *testing.T) {
467467
})
468468
t.Run("introspection query", func(t *testing.T) {
469469
run(t, testDefinition, introspectionQuery,
470+
OperationStats{
471+
NodeCount: 59,
472+
Complexity: 59,
473+
Depth: 13,
474+
},
475+
[]RootFieldStats{
476+
{
477+
TypeName: "Query",
478+
FieldName: "__schema",
479+
Stats: OperationStats{
480+
NodeCount: 59,
481+
Complexity: 59,
482+
Depth: 12,
483+
},
484+
},
485+
},
486+
)
487+
})
488+
t.Run("introspection query with skip", func(t *testing.T) {
489+
runSkipIntrospection(t, testDefinition, introspectionQuery,
470490
OperationStats{
471491
NodeCount: 0,
472492
Complexity: 0,
@@ -477,35 +497,43 @@ func TestCalculateOperationComplexity(t *testing.T) {
477497
})
478498
}
479499

480-
var run = func(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) {
500+
func runConfig(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats, skipIntrospection bool) {
481501
def := unsafeparser.ParseGraphqlDocumentString(definition)
482502
op := unsafeparser.ParseGraphqlDocumentString(operation)
483503
report := operationreport.Report{}
484504

485505
astnormalization.NormalizeOperation(&op, &def, &report)
486506

487-
actualGlobalComplexityResult, actualFieldsComplexityResult := CalculateOperationComplexity(&op, &def, &report)
488-
if report.HasErrors() {
489-
require.NoError(t, report)
490-
}
507+
estimator := NewOperationComplexityEstimator(skipIntrospection)
508+
actualGlobalComplexityResult, actualFieldsComplexityResult := estimator.Do(&op, &def, &report)
509+
require.False(t, report.HasErrors())
491510

492511
assert.Equal(t, expectedGlobalComplexityResult.NodeCount, actualGlobalComplexityResult.NodeCount, "unexpected global node count")
493512
assert.Equal(t, expectedGlobalComplexityResult.Complexity, actualGlobalComplexityResult.Complexity, "unexpected global complexity")
494513
assert.Equal(t, expectedGlobalComplexityResult.Depth, actualGlobalComplexityResult.Depth, "unexpected global depth")
495514
assert.Equal(t, expectedFieldsComplexityResult, actualFieldsComplexityResult, "unexpected fields complexity result")
496515
}
497516

517+
func run(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) {
518+
runConfig(t, definition, operation, expectedGlobalComplexityResult, expectedFieldsComplexityResult, false)
519+
}
520+
521+
func runSkipIntrospection(t *testing.T, definition, operation string, expectedGlobalComplexityResult OperationStats, expectedFieldsComplexityResult []RootFieldStats) {
522+
runConfig(t, definition, operation, expectedGlobalComplexityResult, expectedFieldsComplexityResult, true)
523+
}
524+
498525
func BenchmarkEstimateComplexity(b *testing.B) {
499526
def := unsafeparser.ParseGraphqlDocumentString(testDefinition)
500527
op := unsafeparser.ParseGraphqlDocumentString(complexQuery)
501528

502-
estimator := NewOperationComplexityEstimator()
503-
report := operationreport.Report{}
504-
505529
b.ResetTimer()
506530
b.ReportAllocs()
507531

508532
for i := 0; i < b.N; i++ {
533+
// We use NewOperationComplexityEstimator for every operation in production, thus
534+
// we want it in the benchmarking loop.
535+
estimator := NewOperationComplexityEstimator(false)
536+
report := operationreport.Report{}
509537
globalComplexityResult, _ := estimator.Do(&op, &def, &report)
510538
if report.HasErrors() {
511539
b.Fatal(report)

0 commit comments

Comments
 (0)