Skip to content
93 changes: 68 additions & 25 deletions router/core/plan_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"os"
"time"

"github.com/wundergraph/cosmo/router/pkg/metric"

Expand Down Expand Up @@ -53,6 +54,26 @@ type Planner struct {
operationValidator *astvalidation.OperationValidator
}

type OperationTimes struct {
ParseTime time.Duration
NormalizeTime time.Duration
ValidateTime time.Duration
PlanTime time.Duration
}

func (ot *OperationTimes) TotalTime() time.Duration {
return ot.ParseTime + ot.NormalizeTime + ot.ValidateTime + ot.PlanTime
}

func (ot OperationTimes) Merge(other OperationTimes) OperationTimes {
return OperationTimes{
ParseTime: ot.ParseTime + other.ParseTime,
NormalizeTime: ot.NormalizeTime + other.NormalizeTime,
ValidateTime: ot.ValidateTime + other.ValidateTime,
PlanTime: ot.PlanTime + other.PlanTime,
}
}

type PlanOutputFormat string

const (
Expand All @@ -75,57 +96,74 @@ func NewPlanner(planConfiguration *plan.Configuration, definition *ast.Document,
}

// PlanOperation creates a query plan from an operation file in a pretty-printed text or JSON format
func (pl *Planner) PlanOperation(operationFilePath string, outputFormat PlanOutputFormat) (string, error) {
operation, err := pl.ParseAndPrepareOperation(operationFilePath)
func (pl *Planner) PlanOperation(operationFilePath string, outputFormat PlanOutputFormat) (string, OperationTimes, error) {
operation, opTimes, err := pl.ParseAndPrepareOperation(operationFilePath)
if err != nil {
return "", err
return "", opTimes, err
}

rawPlan, err := pl.PlanPreparedOperation(operation)
rawPlan, opTimes2, err := pl.PlanPreparedOperation(operation)
opTimes = opTimes.Merge(opTimes2)
if err != nil {
return "", fmt.Errorf("failed to plan operation: %w", err)
return "", opTimes, fmt.Errorf("failed to plan operation: %w", err)
}

switch outputFormat {
case PlanOutputFormatText:
return rawPlan.PrettyPrint(), nil
return rawPlan.PrettyPrint(), opTimes, nil
case PlanOutputFormatJSON:
marshal, err := json.Marshal(rawPlan)
if err != nil {
return "", fmt.Errorf("failed to marshal raw plan: %w", err)
return "", opTimes, fmt.Errorf("failed to marshal raw plan: %w", err)
}
return string(marshal), nil
return string(marshal), opTimes, nil
}

return "", fmt.Errorf("invalid outputFormat specified: %q", outputFormat)
return "", opTimes, fmt.Errorf("invalid outputFormat specified: %q", outputFormat)
}

// ParseAndPrepareOperation parses, normalizes and validates the operation
func (pl *Planner) ParseAndPrepareOperation(operationFilePath string) (*ast.Document, error) {
func (pl *Planner) ParseAndPrepareOperation(operationFilePath string) (*ast.Document, OperationTimes, error) {
start := time.Now()
operation, err := pl.parseOperation(operationFilePath)
parseTime := time.Since(start)
if err != nil {
return nil, &PlannerOperationValidationError{err: err}
return nil, OperationTimes{ParseTime: parseTime}, &PlannerOperationValidationError{err: err}
}

return pl.PrepareOperation(operation)
operation, opTimes, err := pl.PrepareOperation(operation)
opTimes.ParseTime = parseTime
if err != nil {
return nil, opTimes, err
}

return operation, opTimes, nil
}

// PrepareOperation normalizes and validates the operation
func (pl *Planner) PrepareOperation(operation *ast.Document) (*ast.Document, error) {
func (pl *Planner) PrepareOperation(operation *ast.Document) (*ast.Document, OperationTimes, error) {
operationName := findOperationName(operation)
if operationName == nil {
return nil, &PlannerOperationValidationError{err: errors.New("operation name not found")}
return nil, OperationTimes{}, &PlannerOperationValidationError{err: errors.New("operation name not found")}
}

if err := pl.normalizeOperation(operation, operationName); err != nil {
return nil, &PlannerOperationValidationError{err: err}
opTimes := OperationTimes{}

start := time.Now()
err := pl.normalizeOperation(operation, operationName)
opTimes.NormalizeTime = time.Since(start)
if err != nil {
return nil, opTimes, &PlannerOperationValidationError{err: err}
}

if err := pl.validateOperation(operation); err != nil {
return nil, &PlannerOperationValidationError{err: err}
start = time.Now()
err = pl.validateOperation(operation)
opTimes.ValidateTime = time.Since(start)
if err != nil {
return nil, opTimes, &PlannerOperationValidationError{err: err}
}

return operation, nil
return operation, opTimes, nil
}

func (pl *Planner) normalizeOperation(operation *ast.Document, operationName []byte) (err error) {
Expand Down Expand Up @@ -160,7 +198,7 @@ func (pl *Planner) normalizeOperation(operation *ast.Document, operationName []b
}

// PlanPreparedOperation creates a query plan from a normalized and validated operation
func (pl *Planner) PlanPreparedOperation(operation *ast.Document) (planNode *resolve.FetchTreeQueryPlanNode, err error) {
func (pl *Planner) PlanPreparedOperation(operation *ast.Document) (planNode *resolve.FetchTreeQueryPlanNode, opTimes OperationTimes, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during plan generation: %v", r)
Expand All @@ -172,25 +210,30 @@ func (pl *Planner) PlanPreparedOperation(operation *ast.Document) (planNode *res
operationName := findOperationName(operation)

if operationName == nil {
return nil, errors.New("operation name not found")
return nil, opTimes, errors.New("operation name not found")
}

// create and postprocess the plan
start := time.Now()
preparedPlan := pl.planner.Plan(operation, pl.definition, string(operationName), &report, plan.IncludeQueryPlanInResponse())
opTimes.PlanTime = time.Since(start)
if report.HasErrors() {
return nil, errors.New(report.Error())
return nil, opTimes, errors.New(report.Error())
}

post := postprocess.NewProcessor()
post.Process(preparedPlan)
// measure postprocessing time as part of planning time
opTimes.PlanTime = time.Since(start)

switch p := preparedPlan.(type) {
case *plan.SynchronousResponsePlan:
return p.Response.Fetches.QueryPlan(), nil
return p.Response.Fetches.QueryPlan(), opTimes, nil
case *plan.SubscriptionResponsePlan:
return p.Response.Response.Fetches.QueryPlan(), nil
return p.Response.Response.Fetches.QueryPlan(), opTimes, nil
}

return &resolve.FetchTreeQueryPlanNode{}, nil
return &resolve.FetchTreeQueryPlanNode{}, opTimes, nil
}

func (pl *Planner) validateOperation(operation *ast.Document) (err error) {
Expand Down
2 changes: 1 addition & 1 deletion router/core/plan_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestPlanOperationPanic(t *testing.T) {
}

assert.NotPanics(t, func() {
_, err = planner.PlanPreparedOperation(invalidOperation)
_, _, err = planner.PlanPreparedOperation(invalidOperation)
assert.Error(t, err)
})
}
Expand Down
8 changes: 4 additions & 4 deletions router/internal/planningbenchmark/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ func TestPlanning(t *testing.T) {
pl, err := pg.GetPlanner()
require.NoError(t, err)

opDoc, err := pl.ParseAndPrepareOperation(cfg.OperationPath)
opDoc, _, err := pl.ParseAndPrepareOperation(cfg.OperationPath)
require.NoError(t, err)

start := time.Now()
p, err := pl.PlanPreparedOperation(opDoc)
p, _, err := pl.PlanPreparedOperation(opDoc)
require.NoError(t, err)
t.Logf("Planning completed in %v", time.Since(start))

Expand Down Expand Up @@ -69,12 +69,12 @@ func BenchmarkPlanning(b *testing.B) {

for b.Loop() {
b.StopTimer()
opDoc, err := pl.ParseAndPrepareOperation(cfg.OperationPath)
opDoc, _, err := pl.ParseAndPrepareOperation(cfg.OperationPath)
require.NoError(b, err)
b.SetBytes(int64(len(opDoc.Input.RawBytes)))
b.StartTimer()

_, err = pl.PlanPreparedOperation(opDoc)
_, _, err = pl.PlanPreparedOperation(opDoc)
require.NoError(b, err)
}
}
16 changes: 9 additions & 7 deletions router/pkg/plan_generator/plan_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ type QueryPlanResults struct {
}

type QueryPlanResult struct {
FileName string `json:"file_name,omitempty"`
Plan string `json:"plan,omitempty"`
Error string `json:"error,omitempty"`
Warning string `json:"warning,omitempty"`
FileName string `json:"file_name,omitempty"`
Plan string `json:"plan,omitempty"`
Error string `json:"error,omitempty"`
Warning string `json:"warning,omitempty"`
Timings core.OperationTimes `json:"timings,omitempty"`
}

func PlanGenerator(ctx context.Context, cfg QueryPlanConfig) error {
Expand Down Expand Up @@ -145,10 +146,11 @@ func PlanGenerator(ctx context.Context, cfg QueryPlanConfig) error {

queryFilePath := filepath.Join(queriesPath, queryFile.Name())

outContent, err := planner.PlanOperation(queryFilePath, cfg.OutputFormat)
outContent, opTimes, err := planner.PlanOperation(queryFilePath, cfg.OutputFormat)
res := QueryPlanResult{
FileName: queryFile.Name(),
Plan: outContent,
FileName: queryFile.Name(),
Plan: outContent,
Timings: opTimes,
}
if err != nil {
if _, ok := err.(*core.PlannerOperationValidationError); ok {
Expand Down
35 changes: 31 additions & 4 deletions router/pkg/plan_generator/plan_generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package plan_generator
import (
"context"
"encoding/json"
"github.com/wundergraph/cosmo/router/core"
"os"
"path"
"path/filepath"
"runtime"
"strings"
"testing"

"github.com/wundergraph/cosmo/router/core"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
Expand Down Expand Up @@ -210,7 +211,16 @@ func TestPlanGenerator(t *testing.T) {
assert.NoError(t, err)
resultsExpected, err := os.ReadFile(path.Join(getTestDataDir(), "plans", "base", ReportFileName))
assert.NoError(t, err)
assert.Equal(t, string(resultsExpected), string(results))
resultsStruct := QueryPlanResults{}
json.Unmarshal(results, &resultsStruct)
resultsExpectedStruct := QueryPlanResults{}
json.Unmarshal(resultsExpected, &resultsExpectedStruct)
require.Len(t, resultsStruct.Plans, len(resultsExpectedStruct.Plans))
for i := range resultsStruct.Plans {
assert.Equal(t, resultsStruct.Plans[i].Plan, resultsExpectedStruct.Plans[i].Plan)
assert.Equal(t, resultsStruct.Plans[i].Error, resultsExpectedStruct.Plans[i].Error)
assert.Equal(t, resultsStruct.Plans[i].Warning, resultsExpectedStruct.Plans[i].Warning)
}
})

t.Run("will not fail on warnings and results should return the warnings and generate results file", func(t *testing.T) {
Expand Down Expand Up @@ -241,7 +251,16 @@ func TestPlanGenerator(t *testing.T) {
assert.NoError(t, err)
resultsExpected, err := os.ReadFile(path.Join(getTestDataDir(), "plans", "base", ReportFileName))
assert.NoError(t, err)
assert.Equal(t, string(resultsExpected), string(results))
resultsStruct := QueryPlanResults{}
json.Unmarshal(results, &resultsStruct)
resultsExpectedStruct := QueryPlanResults{}
json.Unmarshal(resultsExpected, &resultsExpectedStruct)
require.Len(t, resultsStruct.Plans, len(resultsExpectedStruct.Plans))
for i := range resultsStruct.Plans {
assert.Equal(t, resultsStruct.Plans[i].Plan, resultsExpectedStruct.Plans[i].Plan)
assert.Equal(t, resultsStruct.Plans[i].Error, resultsExpectedStruct.Plans[i].Error)
assert.Equal(t, resultsStruct.Plans[i].Warning, resultsExpectedStruct.Plans[i].Warning)
}
})

t.Run("will not fail on warnings and files should have warnings and generate files", func(t *testing.T) {
Expand Down Expand Up @@ -270,7 +289,15 @@ func TestPlanGenerator(t *testing.T) {
assert.NoError(t, err)
expected, err := os.ReadFile(path.Join(getTestDataDir(), "plans", "base", filename))
assert.NoError(t, err)
assert.Equal(t, string(expected), string(queryPlan))
resultsStruct := QueryPlanResults{}
json.Unmarshal(queryPlan, &resultsStruct)
resultsExpectedStruct := QueryPlanResults{}
json.Unmarshal(expected, &resultsExpectedStruct)
for i := range resultsStruct.Plans {
assert.Equal(t, resultsStruct.Plans[i].Plan, resultsExpectedStruct.Plans[i].Plan)
assert.Equal(t, resultsStruct.Plans[i].Error, resultsExpectedStruct.Plans[i].Error)
assert.Equal(t, resultsStruct.Plans[i].Warning, resultsExpectedStruct.Plans[i].Warning)
}
})
}
})
Expand Down
Loading