Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,10 @@ func assertNodesEqualWithDiff(t *testing.T, expected, actual sql.Node) {
}
}

// TestRecursiveViewDefinition verifies that analyzing a self-referencing view
// terminates with an error rather than recursing forever, using the default
// in-memory harness.
func TestRecursiveViewDefinition(t *testing.T) {
	harness := enginetest.NewDefaultMemoryHarness()
	enginetest.TestRecursiveViewDefinition(t, harness)
}

func TestTableFunctions(t *testing.T) {
var tableFunctionScriptTests = []queries.ScriptTest{
{
Expand Down
22 changes: 22 additions & 0 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,28 @@ func TestViews(t *testing.T, harness Harness) {
})
}

// TestRecursiveViewDefinition creates a view whose definition selects from
// itself and asserts that querying it fails with ErrMaxAnalysisIters instead
// of recursing without bound in the analyzer.
func TestRecursiveViewDefinition(t *testing.T, harness Harness) {
	harness.Setup(setup.MydbData, setup.MytableData)
	engine := mustNewEngine(t, harness)
	defer engine.Close()
	ctx := NewContext(harness)

	database, err := engine.Analyzer.Catalog.Database(ctx, "mydb")
	require.NoError(t, err)

	// Privileged wrappers hide the underlying database; peel one off if present
	// so we can reach the view-storage interface.
	if privileged, isPrivileged := database.(mysql_db.PrivilegedDatabase); isPrivileged {
		database = privileged.Unwrap()
	}

	viewDb, isViewDb := database.(sql.ViewDatabase)
	require.True(t, isViewDb, "expected sql.ViewDatabase")

	require.NoError(t, viewDb.CreateView(ctx, "recursiveView", "select * from recursiveView"))

	AssertErr(t, engine, harness, "select * from recursiveView", analyzer.ErrMaxAnalysisIters)
}

func TestViewsPrepared(t *testing.T, harness Harness) {
harness.Setup(setup.MydbData, setup.MytableData)
e := mustNewEngine(t, harness)
Expand Down
8 changes: 8 additions & 0 deletions sql/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,17 @@ func (a *Analyzer) analyzeThroughBatch(ctx *sql.Context, n sql.Node, scope *Scop
}, sel)
}

// maxBatchRecursion bounds how many times the analyzer may recursively invoke
// itself (each recursive invocation increments a depth counter on the scope),
// so that queries which would otherwise analyze forever — such as a view
// defined in terms of itself — fail with an error instead of overflowing the
// stack. The limit is high but arbitrary.
const maxBatchRecursion = 100

func (a *Analyzer) analyzeWithSelector(ctx *sql.Context, n sql.Node, scope *Scope, batchSelector BatchSelector, ruleSelector RuleSelector) (sql.Node, transform.TreeIdentity, error) {
span, ctx := ctx.Span("analyze")

if scope.RecursionDepth() > maxBatchRecursion {
return n, transform.SameTree, ErrMaxAnalysisIters.New(maxBatchRecursion)
}

var (
same = transform.SameTree
allSame = transform.SameTree
Expand Down
6 changes: 3 additions & 3 deletions sql/analyzer/resolve_subqueries.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope,
return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
switch n := n.(type) {
case *plan.SubqueryAlias:
// subqueries do not have access to outer scope
child, same, err := a.analyzeThroughBatch(ctx, n.Child, nil, "default-rules", sel)
// subqueries do not have access to outer scope, but we do need a scope object to track recursion depth
child, same, err := a.analyzeThroughBatch(ctx, n.Child, newScopeWithDepth(scope.RecursionDepth()+1), "default-rules", sel)
if err != nil {
return nil, same, err
}
Expand Down Expand Up @@ -59,7 +59,7 @@ func finalizeSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope,
switch n := n.(type) {
case *plan.SubqueryAlias:
// subqueries do not have access to outer scope
child, same, err := a.analyzeStartingAtBatch(ctx, n.Child, nil, "default-rules", sel)
child, same, err := a.analyzeStartingAtBatch(ctx, n.Child, newScopeWithDepth(scope.RecursionDepth()+1), "default-rules", sel)
if err != nil {
return nil, same, err
}
Expand Down
21 changes: 18 additions & 3 deletions sql/analyzer/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type Scope struct {
// Memo nodes are nodes in the execution context that shouldn't be considered for name resolution, but are still
// important for analysis.
memos []sql.Node
// recursionDepth tracks how many times we've recursed with analysis, to avoid stack overflows from infinite recursion
recursionDepth int

procedures *ProcedureCache
}
Expand All @@ -42,12 +44,18 @@ func (s *Scope) newScope(node sql.Node) *Scope {
newNodes = append(newNodes, node)
newNodes = append(newNodes, s.nodes...)
return &Scope{
nodes: newNodes,
memos: s.memos,
procedures: s.procedures,
nodes: newNodes,
memos: s.memos,
recursionDepth: s.recursionDepth + 1,
procedures: s.procedures,
}
}

// newScopeWithDepth returns an otherwise-empty *Scope whose recursion depth is
// initialized to depth. It is used where analysis needs a fresh scope (e.g. a
// subquery that cannot see outer names) that still carries the recursion count
// forward so runaway recursive analysis can be detected.
func newScopeWithDepth(depth int) *Scope {
	scope := &Scope{}
	scope.recursionDepth = depth
	return scope
}

// memo creates a new Scope object with the memo node given. Memo nodes don't affect name resolution, but are used in
// other parts of analysis, such as error handling for trigger / procedure execution.
func (s *Scope) memo(node sql.Node) *Scope {
Expand Down Expand Up @@ -83,6 +91,13 @@ func (s *Scope) MemoNodes() []sql.Node {
return s.memos
}

// RecursionDepth reports how many levels of recursive analysis produced this
// scope. A nil receiver is treated as depth zero, consistent with the other
// nil-safe Scope accessors in this file.
func (s *Scope) RecursionDepth() int {
	if s != nil {
		return s.recursionDepth
	}
	return 0
}

func (s *Scope) procedureCache() *ProcedureCache {
if s == nil {
return nil
Expand Down