Skip to content

Commit

Permalink
feat: add max query depth (wundergraph#1153)
Browse files Browse the repository at this point in the history
  • Loading branch information
df-wg authored Sep 10, 2024
1 parent a1469b1 commit 5475a96
Show file tree
Hide file tree
Showing 9 changed files with 319 additions and 14 deletions.
170 changes: 170 additions & 0 deletions router-tests/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/wundergraph/cosmo/router/pkg/otel"
"github.com/wundergraph/cosmo/router/pkg/trace/tracetest"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/trace"
tracetest2 "go.opentelemetry.io/otel/sdk/trace/tracetest"
"io"
"math/rand"
"net/http"
Expand Down Expand Up @@ -965,6 +970,171 @@ func TestBlockNonPersistedOperations(t *testing.T) {
})
}

func TestQueryDepthLimit(t *testing.T) {
t.Parallel()
t.Run("max query depth of 0 doesn't block", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 0
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.JSONEq(t, `{"data":{"employee":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}}`, res.Body)
})
})

t.Run("allows queries up to the max depth", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 3
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.JSONEq(t, `{"data":{"employee":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}}`, res.Body)
})
})

t.Run("max query depth blocks queries over the limit", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 2
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res, _ := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.Equal(t, 400, res.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}],"data":null}`, res.Body)
})
})

t.Run("max query depth blocks persisted queries over the limit", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 2
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
header := make(http.Header)
header.Add("graphql-client-name", "my-client")
res, _ := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
OperationName: []byte(`Find`),
Variables: []byte(`{"criteria": {"nationality": "GERMAN" }}`),
Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "e33580cf6276de9a75fb3b1c4b7580fec2a1c8facd13f3487bf6c7c3f854f7e3"}}`),
Header: header,
})
require.Equal(t, 400, res.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}],"data":null}`, res.Body)
})
})

t.Run("max query depth doesn't block persisted queries if DisableDepthLimitPersistedOperations set", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 2
securityConfiguration.DepthLimit.CacheSize = 1024
securityConfiguration.DepthLimit.IgnorePersistedOperations = true
},
}, func(t *testing.T, xEnv *testenv.Environment) {
header := make(http.Header)
header.Add("graphql-client-name", "my-client")
res, _ := xEnv.MakeGraphQLRequestOverGET(testenv.GraphQLRequest{
OperationName: []byte(`Find`),
Variables: []byte(`{"criteria": {"nationality": "GERMAN" }}`),
Extensions: []byte(`{"persistedQuery": {"version": 1, "sha256Hash": "e33580cf6276de9a75fb3b1c4b7580fec2a1c8facd13f3487bf6c7c3f854f7e3"}}`),
Header: header,
})
require.Equal(t, 200, res.Response.StatusCode)
//require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}],"data":null}`, res.Body)
})
})

t.Run("query depth validation caches success and failure runs", func(t *testing.T) {
t.Parallel()

metricReader := metric.NewManualReader()
exporter := tracetest.NewInMemoryExporter(t)
testenv.Run(t, &testenv.Config{
TraceExporter: exporter,
MetricReader: metricReader,
ModifySecurityConfiguration: func(securityConfiguration *config.SecurityConfiguration) {
securityConfiguration.DepthLimit.Enabled = true
securityConfiguration.DepthLimit.Limit = 2
securityConfiguration.DepthLimit.CacheSize = 1024
},
}, func(t *testing.T, xEnv *testenv.Environment) {
failedRes, _ := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.Equal(t, 400, failedRes.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}],"data":null}`, failedRes.Body)

testSpan := requireSpanWithName(t, exporter, "Operation - Validate")
require.Contains(t, testSpan.Attributes(), otel.WgQueryDepth.Int(3))
require.Contains(t, testSpan.Attributes(), otel.WgQueryDepthCacheHit.Bool(false))
exporter.Reset()

failedRes2, _ := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `{ employee(id:1) { id details { forename surname } } }`,
})
require.Equal(t, 400, failedRes2.Response.StatusCode)
require.Equal(t, `{"errors":[{"message":"The query depth 3 exceeds the max query depth allowed (2)"}],"data":null}`, failedRes2.Body)

testSpan2 := requireSpanWithName(t, exporter, "Operation - Validate")
require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepth.Int(3))
require.Contains(t, testSpan2.Attributes(), otel.WgQueryDepthCacheHit.Bool(true))
exporter.Reset()

successRes := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.JSONEq(t, employeesIDData, successRes.Body)
testSpan3 := requireSpanWithName(t, exporter, "Operation - Validate")
require.Contains(t, testSpan3.Attributes(), otel.WgQueryDepth.Int(2))
require.Contains(t, testSpan3.Attributes(), otel.WgQueryDepthCacheHit.Bool(false))
exporter.Reset()

successRes2 := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.JSONEq(t, employeesIDData, successRes2.Body)
testSpan4 := requireSpanWithName(t, exporter, "Operation - Validate")
require.Contains(t, testSpan4.Attributes(), otel.WgQueryDepth.Int(2))
require.Contains(t, testSpan4.Attributes(), otel.WgQueryDepthCacheHit.Bool(true))
})
})
}

func requireSpanWithName(t *testing.T, exporter *tracetest2.InMemoryExporter, name string) trace.ReadOnlySpan {
sn := exporter.GetSpans().Snapshots()
var testSpan trace.ReadOnlySpan
for _, span := range sn {
if span.Name() == name {
testSpan = span
break
}
}
require.NotNil(t, testSpan)
return testSpan
}

func TestPartialOriginErrors(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
Expand Down
20 changes: 20 additions & 0 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ type graphMux struct {
planCache ExecutionPlanCache[uint64, *planWithMetaData]
normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
validationCache *ristretto.Cache[uint64, bool]
queryDepthCache *ristretto.Cache[uint64, int]
}

func (s *graphMux) Shutdown(_ context.Context) {
Expand All @@ -311,6 +312,9 @@ func (s *graphMux) Shutdown(_ context.Context) {
if s.validationCache != nil {
s.validationCache.Close()
}
if s.queryDepthCache != nil {
s.queryDepthCache.Close()
}
}

// buildGraphMux creates a new graph mux with the given feature flags and engine configuration.
Expand Down Expand Up @@ -393,6 +397,18 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
}
}

if s.securityConfiguration.DepthLimit.Enabled && s.securityConfiguration.DepthLimit.CacheSize > 0 {
queryDepthCacheConfig := &ristretto.Config[uint64, int]{
MaxCost: s.securityConfiguration.DepthLimit.CacheSize,
NumCounters: s.securityConfiguration.DepthLimit.CacheSize * 10,
BufferItems: 64,
}
gm.queryDepthCache, err = ristretto.NewCache[uint64, int](queryDepthCacheConfig)
if err != nil {
return nil, fmt.Errorf("failed to create query depth cache: %w", err)
}
}

metrics := NewRouterMetrics(&routerMetricsConfig{
metrics: s.metricStore,
gqlMetricsExporter: s.gqlMetricsExporter,
Expand Down Expand Up @@ -575,6 +591,7 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
EnablePersistedOperationsCache: s.engineExecutionConfiguration.EnablePersistedOperationsCache,
NormalizationCache: gm.normalizationCache,
ValidationCache: gm.validationCache,
QueryDepthCache: gm.queryDepthCache,
ParseKitPoolSize: s.engineExecutionConfiguration.ParseKitPoolSize,
})
operationPlanner := NewOperationPlanner(executor, gm.planCache)
Expand Down Expand Up @@ -635,6 +652,9 @@ func (s *graphServer) buildGraphMux(ctx context.Context,
FileUploadEnabled: s.fileUploadConfig.Enabled,
MaxUploadFiles: s.fileUploadConfig.MaxFiles,
MaxUploadFileSize: int(s.fileUploadConfig.MaxFileSizeBytes),
QueryDepthEnabled: s.securityConfiguration.DepthLimit.Enabled,
QueryDepthLimit: s.securityConfiguration.DepthLimit.Limit,
QueryIgnorePersistent: s.securityConfiguration.DepthLimit.IgnorePersistedOperations,
AlwaysIncludeQueryPlan: s.engineExecutionConfiguration.Debug.AlwaysIncludeQueryPlan,
AlwaysSkipLoader: s.engineExecutionConfiguration.Debug.AlwaysSkipLoader,
QueryPlansEnabled: s.Config.queryPlansEnabled,
Expand Down
45 changes: 34 additions & 11 deletions router/core/graphql_prehandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,20 @@ import (
)

type PreHandlerOptions struct {
Logger *zap.Logger
Executor *Executor
Metrics RouterMetrics
OperationProcessor *OperationProcessor
Planner *OperationPlanner
AccessController *AccessController
OperationBlocker *OperationBlocker
RouterPublicKey *ecdsa.PublicKey
TracerProvider *sdktrace.TracerProvider
MaxUploadFiles int
MaxUploadFileSize int
Logger *zap.Logger
Executor *Executor
Metrics RouterMetrics
OperationProcessor *OperationProcessor
Planner *OperationPlanner
AccessController *AccessController
OperationBlocker *OperationBlocker
RouterPublicKey *ecdsa.PublicKey
TracerProvider *sdktrace.TracerProvider
MaxUploadFiles int
MaxUploadFileSize int
QueryDepthEnabled bool
QueryDepthLimit int
QueryIgnorePersistent bool

FlushTelemetryAfterResponse bool
FileUploadEnabled bool
Expand Down Expand Up @@ -76,6 +79,9 @@ type PreHandler struct {
fileUploadEnabled bool
maxUploadFiles int
maxUploadFileSize int
queryDepthEnabled bool
queryDepthLimit int
queryIgnorePersistent bool
bodyReadBuffers *sync.Pool
trackSchemaUsageInfo bool
}
Expand Down Expand Up @@ -115,6 +121,9 @@ func NewPreHandler(opts *PreHandlerOptions) *PreHandler {
fileUploadEnabled: opts.FileUploadEnabled,
maxUploadFiles: opts.MaxUploadFiles,
maxUploadFileSize: opts.MaxUploadFileSize,
queryDepthEnabled: opts.QueryDepthEnabled,
queryDepthLimit: opts.QueryDepthLimit,
queryIgnorePersistent: opts.QueryIgnorePersistent,
bodyReadBuffers: &sync.Pool{},
alwaysIncludeQueryPlan: opts.AlwaysIncludeQueryPlan,
alwaysSkipLoader: opts.AlwaysSkipLoader,
Expand Down Expand Up @@ -552,6 +561,20 @@ func (h *PreHandler) handleOperation(req *http.Request, buf *bytes.Buffer, httpO
// this allows us to generate query plans without having to provide variables
engineValidateSpan.SetAttributes(otel.WgVariablesValidationSkipped.Bool(true))
}

// Validate that the planned query doesn't exceed the maximum query depth configured
// This check runs if they've configured a max query depth, and it can optionally be turned off for persisted operations
if h.queryDepthEnabled && h.queryDepthLimit > 0 && (!operationKit.parsedOperation.IsPersistedOperation || operationKit.parsedOperation.IsPersistedOperation && !h.queryIgnorePersistent) {
cacheHit, depth, queryDepthErr := operationKit.ValidateQueryDepth(h.queryDepthLimit, operationKit.kit.doc, h.executor.RouterSchema)
engineValidateSpan.SetAttributes(otel.WgQueryDepth.Int(depth))
engineValidateSpan.SetAttributes(otel.WgQueryDepthCacheHit.Bool(cacheHit))
if queryDepthErr != nil {
rtrace.AttachErrToSpan(engineValidateSpan, err)
engineValidateSpan.End()

return nil, queryDepthErr
}
}
engineValidateSpan.End()

httpOperation.traceTimings.EndValidate()
Expand Down
42 changes: 42 additions & 0 deletions router/core/operation_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"github.com/wundergraph/graphql-go-tools/v2/pkg/middleware/operation_complexity"
"io"
"net/http"
"net/url"
Expand Down Expand Up @@ -85,6 +86,7 @@ type OperationProcessorOptions struct {
EnablePersistedOperationsCache bool
NormalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
ValidationCache *ristretto.Cache[uint64, bool]
QueryDepthCache *ristretto.Cache[uint64, int]
ParseKitPoolSize int
}

Expand Down Expand Up @@ -122,6 +124,7 @@ type OperationCache struct {

normalizationCache *ristretto.Cache[uint64, NormalizationCacheEntry]
validationCache *ristretto.Cache[uint64, bool]
queryDepthCache *ristretto.Cache[uint64, int]
}

// OperationKit provides methods to parse, normalize and validate operations.
Expand Down Expand Up @@ -740,6 +743,39 @@ func (o *OperationKit) Validate(skipLoader bool) (cacheHit bool, err error) {
return
}

// ValidateQueryDepth validates that the operation query depth isn't greater than the max query depth.
func (o *OperationKit) ValidateQueryDepth(maxQueryDepth int, operation, definition *ast.Document) (bool, int, error) {
if o.cache != nil && o.cache.queryDepthCache != nil {
depth, cacheHit := o.cache.queryDepthCache.Get(o.parsedOperation.ID)
if cacheHit {
valid := depth <= maxQueryDepth
if !valid {
return cacheHit, depth, &httpGraphqlError{
message: fmt.Sprintf("The query depth %d exceeds the max query depth allowed (%d)", depth, maxQueryDepth),
statusCode: http.StatusBadRequest,
}
}
return cacheHit, depth, nil
}
}

report := operationreport.Report{}
globalComplexityResult, _ := operation_complexity.CalculateOperationComplexity(operation, definition, &report)
valid := globalComplexityResult.Depth <= maxQueryDepth

if o.cache != nil && o.cache.queryDepthCache != nil {
o.cache.queryDepthCache.Set(o.parsedOperation.ID, globalComplexityResult.Depth, 1)
}

if !valid {
return false, globalComplexityResult.Depth, &httpGraphqlError{
message: fmt.Sprintf("The query depth %d exceeds the max query depth allowed (%d)", globalComplexityResult.Depth, maxQueryDepth),
statusCode: http.StatusBadRequest,
}
}
return false, globalComplexityResult.Depth, nil
}

var (
literalIF = []byte("if")
)
Expand Down Expand Up @@ -827,6 +863,12 @@ func NewOperationProcessor(opts OperationProcessorOptions) *OperationProcessor {
}
processor.operationCache.validationCache = opts.ValidationCache
}
if opts.QueryDepthCache != nil {
if processor.operationCache == nil {
processor.operationCache = &OperationCache{}
}
processor.operationCache.queryDepthCache = opts.QueryDepthCache
}
return processor
}

Expand Down
Loading

0 comments on commit 5475a96

Please sign in to comment.