Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

sql/*: add support for create index statement #182

Merged
merged 2 commits into from
May 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions benchmark/tpc_h_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,14 @@ func executeQueries(b *testing.B, e *sqle.Engine) error {
func genDB(b *testing.B) (sql.Database, error) {
db := mem.NewDatabase("tpch")

memDb, ok := db.(*mem.Database)
if !ok {
b.Fatal("database cannot be casted to mem database")
}

for _, m := range tpchTableMetadata {
b.Log("generating table", m.name)
t := mem.NewTable(m.name, m.schema)
if err := insertDataToTable(t, len(m.schema)); err != nil {
return nil, err
}

memDb.AddTable(m.name, t)
db.AddTable(m.name, t)
}

return db, nil
Expand Down
17 changes: 5 additions & 12 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,8 @@ func TestAmbiguousColumnResolution(t *testing.T) {
require.Nil(table2.Insert(sql.NewRow("pux", int64(1))))

db := mem.NewDatabase("mydb")

memDb, ok := db.(*mem.Database)
require.True(ok)

memDb.AddTable(table.Name(), table)
memDb.AddTable(table2.Name(), table2)
db.AddTable(table.Name(), table)
db.AddTable(table2.Name(), table2)

e := sqle.New()
e.AddDatabase(db)
Expand Down Expand Up @@ -374,12 +370,9 @@ func newEngine(t *testing.T) *sqle.Engine {
require.Nil(table3.Insert(sql.NewRow("c", int32(3))))

db := mem.NewDatabase("mydb")
memDb, ok := db.(*mem.Database)
require.True(ok)

memDb.AddTable(table.Name(), table)
memDb.AddTable(table2.Name(), table2)
memDb.AddTable(table3.Name(), table3)
db.AddTable(table.Name(), table)
db.AddTable(table2.Name(), table2)
db.AddTable(table3.Name(), table3)

e := sqle.New()
e.AddDatabase(db)
Expand Down
10 changes: 4 additions & 6 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ func Example() {

// Iterate results and print them.
for {
ro, err := r.Next()
row, err := r.Next()
if err == io.EOF {
break
}
checkIfError(err)

name := ro[0]
count := ro[1]
name := row[0]
count := row[1]

fmt.Println(name, count)
}
Expand All @@ -50,9 +50,7 @@ func createTestDatabase() sql.Database {
{Name: "name", Type: sql.Text, Source: "mytable"},
{Name: "email", Type: sql.Text, Source: "mytable"},
})
memDb, _ := db.(*mem.Database)

memDb.AddTable("mytable", table)
db.AddTable("mytable", table)
table.Insert(sql.NewRow("John Doe", "john@doe.com"))
table.Insert(sql.NewRow("John Doe", "johnalt@doe.com"))
table.Insert(sql.NewRow("Jane Doe", "jane@doe.com"))
Expand Down
4 changes: 2 additions & 2 deletions mem/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type Database struct {
}

// NewDatabase creates a new database with the given name.
func NewDatabase(name string) sql.Database {
func NewDatabase(name string) *Database {
return &Database{
name: name,
tables: map[string]sql.Table{},
Expand All @@ -29,7 +29,7 @@ func (d *Database) Tables() map[string]sql.Table {
}

// AddTable adds a new table to the database.
func (d *Database) AddTable(name string, t *Table) {
func (d *Database) AddTable(name string, t sql.Table) {
d.tables[name] = t
}

Expand Down
4 changes: 1 addition & 3 deletions mem/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ func TestDatabase_AddTable(t *testing.T) {
tables := db.Tables()
require.Equal(0, len(tables))

altDb, ok := db.(sql.Alterable)
require.True(ok)

var altDb sql.Alterable = db
err := altDb.Create("test_table", sql.Schema{})
require.NoError(err)

Expand Down
5 changes: 1 addition & 4 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,13 @@ func setupMemDB(require *require.Assertions) *sqle.Engine {
db := mem.NewDatabase("test")
e.AddDatabase(db)

memDb, ok := db.(*mem.Database)
require.True(ok)

tableTest := mem.NewTable("test", sql.Schema{{Name: "c1", Type: sql.Int32, Source: "test"}})

for i := 0; i < 1010; i++ {
require.NoError(tableTest.Insert(sql.NewRow(int32(i))))
}

memDb.AddTable("test", tableTest)
db.AddTable("test", tableTest)

return e
}
Expand Down
16 changes: 6 additions & 10 deletions sql/analyzer/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,8 @@ func TestAnalyzer_Analyze(t *testing.T) {
})
table2 := mem.NewTable("mytable2", sql.Schema{{Name: "i2", Type: sql.Int32, Source: "mytable2"}})
db := mem.NewDatabase("mydb")

memDb, ok := db.(*mem.Database)
require.True(ok)

memDb.AddTable("mytable", table)
memDb.AddTable("mytable2", table2)
db.AddTable("mytable", table)
db.AddTable("mytable2", table2)

catalog := &sql.Catalog{Databases: []sql.Database{db}}
a := New(catalog)
Expand Down Expand Up @@ -208,16 +204,16 @@ func TestAddRule(t *testing.T) {
require := require.New(t)

a := New(nil)
require.Len(a.Rules, 12)
a.AddRule("foo", pushdown)
require.Len(a.Rules, 13)
a.AddRule("foo", pushdown)
require.Len(a.Rules, 14)
}

func TestAddValidationRule(t *testing.T) {
require := require.New(t)

a := New(nil)
require.Len(a.ValidationRules, 5)
a.AddValidationRule("foo", validateGroupBy)
require.Len(a.ValidationRules, 6)
a.AddValidationRule("foo", validateGroupBy)
require.Len(a.ValidationRules, 7)
}
22 changes: 22 additions & 0 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var DefaultRules = []Rule{
{"pushdown", pushdown},
{"optimize_distinct", optimizeDistinct},
{"erase_projection", eraseProjection},
{"index_catalog", indexCatalog},
}

var (
Expand Down Expand Up @@ -638,6 +639,27 @@ func dedupStrings(in []string) []string {
return result
}

// indexCatalog sets the catalog in the CreateIndex nodes.
func indexCatalog(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
if !n.Resolved() {
return n, nil
}

ci, ok := n.(*plan.CreateIndex)
if !ok {
return n, nil
}

span, ctx := ctx.Span("index_catalog")
defer span.Finish()

nc := *ci
ci.Catalog = a.Catalog
ci.CurrentDatabase = a.CurrentDatabase

return &nc, nil
}

func pushdown(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
span, ctx := ctx.Span("pushdown")
defer span.Finish()
Expand Down
19 changes: 5 additions & 14 deletions sql/analyzer/rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,9 @@ func TestResolveSubqueries(t *testing.T) {
})
table3 := mem.NewTable("baz", sql.Schema{{Name: "c", Type: sql.Int64, Source: "baz"}})
db := mem.NewDatabase("mydb")
memDb, ok := db.(*mem.Database)
require.True(ok)

memDb.AddTable("foo", table1)
memDb.AddTable("bar", table2)
memDb.AddTable("baz", table3)
db.AddTable("foo", table1)
db.AddTable("bar", table2)
db.AddTable("baz", table3)

catalog := &sql.Catalog{Databases: []sql.Database{db}}
a := New(catalog)
Expand Down Expand Up @@ -106,10 +103,7 @@ func TestResolveTables(t *testing.T) {

table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}})
db := mem.NewDatabase("mydb")
memDb, ok := db.(*mem.Database)
require.True(ok)

memDb.AddTable("mytable", table)
db.AddTable("mytable", table)

catalog := &sql.Catalog{Databases: []sql.Database{db}}

Expand Down Expand Up @@ -144,10 +138,7 @@ func TestResolveTablesNested(t *testing.T) {

table := mem.NewTable("mytable", sql.Schema{{Name: "i", Type: sql.Int32}})
db := mem.NewDatabase("mydb")
memDb, ok := db.(*mem.Database)
require.True(ok)

memDb.AddTable("mytable", table)
db.AddTable("mytable", table)

catalog := &sql.Catalog{Databases: []sql.Database{db}}

Expand Down
39 changes: 39 additions & 0 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package analyzer

import (
"strings"

errors "gopkg.in/src-d/go-errors.v1"
"gopkg.in/src-d/go-mysql-server.v0/sql"
"gopkg.in/src-d/go-mysql-server.v0/sql/expression"
Expand All @@ -13,6 +15,7 @@ const (
validateGroupByRule = "validate_group_by"
validateSchemaSourceRule = "validate_schema_source"
validateProjectTuplesRule = "validate_project_tuples"
validateIndexCreationRule = "validate_index_creation"
)

var (
Expand All @@ -30,6 +33,9 @@ var (
// ErrProjectTuple is returned when there is a tuple of more than 1 column
// inside a projection.
ErrProjectTuple = errors.NewKind("selected field %d should have 1 column, but has %d")
// ErrUnknownIndexColumns is returned when there are columns in the expr
// to index that are unknown in the table.
ErrUnknownIndexColumns = errors.NewKind("unknown columns to index for table %q: %s")
)

// DefaultValidationRules to apply while analyzing nodes.
Expand All @@ -39,6 +45,7 @@ var DefaultValidationRules = []ValidationRule{
{validateGroupByRule, validateGroupBy},
{validateSchemaSourceRule, validateSchemaSource},
{validateProjectTuplesRule, validateProjectTuples},
{validateIndexCreationRule, validateIndexCreation},
}

func validateIsResolved(ctx *sql.Context, n sql.Node) error {
Expand Down Expand Up @@ -130,6 +137,38 @@ func validateSchemaSource(ctx *sql.Context, n sql.Node) error {
return nil
}

func validateIndexCreation(ctx *sql.Context, n sql.Node) error {
span, ctx := ctx.Span("validate_index_creation")
defer span.Finish()

ci, ok := n.(*plan.CreateIndex)
if !ok {
return nil
}

schema := ci.Table.Schema()
table := schema[0].Source

var unknownColumns []string
for _, expr := range ci.Exprs {
expression.Inspect(expr, func(e sql.Expression) bool {
gf, ok := e.(*expression.GetField)
if ok {
if gf.Table() != table || !schema.Contains(gf.Name()) {
unknownColumns = append(unknownColumns, gf.Name())
}
}
return true
})
}

if len(unknownColumns) > 0 {
return ErrUnknownIndexColumns.New(table, strings.Join(unknownColumns, ", "))
}

return nil
}

func validateSchema(t sql.Table) error {
name := t.Name()
for _, col := range t.Schema() {
Expand Down
67 changes: 67 additions & 0 deletions sql/analyzer/validation_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,73 @@ func TestValidateProjectTuples(t *testing.T) {
}
}

func TestValidateIndexCreation(t *testing.T) {
table := mem.NewTable("foo", sql.Schema{
{Name: "a", Source: "foo"},
{Name: "b", Source: "foo"},
})

testCases := []struct {
name string
node sql.Node
ok bool
}{
{
"columns from another table",
plan.NewCreateIndex(
"idx", table,
[]sql.Expression{expression.NewEquals(
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
expression.NewGetFieldWithTable(0, sql.Int64, "bar", "b", false),
)},
"",
make(map[string]string),
),
false,
},
{
"columns that don't exist",
plan.NewCreateIndex(
"idx", table,
[]sql.Expression{expression.NewEquals(
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "c", false),
)},
"",
make(map[string]string),
),
false,
},
{
"columns only from table",
plan.NewCreateIndex(
"idx", table,
[]sql.Expression{expression.NewEquals(
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "a", false),
expression.NewGetFieldWithTable(0, sql.Int64, "foo", "b", false),
)},
"",
make(map[string]string),
),
true,
},
}

rule := getValidationRule(validateIndexCreationRule)
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
err := rule.Apply(sql.NewEmptyContext(), tt.node)
if tt.ok {
require.NoError(err)
} else {
require.Error(err)
require.True(ErrUnknownIndexColumns.Is(err))
}
})
}
}

type dummyNode struct{ resolved bool }

func (n dummyNode) String() string { return "dummynode" }
Expand Down
Loading