Skip to content

Commit cf70da4

Browse files
authored
Merge pull request #2463 from dolthub/nicktobey/insert-as
Update GMS to detect INSERT statements with row alias and return error.
2 parents 436ceb5 + f43f8e3 commit cf70da4

File tree

7 files changed

+94
-34
lines changed

7 files changed

+94
-34
lines changed

enginetest/enginetests.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,10 @@ func TestReadOnlyDatabases(t *testing.T, harness ReadOnlyDatabaseHarness) {
367367
} {
368368
for _, tt := range querySet {
369369
t.Run(tt.WriteQuery, func(t *testing.T) {
370+
if tt.Skip {
371+
t.Skip()
372+
return
373+
}
370374
AssertErrWithBindings(t, engine, harness, tt.WriteQuery, tt.Bindings, analyzererrors.ErrReadOnlyDatabase)
371375
})
372376
}
@@ -1235,10 +1239,16 @@ func TestDelete(t *testing.T, harness Harness) {
12351239
for name, coster := range biasedCosters {
12361240
t.Run(name+" join", func(t *testing.T) {
12371241
for _, tt := range queries.DeleteJoinTests {
1238-
e := mustNewEngine(t, harness)
1239-
e.EngineAnalyzer().Coster = coster
1240-
defer e.Close()
1241-
RunWriteQueryTestWithEngine(t, harness, e, tt)
1242+
t.Run(tt.WriteQuery, func(t *testing.T) {
1243+
if tt.Skip {
1244+
t.Skip()
1245+
return
1246+
}
1247+
e := mustNewEngine(t, harness)
1248+
e.EngineAnalyzer().Coster = coster
1249+
defer e.Close()
1250+
RunWriteQueryTestWithEngine(t, harness, e, tt)
1251+
})
12421252
}
12431253
})
12441254
}

enginetest/evaluation.go

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,37 +1046,45 @@ func ExtractQueryNode(node sql.Node) sql.Node {
10461046

10471047
// RunWriteQueryTest runs the specified |tt| WriteQueryTest using the specified harness.
10481048
func RunWriteQueryTest(t *testing.T, harness Harness, tt queries.WriteQueryTest) {
1049-
e := mustNewEngine(t, harness)
1050-
defer e.Close()
1051-
RunWriteQueryTestWithEngine(t, harness, e, tt)
1049+
t.Run(tt.WriteQuery, func(t *testing.T) {
1050+
if tt.Skip {
1051+
t.Skip()
1052+
return
1053+
}
1054+
e := mustNewEngine(t, harness)
1055+
defer e.Close()
1056+
RunWriteQueryTestWithEngine(t, harness, e, tt)
1057+
})
10521058
}
10531059

10541060
// RunWriteQueryTestWithEngine runs the specified |tt| WriteQueryTest, using the specified harness and engine. Callers
10551061
// are still responsible for closing the engine.
10561062
func RunWriteQueryTestWithEngine(t *testing.T, harness Harness, e QueryEngine, tt queries.WriteQueryTest) {
1057-
t.Run(tt.WriteQuery, func(t *testing.T) {
1058-
if sh, ok := harness.(SkippingHarness); ok {
1059-
if sh.SkipQueryTest(tt.WriteQuery) {
1060-
t.Logf("Skipping query %s", tt.WriteQuery)
1061-
return
1062-
}
1063-
if sh.SkipQueryTest(tt.SelectQuery) {
1064-
t.Logf("Skipping query %s", tt.SelectQuery)
1065-
return
1066-
}
1063+
if sh, ok := harness.(SkippingHarness); ok {
1064+
if sh.SkipQueryTest(tt.WriteQuery) {
1065+
t.Logf("Skipping query %s", tt.WriteQuery)
1066+
return
10671067
}
1068-
ctx := NewContext(harness)
1069-
TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil)
1070-
expectedSelect := tt.ExpectedSelect
1071-
if IsServerEngine(e) && tt.SkipServerEngine {
1072-
expectedSelect = nil
1068+
if sh.SkipQueryTest(tt.SelectQuery) {
1069+
t.Logf("Skipping query %s", tt.SelectQuery)
1070+
return
10731071
}
1074-
TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil)
1075-
})
1072+
}
1073+
ctx := NewContext(harness)
1074+
TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil)
1075+
expectedSelect := tt.ExpectedSelect
1076+
if IsServerEngine(e) && tt.SkipServerEngine {
1077+
expectedSelect = nil
1078+
}
1079+
TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil)
10761080
}
10771081

10781082
func runWriteQueryTestPrepared(t *testing.T, harness Harness, tt queries.WriteQueryTest) {
10791083
t.Run(tt.WriteQuery, func(t *testing.T) {
1084+
if tt.Skip {
1085+
t.Skip()
1086+
return
1087+
}
10801088
if sh, ok := harness.(SkippingHarness); ok {
10811089
if sh.SkipQueryTest(tt.WriteQuery) {
10821090
t.Logf("Skipping query %s", tt.WriteQuery)

enginetest/queries/insert_queries.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,20 @@ var InsertQueries = []WriteQueryTest{
476476
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
477477
ExpectedSelect: []sql.Row{{int64(1), "hi"}},
478478
},
479+
{
480+
WriteQuery: "INSERT INTO mytable (i,s) values (1, 'hi') AS dt(new_i,new_s) ON DUPLICATE KEY UPDATE s=new_s",
481+
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
482+
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
483+
ExpectedSelect: []sql.Row{{int64(1), "hi"}},
484+
Skip: true, // https://github.com/dolthub/dolt/issues/7638
485+
},
486+
{
487+
WriteQuery: "INSERT INTO mytable (i,s) values (1, 'hi') AS dt ON DUPLICATE KEY UPDATE mytable.s=dt.s",
488+
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
489+
SelectQuery: "SELECT * FROM mytable WHERE i = 1",
490+
ExpectedSelect: []sql.Row{{int64(1), "hir"}},
491+
Skip: true, // https://github.com/dolthub/dolt/issues/7638
492+
},
479493
{
480494
WriteQuery: "INSERT INTO mytable (s,i) values ('dup',1) ON DUPLICATE KEY UPDATE s=CONCAT(VALUES(s), 'licate')",
481495
ExpectedWriteResult: []sql.Row{{types.NewOkResult(2)}},
@@ -1881,6 +1895,18 @@ var InsertScripts = []ScriptTest{
18811895
},
18821896
},
18831897
},
1898+
{
1899+
Name: "insert on duplicate key with incorrect row alias",
1900+
SetUpScript: []string{
1901+
`create table a (i int primary key)`,
1902+
},
1903+
Assertions: []ScriptTestAssertion{
1904+
{
1905+
Query: `insert into a values (1) as new(c, d) on duplicate key update i = c`,
1906+
ExpectedErr: sql.ErrColumnCountMismatch,
1907+
},
1908+
},
1909+
},
18841910
{
18851911
Name: "Insert throws primary key violations",
18861912
SetUpScript: []string{

enginetest/queries/queries.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10763,6 +10763,7 @@ type WriteQueryTest struct {
1076310763
SelectQuery string
1076410764
ExpectedSelect []sql.Row
1076510765
Bindings map[string]*query.BindVariable
10766+
Skip bool
1076610767
SkipServerEngine bool
1076710768
}
1076810769

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ require (
66
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e
77
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
88
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
9-
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df
9+
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80
1010
github.com/go-kit/kit v0.10.0
1111
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
1212
github.com/gocraft/dbr/v2 v2.7.2

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
5858
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
5959
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
6060
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
61-
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df h1:hXB89Qhyu0ymVhP4AvuCtHWGpQmZN0Tt5Cc58Ig8/dg=
62-
github.com/dolthub/vitess v0.0.0-20240415200146-562b545c47df/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
61+
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80 h1:BG7DheiFrbvKYtPmZ1avXA/VPKzz+Bv7L0ytUi83kyQ=
62+
github.com/dolthub/vitess v0.0.0-20240416194558-081bbdc97e80/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
6363
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
6464
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
6565
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=

sql/planbuilder/dml.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package planbuilder
1616

1717
import (
18+
"errors"
1819
"fmt"
1920
"strings"
2021

@@ -128,7 +129,7 @@ func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnName
128129
switch v := ir.(type) {
129130
case ast.SelectStatement:
130131
return b.buildSelectStmt(inScope, v)
131-
case ast.Values:
132+
case *ast.AliasedValues:
132133
outScope = b.buildInsertValues(inScope, v, columnNames, tableName, destSchema)
133134
default:
134135
err := sql.ErrUnsupportedSyntax.New(ast.String(ir))
@@ -137,7 +138,7 @@ func (b *Builder) insertRowsToNode(inScope *scope, ir ast.InsertRows, columnName
137138
return
138139
}
139140

140-
func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) {
141+
func (b *Builder) buildInsertValues(inScope *scope, v *ast.AliasedValues, columnNames []string, tableName string, destSchema sql.Schema) (outScope *scope) {
141142
columnDefaultValues := make([]*sql.ColumnDefaultValue, len(columnNames))
142143

143144
for i, columnName := range columnNames {
@@ -156,8 +157,22 @@ func (b *Builder) buildInsertValues(inScope *scope, v ast.Values, columnNames []
156157
}
157158
}
158159

159-
exprTuples := make([][]sql.Expression, len(v))
160-
for i, vt := range v {
160+
if !v.As.IsEmpty() {
161+
if len(v.Columns) != 0 {
162+
for _, tuple := range v.Values {
163+
if len(v.Columns) != len(tuple) {
164+
err := sql.ErrColumnCountMismatch.New()
165+
b.handleErr(err)
166+
}
167+
}
168+
169+
err := errors.New("insert row aliases are not currently supported; use the VALUES() function instead")
170+
b.handleErr(err)
171+
}
172+
}
173+
174+
exprTuples := make([][]sql.Expression, len(v.Values))
175+
for i, vt := range v.Values {
161176
// noExprs is an edge case where we fill VALUES with nil expressions
162177
noExprs := len(vt) == 0
163178
// triggerUnknownTable is an edge case where we ignored an unresolved
@@ -217,10 +232,10 @@ func reorderSchema(names []string, schema sql.Schema) sql.Schema {
217232
return newSch
218233
}
219234

220-
func (b *Builder) buildValues(inScope *scope, v ast.Values) (outScope *scope) {
235+
func (b *Builder) buildValues(inScope *scope, v ast.AliasedValues) (outScope *scope) {
221236
// TODO add literals to outScope?
222-
exprTuples := make([][]sql.Expression, len(v))
223-
for i, vt := range v {
237+
exprTuples := make([][]sql.Expression, len(v.Values))
238+
for i, vt := range v.Values {
224239
exprs := make([]sql.Expression, len(vt))
225240
exprTuples[i] = exprs
226241
for j, e := range vt {

0 commit comments

Comments
 (0)