Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

*: implement subquery expressions #835

Merged
merged 1 commit into from
Oct 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions SUPPORTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,6 @@
- div
- %

## Subqueries
- supported only as tables, not as expressions.

## Functions
- ARRAY_LENGTH
- CEIL
Expand Down Expand Up @@ -133,3 +130,6 @@
- WEEKDAY
- YEAR
- YEARWEEK

## Subqueries
Supported both as a table and as expressions but they can't access the parent query scope.
2 changes: 1 addition & 1 deletion engine.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sqle // import "github.com/src-d/go-mysql-server"
package sqle

import (
"time"
Expand Down
30 changes: 26 additions & 4 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,28 @@ var queries = []struct {
{int64(5), "there is some text in here"},
},
},
{
`SELECT i FROM mytable WHERE i = (SELECT 1)`,
[]sql.Row{{int64(1)}},
},
{
`SELECT i FROM mytable WHERE i IN (SELECT i FROM mytable)`,
[]sql.Row{
{int64(1)},
{int64(2)},
{int64(3)},
},
},
{
`SELECT i FROM mytable WHERE i NOT IN (SELECT i FROM mytable ORDER BY i ASC LIMIT 2)`,
[]sql.Row{
{int64(3)},
},
},
{
`SELECT (SELECT i FROM mytable ORDER BY i ASC LIMIT 1) AS x`,
[]sql.Row{{int64(1)}},
},
}

func TestQueries(t *testing.T) {
Expand Down Expand Up @@ -1901,7 +1923,7 @@ func TestInsertInto(t *testing.T) {
[]sql.Row{{int64(1)}},
"SELECT * FROM typestable WHERE id = 999;",
[]sql.Row{{
int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1),
int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
int64(0), int64(0), int64(0), int64(0),
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
Expand All @@ -1919,7 +1941,7 @@ func TestInsertInto(t *testing.T) {
[]sql.Row{{int64(1)}},
"SELECT * FROM typestable WHERE id = 999;",
[]sql.Row{{
int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1),
int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
int64(0), int64(0), int64(0), int64(0),
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
Expand Down Expand Up @@ -2101,7 +2123,7 @@ func TestReplaceInto(t *testing.T) {
[]sql.Row{{int64(1)}},
"SELECT * FROM typestable WHERE id = 999;",
[]sql.Row{{
int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1),
int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
int64(0), int64(0), int64(0), int64(0),
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
Expand All @@ -2119,7 +2141,7 @@ func TestReplaceInto(t *testing.T) {
[]sql.Row{{int64(1)}},
"SELECT * FROM typestable WHERE id = 999;",
[]sql.Row{{
int64(999), int64(-math.MaxInt8-1), int64(-math.MaxInt16-1), int64(-math.MaxInt32-1), int64(-math.MaxInt64-1),
int64(999), int64(-math.MaxInt8 - 1), int64(-math.MaxInt16 - 1), int64(-math.MaxInt32 - 1), int64(-math.MaxInt64 - 1),
int64(0), int64(0), int64(0), int64(0),
float64(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
timeParse(sql.TimestampLayout, "0010-04-05 12:51:36"), timeParse(sql.DateLayout, "0101-11-07"),
Expand Down
2 changes: 1 addition & 1 deletion log.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sqle // import "github.com/src-d/go-mysql-server"
package sqle

import (
"github.com/golang/glog"
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package server // import "github.com/src-d/go-mysql-server/server"
package server

import (
"time"
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/analyzer.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package analyzer // import "github.com/src-d/go-mysql-server/sql/analyzer"
package analyzer

import (
"os"
Expand Down
14 changes: 13 additions & 1 deletion sql/analyzer/assign_indexes.go
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,20 @@ func containsColumns(e sql.Expression) bool {
return result
}

func containsSubquery(e sql.Expression) bool {
var result bool
expression.Inspect(e, func(e sql.Expression) bool {
if _, ok := e.(*expression.Subquery); ok {
result = true
return false
}
return true
})
return result
}

func isEvaluable(e sql.Expression) bool {
return !containsColumns(e)
return !containsColumns(e) && !containsSubquery(e)
}

func canMergeIndexes(a, b sql.IndexLookup) bool {
Expand Down
24 changes: 23 additions & 1 deletion sql/analyzer/resolve_subqueries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package analyzer

import (
"github.com/src-d/go-mysql-server/sql"
"github.com/src-d/go-mysql-server/sql/expression"
"github.com/src-d/go-mysql-server/sql/plan"
)

Expand All @@ -10,7 +11,7 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err
defer span.Finish()

a.Log("resolving subqueries")
return plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
n, err := plan.TransformUp(n, func(n sql.Node) (sql.Node, error) {
switch n := n.(type) {
case *plan.SubqueryAlias:
a.Log("found subquery %q with child of type %T", n.Name(), n.Child)
Expand All @@ -24,4 +25,25 @@ func resolveSubqueries(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, err
return n, nil
}
})
if err != nil {
return nil, err
}

return plan.TransformExpressionsUp(n, func(e sql.Expression) (sql.Expression, error) {
s, ok := e.(*expression.Subquery)
if !ok || s.Resolved() {
return e, nil
}

q, err := a.Analyze(ctx, s.Query)
if err != nil {
return nil, err
}

if qp, ok := q.(*plan.QueryProcess); ok {
q = qp.Child
}

return s.WithQuery(q), nil
})
}
27 changes: 27 additions & 0 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
validateCaseResultTypesRule = "validate_case_result_types"
validateIntervalUsageRule = "validate_interval_usage"
validateExplodeUsageRule = "validate_explode_usage"
validateSubqueryColumnsRule = "validate_subquery_columns"
)

var (
Expand Down Expand Up @@ -57,6 +58,12 @@ var (
ErrExplodeInvalidUse = errors.NewKind(
"using EXPLODE is not supported outside a Project node",
)

// ErrSubqueryColumns is returned when an expression subquery returns
// more than a single column.
ErrSubqueryColumns = errors.NewKind(
"subquery expressions can only return a single column",
)
)

// DefaultValidationRules to apply while analyzing nodes.
Expand All @@ -70,6 +77,7 @@ var DefaultValidationRules = []Rule{
{validateCaseResultTypesRule, validateCaseResultTypes},
{validateIntervalUsageRule, validateIntervalUsage},
{validateExplodeUsageRule, validateExplodeUsage},
{validateSubqueryColumnsRule, validateSubqueryColumns},
}

func validateIsResolved(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
Expand Down Expand Up @@ -322,6 +330,25 @@ func validateExplodeUsage(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node,
return n, nil
}

func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
valid := true
plan.InspectExpressions(n, func(e sql.Expression) bool {
s, ok := e.(*expression.Subquery)
if ok && len(s.Query.Schema()) != 1 {
valid = false
return false
}

return true
})

if !valid {
return nil, ErrSubqueryColumns.New()
}

return n, nil
}

func stringContains(strs []string, target string) bool {
for _, s := range strs {
if s == target {
Expand Down
31 changes: 31 additions & 0 deletions sql/analyzer/validation_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,37 @@ func TestValidateExplodeUsage(t *testing.T) {
}
}

func TestValidateSubqueryColumns(t *testing.T) {
require := require.New(t)
ctx := sql.NewEmptyContext()

node := plan.NewProject([]sql.Expression{
expression.NewSubquery(plan.NewProject(
[]sql.Expression{
lit(1),
lit(2),
},
dummyNode{true},
)),
}, dummyNode{true})

_, err := validateSubqueryColumns(ctx, nil, node)
require.Error(err)
require.True(ErrSubqueryColumns.Is(err))

node = plan.NewProject([]sql.Expression{
expression.NewSubquery(plan.NewProject(
[]sql.Expression{
lit(1),
},
dummyNode{true},
)),
}, dummyNode{true})

_, err = validateSubqueryColumns(ctx, nil, node)
require.NoError(err)
}

type dummyNode struct{ resolved bool }

func (n dummyNode) String() string { return "dummynode" }
Expand Down
2 changes: 1 addition & 1 deletion sql/core.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sql // import "github.com/src-d/go-mysql-server/sql"
package sql

import (
"fmt"
Expand Down
58 changes: 56 additions & 2 deletions sql/expression/comparison.go
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, err
}

// TODO: support subqueries
switch right := in.Right().(type) {
case Tuple:
for _, el := range right {
Expand Down Expand Up @@ -496,6 +495,34 @@ func (in *In) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
}
}

return false, nil
case *Subquery:
if leftElems > 1 {
return nil, ErrInvalidOperandColumns.New(leftElems, 1)
}

typ := right.Type()
values, err := right.EvalMultiple(ctx)
if err != nil {
return nil, err
}

for _, val := range values {
val, err = typ.Convert(val)
if err != nil {
return nil, err
}

cmp, err := typ.Compare(left, val)
if err != nil {
return nil, err
}

if cmp == 0 {
return true, nil
}
}

return false, nil
default:
return nil, ErrUnsupportedInOperand.New(right)
Expand Down Expand Up @@ -547,7 +574,6 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, err
}

// TODO: support subqueries
switch right := in.Right().(type) {
case Tuple:
for _, el := range right {
Expand Down Expand Up @@ -577,6 +603,34 @@ func (in *NotIn) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
}
}

return true, nil
case *Subquery:
if leftElems > 1 {
return nil, ErrInvalidOperandColumns.New(leftElems, 1)
}

typ := right.Type()
values, err := right.EvalMultiple(ctx)
if err != nil {
return nil, err
}

for _, val := range values {
val, err = typ.Convert(val)
if err != nil {
return nil, err
}

cmp, err := typ.Compare(left, val)
if err != nil {
return nil, err
}

if cmp == 0 {
return false, nil
}
}

return true, nil
default:
return nil, ErrUnsupportedInOperand.New(right)
Expand Down
Loading