Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 89e8320

Browse files
authored
Merge pull request #832 from liquidata-inc/updates
Implemented UPDATE
2 parents b8d9155 + e847998 commit 89e8320

File tree

8 files changed

+493
-8
lines changed

8 files changed

+493
-8
lines changed

engine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func (e *Engine) Query(
120120
case *plan.CreateIndex:
121121
typ = sql.CreateIndexProcess
122122
perm = auth.ReadPerm | auth.WritePerm
123-
case *plan.InsertInto, *plan.DeleteFrom, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables:
123+
case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.DropIndex, *plan.UnlockTables, *plan.LockTables:
124124
perm = auth.ReadPerm | auth.WritePerm
125125
}
126126

engine_test.go

Lines changed: 143 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/src-d/go-mysql-server/sql/plan"
1919
"github.com/src-d/go-mysql-server/test"
2020

21+
2122
"github.com/stretchr/testify/require"
2223
)
2324

@@ -2245,6 +2246,142 @@ func TestReplaceIntoErrors(t *testing.T) {
22452246
}
22462247
}
22472248

2249+
func TestUpdate(t *testing.T) {
2250+
var updates = []struct {
2251+
updateQuery string
2252+
expectedUpdate []sql.Row
2253+
selectQuery string
2254+
expectedSelect []sql.Row
2255+
}{
2256+
{
2257+
"UPDATE mytable SET s = 'updated';",
2258+
[]sql.Row{{int64(3), int64(3)}},
2259+
"SELECT * FROM mytable;",
2260+
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}},
2261+
},
2262+
{
2263+
"UPDATE mytable SET s = 'updated' WHERE i > 9999;",
2264+
[]sql.Row{{int64(0), int64(0)}},
2265+
"SELECT * FROM mytable;",
2266+
[]sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}},
2267+
},
2268+
{
2269+
"UPDATE mytable SET s = 'updated' WHERE i = 1;",
2270+
[]sql.Row{{int64(1), int64(1)}},
2271+
"SELECT * FROM mytable;",
2272+
[]sql.Row{{int64(1), "updated"}, {int64(2), "second row"}, {int64(3), "third row"}},
2273+
},
2274+
{
2275+
"UPDATE mytable SET s = 'updated' WHERE i <> 9999;",
2276+
[]sql.Row{{int64(3), int64(3)}},
2277+
"SELECT * FROM mytable;",
2278+
[]sql.Row{{int64(1), "updated"},{int64(2), "updated"},{int64(3), "updated"}},
2279+
},
2280+
{
2281+
"UPDATE floattable SET f32 = f32 + f32, f64 = f32 * f64 WHERE i = 2;",
2282+
[]sql.Row{{int64(1), int64(1)}},
2283+
"SELECT * FROM floattable WHERE i = 2;",
2284+
[]sql.Row{{int64(2), float32(3.0), float64(4.5)}},
2285+
},
2286+
{
2287+
"UPDATE floattable SET f32 = 5, f32 = 4 WHERE i = 1;",
2288+
[]sql.Row{{int64(1), int64(1)}},
2289+
"SELECT f32 FROM floattable WHERE i = 1;",
2290+
[]sql.Row{{float32(4.0)}},
2291+
},
2292+
{
2293+
"UPDATE mytable SET s = 'first row' WHERE i = 1;",
2294+
[]sql.Row{{int64(1), int64(0)}},
2295+
"SELECT * FROM mytable;",
2296+
[]sql.Row{{int64(1), "first row"}, {int64(2), "second row"}, {int64(3), "third row"}},
2297+
},
2298+
{
2299+
"UPDATE niltable SET b = NULL WHERE f IS NULL;",
2300+
[]sql.Row{{int64(2), int64(1)}},
2301+
"SELECT * FROM niltable WHERE f IS NULL;",
2302+
[]sql.Row{{int64(4), nil, nil}, {nil, nil, nil}},
2303+
},
2304+
{
2305+
"UPDATE mytable SET s = 'updated' ORDER BY i ASC LIMIT 2;",
2306+
[]sql.Row{{int64(2), int64(2)}},
2307+
"SELECT * FROM mytable;",
2308+
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "third row"}},
2309+
},
2310+
{
2311+
"UPDATE mytable SET s = 'updated' ORDER BY i DESC LIMIT 2;",
2312+
[]sql.Row{{int64(2), int64(2)}},
2313+
"SELECT * FROM mytable;",
2314+
[]sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "updated"}},
2315+
},
2316+
{
2317+
"UPDATE mytable SET s = 'updated' ORDER BY i LIMIT 1 OFFSET 1;",
2318+
[]sql.Row{{int64(1), int64(1)}},
2319+
"SELECT * FROM mytable;",
2320+
[]sql.Row{{int64(1), "first row"}, {int64(2), "updated"}, {int64(3), "third row"}},
2321+
},
2322+
{
2323+
"UPDATE mytable SET s = 'updated';",
2324+
[]sql.Row{{int64(3), int64(3)}},
2325+
"SELECT * FROM mytable;",
2326+
[]sql.Row{{int64(1), "updated"}, {int64(2), "updated"}, {int64(3), "updated"}},
2327+
},
2328+
}
2329+
2330+
for _, update := range updates {
2331+
e := newEngine(t)
2332+
ctx := newCtx()
2333+
testQueryWithContext(ctx, t, e, update.updateQuery, update.expectedUpdate)
2334+
testQueryWithContext(ctx, t, e, update.selectQuery, update.expectedSelect)
2335+
}
2336+
}
2337+
2338+
func TestUpdateErrors(t *testing.T) {
2339+
var expectedFailures = []struct {
2340+
name string
2341+
query string
2342+
}{
2343+
{
2344+
"invalid table",
2345+
"UPDATE doesnotexist SET i = 0;",
2346+
},
2347+
{
2348+
"invalid column set",
2349+
"UPDATE mytable SET z = 0;",
2350+
},
2351+
{
2352+
"invalid column set value",
2353+
"UPDATE mytable SET i = z;",
2354+
},
2355+
{
2356+
"invalid column where",
2357+
"UPDATE mytable SET s = 'hi' WHERE z = 1;",
2358+
},
2359+
{
2360+
"invalid column order by",
2361+
"UPDATE mytable SET s = 'hi' ORDER BY z;",
2362+
},
2363+
{
2364+
"negative limit",
2365+
"UPDATE mytable SET s = 'hi' LIMIT -1;",
2366+
},
2367+
{
2368+
"negative offset",
2369+
"UPDATE mytable SET s = 'hi' LIMIT 1 OFFSET -1;",
2370+
},
2371+
{
2372+
"set null on non-nullable",
2373+
"UPDATE mytable SET s = NULL;",
2374+
},
2375+
}
2376+
2377+
for _, expectedFailure := range expectedFailures {
2378+
t.Run(expectedFailure.name, func(t *testing.T) {
2379+
_, _, err := newEngine(t).Query(newCtx(), expectedFailure.query)
2380+
require.Error(t, err)
2381+
})
2382+
}
2383+
}
2384+
22482385
const testNumPartitions = 5
22492386

22502387
func TestAmbiguousColumnResolution(t *testing.T) {
@@ -2670,12 +2807,12 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine {
26702807

26712808
insertRows(
26722809
t, floatTable,
2673-
sql.NewRow(1, float32(1.0), float64(1.0)),
2674-
sql.NewRow(2, float32(1.5), float64(1.5)),
2675-
sql.NewRow(3, float32(2.0), float64(2.0)),
2676-
sql.NewRow(4, float32(2.5), float64(2.5)),
2677-
sql.NewRow(-1, float32(-1.0), float64(-1.0)),
2678-
sql.NewRow(-2, float32(-1.5), float64(-1.5)),
2810+
sql.NewRow(int64(1), float32(1.0), float64(1.0)),
2811+
sql.NewRow(int64(2), float32(1.5), float64(1.5)),
2812+
sql.NewRow(int64(3), float32(2.0), float64(2.0)),
2813+
sql.NewRow(int64(4), float32(2.5), float64(2.5)),
2814+
sql.NewRow(int64(-1), float32(-1.0), float64(-1.0)),
2815+
sql.NewRow(int64(-2), float32(-1.5), float64(-1.5)),
26792816
)
26802817

26812818
nilTable := memory.NewPartitionedTable("niltable", sql.Schema{

memory/table.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,37 @@ func (t *Table) Delete(ctx *sql.Context, row sql.Row) error {
290290
return nil
291291
}
292292

293+
func (t *Table) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
294+
if err := checkRow(t.schema, oldRow); err != nil {
295+
return err
296+
}
297+
if err := checkRow(t.schema, newRow); err != nil {
298+
return err
299+
}
300+
301+
matches := false
302+
for partitionIndex, partition := range t.partitions {
303+
for partitionRowIndex, partitionRow := range partition {
304+
matches = true
305+
for rIndex, val := range oldRow {
306+
if val != partitionRow[rIndex] {
307+
matches = false
308+
break
309+
}
310+
}
311+
if matches {
312+
t.partitions[partitionIndex][partitionRowIndex] = newRow
313+
break
314+
}
315+
}
316+
if matches {
317+
break
318+
}
319+
}
320+
321+
return nil
322+
}
323+
293324
func checkRow(schema sql.Schema, row sql.Row) error {
294325
if len(row) != len(schema) {
295326
return sql.ErrUnexpectedRowLength.New(len(schema), len(row))

sql/analyzer/pushdown.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func pushdown(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error) {
2020

2121
// don't do pushdown on certain queries
2222
switch n.(type) {
23-
case *plan.InsertInto, *plan.DeleteFrom, *plan.CreateIndex:
23+
case *plan.InsertInto, *plan.DeleteFrom, *plan.Update, *plan.CreateIndex:
2424
return n, nil
2525
}
2626

sql/core.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,12 @@ type Replacer interface {
217217
Inserter
218218
}
219219

220+
// Updater allows rows to be updated.
221+
type Updater interface {
222+
// Update the given row. Provides both the old and new rows.
223+
Update(ctx *Context, old Row, new Row) error
224+
}
225+
220226
// Database represents the database.
221227
type Database interface {
222228
Nameable

sql/expression/set.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package expression
2+
3+
import (
4+
"fmt"
5+
"github.com/src-d/go-mysql-server/sql"
6+
"gopkg.in/src-d/go-errors.v1"
7+
)
8+
9+
var errCannotSetField = errors.NewKind("Expected GetField expression on left but got %T")
10+
11+
// SetField updates the value of a field from a row.
12+
type SetField struct {
13+
BinaryExpression
14+
}
15+
16+
// NewSetField creates a new SetField expression.
17+
func NewSetField(colName, expr sql.Expression) sql.Expression {
18+
return &SetField{BinaryExpression{Left: colName, Right: expr}}
19+
}
20+
21+
func (s *SetField) String() string {
22+
return fmt.Sprintf("SETFIELD %s = %s", s.Left, s.Right)
23+
}
24+
25+
// Type implements the Expression interface.
26+
func (s *SetField) Type() sql.Type {
27+
return s.Left.Type()
28+
}
29+
30+
// Eval implements the Expression interface.
31+
// Returns a copy of the given row with an updated value.
32+
func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
33+
getField, ok := s.Left.(*GetField)
34+
if !ok {
35+
return nil, errCannotSetField.New(s.Left)
36+
}
37+
if getField.fieldIndex < 0 || getField.fieldIndex >= len(row) {
38+
return nil, ErrIndexOutOfBounds.New(getField.fieldIndex, len(row))
39+
}
40+
val, err := s.Right.Eval(ctx, row)
41+
if err != nil {
42+
return nil, err
43+
}
44+
if val != nil {
45+
val, err = getField.fieldType.Convert(val)
46+
if err != nil {
47+
return nil, err
48+
}
49+
}
50+
updatedRow := row.Copy()
51+
updatedRow[getField.fieldIndex] = val
52+
return updatedRow, nil
53+
}
54+
55+
// WithChildren implements the Expression interface.
56+
func (s *SetField) WithChildren(children ...sql.Expression) (sql.Expression, error) {
57+
if len(children) != 2 {
58+
return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 2)
59+
}
60+
return NewSetField(children[0], children[1]), nil
61+
}

sql/parse/parse.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ func convert(ctx *sql.Context, stmt sqlparser.Statement, query string) (sql.Node
153153
return plan.NewRollback(), nil
154154
case *sqlparser.Delete:
155155
return convertDelete(ctx, n)
156+
case *sqlparser.Update:
157+
return convertUpdate(ctx, n)
156158
}
157159
}
158160

@@ -429,6 +431,49 @@ func convertDelete(ctx *sql.Context, d *sqlparser.Delete) (sql.Node, error) {
429431
return plan.NewDeleteFrom(node), nil
430432
}
431433

434+
func convertUpdate(ctx *sql.Context, d *sqlparser.Update) (sql.Node, error) {
435+
node, err := tableExprsToTable(ctx, d.TableExprs)
436+
if err != nil {
437+
return nil, err
438+
}
439+
440+
updateExprs, err := updateExprsToExpressions(ctx, d.Exprs)
441+
if err != nil {
442+
return nil, err
443+
}
444+
445+
if d.Where != nil {
446+
node, err = whereToFilter(ctx, d.Where, node)
447+
if err != nil {
448+
return nil, err
449+
}
450+
}
451+
452+
if len(d.OrderBy) != 0 {
453+
node, err = orderByToSort(ctx, d.OrderBy, node)
454+
if err != nil {
455+
return nil, err
456+
}
457+
}
458+
459+
// Limit must wrap offset, and not vice-versa, so that skipped rows don't count toward the returned row count.
460+
if d.Limit != nil && d.Limit.Offset != nil {
461+
node, err = offsetToOffset(ctx, d.Limit.Offset, node)
462+
if err != nil {
463+
return nil, err
464+
}
465+
}
466+
467+
if d.Limit != nil {
468+
node, err = limitToLimit(ctx, d.Limit.Rowcount, node)
469+
if err != nil {
470+
return nil, err
471+
}
472+
}
473+
474+
return plan.NewUpdate(node, updateExprs), nil
475+
}
476+
432477
func columnDefinitionToSchema(colDef []*sqlparser.ColumnDefinition) (sql.Schema, error) {
433478
var schema sql.Schema
434479
for _, cd := range colDef {
@@ -1241,6 +1286,22 @@ func intervalExprToExpression(ctx *sql.Context, e *sqlparser.IntervalExpr) (sql.
12411286
return expression.NewInterval(expr, e.Unit), nil
12421287
}
12431288

1289+
func updateExprsToExpressions(ctx *sql.Context, e sqlparser.UpdateExprs) ([]sql.Expression, error) {
1290+
res := make([]sql.Expression, len(e))
1291+
for i, updateExpr := range e {
1292+
colName, err := exprToExpression(ctx, updateExpr.Name)
1293+
if err != nil {
1294+
return nil, err
1295+
}
1296+
innerExpr, err := exprToExpression(ctx, updateExpr.Expr)
1297+
if err != nil {
1298+
return nil, err
1299+
}
1300+
res[i] = expression.NewSetField(colName, innerExpr)
1301+
}
1302+
return res, nil
1303+
}
1304+
12441305
func removeComments(s string) string {
12451306
r := bufio.NewReader(strings.NewReader(s))
12461307
var result []rune

0 commit comments

Comments
 (0)